From 03447f268d894471688fd10684182e8ef8fdf274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 10 Jun 2021 13:35:42 +0200 Subject: [PATCH] forgot some fields in Params --- src/scgenerator/initialize.py | 4 ++-- src/scgenerator/utils/parameter.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index ea4ca66..752e85d 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -69,10 +69,10 @@ class Params(BareParams): self.adapt_step_size = True # FIBER - self.interp_range = [ + self.interp_range = ( max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))), min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))), - ] + ) temp_gamma = None if self.effective_mode_diameter is not None: diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index b1cb124..0f317b1 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -89,6 +89,14 @@ def int_pair(name, t): raise ValueError(f"{name!r} must be a list or a tuple of 2 int") +@type_checker(tuple, list) +def float_pair(name, t): + invalid = len(t) != 2 + for m in t: + if invalid or not isinstance(m, (int, float)): + raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers") + + def literal(*l): l = set(l) @@ -103,7 +111,7 @@ def literal(*l): def validator_list(validator): """returns a new validator that applies validator to each el of an iterable""" - @type_checker(list, tuple) + @type_checker(list, tuple, np.ndarray) def _list_validator(name, l): for i, el in enumerate(l): validator(name + f"[{i}]", el) @@ -358,6 +366,8 @@ class BareParams: spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray)) + w0: float = Parameter(positive(float)) + w_power_fact: np.ndarray = Parameter(validator_list(type_checker(np.ndarray))) t: np.ndarray = Parameter(type_checker(np.ndarray)) L_D: float = Parameter(non_negative(float, int)) L_NL: float = Parameter(non_negative(float, int)) @@ -370,6 +380,7 @@ class BareParams: const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) beta_func: Callable[[float], List[float]] = Parameter(func_validator) gamma_func: Callable[[float], float] = Parameter(func_validator) + interp_range: Tuple[float, float] = Parameter(float_pair) def prepare_for_dump(self) -> Dict[str, Any]: param = asdict(self)