Files
scgenerator/src/scgenerator/utils/parameter.py
Benoît Sierro 7f21f18786 misc
2021-08-20 14:46:59 +02:00

497 lines
16 KiB
Python

import datetime as datetime_module
from copy import copy
from dataclasses import asdict, dataclass
from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import numpy as np
from numpy.lib.function_base import disp
from ..const import __version__
@lru_cache
def type_checker(*types):
def _type_checker_wrapper(validator, n=None):
if isinstance(validator, str) and n is not None:
_name = validator
validator = lambda *args: None
def _type_checker_wrapped(name, n):
if not isinstance(n, types):
raise TypeError(
f"{name!r} value must be of type {' or '.join(format(t) for t in types)} "
f"instead of {type(n)}"
)
validator(name, n)
if n is None:
return _type_checker_wrapped
else:
_type_checker_wrapped(_name, n)
return _type_checker_wrapper
@type_checker(str)
def string(name, n):
if len(n) == 0:
raise ValueError(f"{name!r} must not be empty")
def in_range_excl(_min, _max):
@type_checker(float, int)
def _in_range(name, n):
if n <= _min or n >= _max:
raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)")
return _in_range
def in_range_incl(_min, _max):
@type_checker(float, int)
def _in_range(name, n):
if n < _min or n > _max:
raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)")
return _in_range
def boolean(name, n):
if not n is True and not n is False:
raise ValueError(f"{name!r} must be True or False")
@lru_cache
def non_negative(*types):
@type_checker(*types)
def _non_negative(name, n):
if n < 0:
raise ValueError(f"{name!r} must be non negative")
return _non_negative
@lru_cache
def positive(*types):
@type_checker(*types)
def _positive(name, n):
if n <= 0:
raise ValueError(f"{name!r} must be positive")
return _positive
@type_checker(tuple, list)
def int_pair(name, t):
invalid = len(t) != 2
for m in t:
if invalid or not isinstance(m, int):
raise ValueError(f"{name!r} must be a list or a tuple of 2 int")
@type_checker(tuple, list)
def float_pair(name, t):
invalid = len(t) != 2
for m in t:
if invalid or not isinstance(m, (int, float)):
raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers")
def literal(*l):
l = set(l)
@type_checker(str)
def _string(name, s):
if not s in l:
raise ValueError(f"{name!r} must be a str in {l}")
return _string
def validator_list(validator):
"""returns a new validator that applies validator to each el of an iterable"""
@type_checker(list, tuple, np.ndarray)
def _list_validator(name, l):
for i, el in enumerate(l):
validator(name + f"[{i}]", el)
return _list_validator
def validator_or(*validators):
"""combines many validators and raises an exception only if all of them raise an exception"""
n = len(validators)
def _or_validator(name, value):
errors = []
for validator in validators:
try:
validator(name, value)
except (ValueError, TypeError) as e:
errors.append(e)
errors.sort(key=lambda el: isinstance(el, ValueError))
if len(errors) == n:
raise errors[-1]
return _or_validator
def validator_and(*validators):
def _and_validator(name, n):
for v in validators:
v(name, n)
return _and_validator
@type_checker(list, tuple, np.ndarray)
def num_list(name, l):
for i, el in enumerate(l):
type_checker(int, float)(name + f"[{i}]", el)
def func_validator(name, n):
if not callable(n):
raise TypeError(f"{name!r} must be callable")
class Parameter:
def __init__(self, validator, converter=None, default=None, display_info=None):
"""Single parameter
Parameters
----------
tpe : type
type of the paramter
validators : Callable[[str, Any], None]
signature : validator(name, value)
must raise a ValueError when value doesn't fit the criteria checked by
validator. name is passed to validator to be included in the error message
converter : Callable, optional
converts a valid value (for example, str.lower), by default None
default : callable, optional
factory function for a default value (for example, list), by default None
"""
self.validator = validator
self.converter = converter
self.default = default
self.display_info = display_info
def __set_name__(self, owner, name):
self.name = name
def __get__(self, instance, owner):
if not instance:
return self
return instance.__dict__[self.name]
def __delete__(self, instance):
del instance.__dict__[self.name]
def __set__(self, instance, value):
if isinstance(value, Parameter):
defaut = None if self.default is None else copy(self.default)
instance.__dict__[self.name] = defaut
else:
if value is not None:
self.validator(self.name, value)
if self.converter is not None:
value = self.converter(value)
instance.__dict__[self.name] = value
def display(self, num: float):
if self.display_info is None:
return str(num)
else:
fac, unit = self.display_info
num_str = format(num * fac, ".2f")
if num_str.endswith(".00"):
num_str = num_str[:-3]
return f"{num_str} {unit}"
class VariableParameter:
def __init__(self, parameterBase):
self.pbase = parameterBase
self.list_checker = type_checker(list, tuple, np.ndarray)
def __set_name__(self, owner, name):
self.name = name
def __get__(self, instance, owner):
if not instance:
return self
return instance.__dict__[self.name]
def __delete__(self, instance):
del instance.__dict__[self.name]
def __set__(self, instance, value: dict):
if isinstance(value, VariableParameter):
value = {}
else:
for k, v in value.items():
self.list_checker("variable " + k, v)
if k not in valid_variable:
raise TypeError(f"{k!r} is not a valid variable parameter")
if len(v) == 0:
raise ValueError(f"variable parameter {k!r} must not be empty")
p = getattr(self.pbase, k)
for el in v:
p.validator(k, el)
instance.__dict__[self.name] = value
valid_variable = {
"dispersion_file",
"field_file",
"loss_file",
"A_eff_file",
"beta",
"gamma",
"pitch",
"pitch_ratio",
"effective_mode_diameter",
"core_radius",
"capillary_num",
"capillary_outer_d",
"capillary_thickness",
"capillary_spacing",
"capillary_resonance_strengths",
"capillary_nested",
"he_mode",
"fit_parameters",
"input_transmission",
"n2",
"pressure",
"temperature",
"gas_name",
"plasma_density",
"peak_power",
"mean_power",
"peak_power",
"energy",
"quantum_noise",
"shape",
"wavelength",
"intensity_noise",
"width",
"t0",
"soliton_num",
"behaviors",
"raman_type",
"tolerated_error",
"step_size",
"interpolation_degree",
"ideal_gas",
}
hc_model_specific_parameters = dict(
marcatili=["core_radius", "he_mode"],
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
hasan=[
"core_radius",
"capillary_num",
"capillary_thickness",
"capillary_resonance_strengths",
"capillary_nested",
"capillary_spacing",
"capillary_outer_d",
],
)
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
@dataclass
class BareParams:
"""
This class defines each valid parameter's name, type and valid value but doesn't provide
any method to act on those. For that, use initialize.Params
"""
# root
name: str = Parameter(string)
prev_data_dir: str = Parameter(string)
previous_config_file: str = Parameter(string)
# # fiber
input_transmission: float = Parameter(in_range_incl(0, 1))
gamma: float = Parameter(non_negative(float, int))
n2: float = Parameter(non_negative(float, int))
loss: str = Parameter(literal("capillary"))
loss_file: str = Parameter(string)
effective_mode_diameter: float = Parameter(positive(float, int))
A_eff: float = Parameter(non_negative(float, int))
A_eff_file: str = Parameter(string)
pitch: float = Parameter(in_range_excl(0, 1e-3))
pitch_ratio: float = Parameter(in_range_excl(0, 1))
core_radius: float = Parameter(in_range_excl(0, 1e-3))
he_mode: Tuple[int, int] = Parameter(int_pair)
fit_parameters: Tuple[int, int] = Parameter(int_pair)
beta: Iterable[float] = Parameter(num_list)
dispersion_file: str = Parameter(string)
model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"))
length: float = Parameter(non_negative(float, int))
capillary_num: int = Parameter(positive(int))
capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
capillary_resonance_strengths: Iterable[float] = Parameter(num_list)
capillary_nested: int = Parameter(non_negative(int))
# gas
gas_name: str = Parameter(string, converter=str.lower)
pressure: Union[float, Iterable[float]] = Parameter(
validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar")
)
temperature: float = Parameter(positive(float, int), display_info=(1, "K"))
plasma_density: float = Parameter(non_negative(float, int))
# pulse
field_file: str = Parameter(string)
repetition_rate: float = Parameter(non_negative(float, int), display_info=(1e-6, "MHz"))
peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW"))
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
soliton_num: float = Parameter(non_negative(float, int))
quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"))
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
# simulation
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))
parallel: bool = Parameter(boolean)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
ideal_gas: bool = Parameter(boolean)
repeat: int = Parameter(positive(int))
t_num: int = Parameter(positive(int))
z_num: int = Parameter(positive(int))
time_window: float = Parameter(positive(float, int))
dt: float = Parameter(in_range_excl(0, 5e-15))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3))
step_size: float = Parameter(positive(float, int))
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
interpolation_degree: int = Parameter(positive(int))
prev_sim_dir: str = Parameter(string)
recovery_last_stored: int = Parameter(non_negative(int))
worker_num: int = Parameter(positive(int))
# computed
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
alpha: np.ndarray = Parameter(type_checker(np.ndarray))
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
w: np.ndarray = Parameter(type_checker(np.ndarray))
l: np.ndarray = Parameter(type_checker(np.ndarray))
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
w0: float = Parameter(positive(float))
w_power_fact: np.ndarray = Parameter(validator_list(type_checker(np.ndarray)))
t: np.ndarray = Parameter(type_checker(np.ndarray))
L_D: float = Parameter(non_negative(float, int))
L_NL: float = Parameter(non_negative(float, int))
L_sol: float = Parameter(non_negative(float, int))
dynamic_dispersion: bool = Parameter(boolean)
adapt_step_size: bool = Parameter(boolean)
error_ok: float = Parameter(positive(float))
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
beta_func: Callable[[float], List[float]] = Parameter(func_validator)
gamma_func: Callable[[float], float] = Parameter(func_validator)
interp_range: Tuple[float, float] = Parameter(float_pair)
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string)
def prepare_for_dump(self) -> Dict[str, Any]:
param = asdict(self)
param = BareParams.strip_params_dict(param)
param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__
return param
@staticmethod
def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved
(dropped because they take a lot of space and can be exactly reconstructed)
Parameters
----------
dico : dict
dictionary
"""
forbiden_keys = [
"w_c",
"w_power_fact",
"field_0",
"spec_0",
"w",
"t",
"z_targets",
"l",
"alpha",
"gamma_arr",
"A_eff_arr",
]
types = (np.ndarray, float, int, str, list, tuple, dict)
out = {}
for key, value in dico.items():
if key in forbiden_keys:
continue
if not isinstance(value, types):
continue
if isinstance(value, dict):
out[key] = BareParams.strip_params_dict(value)
elif isinstance(value, np.ndarray) and value.dtype == complex:
continue
else:
out[key] = value
if "variable" in out and len(out["variable"]) == 0:
del out["variable"]
return out
@dataclass
class BareConfig(BareParams):
variable: dict = VariableParameter(BareParams)
if __name__ == "__main__":
numero = type_checker(int)
@numero
def natural_number(name, n):
if n < 0:
raise ValueError(f"{name!r} must be positive")
try:
numero("a", np.arange(45))
except Exception as e:
print(e)
try:
natural_number("b", -1)
except Exception as e:
print(e)
try:
natural_number("c", 1.0)
except Exception as e:
print(e)
try:
natural_number("d", 1)
print("success !")
except Exception as e:
print(e)