Fixed Param Pickling. Work towards AbstractConfig

This commit is contained in:
Benoît Sierro
2022-05-17 09:06:25 +02:00
parent 02abcbe2a2
commit 3ab20c219c
7 changed files with 41 additions and 33 deletions

View File

@@ -3,7 +3,7 @@ from . import math, operators
from .evaluator import Evaluator from .evaluator import Evaluator
from .legacy import convert_sim_folder from .legacy import convert_sim_folder
from .math import abs2, argclosest, normalized, span, tspace, wspace from .math import abs2, argclosest, normalized, span, tspace, wspace
from .parameter import Configuration, Parameters from .parameter import FileConfiguration, Parameters
from .physics import fiber, materials, pulse, simulate, units, plasma from .physics import fiber, materials, pulse, simulate, units, plasma
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .physics.units import PlotRange from .physics.units import PlotRange

View File

@@ -1,4 +1,4 @@
__version__ = "0.2.5dev" __version__ = "0.2.6dev"
from typing import Any from typing import Any

View File

@@ -9,7 +9,7 @@ import tomli
import tomli_w import tomli_w
from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1 from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1
from .parameter import Configuration, Parameters from .parameter import FileConfiguration, Parameters
from .pbar import PBars from .pbar import PBars
from .utils import save_parameters from .utils import save_parameters
from .variationer import VariationDescriptor from .variationer import VariationDescriptor
@@ -43,7 +43,7 @@ def convert_sim_folder(path: os.PathLike):
master_config = dict(name=path.name, Fiber=configs) master_config = dict(name=path.name, Fiber=configs)
with open(new_root / "initial_config.toml", "wb") as f: with open(new_root / "initial_config.toml", "wb") as f:
tomli_w.dump(Parameters.strip_params_dict(master_config), f) tomli_w.dump(Parameters.strip_params_dict(master_config), f)
configuration = Configuration(path, final_output_path=new_root) configuration = FileConfiguration(path, final_output_path=new_root)
pbar = PBars(configuration.total_num_steps, "Converting") pbar = PBars(configuration.total_num_steps, "Converting")
new_paths: dict[VariationDescriptor, Parameters] = dict(configuration) new_paths: dict[VariationDescriptor, Parameters] = dict(configuration)

View File

@@ -437,7 +437,9 @@ class Parameters:
return self.dump_dict(add_metadata=False) return self.dump_dict(add_metadata=False)
def __setstate__(self, dumped_dict: dict[str, Any]): def __setstate__(self, dumped_dict: dict[str, Any]):
self._param_dico = dumped_dict self._param_dico = DebugDict()
for k, v in dumped_dict.items():
setattr(self, k, v)
self.__post_init__() self.__post_init__()
def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]: def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]:
@@ -539,7 +541,21 @@ class Parameters:
return None return None
class Configuration: class AbstractConfiguration:
fiber_paths: list[Path]
num_sim: int
total_num_steps: int
worker_num: int
final_path: Path
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
raise NotImplementedError()
def save_parameters(self):
raise NotImplementedError()
class FileConfiguration(AbstractConfiguration):
""" """
Primary role is to load the final config file of the simulation and deduce every Primary role is to load the final config file of the simulation and deduce every
simulatin that has to happen. Iterating through the Configuration obj yields a list of simulatin that has to happen. Iterating through the Configuration obj yields a list of
@@ -548,19 +564,12 @@ class Configuration:
""" """
fiber_configs: list[utils.SubConfig] fiber_configs: list[utils.SubConfig]
vary_dicts: list[dict[str, list]]
master_config_dict: dict[str, Any] master_config_dict: dict[str, Any]
fiber_paths: list[Path]
num_sim: int
num_fibers: int num_fibers: int
repeat: int repeat: int
z_num: int z_num: int
total_num_steps: int
worker_num: int
parallel: bool
overwrite: bool overwrite: bool
final_path: Path all_configs: dict[tuple[tuple[int, ...], ...], "FileConfiguration.__SimConfig"]
all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
@dataclass(frozen=True) @dataclass(frozen=True)
class __SimConfig: class __SimConfig:
@@ -643,7 +652,6 @@ class Configuration:
config.fixed["z_num"] * self.variationer.var_num(i) config.fixed["z_num"] * self.variationer.var_num(i)
for i, config in enumerate(self.fiber_configs) for i, config in enumerate(self.fiber_configs)
) )
self.parallel = self.master_config_dict.get("parallel", Parameters.parallel.default)
def __validate_variable(self, vary_dict_list: list[dict[str, list]]): def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
for vary_dict in vary_dict_list: for vary_dict in vary_dict_list:
@@ -675,7 +683,7 @@ class Configuration:
""" """
if index < 0: if index < 0:
index = self.num_fibers + index index = self.num_fibers + index
sim_dict: dict[Path, Configuration.__SimConfig] = {} sim_dict: dict[Path, FileConfiguration.__SimConfig] = {}
for descriptor in self.variationer.iterate(index): for descriptor in self.variationer.iterate(index):
cfg = descriptor.update_config(self.fiber_configs[index].fixed) cfg = descriptor.update_config(self.fiber_configs[index].fixed)
if index > 0: if index > 0:
@@ -711,8 +719,8 @@ class Configuration:
time.sleep(1) time.sleep(1)
def __decide( def __decide(
self, sim_config: "Configuration.__SimConfig" self, sim_config: "FileConfiguration.__SimConfig"
) -> tuple["Configuration.Action", dict[str, Any]]: ) -> tuple["FileConfiguration.Action", dict[str, Any]]:
"""decide what to to with a particular simulation """decide what to to with a particular simulation
Parameters Parameters
@@ -746,7 +754,7 @@ class Configuration:
def sim_status( def sim_status(
self, data_dir: Path, config_dict: dict[str, Any] = None self, data_dir: Path, config_dict: dict[str, Any] = None
) -> tuple["Configuration.State", int]: ) -> tuple["FileConfiguration.State", int]:
"""returns the status of a simulation """returns the status of a simulation
Parameters Parameters

