reworked parameter compilation
This commit is contained in:
@@ -27,9 +27,12 @@ PARAM_FN = "params.toml"
|
|||||||
PARAM_SEPARATOR = " "
|
PARAM_SEPARATOR = " "
|
||||||
|
|
||||||
|
|
||||||
MANDATORY_PARAMETERS = [
|
MANDATORY_PARAMETERS = {
|
||||||
"name",
|
"name",
|
||||||
"w",
|
"w",
|
||||||
|
"t",
|
||||||
|
"fft",
|
||||||
|
"ifft",
|
||||||
"w0",
|
"w0",
|
||||||
"spec_0",
|
"spec_0",
|
||||||
"field_0",
|
"field_0",
|
||||||
@@ -39,16 +42,15 @@ MANDATORY_PARAMETERS = [
|
|||||||
"length",
|
"length",
|
||||||
"adapt_step_size",
|
"adapt_step_size",
|
||||||
"tolerated_error",
|
"tolerated_error",
|
||||||
"recovery_last_stored",
|
|
||||||
"output_path",
|
|
||||||
"repeat",
|
"repeat",
|
||||||
"linear_operator",
|
"linear_operator",
|
||||||
"nonlinear_operator",
|
"nonlinear_operator",
|
||||||
]
|
"soliton_length",
|
||||||
|
"nonlinear_length",
|
||||||
|
"dispersion_length",
|
||||||
|
}
|
||||||
|
|
||||||
ROOT_PARAMETERS = [
|
ROOT_PARAMETERS = [
|
||||||
"repeat",
|
|
||||||
"num",
|
|
||||||
"dt",
|
"dt",
|
||||||
"t_num",
|
"t_num",
|
||||||
"time_window",
|
"time_window",
|
||||||
|
|||||||
@@ -28,10 +28,6 @@ class MissingParameterError(Exception):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class DuplicateParameterError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IncompleteDataFolderError(FileNotFoundError):
|
class IncompleteDataFolderError(FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,10 @@ class Rule:
|
|||||||
return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})"
|
return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})"
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]"
|
return (
|
||||||
|
f"[{', '.join(self.args)}] -- {self.func.__module__}."
|
||||||
|
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def func_name(self) -> str:
|
def func_name(self) -> str:
|
||||||
@@ -192,8 +195,8 @@ class Evaluator:
|
|||||||
if target in self.__curent_lookup:
|
if target in self.__curent_lookup:
|
||||||
raise EvaluatorError(
|
raise EvaluatorError(
|
||||||
"cyclic dependency detected : "
|
"cyclic dependency detected : "
|
||||||
f"{target!r} seems to depend on itself, "
|
f"{target!r} seems to depend on itself, please provide "
|
||||||
f"please provide a value for at least one variable in {self.__curent_lookup!r}. "
|
f"a value for at least one variable in {self.__curent_lookup!r}. "
|
||||||
+ self.attempted_rules_str(target)
|
+ self.attempted_rules_str(target)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -372,9 +375,9 @@ default_rules: list[Rule] = [
|
|||||||
Rule("effective_area", fiber.effective_area_from_diam),
|
Rule("effective_area", fiber.effective_area_from_diam),
|
||||||
Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")),
|
Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")),
|
||||||
Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1),
|
Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1),
|
||||||
Rule("effective_area", fiber.elfective_area_marcatili, priorities=-2),
|
Rule("effective_area", fiber.effective_area_marcatili, priorities=-2),
|
||||||
Rule("effecive_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
|
Rule("effective_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
|
||||||
Rule("effecive_area_arr", fiber.load_custom_effective_area),
|
Rule("effective_area_arr", fiber.load_custom_effective_area),
|
||||||
Rule(
|
Rule(
|
||||||
"V_eff",
|
"V_eff",
|
||||||
fiber.V_parameter_koshiba,
|
fiber.V_parameter_koshiba,
|
||||||
@@ -396,7 +399,7 @@ default_rules: list[Rule] = [
|
|||||||
Rule("n2", lambda: 2.2e-20, priorities=-1),
|
Rule("n2", lambda: 2.2e-20, priorities=-1),
|
||||||
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
|
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
|
||||||
Rule("gamma", fiber.gamma_parameter),
|
Rule("gamma", fiber.gamma_parameter),
|
||||||
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effecive_area_arr"]),
|
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effective_area_arr"]),
|
||||||
# Raman
|
# Raman
|
||||||
Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
|
Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
|
||||||
Rule("raman_fraction", fiber.raman_fraction),
|
Rule("raman_fraction", fiber.raman_fraction),
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import numpy as np
|
|||||||
|
|
||||||
from scgenerator import utils
|
from scgenerator import utils
|
||||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||||
from scgenerator.errors import EvaluatorError
|
|
||||||
from scgenerator.evaluator import Evaluator
|
from scgenerator.evaluator import Evaluator
|
||||||
from scgenerator.operators import Qualifier, SpecOperator
|
from scgenerator.operators import Qualifier, SpecOperator
|
||||||
from scgenerator.utils import update_path_name
|
from scgenerator.utils import update_path_name
|
||||||
@@ -256,7 +255,8 @@ class Parameter:
|
|||||||
if instance._frozen:
|
if instance._frozen:
|
||||||
raise AttributeError("Parameters instance is frozen and can no longer be modified")
|
raise AttributeError("Parameters instance is frozen and can no longer be modified")
|
||||||
|
|
||||||
if isinstance(value, Parameter) and self.default is not None:
|
if isinstance(value, Parameter):
|
||||||
|
if self.default is not None:
|
||||||
instance._param_dico[self.name] = copy(self.default)
|
instance._param_dico[self.name] = copy(self.default)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -299,13 +299,12 @@ class Parameters:
|
|||||||
|
|
||||||
# internal machinery
|
# internal machinery
|
||||||
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
||||||
_evaluator: Evaluator = field(init=False, repr=False)
|
|
||||||
_p_names: ClassVar[Set[str]] = set()
|
_p_names: ClassVar[Set[str]] = set()
|
||||||
_frozen: bool = field(init=False, default=False, repr=False)
|
_frozen: bool = field(init=False, default=False, repr=False)
|
||||||
|
|
||||||
# root
|
# root
|
||||||
name: str = Parameter(string, default="no name")
|
name: str = Parameter(string, default="no name")
|
||||||
output_path: Path = Parameter(type_checker(Path), default=Path("sc_data"), converter=Path)
|
output_path: Path = Parameter(type_checker(Path), converter=Path)
|
||||||
|
|
||||||
# fiber
|
# fiber
|
||||||
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
||||||
@@ -433,10 +432,6 @@ class Parameters:
|
|||||||
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
||||||
version: str = Parameter(string)
|
version: str = Parameter(string)
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self._evaluator = Evaluator.default(self.full_field)
|
|
||||||
self._evaluator.set(self._param_dico)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
||||||
|
|
||||||
@@ -455,23 +450,31 @@ class Parameters:
|
|||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
self.__post_init__()
|
self.__post_init__()
|
||||||
|
|
||||||
def dump_dict(self, compute=True, add_metadata=True) -> dict[str, Any]:
|
def get_evaluator(self):
|
||||||
if compute:
|
evaluator = Evaluator.default(self.full_field)
|
||||||
self.compute_in_place()
|
evaluator.set(self._param_dico.copy())
|
||||||
|
return evaluator
|
||||||
|
|
||||||
|
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||||
param = Parameters.strip_params_dict(self._param_dico)
|
param = Parameters.strip_params_dict(self._param_dico)
|
||||||
if add_metadata:
|
if add_metadata:
|
||||||
param["datetime"] = datetime_module.datetime.now()
|
param["datetime"] = datetime_module.datetime.now()
|
||||||
param["version"] = __version__
|
param["version"] = __version__
|
||||||
return param
|
return param
|
||||||
|
|
||||||
def compute_in_place(self, *to_compute: str):
|
def compile(self, exhaustive=False):
|
||||||
if len(to_compute) == 0:
|
|
||||||
to_compute = MANDATORY_PARAMETERS
|
to_compute = MANDATORY_PARAMETERS
|
||||||
|
evaluator = self.get_evaluator()
|
||||||
for k in to_compute:
|
for k in to_compute:
|
||||||
getattr(self, k)
|
evaluator.compute(k)
|
||||||
|
if exhaustive:
|
||||||
def compute(self, key: str) -> Any:
|
for p in self._p_names:
|
||||||
return self._evaluator.compute(key)
|
if p not in evaluator.params:
|
||||||
|
try:
|
||||||
|
evaluator.compute(p)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return self.__class__(**{k: v for k, v in evaluator.params.items() if k in self._p_names})
|
||||||
|
|
||||||
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
||||||
"""return a pretty formatted string describing the parameters"""
|
"""return a pretty formatted string describing the parameters"""
|
||||||
|
|||||||
@@ -28,7 +28,8 @@ def lambda_for_envelope_dispersion(
|
|||||||
subset of l in the interpolation range with two extra values on each side
|
subset of l in the interpolation range with two extra values on each side
|
||||||
to accomodate for taking gradients
|
to accomodate for taking gradients
|
||||||
np.ndarray
|
np.ndarray
|
||||||
indices of the original l where the values are valid (i.e. without the two extra on each side)
|
indices of the original l where the values are valid
|
||||||
|
(i.e. without the two extra on each side)
|
||||||
"""
|
"""
|
||||||
su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
|
su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
|
||||||
if l[su].min() > 1.01 * wavelength_window[0]:
|
if l[su].min() > 1.01 * wavelength_window[0]:
|
||||||
|
|||||||
@@ -11,20 +11,17 @@ import itertools
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import cache, lru_cache
|
from functools import cache, lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import printable as str_printable
|
from string import printable as str_printable
|
||||||
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union
|
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.ma.extras import isin
|
|
||||||
import pkg_resources as pkg
|
import pkg_resources as pkg
|
||||||
import tomli
|
import tomli
|
||||||
import tomli_w
|
import tomli_w
|
||||||
|
|
||||||
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN
|
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
|
||||||
from scgenerator.errors import DuplicateParameterError
|
|
||||||
from scgenerator.logger import get_logger
|
from scgenerator.logger import get_logger
|
||||||
|
|
||||||
T_ = TypeVar("T_")
|
T_ = TypeVar("T_")
|
||||||
@@ -91,47 +88,6 @@ class Paths:
|
|||||||
return os.path.join(cls.get("plots"), name)
|
return os.path.join(cls.get("plots"), name)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(init=False)
|
|
||||||
class SubConfig:
|
|
||||||
fixed: dict[str, Any]
|
|
||||||
variable: list[dict[str, list]]
|
|
||||||
fixed_keys: set[str]
|
|
||||||
variable_keys: set[str]
|
|
||||||
|
|
||||||
def __init__(self, dico: dict[str, Any]):
|
|
||||||
dico = dico.copy()
|
|
||||||
self.variable = conform_variable_entry(dico.pop("variable", []))
|
|
||||||
self.fixed = dico
|
|
||||||
self.__update
|
|
||||||
|
|
||||||
def __update(self):
|
|
||||||
self.variable_keys = set()
|
|
||||||
self.fixed_keys = set()
|
|
||||||
for dico in self.variable:
|
|
||||||
for key in dico:
|
|
||||||
if key in self.variable_keys:
|
|
||||||
raise DuplicateParameterError(f"{key} is specified twice")
|
|
||||||
self.variable_keys.add(key)
|
|
||||||
for key in self.fixed:
|
|
||||||
if key in self.variable_keys:
|
|
||||||
raise DuplicateParameterError(f"{key} is specified twice")
|
|
||||||
self.fixed_keys.add(key)
|
|
||||||
|
|
||||||
def weak_update(self, other: SubConfig = None, **kwargs):
|
|
||||||
"""similar to a dict update method put prioritizes existing values
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
other : SubConfig
|
|
||||||
other obj
|
|
||||||
"""
|
|
||||||
if other is None:
|
|
||||||
other = SubConfig(kwargs)
|
|
||||||
self.fixed = other.fixed | self.fixed
|
|
||||||
self.variable = other.variable + self.variable
|
|
||||||
self.__update()
|
|
||||||
|
|
||||||
|
|
||||||
def conform_variable_entry(d) -> list[dict[str, list]]:
|
def conform_variable_entry(d) -> list[dict[str, list]]:
|
||||||
if isinstance(d, MutableMapping):
|
if isinstance(d, MutableMapping):
|
||||||
d = [{k: v} for k, v in d.items()]
|
d = [{k: v} for k, v in d.items()]
|
||||||
@@ -240,45 +196,6 @@ def save_toml(path: os.PathLike, dico):
|
|||||||
return dico
|
return dico
|
||||||
|
|
||||||
|
|
||||||
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]:
|
|
||||||
"""loads a configuration file
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : os.PathLike
|
|
||||||
path to the config toml file or a directory containing config files
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
final_path : Path
|
|
||||||
output name of the simulation
|
|
||||||
list[dict[str, Any]]
|
|
||||||
one config per fiber
|
|
||||||
|
|
||||||
"""
|
|
||||||
path = Path(path)
|
|
||||||
if path.name.lower().endswith(".toml"):
|
|
||||||
master_config_dict = _open_config(path)
|
|
||||||
fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")]
|
|
||||||
master_config = SubConfig(master_config_dict)
|
|
||||||
else:
|
|
||||||
master_config = SubConfig(dict(name=path.name))
|
|
||||||
fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))]
|
|
||||||
|
|
||||||
if len(fiber_list) == 0:
|
|
||||||
raise ValueError(f"No fiber in config {path}")
|
|
||||||
for fiber in fiber_list:
|
|
||||||
fiber.weak_update(master_config)
|
|
||||||
if "num" not in fiber_list[0].variable_keys:
|
|
||||||
repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1)))
|
|
||||||
fiber_list[0].weak_update(variable=dict(num=repeat_arg))
|
|
||||||
for p_name in ROOT_PARAMETERS:
|
|
||||||
if any(p_name in conf.variable_keys for conf in fiber_list[1:]):
|
|
||||||
raise ValueError(f"{p_name} should only be specified in the root or first fiber")
|
|
||||||
configs = fiber_list
|
|
||||||
return Path(master_config.fixed["name"]), configs
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def load_material_dico(name: str) -> dict[str, Any]:
|
def load_material_dico(name: str) -> dict[str, Any]:
|
||||||
"""loads a material dictionary
|
"""loads a material dictionary
|
||||||
|
|||||||
Reference in New Issue
Block a user