From 967efa9d13716b3fb22589cc2ee4841e75b0bc6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 17 Oct 2023 09:08:38 +0200 Subject: [PATCH] allow numpy number types in parameters --- pyproject.toml | 2 +- src/scgenerator/parameter.py | 99 +++++++++++++++++++----------------- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c93e8e..7abcdd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.19" +version = "0.3.20" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 380d365..3195141 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -20,6 +20,9 @@ from scgenerator.physics.units import unit_formatter T = TypeVar("T") DISPLAY_FUNCTIONS: dict[str, Callable[[float], str]] = {} +integer = (int, np.integer) +floating = (float, np.floating) +number = (*integer, *floating) def _format_display_info(name: str, value) -> str: @@ -32,7 +35,7 @@ def _format_display_info(name: str, value) -> str: def format_value(name: str, value) -> str: if value is True or value is False: return str(value) - elif isinstance(value, (float, int)): + elif isinstance(value, number): return _format_display_info(name, value) elif isinstance(value, np.ndarray): return np.array2string(value) @@ -86,7 +89,7 @@ def low_string(name, n): def in_range_excl(_min, _max): - @type_checker(float, int) + @type_checker(*number) def _in_range(name, n): if n <= _min or n >= _max: raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)") @@ -95,7 +98,7 @@ def in_range_excl(_min, _max): def in_range_incl(_min, _max): - @type_checker(float, int) + @type_checker(*number) def _in_range(name, n): if n < _min or n > _max: raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)") @@ -137,7 +140,7 @@ def positive(*types): def int_pair(name, t): invalid = len(t) != 2 for m in t: - if invalid or not isinstance(m, int): + if invalid or not isinstance(m, integer): raise ValueError(f"{name!r} must be a list or a tuple of 2 int. got {t!r} instead") @@ -145,7 +148,7 @@ def int_pair(name, t): def float_pair(name, t): invalid = len(t) != 2 for m in t: - if invalid or not isinstance(m, (int, float)): + if invalid or not isinstance(m, number): raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers. got {t!r} instead") @@ -201,7 +204,7 @@ def validator_and(*validators): @type_checker(list, tuple, np.ndarray) def num_list(name, l): for i, el in enumerate(l): - type_checker(int, float)(name + f"[{i}]", el) + type_checker(*number)(name + f"[{i}]", el) def func_validator(name, n): @@ -307,13 +310,13 @@ class Parameters: # fiber input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0) - gamma: float = Parameter(non_negative(float, int)) - n2: float = Parameter(non_negative(float, int)) - chi3: float = Parameter(non_negative(float, int)) + gamma: float = Parameter(non_negative(*number)) + n2: float = Parameter(non_negative(*number)) + chi3: float = Parameter(non_negative(*number)) loss: str = Parameter(literal("capillary")) loss_file: DataFile = Parameter(DataFile.validate) - effective_mode_diameter: float = Parameter(positive(float, int)) - effective_area: float = Parameter(non_negative(float, int)) + effective_mode_diameter: float = Parameter(positive(*number)) + effective_area: float = Parameter(non_negative(*number)) effective_area_file: DataFile = Parameter(DataFile.validate) numerical_aperture: float = Parameter(in_range_excl(0, 1)) pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), unit="m") @@ -326,45 +329,43 @@ class Parameters: model: str = Parameter( literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), ) - zero_dispersion_wavelength: float = Parameter( - validator_list(non_negative(float, int)), unit="m" - ) - length: float = Parameter(non_negative(float, int), unit="m") - capillary_num: int = Parameter(positive(int)) + zero_dispersion_wavelength: float = Parameter(validator_list(non_negative(*number)), unit="m") + length: float = Parameter(non_negative(*number), unit="m") + capillary_num: int = Parameter(positive(*integer)) capillary_radius: float = Parameter(in_range_excl(0, 1e-3), unit="m") capillary_thickness: float = Parameter(in_range_excl(0, 1e-3), unit="m") capillary_spacing: float = Parameter(in_range_excl(0, 1e-3), unit="m") capillary_resonance_strengths: Iterable[float] = Parameter( validator_list(type_checker(int, float, np.ndarray)) ) - capillary_resonance_max_order: int = Parameter(non_negative(int), default=0) - capillary_nested: int = Parameter(non_negative(int), default=0) + capillary_resonance_max_order: int = Parameter(non_negative(*integer), default=0) + capillary_nested: int = Parameter(non_negative(*integer), default=0) # gas gas_name: str = Parameter(low_string, default="vacuum") - pressure: float = Parameter(non_negative(float, int), unit="bar") - pressure_in: float = Parameter(non_negative(float, int), unit="bar") - pressure_out: float = Parameter(non_negative(float, int), unit="bar") - temperature: float = Parameter(positive(float, int), unit="K", default=300) - plasma_density: float = Parameter(non_negative(float, int), default=0) + pressure: float = Parameter(non_negative(*number), unit="bar") + pressure_in: float = Parameter(non_negative(*number), unit="bar") + pressure_out: float = Parameter(non_negative(*number), unit="bar") + temperature: float = Parameter(positive(*number), unit="K", default=300) + plasma_density: float = Parameter(non_negative(*number), default=0) # pulse field_file: DataFile = Parameter(DataFile.validate) input_time: np.ndarray = Parameter(type_checker(np.ndarray)) input_field: np.ndarray = Parameter(type_checker(np.ndarray)) - repetition_rate: float = Parameter(non_negative(float, int), unit="Hz", default=40e6) - peak_power: float = Parameter(positive(float, int), unit="W") - mean_power: float = Parameter(positive(float, int), unit="W") - energy: float = Parameter(positive(float, int), unit="J") - soliton_num: float = Parameter(non_negative(float, int)) - additional_noise_factor: float = Parameter(positive(float, int), default=1) + repetition_rate: float = Parameter(non_negative(*number), unit="Hz", default=40e6) + peak_power: float = Parameter(positive(*number), unit="W") + mean_power: float = Parameter(positive(*number), unit="W") + energy: float = Parameter(positive(*number), unit="J") + soliton_num: float = Parameter(non_negative(*number)) + additional_noise_factor: float = Parameter(positive(*number), default=1) shape: str = Parameter(literal("gaussian", "sech"), default="gaussian") wavelength: float = Parameter(in_range_incl(100e-9, 10000e-9), unit="m") intensity_noise: float = Parameter(in_range_incl(0, 1), unit="%", default=0) noise_correlation: float = Parameter(in_range_incl(-10, 10), default=0) width: float = Parameter(in_range_excl(0, 1e-9), unit="s") t0: float = Parameter(in_range_excl(0, 1e-9), unit="s") - delay: float = Parameter(type_checker(float, int), unit="s") + delay: float = Parameter(type_checker(*number), unit="s") # Behaviors to include quantum_noise: bool = Parameter(boolean, default=False) @@ -378,23 +379,25 @@ class Parameters: literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43" ) raman_type: str = Parameter(literal("measured", "agrawal", "stolen")) - raman_fraction: float = Parameter(non_negative(float, int)) + raman_fraction: float = Parameter(non_negative(*number)) spm: bool = Parameter(boolean, default=True) - repeat: int = Parameter(positive(int), default=1) - t_num: int = Parameter(positive(int), default=4096) - z_num: int = Parameter(positive(int), default=128) - time_window: float = Parameter(positive(float, int), unit="s") + repeat: int = Parameter(positive(*integer), default=1) + t_num: int = Parameter(positive(*integer), default=4096) + z_num: int = Parameter(positive(*integer), default=128) + time_window: float = Parameter(positive(*number), unit="s") dt: float = Parameter(in_range_excl(0, 10e-15), unit="s") tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11) - step_size: float = Parameter(non_negative(float, int), default=0) + step_size: float = Parameter(non_negative(*number), default=0) wavelength_window: tuple[float, float] = Parameter( validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9))), unit="m" ) - interpolation_degree: int = Parameter(validator_and(type_checker(int), in_range_incl(2, 18))) + interpolation_degree: int = Parameter( + validator_and(type_checker(*integer), in_range_incl(2, 18)) + ) prev_sim_dir: str = Parameter(string) - recovery_last_stored: int = Parameter(non_negative(int), default=0) + recovery_last_stored: int = Parameter(non_negative(*integer), default=0) parallel: bool = Parameter(boolean, default=True) - worker_num: int = Parameter(positive(int)) + worker_num: int = Parameter(positive(*integer)) # computed linear_operator: VariableQuantity = Parameter(is_function, can_pickle=False) @@ -404,27 +407,27 @@ class Parameters: ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False) field_0: np.ndarray = Parameter(type_checker(np.ndarray)) spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) - beta2: float = Parameter(type_checker(int, float)) + beta2: float = Parameter(type_checker(*number)) alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray)) - alpha: float = Parameter(non_negative(float, int)) + alpha: float = Parameter(non_negative(*number)) gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray)) effective_area_arr: np.ndarray = Parameter(type_checker(np.ndarray)) - spectrum_factor: float = Parameter(type_checker(float)) - c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray)) + spectrum_factor: float = Parameter(type_checker(*number)) + c_to_a_factor: np.ndarray = Parameter(type_checker(*number, np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray)) l: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray)) - w0: float = Parameter(positive(float)) + w0: float = Parameter(positive(*number)) t: np.ndarray = Parameter(type_checker(np.ndarray)) - dispersion_length: float = Parameter(non_negative(float, int), unit="m") - nonlinear_length: float = Parameter(non_negative(float, int), unit="m") - soliton_length: float = Parameter(non_negative(float, int), unit="m") + dispersion_length: float = Parameter(non_negative(*number), unit="m") + nonlinear_length: float = Parameter(non_negative(*number), unit="m") + soliton_length: float = Parameter(non_negative(*number), unit="m") adapt_step_size: bool = Parameter(boolean) hr_w: np.ndarray = Parameter(type_checker(np.ndarray)) z_targets: np.ndarray = Parameter(type_checker(np.ndarray)) const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) - num: int = Parameter(non_negative(int)) + num: int = Parameter(non_negative(*integer)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) version: str = Parameter(string)