reworked parameter compilation

This commit is contained in:
Benoît Sierro
2023-07-26 11:33:36 +02:00
parent 77f5932fe8
commit e0e262c6f2
6 changed files with 43 additions and 121 deletions

View File

@@ -27,9 +27,12 @@ PARAM_FN = "params.toml"
PARAM_SEPARATOR = " " PARAM_SEPARATOR = " "
MANDATORY_PARAMETERS = [ MANDATORY_PARAMETERS = {
"name", "name",
"w", "w",
"t",
"fft",
"ifft",
"w0", "w0",
"spec_0", "spec_0",
"field_0", "field_0",
@@ -39,16 +42,15 @@ MANDATORY_PARAMETERS = [
"length", "length",
"adapt_step_size", "adapt_step_size",
"tolerated_error", "tolerated_error",
"recovery_last_stored",
"output_path",
"repeat", "repeat",
"linear_operator", "linear_operator",
"nonlinear_operator", "nonlinear_operator",
] "soliton_length",
"nonlinear_length",
"dispersion_length",
}
ROOT_PARAMETERS = [ ROOT_PARAMETERS = [
"repeat",
"num",
"dt", "dt",
"t_num", "t_num",
"time_window", "time_window",

View File

@@ -28,10 +28,6 @@ class MissingParameterError(Exception):
super().__init__(message) super().__init__(message)
class DuplicateParameterError(Exception):
pass
class IncompleteDataFolderError(FileNotFoundError): class IncompleteDataFolderError(FileNotFoundError):
pass pass

View File

@@ -38,7 +38,10 @@ class Rule:
return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})" return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})"
def __str__(self) -> str: def __str__(self) -> str:
return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]" return (
f"[{', '.join(self.args)}] -- {self.func.__module__}."
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
)
@property @property
def func_name(self) -> str: def func_name(self) -> str:
@@ -192,8 +195,8 @@ class Evaluator:
if target in self.__curent_lookup: if target in self.__curent_lookup:
raise EvaluatorError( raise EvaluatorError(
"cyclic dependency detected : " "cyclic dependency detected : "
f"{target!r} seems to depend on itself, " f"{target!r} seems to depend on itself, please provide "
f"please provide a value for at least one variable in {self.__curent_lookup!r}. " f"a value for at least one variable in {self.__curent_lookup!r}. "
+ self.attempted_rules_str(target) + self.attempted_rules_str(target)
) )
else: else:
@@ -372,9 +375,9 @@ default_rules: list[Rule] = [
Rule("effective_area", fiber.effective_area_from_diam), Rule("effective_area", fiber.effective_area_from_diam),
Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")), Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")),
Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1), Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1),
Rule("effective_area", fiber.elfective_area_marcatili, priorities=-2), Rule("effective_area", fiber.effective_area_marcatili, priorities=-2),
Rule("effecive_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]), Rule("effective_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
Rule("effecive_area_arr", fiber.load_custom_effective_area), Rule("effective_area_arr", fiber.load_custom_effective_area),
Rule( Rule(
"V_eff", "V_eff",
fiber.V_parameter_koshiba, fiber.V_parameter_koshiba,
@@ -396,7 +399,7 @@ default_rules: list[Rule] = [
Rule("n2", lambda: 2.2e-20, priorities=-1), Rule("n2", lambda: 2.2e-20, priorities=-1),
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1), Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
Rule("gamma", fiber.gamma_parameter), Rule("gamma", fiber.gamma_parameter),
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effecive_area_arr"]), Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effective_area_arr"]),
# Raman # Raman
Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w), Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
Rule("raman_fraction", fiber.raman_fraction), Rule("raman_fraction", fiber.raman_fraction),

View File

@@ -13,7 +13,6 @@ import numpy as np
from scgenerator import utils from scgenerator import utils
from scgenerator.const import MANDATORY_PARAMETERS, __version__ from scgenerator.const import MANDATORY_PARAMETERS, __version__
from scgenerator.errors import EvaluatorError
from scgenerator.evaluator import Evaluator from scgenerator.evaluator import Evaluator
from scgenerator.operators import Qualifier, SpecOperator from scgenerator.operators import Qualifier, SpecOperator
from scgenerator.utils import update_path_name from scgenerator.utils import update_path_name
@@ -256,7 +255,8 @@ class Parameter:
if instance._frozen: if instance._frozen:
raise AttributeError("Parameters instance is frozen and can no longer be modified") raise AttributeError("Parameters instance is frozen and can no longer be modified")
if isinstance(value, Parameter) and self.default is not None: if isinstance(value, Parameter):
if self.default is not None:
instance._param_dico[self.name] = copy(self.default) instance._param_dico[self.name] = copy(self.default)
return return
@@ -299,13 +299,12 @@ class Parameters:
# internal machinery # internal machinery
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False) _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
_evaluator: Evaluator = field(init=False, repr=False)
_p_names: ClassVar[Set[str]] = set() _p_names: ClassVar[Set[str]] = set()
_frozen: bool = field(init=False, default=False, repr=False) _frozen: bool = field(init=False, default=False, repr=False)
# root # root
name: str = Parameter(string, default="no name") name: str = Parameter(string, default="no name")
output_path: Path = Parameter(type_checker(Path), default=Path("sc_data"), converter=Path) output_path: Path = Parameter(type_checker(Path), converter=Path)
# fiber # fiber
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0) input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
@@ -433,10 +432,6 @@ class Parameters:
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string) version: str = Parameter(string)
def __post_init__(self):
self._evaluator = Evaluator.default(self.full_field)
self._evaluator.set(self._param_dico)
def __repr__(self) -> str: def __repr__(self) -> str:
return "Parameter(" + ", ".join(self.__repr_list__()) + ")" return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
@@ -455,23 +450,31 @@ class Parameters:
setattr(self, k, v) setattr(self, k, v)
self.__post_init__() self.__post_init__()
def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]: def get_evaluator(self):
if compute: evaluator = Evaluator.default(self.full_field)
self.compute_in_place() evaluator.set(self._param_dico.copy())
return evaluator
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
param = Parameters.strip_params_dict(self._param_dico) param = Parameters.strip_params_dict(self._param_dico)
if add_metadata: if add_metadata:
param["datetime"] = datetime_module.datetime.now() param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__ param["version"] = __version__
return param return param
def compute_in_place(self, *to_compute: str): def compile(self, exhaustive=False):
if len(to_compute) == 0:
to_compute = MANDATORY_PARAMETERS to_compute = MANDATORY_PARAMETERS
evaluator = self.get_evaluator()
for k in to_compute: for k in to_compute:
getattr(self, k) evaluator.compute(k)
if exhaustive:
def compute(self, key: str) -> Any: for p in self._p_names:
return self._evaluator.compute(key) if p not in evaluator.params:
try:
evaluator.compute(p)
except Exception:
pass
return self.__class__(**{k: v for k, v in evaluator.params.items() if k in self._p_names})
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str: def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
"""return a pretty formatted string describing the parameters""" """return a pretty formatted string describing the parameters"""

