reworked parameter compilation

This commit is contained in:
Benoît Sierro
2023-07-26 11:33:36 +02:00
parent 77f5932fe8
commit e0e262c6f2
6 changed files with 43 additions and 121 deletions

View File

@@ -27,9 +27,12 @@ PARAM_FN = "params.toml"
PARAM_SEPARATOR = " "
MANDATORY_PARAMETERS = [
MANDATORY_PARAMETERS = {
"name",
"w",
"t",
"fft",
"ifft",
"w0",
"spec_0",
"field_0",
@@ -39,16 +42,15 @@ MANDATORY_PARAMETERS = [
"length",
"adapt_step_size",
"tolerated_error",
"recovery_last_stored",
"output_path",
"repeat",
"linear_operator",
"nonlinear_operator",
]
"soliton_length",
"nonlinear_length",
"dispersion_length",
}
ROOT_PARAMETERS = [
"repeat",
"num",
"dt",
"t_num",
"time_window",

View File

@@ -28,10 +28,6 @@ class MissingParameterError(Exception):
super().__init__(message)
class DuplicateParameterError(Exception):
pass
class IncompleteDataFolderError(FileNotFoundError):
pass

View File

@@ -38,7 +38,10 @@ class Rule:
return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})"
def __str__(self) -> str:
return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]"
return (
f"[{', '.join(self.args)}] -- {self.func.__module__}."
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
)
@property
def func_name(self) -> str:
@@ -192,8 +195,8 @@ class Evaluator:
if target in self.__curent_lookup:
raise EvaluatorError(
"cyclic dependency detected : "
f"{target!r} seems to depend on itself, "
f"please provide a value for at least one variable in {self.__curent_lookup!r}. "
f"{target!r} seems to depend on itself, please provide "
f"a value for at least one variable in {self.__curent_lookup!r}. "
+ self.attempted_rules_str(target)
)
else:
@@ -372,9 +375,9 @@ default_rules: list[Rule] = [
Rule("effective_area", fiber.effective_area_from_diam),
Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")),
Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1),
Rule("effective_area", fiber.elfective_area_marcatili, priorities=-2),
Rule("effecive_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
Rule("effecive_area_arr", fiber.load_custom_effective_area),
Rule("effective_area", fiber.effective_area_marcatili, priorities=-2),
Rule("effective_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
Rule("effective_area_arr", fiber.load_custom_effective_area),
Rule(
"V_eff",
fiber.V_parameter_koshiba,
@@ -396,7 +399,7 @@ default_rules: list[Rule] = [
Rule("n2", lambda: 2.2e-20, priorities=-1),
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
Rule("gamma", fiber.gamma_parameter),
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effecive_area_arr"]),
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effective_area_arr"]),
# Raman
Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
Rule("raman_fraction", fiber.raman_fraction),

View File

@@ -13,7 +13,6 @@ import numpy as np
from scgenerator import utils
from scgenerator.const import MANDATORY_PARAMETERS, __version__
from scgenerator.errors import EvaluatorError
from scgenerator.evaluator import Evaluator
from scgenerator.operators import Qualifier, SpecOperator
from scgenerator.utils import update_path_name
@@ -256,8 +255,9 @@ class Parameter:
if instance._frozen:
raise AttributeError("Parameters instance is frozen and can no longer be modified")
if isinstance(value, Parameter) and self.default is not None:
instance._param_dico[self.name] = copy(self.default)
if isinstance(value, Parameter):
if self.default is not None:
instance._param_dico[self.name] = copy(self.default)
return
is_value, value = self.validate(value)
@@ -299,13 +299,12 @@ class Parameters:
# internal machinery
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
_evaluator: Evaluator = field(init=False, repr=False)
_p_names: ClassVar[Set[str]] = set()
_frozen: bool = field(init=False, default=False, repr=False)
# root
name: str = Parameter(string, default="no name")
output_path: Path = Parameter(type_checker(Path), default=Path("sc_data"), converter=Path)
output_path: Path = Parameter(type_checker(Path), converter=Path)
# fiber
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
@@ -433,10 +432,6 @@ class Parameters:
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string)
def __post_init__(self):
self._evaluator = Evaluator.default(self.full_field)
self._evaluator.set(self._param_dico)
def __repr__(self) -> str:
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
@@ -455,23 +450,31 @@ class Parameters:
setattr(self, k, v)
self.__post_init__()
def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]:
if compute:
self.compute_in_place()
def get_evaluator(self):
evaluator = Evaluator.default(self.full_field)
evaluator.set(self._param_dico.copy())
return evaluator
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
param = Parameters.strip_params_dict(self._param_dico)
if add_metadata:
param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__
return param
def compute_in_place(self, *to_compute: str):
if len(to_compute) == 0:
to_compute = MANDATORY_PARAMETERS
def compile(self, exhaustive=False):
to_compute = MANDATORY_PARAMETERS
evaluator = self.get_evaluator()
for k in to_compute:
getattr(self, k)
def compute(self, key: str) -> Any:
return self._evaluator.compute(key)
evaluator.compute(k)
if exhaustive:
for p in self._p_names:
if p not in evaluator.params:
try:
evaluator.compute(p)
except Exception:
pass
return self.__class__(**{k: v for k, v in evaluator.params.items() if k in self._p_names})
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
"""return a pretty formatted string describing the parameters"""

View File

@@ -28,7 +28,8 @@ def lambda_for_envelope_dispersion(
subset of l in the interpolation range with two extra values on each side
to accomodate for taking gradients
np.ndarray
indices of the original l where the values are valid (i.e. without the two extra on each side)
indices of the original l where the values are valid
(i.e. without the two extra on each side)
"""
su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
if l[su].min() > 1.01 * wavelength_window[0]:

View File

@@ -11,20 +11,17 @@ import itertools
import os
import re
from collections import defaultdict
from dataclasses import dataclass
from functools import cache, lru_cache
from pathlib import Path
from string import printable as str_printable
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union
import numpy as np
from numpy.ma.extras import isin
import pkg_resources as pkg
import tomli
import tomli_w
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN
from scgenerator.errors import DuplicateParameterError
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
from scgenerator.logger import get_logger
T_ = TypeVar("T_")
@@ -91,47 +88,6 @@ class Paths:
return os.path.join(cls.get("plots"), name)
@dataclass(init=False)
class SubConfig:
fixed: dict[str, Any]
variable: list[dict[str, list]]
fixed_keys: set[str]
variable_keys: set[str]
def __init__(self, dico: dict[str, Any]):
dico = dico.copy()
self.variable = conform_variable_entry(dico.pop("variable", []))
self.fixed = dico
self.__update
def __update(self):
self.variable_keys = set()
self.fixed_keys = set()
for dico in self.variable:
for key in dico:
if key in self.variable_keys:
raise DuplicateParameterError(f"{key} is specified twice")
self.variable_keys.add(key)
for key in self.fixed:
if key in self.variable_keys:
raise DuplicateParameterError(f"{key} is specified twice")
self.fixed_keys.add(key)
def weak_update(self, other: SubConfig = None, **kwargs):
"""similar to a dict update method put prioritizes existing values
Parameters
----------
other : SubConfig
other obj
"""
if other is None:
other = SubConfig(kwargs)
self.fixed = other.fixed | self.fixed
self.variable = other.variable + self.variable
self.__update()
def conform_variable_entry(d) -> list[dict[str, list]]:
if isinstance(d, MutableMapping):
d = [{k: v} for k, v in d.items()]
@@ -240,45 +196,6 @@ def save_toml(path: os.PathLike, dico):
return dico
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]:
"""loads a configuration file
Parameters
----------
path : os.PathLike
path to the config toml file or a directory containing config files
Returns
-------
final_path : Path
output name of the simulation
list[dict[str, Any]]
one config per fiber
"""
path = Path(path)
if path.name.lower().endswith(".toml"):
master_config_dict = _open_config(path)
fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")]
master_config = SubConfig(master_config_dict)
else:
master_config = SubConfig(dict(name=path.name))
fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))]
if len(fiber_list) == 0:
raise ValueError(f"No fiber in config {path}")
for fiber in fiber_list:
fiber.weak_update(master_config)
if "num" not in fiber_list[0].variable_keys:
repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1)))
fiber_list[0].weak_update(variable=dict(num=repeat_arg))
for p_name in ROOT_PARAMETERS:
if any(p_name in conf.variable_keys for conf in fiber_list[1:]):
raise ValueError(f"{p_name} should only be specified in the root or first fiber")
configs = fiber_list
return Path(master_config.fixed["name"]), configs
@cache
def load_material_dico(name: str) -> dict[str, Any]:
"""loads a material dictionary