From e0e262c6f2c1fe65cef62d965b1bcd8026de2656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Wed, 26 Jul 2023 11:33:36 +0200 Subject: [PATCH] reworked parameter compilation --- src/scgenerator/const.py | 14 +++--- src/scgenerator/errors.py | 4 -- src/scgenerator/evaluator.py | 17 ++++--- src/scgenerator/parameter.py | 41 ++++++++------- src/scgenerator/physics/fiber.py | 3 +- src/scgenerator/utils.py | 85 +------------------------------- 6 files changed, 43 insertions(+), 121 deletions(-) diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index 1dc33df..3ec5b5c 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -27,9 +27,12 @@ PARAM_FN = "params.toml" PARAM_SEPARATOR = " " -MANDATORY_PARAMETERS = [ +MANDATORY_PARAMETERS = { "name", "w", + "t", + "fft", + "ifft", "w0", "spec_0", "field_0", @@ -39,16 +42,15 @@ MANDATORY_PARAMETERS = [ "length", "adapt_step_size", "tolerated_error", - "recovery_last_stored", - "output_path", "repeat", "linear_operator", "nonlinear_operator", -] + "soliton_length", + "nonlinear_length", + "dispersion_length", +} ROOT_PARAMETERS = [ - "repeat", - "num", "dt", "t_num", "time_window", diff --git a/src/scgenerator/errors.py b/src/scgenerator/errors.py index 93e69a3..416ec08 100644 --- a/src/scgenerator/errors.py +++ b/src/scgenerator/errors.py @@ -28,10 +28,6 @@ class MissingParameterError(Exception): super().__init__(message) -class DuplicateParameterError(Exception): - pass - - class IncompleteDataFolderError(FileNotFoundError): pass diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 94e99db..fc0c70a 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -38,7 +38,10 @@ class Rule: return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})" def __str__(self) -> str: - return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]" + return ( + f"[{', '.join(self.args)}] -- {self.func.__module__}." + f"{self.func.__name__} --> [{', '.join(self.targets)}]" + ) @property def func_name(self) -> str: @@ -192,8 +195,8 @@ class Evaluator: if target in self.__curent_lookup: raise EvaluatorError( "cyclic dependency detected : " - f"{target!r} seems to depend on itself, " - f"please provide a value for at least one variable in {self.__curent_lookup!r}. " + f"{target!r} seems to depend on itself, please provide " + f"a value for at least one variable in {self.__curent_lookup!r}. " + self.attempted_rules_str(target) ) else: @@ -372,9 +375,9 @@ default_rules: list[Rule] = [ Rule("effective_area", fiber.effective_area_from_diam), Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")), Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1), - Rule("effective_area", fiber.elfective_area_marcatili, priorities=-2), - Rule("effecive_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]), - Rule("effecive_area_arr", fiber.load_custom_effective_area), + Rule("effective_area", fiber.effective_area_marcatili, priorities=-2), + Rule("effective_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]), + Rule("effective_area_arr", fiber.load_custom_effective_area), Rule( "V_eff", fiber.V_parameter_koshiba, @@ -396,7 +399,7 @@ default_rules: list[Rule] = [ Rule("n2", lambda: 2.2e-20, priorities=-1), Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1), Rule("gamma", fiber.gamma_parameter), - Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effecive_area_arr"]), + Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effective_area_arr"]), # Raman Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w), Rule("raman_fraction", fiber.raman_fraction), diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 10f552a..e0122f1 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -13,7 +13,6 @@ import numpy as np from scgenerator import utils from scgenerator.const import MANDATORY_PARAMETERS, __version__ -from scgenerator.errors import EvaluatorError from scgenerator.evaluator import Evaluator from scgenerator.operators import Qualifier, SpecOperator from scgenerator.utils import update_path_name @@ -256,8 +255,9 @@ class Parameter: if instance._frozen: raise AttributeError("Parameters instance is frozen and can no longer be modified") - if isinstance(value, Parameter) and self.default is not None: - instance._param_dico[self.name] = copy(self.default) + if isinstance(value, Parameter): + if self.default is not None: + instance._param_dico[self.name] = copy(self.default) return is_value, value = self.validate(value) @@ -299,13 +299,12 @@ class Parameters: # internal machinery _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False) - _evaluator: Evaluator = field(init=False, repr=False) _p_names: ClassVar[Set[str]] = set() _frozen: bool = field(init=False, default=False, repr=False) # root name: str = Parameter(string, default="no name") - output_path: Path = Parameter(type_checker(Path), default=Path("sc_data"), converter=Path) + output_path: Path = Parameter(type_checker(Path), converter=Path) # fiber input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0) @@ -433,10 +432,6 @@ class Parameters: datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) version: str = Parameter(string) - def __post_init__(self): - self._evaluator = Evaluator.default(self.full_field) - self._evaluator.set(self._param_dico) - def __repr__(self) -> str: return "Parameter(" + ", ".join(self.__repr_list__()) + ")" @@ -455,23 +450,31 @@ class Parameters: setattr(self, k, v) self.__post_init__() - def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]: - if compute: - self.compute_in_place() + def get_evaluator(self): + evaluator = Evaluator.default(self.full_field) + evaluator.set(self._param_dico.copy()) + return evaluator + + def dump_dict(self, add_metadata=True) -> dict[str, Any]: param = Parameters.strip_params_dict(self._param_dico) if add_metadata: param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ return param - def compute_in_place(self, *to_compute: str): - if len(to_compute) == 0: - to_compute = MANDATORY_PARAMETERS + def compile(self, exhaustive=False): + to_compute = MANDATORY_PARAMETERS + evaluator = self.get_evaluator() for k in to_compute: - getattr(self, k) - - def compute(self, key: str) -> Any: - return self._evaluator.compute(key) + evaluator.compute(k) + if exhaustive: + for p in self._p_names: + if p not in evaluator.params: + try: + evaluator.compute(p) + except Exception: + pass + return self.__class__(**{k: v for k, v in evaluator.params.items() if k in self._p_names}) def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str: """return a pretty formatted string describing the parameters""" diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index cead06b..dbe5405 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -28,7 +28,8 @@ def lambda_for_envelope_dispersion( subset of l in the interpolation range with two extra values on each side to accomodate for taking gradients np.ndarray - indices of the original l where the values are valid (i.e. without the two extra on each side) + indices of the original l where the values are valid + (i.e. without the two extra on each side) """ su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0] if l[su].min() > 1.01 * wavelength_window[0]: diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index 5b0aac5..4c7b25e 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -11,20 +11,17 @@ import itertools import os import re from collections import defaultdict -from dataclasses import dataclass from functools import cache, lru_cache from pathlib import Path from string import printable as str_printable from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union import numpy as np -from numpy.ma.extras import isin import pkg_resources as pkg import tomli import tomli_w -from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN -from scgenerator.errors import DuplicateParameterError +from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN from scgenerator.logger import get_logger T_ = TypeVar("T_") @@ -91,47 +88,6 @@ class Paths: return os.path.join(cls.get("plots"), name) -@dataclass(init=False) -class SubConfig: - fixed: dict[str, Any] - variable: list[dict[str, list]] - fixed_keys: set[str] - variable_keys: set[str] - - def __init__(self, dico: dict[str, Any]): - dico = dico.copy() - self.variable = conform_variable_entry(dico.pop("variable", [])) - self.fixed = dico - self.__update - - def __update(self): - self.variable_keys = set() - self.fixed_keys = set() - for dico in self.variable: - for key in dico: - if key in self.variable_keys: - raise DuplicateParameterError(f"{key} is specified twice") - self.variable_keys.add(key) - for key in self.fixed: - if key in self.variable_keys: - raise DuplicateParameterError(f"{key} is specified twice") - self.fixed_keys.add(key) - - def weak_update(self, other: SubConfig = None, **kwargs): - """similar to a dict update method put prioritizes existing values - - Parameters - ---------- - other : SubConfig - other obj - """ - if other is None: - other = SubConfig(kwargs) - self.fixed = other.fixed | self.fixed - self.variable = other.variable + self.variable - self.__update() - - def conform_variable_entry(d) -> list[dict[str, list]]: if isinstance(d, MutableMapping): d = [{k: v} for k, v in d.items()] @@ -240,45 +196,6 @@ def save_toml(path: os.PathLike, dico): return dico -def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]: - """loads a configuration file - - Parameters - ---------- - path : os.PathLike - path to the config toml file or a directory containing config files - - Returns - ------- - final_path : Path - output name of the simulation - list[dict[str, Any]] - one config per fiber - - """ - path = Path(path) - if path.name.lower().endswith(".toml"): - master_config_dict = _open_config(path) - fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")] - master_config = SubConfig(master_config_dict) - else: - master_config = SubConfig(dict(name=path.name)) - fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))] - - if len(fiber_list) == 0: - raise ValueError(f"No fiber in config {path}") - for fiber in fiber_list: - fiber.weak_update(master_config) - if "num" not in fiber_list[0].variable_keys: - repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1))) - fiber_list[0].weak_update(variable=dict(num=repeat_arg)) - for p_name in ROOT_PARAMETERS: - if any(p_name in conf.variable_keys for conf in fiber_list[1:]): - raise ValueError(f"{p_name} should only be specified in the root or first fiber") - configs = fiber_list - return Path(master_config.fixed["name"]), configs - - @cache def load_material_dico(name: str) -> dict[str, Any]: """loads a material dictionary