allow numpy number types in parameters

This commit is contained in:
Benoît Sierro
2023-10-17 09:08:38 +02:00
parent 159433f654
commit 967efa9d13
2 changed files with 52 additions and 49 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "scgenerator" name = "scgenerator"
version = "0.3.19" version = "0.3.20"
description = "Simulate nonlinear pulse propagation in optical fibers" description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md" readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]

View File

@@ -20,6 +20,9 @@ from scgenerator.physics.units import unit_formatter
T = TypeVar("T") T = TypeVar("T")
DISPLAY_FUNCTIONS: dict[str, Callable[[float], str]] = {} DISPLAY_FUNCTIONS: dict[str, Callable[[float], str]] = {}
integer = (int, np.integer)
floating = (float, np.floating)
number = (*integer, *floating)
def _format_display_info(name: str, value) -> str: def _format_display_info(name: str, value) -> str:
@@ -32,7 +35,7 @@ def _format_display_info(name: str, value) -> str:
def format_value(name: str, value) -> str: def format_value(name: str, value) -> str:
if value is True or value is False: if value is True or value is False:
return str(value) return str(value)
elif isinstance(value, (float, int)): elif isinstance(value, number):
return _format_display_info(name, value) return _format_display_info(name, value)
elif isinstance(value, np.ndarray): elif isinstance(value, np.ndarray):
return np.array2string(value) return np.array2string(value)
@@ -86,7 +89,7 @@ def low_string(name, n):
def in_range_excl(_min, _max): def in_range_excl(_min, _max):
@type_checker(float, int) @type_checker(*number)
def _in_range(name, n): def _in_range(name, n):
if n <= _min or n >= _max: if n <= _min or n >= _max:
raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)") raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)")
@@ -95,7 +98,7 @@ def in_range_excl(_min, _max):
def in_range_incl(_min, _max): def in_range_incl(_min, _max):
@type_checker(float, int) @type_checker(*number)
def _in_range(name, n): def _in_range(name, n):
if n < _min or n > _max: if n < _min or n > _max:
raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)") raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)")
@@ -137,7 +140,7 @@ def positive(*types):
def int_pair(name, t): def int_pair(name, t):
invalid = len(t) != 2 invalid = len(t) != 2
for m in t: for m in t:
if invalid or not isinstance(m, int): if invalid or not isinstance(m, integer):
raise ValueError(f"{name!r} must be a list or a tuple of 2 int. got {t!r} instead") raise ValueError(f"{name!r} must be a list or a tuple of 2 int. got {t!r} instead")
@@ -145,7 +148,7 @@ def int_pair(name, t):
def float_pair(name, t): def float_pair(name, t):
invalid = len(t) != 2 invalid = len(t) != 2
for m in t: for m in t:
if invalid or not isinstance(m, (int, float)): if invalid or not isinstance(m, number):
raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers. got {t!r} instead") raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers. got {t!r} instead")
@@ -201,7 +204,7 @@ def validator_and(*validators):
@type_checker(list, tuple, np.ndarray) @type_checker(list, tuple, np.ndarray)
def num_list(name, l): def num_list(name, l):
for i, el in enumerate(l): for i, el in enumerate(l):
type_checker(int, float)(name + f"[{i}]", el) type_checker(*number)(name + f"[{i}]", el)
def func_validator(name, n): def func_validator(name, n):
@@ -307,13 +310,13 @@ class Parameters:
# fiber # fiber
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0) input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
gamma: float = Parameter(non_negative(float, int)) gamma: float = Parameter(non_negative(*number))
n2: float = Parameter(non_negative(float, int)) n2: float = Parameter(non_negative(*number))
chi3: float = Parameter(non_negative(float, int)) chi3: float = Parameter(non_negative(*number))
loss: str = Parameter(literal("capillary")) loss: str = Parameter(literal("capillary"))
loss_file: DataFile = Parameter(DataFile.validate) loss_file: DataFile = Parameter(DataFile.validate)
effective_mode_diameter: float = Parameter(positive(float, int)) effective_mode_diameter: float = Parameter(positive(*number))
effective_area: float = Parameter(non_negative(float, int)) effective_area: float = Parameter(non_negative(*number))
effective_area_file: DataFile = Parameter(DataFile.validate) effective_area_file: DataFile = Parameter(DataFile.validate)
numerical_aperture: float = Parameter(in_range_excl(0, 1)) numerical_aperture: float = Parameter(in_range_excl(0, 1))
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), unit="m") pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), unit="m")
@@ -326,45 +329,43 @@ class Parameters:
model: str = Parameter( model: str = Parameter(
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
) )
zero_dispersion_wavelength: float = Parameter( zero_dispersion_wavelength: float = Parameter(validator_list(non_negative(*number)), unit="m")
validator_list(non_negative(float, int)), unit="m" length: float = Parameter(non_negative(*number), unit="m")
) capillary_num: int = Parameter(positive(*integer))
length: float = Parameter(non_negative(float, int), unit="m")
capillary_num: int = Parameter(positive(int))
capillary_radius: float = Parameter(in_range_excl(0, 1e-3), unit="m") capillary_radius: float = Parameter(in_range_excl(0, 1e-3), unit="m")
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3), unit="m") capillary_thickness: float = Parameter(in_range_excl(0, 1e-3), unit="m")
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3), unit="m") capillary_spacing: float = Parameter(in_range_excl(0, 1e-3), unit="m")
capillary_resonance_strengths: Iterable[float] = Parameter( capillary_resonance_strengths: Iterable[float] = Parameter(
validator_list(type_checker(int, float, np.ndarray)) validator_list(type_checker(int, float, np.ndarray))
) )
capillary_resonance_max_order: int = Parameter(non_negative(int), default=0) capillary_resonance_max_order: int = Parameter(non_negative(*integer), default=0)
capillary_nested: int = Parameter(non_negative(int), default=0) capillary_nested: int = Parameter(non_negative(*integer), default=0)
# gas # gas
gas_name: str = Parameter(low_string, default="vacuum") gas_name: str = Parameter(low_string, default="vacuum")
pressure: float = Parameter(non_negative(float, int), unit="bar") pressure: float = Parameter(non_negative(*number), unit="bar")
pressure_in: float = Parameter(non_negative(float, int), unit="bar") pressure_in: float = Parameter(non_negative(*number), unit="bar")
pressure_out: float = Parameter(non_negative(float, int), unit="bar") pressure_out: float = Parameter(non_negative(*number), unit="bar")
temperature: float = Parameter(positive(float, int), unit="K", default=300) temperature: float = Parameter(positive(*number), unit="K", default=300)
plasma_density: float = Parameter(non_negative(float, int), default=0) plasma_density: float = Parameter(non_negative(*number), default=0)
# pulse # pulse
field_file: DataFile = Parameter(DataFile.validate) field_file: DataFile = Parameter(DataFile.validate)
input_time: np.ndarray = Parameter(type_checker(np.ndarray)) input_time: np.ndarray = Parameter(type_checker(np.ndarray))
input_field: np.ndarray = Parameter(type_checker(np.ndarray)) input_field: np.ndarray = Parameter(type_checker(np.ndarray))
repetition_rate: float = Parameter(non_negative(float, int), unit="Hz", default=40e6) repetition_rate: float = Parameter(non_negative(*number), unit="Hz", default=40e6)
peak_power: float = Parameter(positive(float, int), unit="W") peak_power: float = Parameter(positive(*number), unit="W")
mean_power: float = Parameter(positive(float, int), unit="W") mean_power: float = Parameter(positive(*number), unit="W")
energy: float = Parameter(positive(float, int), unit="J") energy: float = Parameter(positive(*number), unit="J")
soliton_num: float = Parameter(non_negative(float, int)) soliton_num: float = Parameter(non_negative(*number))
additional_noise_factor: float = Parameter(positive(float, int), default=1) additional_noise_factor: float = Parameter(positive(*number), default=1)
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian") shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
wavelength: float = Parameter(in_range_incl(100e-9, 10000e-9), unit="m") wavelength: float = Parameter(in_range_incl(100e-9, 10000e-9), unit="m")
intensity_noise: float = Parameter(in_range_incl(0, 1), unit="%", default=0) intensity_noise: float = Parameter(in_range_incl(0, 1), unit="%", default=0)
noise_correlation: float = Parameter(in_range_incl(-10, 10), default=0) noise_correlation: float = Parameter(in_range_incl(-10, 10), default=0)
width: float = Parameter(in_range_excl(0, 1e-9), unit="s") width: float = Parameter(in_range_excl(0, 1e-9), unit="s")
t0: float = Parameter(in_range_excl(0, 1e-9), unit="s") t0: float = Parameter(in_range_excl(0, 1e-9), unit="s")
delay: float = Parameter(type_checker(float, int), unit="s") delay: float = Parameter(type_checker(*number), unit="s")
# Behaviors to include # Behaviors to include
quantum_noise: bool = Parameter(boolean, default=False) quantum_noise: bool = Parameter(boolean, default=False)
@@ -378,23 +379,25 @@ class Parameters:
literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43" literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43"
) )
raman_type: str = Parameter(literal("measured", "agrawal", "stolen")) raman_type: str = Parameter(literal("measured", "agrawal", "stolen"))
raman_fraction: float = Parameter(non_negative(float, int)) raman_fraction: float = Parameter(non_negative(*number))
spm: bool = Parameter(boolean, default=True) spm: bool = Parameter(boolean, default=True)
repeat: int = Parameter(positive(int), default=1) repeat: int = Parameter(positive(*integer), default=1)
t_num: int = Parameter(positive(int), default=4096) t_num: int = Parameter(positive(*integer), default=4096)
z_num: int = Parameter(positive(int), default=128) z_num: int = Parameter(positive(*integer), default=128)
time_window: float = Parameter(positive(float, int), unit="s") time_window: float = Parameter(positive(*number), unit="s")
dt: float = Parameter(in_range_excl(0, 10e-15), unit="s") dt: float = Parameter(in_range_excl(0, 10e-15), unit="s")
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11) tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
step_size: float = Parameter(non_negative(float, int), default=0) step_size: float = Parameter(non_negative(*number), default=0)
wavelength_window: tuple[float, float] = Parameter( wavelength_window: tuple[float, float] = Parameter(
validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9))), unit="m" validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9))), unit="m"
) )
interpolation_degree: int = Parameter(validator_and(type_checker(int), in_range_incl(2, 18))) interpolation_degree: int = Parameter(
validator_and(type_checker(*integer), in_range_incl(2, 18))
)
prev_sim_dir: str = Parameter(string) prev_sim_dir: str = Parameter(string)
recovery_last_stored: int = Parameter(non_negative(int), default=0) recovery_last_stored: int = Parameter(non_negative(*integer), default=0)
parallel: bool = Parameter(boolean, default=True) parallel: bool = Parameter(boolean, default=True)
worker_num: int = Parameter(positive(int)) worker_num: int = Parameter(positive(*integer))
# computed # computed
linear_operator: VariableQuantity = Parameter(is_function, can_pickle=False) linear_operator: VariableQuantity = Parameter(is_function, can_pickle=False)
@@ -404,27 +407,27 @@ class Parameters:
ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False) ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False)
field_0: np.ndarray = Parameter(type_checker(np.ndarray)) field_0: np.ndarray = Parameter(type_checker(np.ndarray))
spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
beta2: float = Parameter(type_checker(int, float)) beta2: float = Parameter(type_checker(*number))
alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray)) alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray))
alpha: float = Parameter(non_negative(float, int)) alpha: float = Parameter(non_negative(*number))
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray)) gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
effective_area_arr: np.ndarray = Parameter(type_checker(np.ndarray)) effective_area_arr: np.ndarray = Parameter(type_checker(np.ndarray))
spectrum_factor: float = Parameter(type_checker(float)) spectrum_factor: float = Parameter(type_checker(*number))
c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray)) c_to_a_factor: np.ndarray = Parameter(type_checker(*number, np.ndarray))
w: np.ndarray = Parameter(type_checker(np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray))
l: np.ndarray = Parameter(type_checker(np.ndarray)) l: np.ndarray = Parameter(type_checker(np.ndarray))
w_c: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray))
w0: float = Parameter(positive(float)) w0: float = Parameter(positive(*number))
t: np.ndarray = Parameter(type_checker(np.ndarray)) t: np.ndarray = Parameter(type_checker(np.ndarray))
dispersion_length: float = Parameter(non_negative(float, int), unit="m") dispersion_length: float = Parameter(non_negative(*number), unit="m")
nonlinear_length: float = Parameter(non_negative(float, int), unit="m") nonlinear_length: float = Parameter(non_negative(*number), unit="m")
soliton_length: float = Parameter(non_negative(float, int), unit="m") soliton_length: float = Parameter(non_negative(*number), unit="m")
adapt_step_size: bool = Parameter(boolean) adapt_step_size: bool = Parameter(boolean)
hr_w: np.ndarray = Parameter(type_checker(np.ndarray)) hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
z_targets: np.ndarray = Parameter(type_checker(np.ndarray)) z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
num: int = Parameter(non_negative(int)) num: int = Parameter(non_negative(*integer))
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string) version: str = Parameter(string)