From 3ab20c219ce06bb00b3047372442a48f4f4ed8b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 17 May 2022 09:06:25 +0200 Subject: [PATCH] Fixed Param Pickling. Work towards AbstractConfig --- src/scgenerator/__init__.py | 2 +- src/scgenerator/const.py | 2 +- src/scgenerator/legacy.py | 4 +-- src/scgenerator/parameter.py | 38 +++++++++++++++---------- src/scgenerator/physics/simulate.py | 20 ++++++------- src/scgenerator/scripts/__init__.py | 4 +-- src/scgenerator/scripts/slurm_submit.py | 4 +-- 7 files changed, 41 insertions(+), 33 deletions(-) diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 00a5dd1..9780858 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -3,7 +3,7 @@ from . import math, operators from .evaluator import Evaluator from .legacy import convert_sim_folder from .math import abs2, argclosest, normalized, span, tspace, wspace -from .parameter import Configuration, Parameters +from .parameter import FileConfiguration, Parameters from .physics import fiber, materials, pulse, simulate, units, plasma from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from .physics.units import PlotRange diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index 92610e3..07da023 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -1,4 +1,4 @@ -__version__ = "0.2.5dev" +__version__ = "0.2.6dev" from typing import Any diff --git a/src/scgenerator/legacy.py b/src/scgenerator/legacy.py index b2d68e3..72c2355 100644 --- a/src/scgenerator/legacy.py +++ b/src/scgenerator/legacy.py @@ -9,7 +9,7 @@ import tomli import tomli_w from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1 -from .parameter import Configuration, Parameters +from .parameter import FileConfiguration, Parameters from .pbar import PBars from .utils import save_parameters from .variationer import VariationDescriptor @@ -43,7 +43,7 @@ def convert_sim_folder(path: os.PathLike): master_config = dict(name=path.name, Fiber=configs) with open(new_root / "initial_config.toml", "wb") as f: tomli_w.dump(Parameters.strip_params_dict(master_config), f) - configuration = Configuration(path, final_output_path=new_root) + configuration = FileConfiguration(path, final_output_path=new_root) pbar = PBars(configuration.total_num_steps, "Converting") new_paths: dict[VariationDescriptor, Parameters] = dict(configuration) diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index d0d1cf4..2cdccbc 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -437,7 +437,9 @@ class Parameters: return self.dump_dict(add_metadata=False) def __setstate__(self, dumped_dict: dict[str, Any]): - self._param_dico = dumped_dict + self._param_dico = DebugDict() + for k, v in dumped_dict.items(): + setattr(self, k, v) self.__post_init__() def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]: @@ -539,7 +541,21 @@ class Parameters: return None -class Configuration: +class AbstractConfiguration: + fiber_paths: list[Path] + num_sim: int + total_num_steps: int + worker_num: int + final_path: Path + + def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]: + raise NotImplementedError() + + def save_parameters(self): + raise NotImplementedError() + + +class FileConfiguration(AbstractConfiguration): """ Primary role is to load the final config file of the simulation and deduce every simulatin that has to happen. Iterating through the Configuration obj yields a list of @@ -548,19 +564,12 @@ class Configuration: """ fiber_configs: list[utils.SubConfig] - vary_dicts: list[dict[str, list]] master_config_dict: dict[str, Any] - fiber_paths: list[Path] - num_sim: int num_fibers: int repeat: int z_num: int - total_num_steps: int - worker_num: int - parallel: bool overwrite: bool - final_path: Path - all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"] + all_configs: dict[tuple[tuple[int, ...], ...], "FileConfiguration.__SimConfig"] @dataclass(frozen=True) class __SimConfig: @@ -643,7 +652,6 @@ class Configuration: config.fixed["z_num"] * self.variationer.var_num(i) for i, config in enumerate(self.fiber_configs) ) - self.parallel = self.master_config_dict.get("parallel", Parameters.parallel.default) def __validate_variable(self, vary_dict_list: list[dict[str, list]]): for vary_dict in vary_dict_list: @@ -675,7 +683,7 @@ class Configuration: """ if index < 0: index = self.num_fibers + index - sim_dict: dict[Path, Configuration.__SimConfig] = {} + sim_dict: dict[Path, FileConfiguration.__SimConfig] = {} for descriptor in self.variationer.iterate(index): cfg = descriptor.update_config(self.fiber_configs[index].fixed) if index > 0: @@ -711,8 +719,8 @@ class Configuration: time.sleep(1) def __decide( - self, sim_config: "Configuration.__SimConfig" - ) -> tuple["Configuration.Action", dict[str, Any]]: + self, sim_config: "FileConfiguration.__SimConfig" + ) -> tuple["FileConfiguration.Action", dict[str, Any]]: """decide what to to with a particular simulation Parameters @@ -746,7 +754,7 @@ class Configuration: def sim_status( self, data_dir: Path, config_dict: dict[str, Any] = None - ) -> tuple["Configuration.State", int]: + ) -> tuple["FileConfiguration.State", int]: """returns the status of a simulation Parameters diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 3966990..a36ad16 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -13,7 +13,7 @@ import numpy as np from .. import solver, utils from ..logger import get_logger from ..operators import CurrentState -from ..parameter import Configuration, Parameters +from ..parameter import FileConfiguration, Parameters from ..pbar import PBars, ProgressBarActor, progress_worker try: @@ -310,7 +310,7 @@ class Simulations: @classmethod def new( - cls, configuration: Configuration, method: Union[str, Type["Simulations"]] = None + cls, configuration: FileConfiguration, method: Union[str, Type["Simulations"]] = None ) -> "Simulations": """Prefered method to create a new simulations object @@ -323,12 +323,12 @@ class Simulations: if isinstance(method, str): method = Simulations.simulation_methods_dict[method] return method(configuration) - elif configuration.num_sim > 1 and configuration.parallel: + elif configuration.num_sim > 1 and configuration.worker_num > 1: return Simulations.get_best_method()(configuration) else: return SequencialSimulations(configuration) - def __init__(self, configuration: Configuration): + def __init__(self, configuration: FileConfiguration): """ Parameters ---------- @@ -397,7 +397,7 @@ class SequencialSimulations(Simulations, priority=0): def is_available(cls): return True - def __init__(self, configuration: Configuration): + def __init__(self, configuration: FileConfiguration): super().__init__(configuration) self.pbars = PBars( self.configuration.total_num_steps, @@ -422,7 +422,7 @@ class MultiProcSimulations(Simulations, priority=1): def is_available(cls): return True - def __init__(self, configuration: Configuration): + def __init__(self, configuration: FileConfiguration): super().__init__(configuration) if configuration.worker_num is not None: self.sim_jobs_per_node = configuration.worker_num @@ -502,7 +502,7 @@ class RaySimulations(Simulations, priority=2): def __init__( self, - configuration: Configuration, + configuration: FileConfiguration, ): super().__init__(configuration) @@ -578,7 +578,7 @@ def run_simulation( config_file: os.PathLike, method: Union[str, Type[Simulations]] = None, ): - config = Configuration(config_file, wait=True) + config = FileConfiguration(config_file, wait=True) sim = new_simulation(config, method) sim.run() @@ -588,7 +588,7 @@ def run_simulation( def new_simulation( - configuration: Configuration, + configuration: FileConfiguration, method: Union[str, Type[Simulations]] = None, ) -> Simulations: logger = get_logger(__name__) @@ -618,7 +618,7 @@ def parallel_RK4IP( tuple[tuple[list[tuple[str, Any]], Parameters, int, int, np.ndarray], ...], None, None ]: logger = get_logger(__name__) - params = list(Configuration(config)) + params = list(FileConfiguration(config)) n = len(params) z_num = params[0][1].z_num diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index 2fb6022..b59d168 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -11,7 +11,7 @@ from tqdm import tqdm from .. import env, math from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN from ..legacy import translate_parameters -from ..parameter import Configuration, Parameters +from ..parameter import FileConfiguration, Parameters from ..physics import fiber, units from ..plotting import plot_setup, transform_2D_propagation, get_extent from ..spectra import SimulationSeries @@ -271,7 +271,7 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]: cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) - for style, (descriptor, params), _ in zip(cc, Configuration(config_path), range(20)): + for style, (descriptor, params), _ in zip(cc, FileConfiguration(config_path), range(20)): yield style, descriptor.branch.formatted_descriptor(), params diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 1c134ca..c1e804e 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -10,7 +10,7 @@ from typing import Tuple import numpy as np from ..utils import Paths -from ..parameter import Configuration +from ..parameter import FileConfiguration def primes(n): @@ -126,7 +126,7 @@ def main(): "time format must be an integer number of minute or must match the pattern hh:mm:ss" ) - config = Configuration(args.config) + config = FileConfiguration(args.config) final_name = config.final_path sim_num = config.num_sim