From f724d1dcf632b27692bbebe286a054d918d02113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 9 Sep 2021 12:37:58 +0200 Subject: [PATCH] changed to tuple index in Configuration --- 1 | 0 play.py | 9 +- src/scgenerator/__init__.py | 2 +- src/scgenerator/physics/fiber.py | 13 -- src/scgenerator/physics/pulse.py | 5 +- src/scgenerator/physics/simulate.py | 30 ++-- src/scgenerator/plotting.py | 3 + src/scgenerator/scripts/__init__.py | 8 +- src/scgenerator/scripts/slurm_submit.py | 2 +- src/scgenerator/spectra.py | 4 +- src/scgenerator/utils/__init__.py | 70 +++++--- src/scgenerator/utils/parameter.py | 202 ++++++++++++++---------- 12 files changed, 209 insertions(+), 139 deletions(-) create mode 100644 1 diff --git a/1 b/1 new file mode 100644 index 0000000..e69de29 diff --git a/play.py b/play.py index 9543fe2..9f687f3 100644 --- a/play.py +++ b/play.py @@ -3,13 +3,20 @@ import numpy as np import scgenerator as sc import matplotlib.pyplot as plt from pathlib import Path +from pprint import pprint + + +def _main(): + print(os.getcwd()) + for v_list, params in sc.Configuration("PM1550+PM2000D+PM1550/Pos30000.toml"): + print(params.fiber_map) def main(): drr = os.getcwd() os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations") try: - sc.run_simulation("PM1550+PM2000D/Pos30000.toml") + _main() finally: os.chdir(drr) diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 6942d10..d5e5c7b 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -4,5 +4,5 @@ from .physics import fiber, materials, pulse, simulate, units from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .spectra import Pulse, Spectrum -from .utils import Paths, load_toml +from .utils import Paths, open_config from .utils.parameter import Configuration, Parameters, PlotRange diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 43b4155..ccddbb1 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -1124,16 +1124,3 @@ def capillary_loss( nu_n = 0.5 * (chi_silica + 2) / np.sqrt(chi_silica) alpha[mask] = nu_n * (u_nm(*he_mode) * wl_for_disp[mask] / pipi) ** 2 * core_radius ** -3 return alpha - - -if __name__ == "__main__": - w = np.linspace(0, 1, 4096) - c = np.arange(8) - import time - - t = time.time() - - for _ in range(10000): - dispersion_from_coefficients(w, c) - - print((time.time() - t) / 10, "ms") diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 001f292..6a07c70 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -1055,7 +1055,10 @@ def rin_curve(spectra: np.ndarray) -> np.ndarray: def measure_field(t: np.ndarray, field: np.ndarray) -> Tuple[float, float, float]: """returns fwhm, peak_power, energy""" - intensity = abs2(field) + if np.iscomplexobj(field): + intensity = abs2(field) + else: + intensity = field _, fwhm_lim, _, _ = find_lobe_limits(t, intensity) fwhm = length(fwhm_lim) peak_power = intensity.max() diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index bee847a..6878261 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -459,17 +459,19 @@ class Simulations: self.configuration = configuration - self.name = self.configuration.name + self.name = self.configuration.final_path self.sim_dir = self.configuration.final_sim_dir self.configuration.save_parameters() self.sim_jobs_per_node = 1 def finished_and_complete(self): - for sim in self.configuration.data_dirs: - for data_dir in sim: - if self.configuration.sim_status(data_dir)[0] != self.configuration.State.COMPLETE: - return False + for sim in self.configuration.all_configs_dict.values(): + if ( + self.configuration.sim_status(sim.output_path)[0] + != self.configuration.State.COMPLETE + ): + return False return True def run(self): @@ -517,12 +519,14 @@ class SequencialSimulations(Simulations, priority=0): def __init__(self, configuration: Configuration, task_id): super().__init__(configuration, task_id=task_id) self.pbars = utils.PBars( - self.configuration.total_num_steps, "Simulating " + self.configuration.name, 1 + self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1 ) self.configuration.skip_callback = lambda num: self.pbars.update(0, num) def new_sim(self, v_list_str: str, params: Parameters): - self.logger.info(f"{self.configuration.name} : launching simulation with {v_list_str}") + self.logger.info( + f"{self.configuration.final_path} : launching simulation with {v_list_str}" + ) SequentialRK4IP( params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id ).run() @@ -558,7 +562,7 @@ class MultiProcSimulations(Simulations, priority=1): self.p_worker = multiprocessing.Process( target=utils.progress_worker, args=( - self.configuration.name, + self.configuration.final_path, self.sim_jobs_per_node, self.configuration.total_num_steps, self.progress_queue, @@ -646,7 +650,7 @@ class RaySimulations(Simulations, priority=2): self.num_submitted = 0 self.rolling_id = 0 self.p_actor = ray.remote(utils.ProgressBarActor).remote( - self.configuration.name, self.sim_jobs_total, self.configuration.total_num_steps + self.configuration.final_path, self.sim_jobs_total, self.configuration.total_num_steps ) self.configuration.skip_callback = lambda num: ray.get(self.p_actor.update.remote(0, num)) @@ -668,7 +672,9 @@ class RaySimulations(Simulations, priority=2): ) self.num_submitted += 1 - self.logger.info(f"{self.configuration.name} : launching simulation with {v_list_str}") + self.logger.info( + f"{self.configuration.final_path} : launching simulation with {v_list_str}" + ) def collect_1_job(self): ray.get(self.p_actor.update_pbars.remote()) @@ -707,7 +713,7 @@ def run_simulation( final_name = env.get(env.OUTPUT_PATH) if final_name is None: - final_name = config.name + final_name = config.final_path utils.merge(final_name, path_trees) try: @@ -722,7 +728,7 @@ def new_simulation( ) -> Simulations: logger = get_logger(__name__) task_id = random.randint(1e9, 1e12) - logger.info(f"running {configuration.name}") + logger.info(f"running {configuration.final_path}") return Simulations.new(configuration, task_id, method) diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 09338ee..2aa6e55 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -259,6 +259,7 @@ def propagation_plot( params: Parameters, ax: plt.Axes, log: Union[int, float, bool, str] = "1D", + renormalize: bool = False, vmin: float = None, vmax: float = None, transpose: bool = False, @@ -295,6 +296,8 @@ def propagation_plot( """ x_axis, y_axis, values = transform_2D_propagation(values, plt_range, params, log, skip) + if renormalize and log is False: + values = values / values.max() if log is not False: vmax = defaults["vmax"] if vmax is None else vmax vmin = defaults["vmin"] if vmin is None else vmin diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index f249345..88bd0d1 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -1,6 +1,4 @@ -import itertools import os -from itertools import cycle from pathlib import Path from typing import Any, Iterable, Optional @@ -14,7 +12,7 @@ from ..const import PARAM_FN, PARAM_SEPARATOR from ..physics import fiber, units from ..plotting import plot_setup from ..spectra import Pulse -from ..utils import auto_crop, load_toml, save_toml, translate_parameters +from ..utils import auto_crop, open_config, save_toml, translate_parameters from ..utils.parameter import ( Configuration, Parameters, @@ -259,7 +257,7 @@ def finish_plot(fig, legend_axes, all_labels, params): def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]: cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) - pseq = Configuration(load_toml(config_path)) + pseq = Configuration(open_config(config_path)) for style, (variables, params) in zip(cc, pseq): lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]] yield style, lbl, params @@ -268,7 +266,7 @@ def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters def convert_params(params_file: os.PathLike): p = Path(params_file) if p.name == PARAM_FN: - d = load_toml(params_file) + d = open_config(params_file) d = translate_parameters(d) save_toml(params_file, d) print(f"converted {p}") diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 078a844..a977f35 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -127,7 +127,7 @@ def main(): ) config = Configuration(args.config) - final_name = config.name + final_name = config.final_path sim_num = config.num_sim if args.command == "merge": diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 6e8b089..2401b16 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -150,7 +150,9 @@ class Pulse(Sequence): raise FileNotFoundError(f"Folder {self.path} does not exist") self.params = Parameters.load(self.path / "params.toml") - self.params.compute(["t", "l", "w_c", "w0", "z_targets"]) + self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"]) + if self.params.fiber_map is None: + self.params.fiber_map = {0.0: self.params.name} try: self.z = np.load(os.path.join(path, "z.npy")) diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 8447303..42f7e9e 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -20,6 +20,8 @@ from string import printable as str_printable from typing import Any, Callable, Generator, Iterable, MutableMapping, Sequence, TypeVar, Union import numpy as np +from numpy.lib.arraysetops import isin +from numpy.lib.function_base import insert import pkg_resources as pkg import toml from tqdm import tqdm @@ -88,18 +90,21 @@ def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: return np.load(prev_data_dir / SPEC1_FN.format(num)) -def conform_toml_path(path: os.PathLike) -> Path: - path = Path(path) - if not path.name.lower().endswith(".toml"): - path = path.parent / (path.name + ".toml") +def conform_toml_path(path: os.PathLike) -> str: + path: str = str(path) + if not path.lower().endswith(".toml"): + path = path + ".toml" return path -def load_toml(path: os.PathLike): - """returns a dictionary parsed from the specified toml file""" +def open_config(path: os.PathLike): + """returns a dictionary parsed from the specified toml file + This also handle having a 'INCLUDE' argument that will fill + otherwise unspecified keys with what's in the INCLUDE file(s)""" + path = conform_toml_path(path) - with open(path, mode="r") as file: - dico = toml.load(file) + dico = resolve_loadfile_arg(load_toml(path)) + dico.setdefault("variable", {}) for key in {"simulation", "fiber", "gas", "pulse"} & dico.keys(): section = dico.pop(key, {}) @@ -110,6 +115,35 @@ def load_toml(path: os.PathLike): return dico +def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]: + if (f_list := dico.pop("INCLUDE", None)) is not None: + if isinstance(f_list, str): + f_list = [f_list] + for to_load in f_list: + loaded = load_toml(to_load) + for k, v in loaded.items(): + if k not in dico and k not in dico.get("variable", {}): + dico[k] = v + for k, v in dico.items(): + if isinstance(v, MutableMapping): + dico[k] = resolve_loadfile_arg(v) + elif isinstance(v, Sequence): + for i, vv in enumerate(v): + if isinstance(vv, MutableMapping): + dico[k][i] = resolve_loadfile_arg(vv) + return dico + + +def load_toml(descr: str) -> dict[str, Any]: + if ":" in descr: + path, entry = descr.split(":", 1) + with open(path) as file: + return toml.load(file)[entry] + else: + with open(descr) as file: + return toml.load(file) + + def save_toml(path: os.PathLike, dico): """saves a dictionary into a toml file""" path = conform_toml_path(path) @@ -119,7 +153,7 @@ def save_toml(path: os.PathLike, dico): def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]: - loaded_config = load_toml(final_config_path) + loaded_config = open_config(final_config_path) final_name = loaded_config.get("name") fiber_list = loaded_config.pop("Fiber") configs = [] @@ -133,7 +167,7 @@ def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, else: configs.append(loaded_config) while "previous_config_file" in configs[0]: - configs.insert(0, load_toml(configs[0]["previous_config_file"])) + configs.insert(0, open_config(configs[0]["previous_config_file"])) configs[0].setdefault("variable", {}) for pre, nex in zip(configs[:-1], configs[1:]): variable = nex.pop("variable", {}) @@ -189,13 +223,9 @@ def load_material_dico(name: str) -> dict[str, Any]: def update_appended_params(source: Path, destination: Path, z: Sequence): z_num = len(z) - params = load_toml(source) - if "simulation" in params: - params["simulation"]["z_num"] = z_num - params["fiber"]["length"] = float(z[-1] - z[0]) - else: - params["z_num"] = z_num - params["length"] = float(z[-1] - z[0]) + params = open_config(source) + params["z_num"] = z_num + params["length"] = float(z[-1] - z[0]) for p_name in ["recovery_data_dir", "prev_data_dir", "output_path"]: if p_name in params: del params[p_name] @@ -230,7 +260,9 @@ def build_path_branch(data_dir: Path) -> tuple[Path, ...]: if not data_dir.is_dir(): return None path_branch = [data_dir] - while (prev_sim_path := load_toml(path_branch[-1] / PARAM_FN).get("prev_data_dir")) is not None: + while ( + prev_sim_path := open_config(path_branch[-1] / PARAM_FN).get("prev_data_dir") + ) is not None: p = Path(prev_sim_path).resolve() if not p.exists(): p = Path(*p.parts[-2:]).resolve() @@ -345,7 +377,7 @@ def merge(destination: os.PathLike, path_trees: list[PathTree] = None): conf, destination / f"initial_config_{i}.toml", ) - prev_z_num = load_toml(conf).get("z_num", prev_z_num) + prev_z_num = open_config(conf).get("z_num", prev_z_num) z_num += prev_z_num pbars = PBars( diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index 22fd485..b017f9c 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Generator, Iterable, Literal, Optional, TypeVa import numpy as np from .. import math, utils -from ..const import PARAM_SEPARATOR, __version__ +from ..const import PARAM_FN, PARAM_SEPARATOR, __version__ from ..errors import EvaluatorError, NoDefaultError from ..logger import get_logger from ..physics import fiber, materials, pulse, units @@ -308,6 +308,10 @@ class Parameter: return f"{num_str} {unit}" +def fiber_map_converter(d: dict[str, str]) -> dict[float, str]: + return {float(k): v for k, v in d.items()} + + @dataclass class Parameters: """ @@ -421,11 +425,13 @@ class Parameters: const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) beta_func: Callable[[float], list[float]] = Parameter(func_validator) gamma_func: Callable[[float], float] = Parameter(func_validator) + fiber_map: dict[float, str] = Parameter(type_checker(dict), converter=fiber_map_converter) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) version: str = Parameter(string) def prepare_for_dump(self) -> dict[str, Any]: param = asdict(self) + param["fiber_map"] = {str(z): n for z, n in param.get("fiber_map", {}).items()} param = Parameters.strip_params_dict(param) param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ @@ -448,7 +454,7 @@ class Parameters: @classmethod def load(cls, path: os.PathLike) -> "Parameters": - return cls(**utils.load_toml(path)) + return cls(**utils.open_config(path)) @classmethod def load_and_compute(cls, path: os.PathLike) -> "Parameters": @@ -756,8 +762,7 @@ class Configuration: obj with the output path of the simulation saved in its output_path attribute. """ - configs: list[dict[str, Any]] - data_dirs: list[list[Path]] + master_configs: list[dict[str, Any]] sim_dirs: list[Path] num_sim: int repeat: int @@ -766,14 +771,20 @@ class Configuration: worker_num: int parallel: bool overwrite: bool - name: str - all_required: list[list[tuple[list[tuple[str, Any]], dict[str, Any]]]] - # | | | | | - # | | | | param name and value - # | | | all variable parameters - # | | list of all variable parameters associated with the full config dict - # | list of all configs for 1 fiber - # list of all fibers + final_path: str + all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"] + all_configs_list: list[list["Configuration.__SimConfig"]] + + @dataclass(frozen=True) + class __SimConfig: + vary_list: list[tuple[str, Any]] + config: dict[str, Any] + output_path: Path + index: tuple[tuple[int, ...], ...] + + @property + def sim_num(self) -> int: + return len(self.index) class State(enum.Enum): COMPLETE = enum.auto() @@ -793,46 +804,48 @@ class Configuration: ): self.logger = get_logger(__name__) - self.configs, self.name = utils.load_config_sequence(final_config_path) - if self.name is None: - self.name = Parameters.name.default + self.master_configs, self.final_path = utils.load_config_sequence(final_config_path) + if self.final_path is None: + self.final_path = Parameters.name.default + self.name = Path(self.final_path).name self.z_num = 0 self.total_num_steps = 0 self.sim_dirs = [] self.overwrite = overwrite self.skip_callback = skip_callback - self.worker_num = self.configs[0].get("worker_num", max(1, os.cpu_count() // 2)) - self.repeat = self.configs[0].get("repeat", 1) + self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2)) + self.repeat = self.master_configs[0].get("repeat", 1) names = set() - for i, config in enumerate(self.configs): + for i, config in enumerate(self.master_configs): self.z_num += config["z_num"] config.setdefault("name", f"{Parameters.name.default} {i}") given_name = config["name"] - i = 0 + fn_i = 0 while config["name"] in names: - config["name"] = given_name + f"_{i}" - i += 1 + config["name"] = given_name + f"_{fn_i}" + fn_i += 1 names.add(config["name"]) self.sim_dirs.append( utils.ensure_folder( - Path("__" + config["name"] + "__"), + Path("_".join(["_", self.name, Path(config["name"]).name, "_"])), mkdir=False, prevent_overwrite=not self.overwrite, ) ) self.__validate_variable(config) self.__compute_sim_dirs() - [Evaluator.evaluate_default(req[0][1], check_only=True) for req in self.all_required] - self.num_sim = len(self.data_dirs[-1]) + [Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list] + self.num_sim = len(self.all_configs_list[-1]) self.total_num_steps = sum( - config["z_num"] * len(self.data_dirs[i]) for i, config in enumerate(self.configs) + config["z_num"] * len(self.all_configs_list[i]) + for i, config in enumerate(self.master_configs) ) self.final_sim_dir = utils.ensure_folder( - Path(self.configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite + Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite ) - self.parallel = self.configs[0].get("parallel", Parameters.parallel.default) + self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default) def __validate_variable(self, config: dict[str, Any]): for k, v in config.get("variable", {}).items(): @@ -844,14 +857,15 @@ class Configuration: raise ValueError(f"variable parameter {k!r} must not be empty") def __compute_sim_dirs(self): - self.all_required = [] - self.data_dirs = [] - self.configs[0]["variable"]["num"] = list(range(self.configs[0].get("repeat", 1))) - dp = DataPather([c["variable"] for c in self.configs]) - for i, conf in enumerate(self.configs): - self.all_required.append([]) - self.data_dirs.append([]) - for prev_path, this_path, this_vary in dp.all_vary_list(i): + self.all_configs_dict = {} + self.all_configs_list = [] + self.master_configs[0]["variable"]["num"] = list( + range(self.master_configs[0].get("repeat", 1)) + ) + dp = DataPather([c["variable"] for c in self.master_configs]) + for i, conf in enumerate(self.master_configs): + self.all_configs_list.append([]) + for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i): this_conf = conf.copy() if i > 0: prev_path = utils.ensure_folder( @@ -862,50 +876,56 @@ class Configuration: this_path = utils.ensure_folder( self.sim_dirs[i] / this_path, not self.overwrite, False ) - self.data_dirs[i].append(this_path) this_conf.pop("variable") conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf - self.all_required[i].append((this_vary, conf_to_use)) + self.all_configs_dict[sim_index] = self.__SimConfig( + this_vary, conf_to_use, this_path, sim_index + ) + self.all_configs_list[i].append(self.all_configs_dict[sim_index]) def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]: - for sim_paths, fiber in zip(self.data_dirs, self.all_required): - for variable_list, data_dir, params in self.__iter_1_sim(sim_paths, fiber): - params.output_path = str(data_dir) - yield variable_list, params + for i, sim_config_list in enumerate(self.all_configs_list): + for sim_config, params in self.__iter_1_sim(sim_config_list): + fiber_map = [] + for j in range(i + 1): + this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config + if j > 0: + prev_conf = self.all_configs_dict[sim_config.index[:j]].config + length = prev_conf["length"] + fiber_map[j - 1][0] + else: + length = 0.0 + fiber_map.append((length, this_conf["name"])) + params.output_path = str(sim_config.output_path) + params.fiber_map = dict(fiber_map) + yield sim_config.vary_list, params def __iter_1_sim( - self, sim_paths: list[Path], fiber: list[tuple[list[tuple[str, Any]], dict[str, Any]]] - ) -> Generator[tuple[list[tuple[str, Any]], Path, Parameters], None, None]: - """iterates through the parameters of only one fiber. It takes care of recovery partially completed - simulations, skipping complete ones and waiting for the previous fiber to finish + self, configs: list["Configuration.__SimConfig"] + ) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]: + """iterates through the parameters of only one fiber. It takes care of recovering partially + completed simulations, skipping complete ones and waiting for the previous fiber to finish Parameters ---------- - sim_paths : list[Path] - output_paths of the desired simulations - fiber : list[tuple[list[tuple[str, Any]], dict[str, Any]]] - list of variable list and config dict as yielded by variable_iterator + configs : list[__SimConfig] + list of configuration obj Yields ------- - list[tuple[str, Any]] - list of variable paramters - Path - desired output path + __SimConfig + configuration obj Parameters computed Parameters obj """ - sim_dict: dict[Path, tuple[list[tuple[str, Any]], dict[str, Any]]] = dict( - zip(sim_paths, fiber) - ) + sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs} while len(sim_dict) > 0: - for data_dir, (variable_list, config_dict) in sim_dict.items(): - task, config_dict = self.__decide(data_dir, config_dict) + for data_dir, sim_config in sim_dict.items(): + task, config_dict = self.__decide(sim_config) if task == self.Action.RUN: sim_dict.pop(data_dir) p = Parameters(**config_dict) p.compute() - yield variable_list, data_dir, p + yield sim_config, p if "recovery_last_stored" in config_dict and self.skip_callback is not None: self.skip_callback(config_dict["recovery_last_stored"]) break @@ -920,16 +940,13 @@ class Configuration: time.sleep(1) def __decide( - self, data_dir: Path, config_dict: dict[str, Any] + self, sim_config: "Configuration.__SimConfig" ) -> tuple["Configuration.Action", dict[str, Any]]: """decide what to to with a particular simulation Parameters ---------- - data_dir : Path - path to the output of the simulation - config_dict : dict[str, Any] - configuration of the simulation + sim_config : __SimConfig Returns ------- @@ -939,20 +956,20 @@ class Configuration: config dictionary. The only key possibly modified is 'prev_data_dir', which gets set if the simulation is partially completed """ - out_status, num = self.sim_status(data_dir, config_dict) + out_status, num = self.sim_status(sim_config.output_path, sim_config.config) if out_status == self.State.COMPLETE: - return self.Action.SKIP, config_dict + return self.Action.SKIP, sim_config.config elif out_status == self.State.PARTIAL: - config_dict["recovery_data_dir"] = str(data_dir) - config_dict["recovery_last_stored"] = num - return self.Action.RUN, config_dict + sim_config.config["recovery_data_dir"] = str(sim_config.output_path) + sim_config.config["recovery_last_stored"] = num + return self.Action.RUN, sim_config.config - if "prev_data_dir" in config_dict: - prev_data_path = Path(config_dict["prev_data_dir"]) + if "prev_data_dir" in sim_config.config: + prev_data_path = Path(sim_config.config["prev_data_dir"]) prev_status, _ = self.sim_status(prev_data_path) if prev_status in {self.State.PARTIAL, self.State.ABSENT}: - return self.Action.WAIT, config_dict - return self.Action.RUN, config_dict + return self.Action.WAIT, sim_config.config + return self.Action.RUN, sim_config.config def sim_status( self, data_dir: Path, config_dict: dict[str, Any] = None @@ -975,9 +992,9 @@ class Configuration: num = utils.find_last_spectrum_num(data_dir) if config_dict is None: try: - config_dict = utils.load_toml(data_dir / "params.toml") + config_dict = utils.open_config(data_dir / PARAM_FN) except FileNotFoundError: - self.logger.warning(f"did not find 'params.toml' in {data_dir}") + self.logger.warning(f"did not find {PARAM_FN!r} in {data_dir}") return self.State.ABSENT, 0 if num == config_dict["z_num"] - 1: return self.State.COMPLETE, num @@ -989,18 +1006,23 @@ class Configuration: raise ValueError(f"Too many spectra in {data_dir}") def save_parameters(self): - for config, sim_dir in zip(self.configs, self.sim_dirs): + for config, sim_dir in zip(self.master_configs, self.sim_dirs): os.makedirs(sim_dir, exist_ok=True) utils.save_toml(sim_dir / f"initial_config.toml", config) + @property + def first(self) -> Parameters: + for _, param in self: + return param + class DataPather: def __init__(self, dl: list[dict[str, Any]]): self.dict_list = dl - self.n = len(self.dict_list) - self.final_list = list(self.dico_iterator(self.n)) - def dico_iterator(self, index: int) -> Generator[list[list[tuple[str, Any]]], None, None]: + def dico_iterator( + self, index: int + ) -> Generator[tuple[tuple[tuple[int, ...]], list[list[tuple[str, Any]]]], None, None]: """iterates through every possible combination of a list of dict of lists Parameters @@ -1034,12 +1056,14 @@ class DataPather: for r in itertools.product(*ranges): flat = [(d_tem_list[i][0], d_tem_list[i][1][j]) for i, j in enumerate(r)] + pos = tuple(r) out = [flat[left:right] for left, right in zip(dict_pos[:-1], dict_pos[1:])] - yield out + pos = tuple(pos[left:right] for left, right in zip(dict_pos[:-1], dict_pos[1:])) + yield pos, out def all_vary_list(self, index): - for l in self.dico_iterator(index): - unique_vary = [] + for sim_index, l in self.dico_iterator(index): + unique_vary: list[tuple[str, Any]] = [] for ll in l[: index + 1]: for pname, pval in ll: for i, (pn, _) in enumerate(unique_vary): @@ -1047,9 +1071,12 @@ class DataPather: del unique_vary[i] break unique_vary.append((pname, pval)) - yield format_variable_list(reduce_all_variable(l[:index])), format_variable_list( - reduce_all_variable(l) - ), unique_vary + yield sim_index, format_variable_list( + reduce_all_variable(l[:index]) + ), format_variable_list(reduce_all_variable(l)), unique_vary + + def __repr__(self): + return f"DataPather([{', '.join(repr(d) for d in self.dict_list)}])" @dataclass @@ -1065,6 +1092,11 @@ class PlotRange: def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: return sort_axis(axis, self) + def __iter__(self): + yield self.left + yield self.right + yield self.unit.__name__ + def sort_axis( axis: np.ndarray, plt_range: PlotRange