forgot some fields in Params

This commit is contained in:
Benoît Sierro
2021-06-10 13:35:42 +02:00
parent 74cb057dbe
commit 03447f268d
2 changed files with 14 additions and 3 deletions

View File

@@ -69,10 +69,10 @@ class Params(BareParams):
self.adapt_step_size = True self.adapt_step_size = True
# FIBER # FIBER
self.interp_range = [ self.interp_range = (
max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))), max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))),
min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))), min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))),
] )
temp_gamma = None temp_gamma = None
if self.effective_mode_diameter is not None: if self.effective_mode_diameter is not None:

View File

@@ -89,6 +89,14 @@ def int_pair(name, t):
raise ValueError(f"{name!r} must be a list or a tuple of 2 int") raise ValueError(f"{name!r} must be a list or a tuple of 2 int")
@type_checker(tuple, list)
def float_pair(name, t):
invalid = len(t) != 2
for m in t:
if invalid or not isinstance(m, (int, float)):
raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers")
def literal(*l): def literal(*l):
l = set(l) l = set(l)
@@ -103,7 +111,7 @@ def literal(*l):
def validator_list(validator): def validator_list(validator):
"""returns a new validator that applies validator to each el of an iterable""" """returns a new validator that applies validator to each el of an iterable"""
@type_checker(list, tuple) @type_checker(list, tuple, np.ndarray)
def _list_validator(name, l): def _list_validator(name, l):
for i, el in enumerate(l): for i, el in enumerate(l):
validator(name + f"[{i}]", el) validator(name + f"[{i}]", el)
@@ -358,6 +366,8 @@ class BareParams:
spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
w: np.ndarray = Parameter(type_checker(np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray))
w_c: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray))
w0: float = Parameter(positive(float))
w_power_fact: np.ndarray = Parameter(validator_list(type_checker(np.ndarray)))
t: np.ndarray = Parameter(type_checker(np.ndarray)) t: np.ndarray = Parameter(type_checker(np.ndarray))
L_D: float = Parameter(non_negative(float, int)) L_D: float = Parameter(non_negative(float, int))
L_NL: float = Parameter(non_negative(float, int)) L_NL: float = Parameter(non_negative(float, int))
@@ -370,6 +380,7 @@ class BareParams:
const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
beta_func: Callable[[float], List[float]] = Parameter(func_validator) beta_func: Callable[[float], List[float]] = Parameter(func_validator)
gamma_func: Callable[[float], float] = Parameter(func_validator) gamma_func: Callable[[float], float] = Parameter(func_validator)
interp_range: Tuple[float, float] = Parameter(float_pair)
def prepare_for_dump(self) -> Dict[str, Any]: def prepare_for_dump(self) -> Dict[str, Any]:
param = asdict(self) param = asdict(self)