fixed z bug

This commit is contained in:
Benoît Sierro
2021-05-27 15:04:54 +02:00
parent f88995aa6a
commit ce648b12ff
4 changed files with 25 additions and 17 deletions

View File

@@ -43,7 +43,7 @@ def create_parser():
default=None, default=None,
) )
run_parser.add_argument( run_parser.add_argument(
"--output-name", "--o", help="path to the final output folder", default=None "--output-name", "-o", help="path to the final output folder", default=None
) )
run_parser.set_defaults(func=run_sim) run_parser.set_defaults(func=run_sim)
@@ -65,7 +65,7 @@ def create_parser():
"path", help="path to the final simulation folder containing 'initial_config.toml'" "path", help="path to the final simulation folder containing 'initial_config.toml'"
) )
merge_parser.add_argument( merge_parser.add_argument(
"--output-name", "--o", help="path to the final output folder", default=None "--output-name", "-o", help="path to the final output folder", default=None
) )
merge_parser.set_defaults(func=merge) merge_parser.set_defaults(func=merge)

View File

@@ -139,14 +139,14 @@ class RecoveryParamSequence(ParamSequence):
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
for variable_list, full_config in required_simulations(self.config): for variable_list, full_config in required_simulations(self.config):
sub_folder = os.path.join( data_dir = os.path.join(
io.get_data_folder(self.id), utils.format_variable_list(variable_list) io.get_data_folder(self.id), utils.format_variable_list(variable_list)
) )
if not io.propagation_initiated(sub_folder): if not io.propagation_initiated(data_dir):
yield variable_list, compute_init_parameters(full_config) yield variable_list, compute_init_parameters(full_config)
elif io.num_left_to_propagate(sub_folder, self.config["simulation"]["z_num"]) != 0: elif io.num_left_to_propagate(data_dir, self.config["simulation"]["z_num"]) != 0:
yield variable_list, recover_params(full_config, variable_list, self.id) yield variable_list, recover_params(full_config, data_dir)
else: else:
continue continue
@@ -515,15 +515,20 @@ def _ensure_consistency(config):
return config return config
def recover_params(params: dict, variable_only: List[Tuple[str, Any]], task_id: int): def recover_params(config: Dict[str, Any], data_folder: os.PathLike) -> Dict[str, Any]:
params = compute_init_parameters(params) path = Path(data_folder)
vary_str = utils.format_variable_list(variable_only) params = compute_init_parameters(config)
path = os.path.join(io.get_data_folder(task_id), vary_str) try:
num, last_spectrum = io.load_last_spectrum(path) prev_params = io.load_toml(path / "params.toml")
except FileNotFoundError:
prev_params = {}
for k, v in prev_params.items():
params.setdefault(k, v)
num, last_spectrum = io.load_last_spectrum(str(path))
params["spec_0"] = last_spectrum params["spec_0"] = last_spectrum
params["field_0"] = np.fft.ifft(last_spectrum) params["field_0"] = np.fft.ifft(last_spectrum)
params["recovery_last_stored"] = num params["recovery_last_stored"] = num
params["cons_qty"] = np.load(os.path.join(path, "cons_qty.npy")) params["cons_qty"] = np.load(os.path.join(data_folder, "cons_qty.npy"))
return params return params

View File

@@ -361,6 +361,7 @@ def find_last_spectrum_file(path: str):
def load_last_spectrum(path: str) -> Tuple[int, np.ndarray]: def load_last_spectrum(path: str) -> Tuple[int, np.ndarray]:
"""return the last spectrum stored in path as well as its id"""
num = find_last_spectrum_file(path) num = find_last_spectrum_file(path)
return num, np.load(os.path.join(path, f"spectrum_{num}.npy")) return num, np.load(os.path.join(path, f"spectrum_{num}.npy"))
@@ -421,10 +422,10 @@ def update_appended_params(param_path, new_path, z):
params = load_toml(param_path) params = load_toml(param_path)
if "simulation" in params: if "simulation" in params:
params["simulation"]["z_num"] = z_num params["simulation"]["z_num"] = z_num
params["simulation"]["z_targets"] = z_num params["simulation"]["z_targets"] = z
else: else:
params["z_num"] = z_num params["z_num"] = z_num
params["z_targets"] = z_num params["z_targets"] = z
save_toml(new_path, params) save_toml(new_path, params)

View File

@@ -124,13 +124,13 @@ class RK4IP:
def _setup_sim_parameters(self): def _setup_sim_parameters(self):
# making sure to keep only the z that we want # making sure to keep only the z that we want
self.z_stored = list(self.z_targets.copy()[0 : self.starting_num + 1])
self.z_targets = list(self.z_targets.copy()[self.starting_num :]) self.z_targets = list(self.z_targets.copy()[self.starting_num :])
self.z_targets.sort() self.z_targets.sort()
self.store_num = len(self.z_targets) self.store_num = len(self.z_targets)
# Initial setup of simulation parameters # Initial setup of simulation parameters
self.d_w = self.w_c[1] - self.w_c[0] # resolution of the frequency grid self.d_w = self.w_c[1] - self.w_c[0] # resolution of the frequency grid
self.z_stored = list(self.z_targets.copy()[0 : self.starting_num + 1])
self.z = self.z_targets.pop(0) self.z = self.z_targets.pop(0)
# Setup initial values for every physical quantity that we want to track # Setup initial values for every physical quantity that we want to track
@@ -497,7 +497,7 @@ class SequencialSimulations(Simulations, priority=0):
def new_sim(self, variable_list: List[tuple], params: Dict[str, Any]): def new_sim(self, variable_list: List[tuple], params: Dict[str, Any]):
v_list_str = utils.format_variable_list(variable_list) v_list_str = utils.format_variable_list(variable_list)
self.logger.info(f"launching simulation with {v_list_str}") self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
SequentialRK4IP( SequentialRK4IP(
params, self.overall_pbar, save_data=True, job_identifier=v_list_str, task_id=self.id params, self.overall_pbar, save_data=True, job_identifier=v_list_str, task_id=self.id
).run() ).run()
@@ -678,7 +678,9 @@ class RaySimulations(Simulations, priority=2):
self.actors[new_job.task_id()] = new_actor self.actors[new_job.task_id()] = new_actor
self.jobs.append(new_job) self.jobs.append(new_job)
self.logger.info(f"launching simulation with {v_list_str}, job : {self.jobs[-1].hex()}") self.logger.info(
f"{self.param_seq.name} : launching simulation with {v_list_str}, job : {self.jobs[-1].hex()}"
)
def finish(self): def finish(self):
while len(self.jobs) > 0: while len(self.jobs) > 0: