Fixed Param Pickling. Work towards AbstractConfig

This commit is contained in:
Benoît Sierro
2022-05-17 09:06:25 +02:00
parent 02abcbe2a2
commit 3ab20c219c
7 changed files with 41 additions and 33 deletions

View File

@@ -3,7 +3,7 @@ from . import math, operators
from .evaluator import Evaluator
from .legacy import convert_sim_folder
from .math import abs2, argclosest, normalized, span, tspace, wspace
from .parameter import Configuration, Parameters
from .parameter import FileConfiguration, Parameters
from .physics import fiber, materials, pulse, simulate, units, plasma
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .physics.units import PlotRange

View File

@@ -1,4 +1,4 @@
__version__ = "0.2.5dev"
__version__ = "0.2.6dev"
from typing import Any

View File

@@ -9,7 +9,7 @@ import tomli
import tomli_w
from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1
from .parameter import Configuration, Parameters
from .parameter import FileConfiguration, Parameters
from .pbar import PBars
from .utils import save_parameters
from .variationer import VariationDescriptor
@@ -43,7 +43,7 @@ def convert_sim_folder(path: os.PathLike):
master_config = dict(name=path.name, Fiber=configs)
with open(new_root / "initial_config.toml", "wb") as f:
tomli_w.dump(Parameters.strip_params_dict(master_config), f)
configuration = Configuration(path, final_output_path=new_root)
configuration = FileConfiguration(path, final_output_path=new_root)
pbar = PBars(configuration.total_num_steps, "Converting")
new_paths: dict[VariationDescriptor, Parameters] = dict(configuration)

View File

@@ -437,7 +437,9 @@ class Parameters:
return self.dump_dict(add_metadata=False)
def __setstate__(self, dumped_dict: dict[str, Any]):
self._param_dico = dumped_dict
self._param_dico = DebugDict()
for k, v in dumped_dict.items():
setattr(self, k, v)
self.__post_init__()
def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]:
@@ -539,7 +541,21 @@ class Parameters:
return None
class Configuration:
class AbstractConfiguration:
fiber_paths: list[Path]
num_sim: int
total_num_steps: int
worker_num: int
final_path: Path
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
raise NotImplementedError()
def save_parameters(self):
raise NotImplementedError()
class FileConfiguration(AbstractConfiguration):
"""
Primary role is to load the final config file of the simulation and deduce every
simulatin that has to happen. Iterating through the Configuration obj yields a list of
@@ -548,19 +564,12 @@ class Configuration:
"""
fiber_configs: list[utils.SubConfig]
vary_dicts: list[dict[str, list]]
master_config_dict: dict[str, Any]
fiber_paths: list[Path]
num_sim: int
num_fibers: int
repeat: int
z_num: int
total_num_steps: int
worker_num: int
parallel: bool
overwrite: bool
final_path: Path
all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
all_configs: dict[tuple[tuple[int, ...], ...], "FileConfiguration.__SimConfig"]
@dataclass(frozen=True)
class __SimConfig:
@@ -643,7 +652,6 @@ class Configuration:
config.fixed["z_num"] * self.variationer.var_num(i)
for i, config in enumerate(self.fiber_configs)
)
self.parallel = self.master_config_dict.get("parallel", Parameters.parallel.default)
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
for vary_dict in vary_dict_list:
@@ -675,7 +683,7 @@ class Configuration:
"""
if index < 0:
index = self.num_fibers + index
sim_dict: dict[Path, Configuration.__SimConfig] = {}
sim_dict: dict[Path, FileConfiguration.__SimConfig] = {}
for descriptor in self.variationer.iterate(index):
cfg = descriptor.update_config(self.fiber_configs[index].fixed)
if index > 0:
@@ -711,8 +719,8 @@ class Configuration:
time.sleep(1)
def __decide(
self, sim_config: "Configuration.__SimConfig"
) -> tuple["Configuration.Action", dict[str, Any]]:
self, sim_config: "FileConfiguration.__SimConfig"
) -> tuple["FileConfiguration.Action", dict[str, Any]]:
"""decide what to to with a particular simulation
Parameters
@@ -746,7 +754,7 @@ class Configuration:
def sim_status(
self, data_dir: Path, config_dict: dict[str, Any] = None
) -> tuple["Configuration.State", int]:
) -> tuple["FileConfiguration.State", int]:
"""returns the status of a simulation
Parameters

View File

@@ -13,7 +13,7 @@ import numpy as np
from .. import solver, utils
from ..logger import get_logger
from ..operators import CurrentState
from ..parameter import Configuration, Parameters
from ..parameter import FileConfiguration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker
try:
@@ -310,7 +310,7 @@ class Simulations:
@classmethod
def new(
cls, configuration: Configuration, method: Union[str, Type["Simulations"]] = None
cls, configuration: FileConfiguration, method: Union[str, Type["Simulations"]] = None
) -> "Simulations":
"""Prefered method to create a new simulations object
@@ -323,12 +323,12 @@ class Simulations:
if isinstance(method, str):
method = Simulations.simulation_methods_dict[method]
return method(configuration)
elif configuration.num_sim > 1 and configuration.parallel:
elif configuration.num_sim > 1 and configuration.worker_num > 1:
return Simulations.get_best_method()(configuration)
else:
return SequencialSimulations(configuration)
def __init__(self, configuration: Configuration):
def __init__(self, configuration: FileConfiguration):
"""
Parameters
----------
@@ -397,7 +397,7 @@ class SequencialSimulations(Simulations, priority=0):
def is_available(cls):
return True
def __init__(self, configuration: Configuration):
def __init__(self, configuration: FileConfiguration):
super().__init__(configuration)
self.pbars = PBars(
self.configuration.total_num_steps,
@@ -422,7 +422,7 @@ class MultiProcSimulations(Simulations, priority=1):
def is_available(cls):
return True
def __init__(self, configuration: Configuration):
def __init__(self, configuration: FileConfiguration):
super().__init__(configuration)
if configuration.worker_num is not None:
self.sim_jobs_per_node = configuration.worker_num
@@ -502,7 +502,7 @@ class RaySimulations(Simulations, priority=2):
def __init__(
self,
configuration: Configuration,
configuration: FileConfiguration,
):
super().__init__(configuration)
@@ -578,7 +578,7 @@ def run_simulation(
config_file: os.PathLike,
method: Union[str, Type[Simulations]] = None,
):
config = Configuration(config_file, wait=True)
config = FileConfiguration(config_file, wait=True)
sim = new_simulation(config, method)
sim.run()
@@ -588,7 +588,7 @@ def run_simulation(
def new_simulation(
configuration: Configuration,
configuration: FileConfiguration,
method: Union[str, Type[Simulations]] = None,
) -> Simulations:
logger = get_logger(__name__)
@@ -618,7 +618,7 @@ def parallel_RK4IP(
tuple[tuple[list[tuple[str, Any]], Parameters, int, int, np.ndarray], ...], None, None
]:
logger = get_logger(__name__)
params = list(Configuration(config))
params = list(FileConfiguration(config))
n = len(params)
z_num = params[0][1].z_num

View File

@@ -11,7 +11,7 @@ from tqdm import tqdm
from .. import env, math
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN
from ..legacy import translate_parameters
from ..parameter import Configuration, Parameters
from ..parameter import FileConfiguration, Parameters
from ..physics import fiber, units
from ..plotting import plot_setup, transform_2D_propagation, get_extent
from ..spectra import SimulationSeries
@@ -271,7 +271,7 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
for style, (descriptor, params), _ in zip(cc, Configuration(config_path), range(20)):
for style, (descriptor, params), _ in zip(cc, FileConfiguration(config_path), range(20)):
yield style, descriptor.branch.formatted_descriptor(), params

View File

@@ -10,7 +10,7 @@ from typing import Tuple
import numpy as np
from ..utils import Paths
from ..parameter import Configuration
from ..parameter import FileConfiguration
def primes(n):
@@ -126,7 +126,7 @@ def main():
"time format must be an integer number of minute or must match the pattern hh:mm:ss"
)
config = Configuration(args.config)
config = FileConfiguration(args.config)
final_name = config.final_path
sim_num = config.num_sim