From 2d9d24da1641999dc9f7e33bcc31e9ff82cd6825 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 11 May 2021 08:48:25 +0200 Subject: [PATCH] num_steps added in ParamSequence --- src/scgenerator/__init__.py | 2 +- src/scgenerator/initialize.py | 10 +++++++--- src/scgenerator/io.py | 6 +++--- src/scgenerator/physics/simulate.py | 4 ++-- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 46e51b1..e3f5c27 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -2,5 +2,5 @@ from .initialize import compute_init_parameters from .io import Paths, load_toml from .math import abs2, argclosest, span from .physics import fiber, materials, pulse, simulate, units -from .physics.simulate import RK4IP, new_simulations +from .physics.simulate import RK4IP, new_simulations, resume_simulations from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 6fc1760..5a4aae7 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -20,10 +20,11 @@ class ParamSequence(Mapping): self.name = self.config["name"] self.num_sim, self.num_variable = count_variations(self.config) + self.num_steps = self.num_sim * self.config["simulation", "z_num"] self.single_sim = self.num_sim == 1 def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: - """iterates through all possible parameters, yielding a config as welle as a flattened + """iterates through all possible parameters, yielding a config as well as a flattened computed parameters set each time""" for variable_list, full_config in required_simulations(self.config): yield variable_list, compute_init_parameters(full_config) @@ -42,9 +43,12 @@ class RecoveryParamSequence(ParamSequence): def __init__(self, config, task_id): super().__init__(config) self.id = task_id + self.num_steps = 0 for sub_folder in io.get_data_subfolders(io.get_data_folder(self.id)): - if io.propagation_completed(sub_folder, config["simulation"]["z_num"]): + num_left = io.num_left_to_propagate(sub_folder, config["simulation"]["z_num"]) + if num_left == 0: self.num_sim -= 1 + self.num_steps += num_left self.single_sim = self.num_sim == 1 def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: @@ -56,7 +60,7 @@ class RecoveryParamSequence(ParamSequence): if not io.propagation_initiated(sub_folder): yield variable_list, compute_init_parameters(full_config) - elif not io.propagation_completed(sub_folder, self.config["simulation"]["z_num"]): + elif io.num_left_to_propagate(sub_folder, self.config["simulation"]["z_num"]) != 0: yield variable_list, recover_params(full_config, variable_list, self.id) else: continue diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index a6103f4..4bb825a 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -302,7 +302,7 @@ def check_data_integrity(sub_folders: List[str], init_z_num: int): raised if not all spectra are present in any folder """ for sub_folder in sub_folders: - if not propagation_completed(sub_folder, init_z_num): + if num_left_to_propagate(sub_folder, init_z_num) != 0: raise IncompleteDataFolderError( f"not enough spectra of the specified {init_z_num} found in {sub_folder}" ) @@ -314,7 +314,7 @@ def propagation_initiated(sub_folder) -> bool: return False -def propagation_completed(sub_folder: str, init_z_num: int): +def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int: """checks if a propagation has completed Parameters @@ -344,7 +344,7 @@ def propagation_completed(sub_folder: str, init_z_num: int): + f" but the parameter file in {sub_folder} specifies {z_num}" ) - return num_spectra == z_num + return z_num - num_spectra def find_last_spectrum_file(path: str): diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 11da88b..e691a1d 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -358,10 +358,10 @@ class Simulations: def limit_concurrent_jobs(self, max_concurrent_jobs): self.max_concurrent_jobs = max_concurrent_jobs - def update(self, param_seq): + def update(self, param_seq: initialize.ParamSequence): self.param_seq = param_seq self.progress_tracker = utils.ProgressTracker( - len(self.param_seq) * self.param_seq["simulation", "z_num"], + self.param_seq.num_steps, percent_incr=1, logger=self.logger, prefix="Overall : ",