View File

@@ -13,7 +13,7 @@ import numpy as np
from .. import solver, utils from .. import solver, utils
from ..logger import get_logger from ..logger import get_logger
from ..operators import CurrentState from ..operators import CurrentState
from ..parameter import Configuration, Parameters from ..parameter import FileConfiguration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker from ..pbar import PBars, ProgressBarActor, progress_worker
try: try:
@@ -310,7 +310,7 @@ class Simulations:
@classmethod @classmethod
def new( def new(
cls, configuration: Configuration, method: Union[str, Type["Simulations"]] = None cls, configuration: FileConfiguration, method: Union[str, Type["Simulations"]] = None
) -> "Simulations": ) -> "Simulations":
"""Prefered method to create a new simulations object """Prefered method to create a new simulations object
@@ -323,12 +323,12 @@ class Simulations:
if isinstance(method, str): if isinstance(method, str):
method = Simulations.simulation_methods_dict[method] method = Simulations.simulation_methods_dict[method]
return method(configuration) return method(configuration)
elif configuration.num_sim > 1 and configuration.parallel: elif configuration.num_sim > 1 and configuration.worker_num > 1:
return Simulations.get_best_method()(configuration) return Simulations.get_best_method()(configuration)
else: else:
return SequencialSimulations(configuration) return SequencialSimulations(configuration)
def __init__(self, configuration: Configuration): def __init__(self, configuration: FileConfiguration):
""" """
Parameters Parameters
---------- ----------
@@ -397,7 +397,7 @@ class SequencialSimulations(Simulations, priority=0):
def is_available(cls): def is_available(cls):
return True return True
def __init__(self, configuration: Configuration): def __init__(self, configuration: FileConfiguration):
super().__init__(configuration) super().__init__(configuration)
self.pbars = PBars( self.pbars = PBars(
self.configuration.total_num_steps, self.configuration.total_num_steps,
@@ -422,7 +422,7 @@ class MultiProcSimulations(Simulations, priority=1):
def is_available(cls): def is_available(cls):
return True return True
def __init__(self, configuration: Configuration): def __init__(self, configuration: FileConfiguration):
super().__init__(configuration) super().__init__(configuration)
if configuration.worker_num is not None: if configuration.worker_num is not None:
self.sim_jobs_per_node = configuration.worker_num self.sim_jobs_per_node = configuration.worker_num
@@ -502,7 +502,7 @@ class RaySimulations(Simulations, priority=2):
def __init__( def __init__(
self, self,
configuration: Configuration, configuration: FileConfiguration,
): ):
super().__init__(configuration) super().__init__(configuration)
@@ -578,7 +578,7 @@ def run_simulation(
config_file: os.PathLike, config_file: os.PathLike,
method: Union[str, Type[Simulations]] = None, method: Union[str, Type[Simulations]] = None,
): ):
config = Configuration(config_file, wait=True) config = FileConfiguration(config_file, wait=True)
sim = new_simulation(config, method) sim = new_simulation(config, method)
sim.run() sim.run()
@@ -588,7 +588,7 @@ def run_simulation(
def new_simulation( def new_simulation(
configuration: Configuration, configuration: FileConfiguration,
method: Union[str, Type[Simulations]] = None, method: Union[str, Type[Simulations]] = None,
) -> Simulations: ) -> Simulations:
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -618,7 +618,7 @@ def parallel_RK4IP(
tuple[tuple[list[tuple[str, Any]], Parameters, int, int, np.ndarray], ...], None, None tuple[tuple[list[tuple[str, Any]], Parameters, int, int, np.ndarray], ...], None, None
]: ]:
logger = get_logger(__name__) logger = get_logger(__name__)
params = list(Configuration(config)) params = list(FileConfiguration(config))
n = len(params) n = len(params)
z_num = params[0][1].z_num z_num = params[0][1].z_num

View File

@@ -11,7 +11,7 @@ from tqdm import tqdm
from .. import env, math from .. import env, math
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN
from ..legacy import translate_parameters from ..legacy import translate_parameters
from ..parameter import Configuration, Parameters from ..parameter import FileConfiguration, Parameters
from ..physics import fiber, units from ..physics import fiber, units
from ..plotting import plot_setup, transform_2D_propagation, get_extent from ..plotting import plot_setup, transform_2D_propagation, get_extent
from ..spectra import SimulationSeries from ..spectra import SimulationSeries
@@ -271,7 +271,7 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]: def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
for style, (descriptor, params), _ in zip(cc, Configuration(config_path), range(20)): for style, (descriptor, params), _ in zip(cc, FileConfiguration(config_path), range(20)):
yield style, descriptor.branch.formatted_descriptor(), params yield style, descriptor.branch.formatted_descriptor(), params

View File

@@ -10,7 +10,7 @@ from typing import Tuple
import numpy as np import numpy as np
from ..utils import Paths from ..utils import Paths
from ..parameter import Configuration from ..parameter import FileConfiguration
def primes(n): def primes(n):
@@ -126,7 +126,7 @@ def main():
"time format must be an integer number of minute or must match the pattern hh:mm:ss" "time format must be an integer number of minute or must match the pattern hh:mm:ss"
) )
config = Configuration(args.config) config = FileConfiguration(args.config)
final_name = config.final_path final_name = config.final_path
sim_num = config.num_sim sim_num = config.num_sim