View File

@@ -28,7 +28,8 @@ def lambda_for_envelope_dispersion(
subset of l in the interpolation range with two extra values on each side subset of l in the interpolation range with two extra values on each side
to accomodate for taking gradients to accomodate for taking gradients
np.ndarray np.ndarray
indices of the original l where the values are valid (i.e. without the two extra on each side) indices of the original l where the values are valid
(i.e. without the two extra on each side)
""" """
su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0] su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
if l[su].min() > 1.01 * wavelength_window[0]: if l[su].min() > 1.01 * wavelength_window[0]:

View File

@@ -11,20 +11,17 @@ import itertools
import os import os
import re import re
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass
from functools import cache, lru_cache from functools import cache, lru_cache
from pathlib import Path from pathlib import Path
from string import printable as str_printable from string import printable as str_printable
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union
import numpy as np import numpy as np
from numpy.ma.extras import isin
import pkg_resources as pkg import pkg_resources as pkg
import tomli import tomli
import tomli_w import tomli_w
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
from scgenerator.errors import DuplicateParameterError
from scgenerator.logger import get_logger from scgenerator.logger import get_logger
T_ = TypeVar("T_") T_ = TypeVar("T_")
@@ -91,47 +88,6 @@ class Paths:
return os.path.join(cls.get("plots"), name) return os.path.join(cls.get("plots"), name)
@dataclass(init=False)
class SubConfig:
fixed: dict[str, Any]
variable: list[dict[str, list]]
fixed_keys: set[str]
variable_keys: set[str]
def __init__(self, dico: dict[str, Any]):
dico = dico.copy()
self.variable = conform_variable_entry(dico.pop("variable", []))
self.fixed = dico
self.__update
def __update(self):
self.variable_keys = set()
self.fixed_keys = set()
for dico in self.variable:
for key in dico:
if key in self.variable_keys:
raise DuplicateParameterError(f"{key} is specified twice")
self.variable_keys.add(key)
for key in self.fixed:
if key in self.variable_keys:
raise DuplicateParameterError(f"{key} is specified twice")
self.fixed_keys.add(key)
def weak_update(self, other: SubConfig = None, **kwargs):
"""similar to a dict update method put prioritizes existing values
Parameters
----------
other : SubConfig
other obj
"""
if other is None:
other = SubConfig(kwargs)
self.fixed = other.fixed | self.fixed
self.variable = other.variable + self.variable
self.__update()
def conform_variable_entry(d) -> list[dict[str, list]]: def conform_variable_entry(d) -> list[dict[str, list]]:
if isinstance(d, MutableMapping): if isinstance(d, MutableMapping):
d = [{k: v} for k, v in d.items()] d = [{k: v} for k, v in d.items()]
@@ -240,45 +196,6 @@ def save_toml(path: os.PathLike, dico):
return dico return dico
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]:
"""loads a configuration file
Parameters
----------
path : os.PathLike
path to the config toml file or a directory containing config files
Returns
-------
final_path : Path
output name of the simulation
list[dict[str, Any]]
one config per fiber
"""
path = Path(path)
if path.name.lower().endswith(".toml"):
master_config_dict = _open_config(path)
fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")]
master_config = SubConfig(master_config_dict)
else:
master_config = SubConfig(dict(name=path.name))
fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))]
if len(fiber_list) == 0:
raise ValueError(f"No fiber in config {path}")
for fiber in fiber_list:
fiber.weak_update(master_config)
if "num" not in fiber_list[0].variable_keys:
repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1)))
fiber_list[0].weak_update(variable=dict(num=repeat_arg))
for p_name in ROOT_PARAMETERS:
if any(p_name in conf.variable_keys for conf in fiber_list[1:]):
raise ValueError(f"{p_name} should only be specified in the root or first fiber")
configs = fiber_list
return Path(master_config.fixed["name"]), configs
@cache @cache
def load_material_dico(name: str) -> dict[str, Any]: def load_material_dico(name: str) -> dict[str, Any]:
"""loads a material dictionary """loads a material dictionary