reworked parameter compilation
This commit is contained in:
@@ -27,9 +27,12 @@ PARAM_FN = "params.toml"
|
||||
PARAM_SEPARATOR = " "
|
||||
|
||||
|
||||
MANDATORY_PARAMETERS = [
|
||||
MANDATORY_PARAMETERS = {
|
||||
"name",
|
||||
"w",
|
||||
"t",
|
||||
"fft",
|
||||
"ifft",
|
||||
"w0",
|
||||
"spec_0",
|
||||
"field_0",
|
||||
@@ -39,16 +42,15 @@ MANDATORY_PARAMETERS = [
|
||||
"length",
|
||||
"adapt_step_size",
|
||||
"tolerated_error",
|
||||
"recovery_last_stored",
|
||||
"output_path",
|
||||
"repeat",
|
||||
"linear_operator",
|
||||
"nonlinear_operator",
|
||||
]
|
||||
"soliton_length",
|
||||
"nonlinear_length",
|
||||
"dispersion_length",
|
||||
}
|
||||
|
||||
ROOT_PARAMETERS = [
|
||||
"repeat",
|
||||
"num",
|
||||
"dt",
|
||||
"t_num",
|
||||
"time_window",
|
||||
|
||||
@@ -28,10 +28,6 @@ class MissingParameterError(Exception):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class DuplicateParameterError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class IncompleteDataFolderError(FileNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -38,7 +38,10 @@ class Rule:
|
||||
return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]"
|
||||
return (
|
||||
f"[{', '.join(self.args)}] -- {self.func.__module__}."
|
||||
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
|
||||
)
|
||||
|
||||
@property
|
||||
def func_name(self) -> str:
|
||||
@@ -192,8 +195,8 @@ class Evaluator:
|
||||
if target in self.__curent_lookup:
|
||||
raise EvaluatorError(
|
||||
"cyclic dependency detected : "
|
||||
f"{target!r} seems to depend on itself, "
|
||||
f"please provide a value for at least one variable in {self.__curent_lookup!r}. "
|
||||
f"{target!r} seems to depend on itself, please provide "
|
||||
f"a value for at least one variable in {self.__curent_lookup!r}. "
|
||||
+ self.attempted_rules_str(target)
|
||||
)
|
||||
else:
|
||||
@@ -372,9 +375,9 @@ default_rules: list[Rule] = [
|
||||
Rule("effective_area", fiber.effective_area_from_diam),
|
||||
Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")),
|
||||
Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1),
|
||||
Rule("effective_area", fiber.elfective_area_marcatili, priorities=-2),
|
||||
Rule("effecive_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
|
||||
Rule("effecive_area_arr", fiber.load_custom_effective_area),
|
||||
Rule("effective_area", fiber.effective_area_marcatili, priorities=-2),
|
||||
Rule("effective_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
|
||||
Rule("effective_area_arr", fiber.load_custom_effective_area),
|
||||
Rule(
|
||||
"V_eff",
|
||||
fiber.V_parameter_koshiba,
|
||||
@@ -396,7 +399,7 @@ default_rules: list[Rule] = [
|
||||
Rule("n2", lambda: 2.2e-20, priorities=-1),
|
||||
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
|
||||
Rule("gamma", fiber.gamma_parameter),
|
||||
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effecive_area_arr"]),
|
||||
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effective_area_arr"]),
|
||||
# Raman
|
||||
Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
|
||||
Rule("raman_fraction", fiber.raman_fraction),
|
||||
|
||||
@@ -13,7 +13,6 @@ import numpy as np
|
||||
|
||||
from scgenerator import utils
|
||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||
from scgenerator.errors import EvaluatorError
|
||||
from scgenerator.evaluator import Evaluator
|
||||
from scgenerator.operators import Qualifier, SpecOperator
|
||||
from scgenerator.utils import update_path_name
|
||||
@@ -256,7 +255,8 @@ class Parameter:
|
||||
if instance._frozen:
|
||||
raise AttributeError("Parameters instance is frozen and can no longer be modified")
|
||||
|
||||
if isinstance(value, Parameter) and self.default is not None:
|
||||
if isinstance(value, Parameter):
|
||||
if self.default is not None:
|
||||
instance._param_dico[self.name] = copy(self.default)
|
||||
return
|
||||
|
||||
@@ -299,13 +299,12 @@ class Parameters:
|
||||
|
||||
# internal machinery
|
||||
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
||||
_evaluator: Evaluator = field(init=False, repr=False)
|
||||
_p_names: ClassVar[Set[str]] = set()
|
||||
_frozen: bool = field(init=False, default=False, repr=False)
|
||||
|
||||
# root
|
||||
name: str = Parameter(string, default="no name")
|
||||
output_path: Path = Parameter(type_checker(Path), default=Path("sc_data"), converter=Path)
|
||||
output_path: Path = Parameter(type_checker(Path), converter=Path)
|
||||
|
||||
# fiber
|
||||
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
||||
@@ -433,10 +432,6 @@ class Parameters:
|
||||
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
||||
version: str = Parameter(string)
|
||||
|
||||
def __post_init__(self):
|
||||
self._evaluator = Evaluator.default(self.full_field)
|
||||
self._evaluator.set(self._param_dico)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
||||
|
||||
@@ -455,23 +450,31 @@ class Parameters:
|
||||
setattr(self, k, v)
|
||||
self.__post_init__()
|
||||
|
||||
def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]:
|
||||
if compute:
|
||||
self.compute_in_place()
|
||||
def get_evaluator(self):
|
||||
evaluator = Evaluator.default(self.full_field)
|
||||
evaluator.set(self._param_dico.copy())
|
||||
return evaluator
|
||||
|
||||
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||
param = Parameters.strip_params_dict(self._param_dico)
|
||||
if add_metadata:
|
||||
param["datetime"] = datetime_module.datetime.now()
|
||||
param["version"] = __version__
|
||||
return param
|
||||
|
||||
def compute_in_place(self, *to_compute: str):
|
||||
if len(to_compute) == 0:
|
||||
def compile(self, exhaustive=False):
|
||||
to_compute = MANDATORY_PARAMETERS
|
||||
evaluator = self.get_evaluator()
|
||||
for k in to_compute:
|
||||
getattr(self, k)
|
||||
|
||||
def compute(self, key: str) -> Any:
|
||||
return self._evaluator.compute(key)
|
||||
evaluator.compute(k)
|
||||
if exhaustive:
|
||||
for p in self._p_names:
|
||||
if p not in evaluator.params:
|
||||
try:
|
||||
evaluator.compute(p)
|
||||
except Exception:
|
||||
pass
|
||||
return self.__class__(**{k: v for k, v in evaluator.params.items() if k in self._p_names})
|
||||
|
||||
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
||||
"""return a pretty formatted string describing the parameters"""
|
||||
|
||||
@@ -28,7 +28,8 @@ def lambda_for_envelope_dispersion(
|
||||
subset of l in the interpolation range with two extra values on each side
|
||||
to accomodate for taking gradients
|
||||
np.ndarray
|
||||
indices of the original l where the values are valid (i.e. without the two extra on each side)
|
||||
indices of the original l where the values are valid
|
||||
(i.e. without the two extra on each side)
|
||||
"""
|
||||
su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
|
||||
if l[su].min() > 1.01 * wavelength_window[0]:
|
||||
|
||||
@@ -11,20 +11,17 @@ import itertools
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from functools import cache, lru_cache
|
||||
from pathlib import Path
|
||||
from string import printable as str_printable
|
||||
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.ma.extras import isin
|
||||
import pkg_resources as pkg
|
||||
import tomli
|
||||
import tomli_w
|
||||
|
||||
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN
|
||||
from scgenerator.errors import DuplicateParameterError
|
||||
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
|
||||
from scgenerator.logger import get_logger
|
||||
|
||||
T_ = TypeVar("T_")
|
||||
@@ -91,47 +88,6 @@ class Paths:
|
||||
return os.path.join(cls.get("plots"), name)
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class SubConfig:
|
||||
fixed: dict[str, Any]
|
||||
variable: list[dict[str, list]]
|
||||
fixed_keys: set[str]
|
||||
variable_keys: set[str]
|
||||
|
||||
def __init__(self, dico: dict[str, Any]):
|
||||
dico = dico.copy()
|
||||
self.variable = conform_variable_entry(dico.pop("variable", []))
|
||||
self.fixed = dico
|
||||
self.__update
|
||||
|
||||
def __update(self):
|
||||
self.variable_keys = set()
|
||||
self.fixed_keys = set()
|
||||
for dico in self.variable:
|
||||
for key in dico:
|
||||
if key in self.variable_keys:
|
||||
raise DuplicateParameterError(f"{key} is specified twice")
|
||||
self.variable_keys.add(key)
|
||||
for key in self.fixed:
|
||||
if key in self.variable_keys:
|
||||
raise DuplicateParameterError(f"{key} is specified twice")
|
||||
self.fixed_keys.add(key)
|
||||
|
||||
def weak_update(self, other: SubConfig = None, **kwargs):
|
||||
"""similar to a dict update method put prioritizes existing values
|
||||
|
||||
Parameters
|
||||
----------
|
||||
other : SubConfig
|
||||
other obj
|
||||
"""
|
||||
if other is None:
|
||||
other = SubConfig(kwargs)
|
||||
self.fixed = other.fixed | self.fixed
|
||||
self.variable = other.variable + self.variable
|
||||
self.__update()
|
||||
|
||||
|
||||
def conform_variable_entry(d) -> list[dict[str, list]]:
|
||||
if isinstance(d, MutableMapping):
|
||||
d = [{k: v} for k, v in d.items()]
|
||||
@@ -240,45 +196,6 @@ def save_toml(path: os.PathLike, dico):
|
||||
return dico
|
||||
|
||||
|
||||
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]:
|
||||
"""loads a configuration file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : os.PathLike
|
||||
path to the config toml file or a directory containing config files
|
||||
|
||||
Returns
|
||||
-------
|
||||
final_path : Path
|
||||
output name of the simulation
|
||||
list[dict[str, Any]]
|
||||
one config per fiber
|
||||
|
||||
"""
|
||||
path = Path(path)
|
||||
if path.name.lower().endswith(".toml"):
|
||||
master_config_dict = _open_config(path)
|
||||
fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")]
|
||||
master_config = SubConfig(master_config_dict)
|
||||
else:
|
||||
master_config = SubConfig(dict(name=path.name))
|
||||
fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))]
|
||||
|
||||
if len(fiber_list) == 0:
|
||||
raise ValueError(f"No fiber in config {path}")
|
||||
for fiber in fiber_list:
|
||||
fiber.weak_update(master_config)
|
||||
if "num" not in fiber_list[0].variable_keys:
|
||||
repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1)))
|
||||
fiber_list[0].weak_update(variable=dict(num=repeat_arg))
|
||||
for p_name in ROOT_PARAMETERS:
|
||||
if any(p_name in conf.variable_keys for conf in fiber_list[1:]):
|
||||
raise ValueError(f"{p_name} should only be specified in the root or first fiber")
|
||||
configs = fiber_list
|
||||
return Path(master_config.fixed["name"]), configs
|
||||
|
||||
|
||||
@cache
|
||||
def load_material_dico(name: str) -> dict[str, Any]:
|
||||
"""loads a material dictionary
|
||||
|
||||
Reference in New Issue
Block a user