allow numpy number types in parameters
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "scgenerator"
|
name = "scgenerator"
|
||||||
version = "0.3.19"
|
version = "0.3.20"
|
||||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ from scgenerator.physics.units import unit_formatter
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
DISPLAY_FUNCTIONS: dict[str, Callable[[float], str]] = {}
|
DISPLAY_FUNCTIONS: dict[str, Callable[[float], str]] = {}
|
||||||
|
integer = (int, np.integer)
|
||||||
|
floating = (float, np.floating)
|
||||||
|
number = (*integer, *floating)
|
||||||
|
|
||||||
|
|
||||||
def _format_display_info(name: str, value) -> str:
|
def _format_display_info(name: str, value) -> str:
|
||||||
@@ -32,7 +35,7 @@ def _format_display_info(name: str, value) -> str:
|
|||||||
def format_value(name: str, value) -> str:
|
def format_value(name: str, value) -> str:
|
||||||
if value is True or value is False:
|
if value is True or value is False:
|
||||||
return str(value)
|
return str(value)
|
||||||
elif isinstance(value, (float, int)):
|
elif isinstance(value, number):
|
||||||
return _format_display_info(name, value)
|
return _format_display_info(name, value)
|
||||||
elif isinstance(value, np.ndarray):
|
elif isinstance(value, np.ndarray):
|
||||||
return np.array2string(value)
|
return np.array2string(value)
|
||||||
@@ -86,7 +89,7 @@ def low_string(name, n):
|
|||||||
|
|
||||||
|
|
||||||
def in_range_excl(_min, _max):
|
def in_range_excl(_min, _max):
|
||||||
@type_checker(float, int)
|
@type_checker(*number)
|
||||||
def _in_range(name, n):
|
def _in_range(name, n):
|
||||||
if n <= _min or n >= _max:
|
if n <= _min or n >= _max:
|
||||||
raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)")
|
raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)")
|
||||||
@@ -95,7 +98,7 @@ def in_range_excl(_min, _max):
|
|||||||
|
|
||||||
|
|
||||||
def in_range_incl(_min, _max):
|
def in_range_incl(_min, _max):
|
||||||
@type_checker(float, int)
|
@type_checker(*number)
|
||||||
def _in_range(name, n):
|
def _in_range(name, n):
|
||||||
if n < _min or n > _max:
|
if n < _min or n > _max:
|
||||||
raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)")
|
raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)")
|
||||||
@@ -137,7 +140,7 @@ def positive(*types):
|
|||||||
def int_pair(name, t):
|
def int_pair(name, t):
|
||||||
invalid = len(t) != 2
|
invalid = len(t) != 2
|
||||||
for m in t:
|
for m in t:
|
||||||
if invalid or not isinstance(m, int):
|
if invalid or not isinstance(m, integer):
|
||||||
raise ValueError(f"{name!r} must be a list or a tuple of 2 int. got {t!r} instead")
|
raise ValueError(f"{name!r} must be a list or a tuple of 2 int. got {t!r} instead")
|
||||||
|
|
||||||
|
|
||||||
@@ -145,7 +148,7 @@ def int_pair(name, t):
|
|||||||
def float_pair(name, t):
|
def float_pair(name, t):
|
||||||
invalid = len(t) != 2
|
invalid = len(t) != 2
|
||||||
for m in t:
|
for m in t:
|
||||||
if invalid or not isinstance(m, (int, float)):
|
if invalid or not isinstance(m, number):
|
||||||
raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers. got {t!r} instead")
|
raise ValueError(f"{name!r} must be a list or a tuple of 2 numbers. got {t!r} instead")
|
||||||
|
|
||||||
|
|
||||||
@@ -201,7 +204,7 @@ def validator_and(*validators):
|
|||||||
@type_checker(list, tuple, np.ndarray)
|
@type_checker(list, tuple, np.ndarray)
|
||||||
def num_list(name, l):
|
def num_list(name, l):
|
||||||
for i, el in enumerate(l):
|
for i, el in enumerate(l):
|
||||||
type_checker(int, float)(name + f"[{i}]", el)
|
type_checker(*number)(name + f"[{i}]", el)
|
||||||
|
|
||||||
|
|
||||||
def func_validator(name, n):
|
def func_validator(name, n):
|
||||||
@@ -307,13 +310,13 @@ class Parameters:
|
|||||||
|
|
||||||
# fiber
|
# fiber
|
||||||
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
||||||
gamma: float = Parameter(non_negative(float, int))
|
gamma: float = Parameter(non_negative(*number))
|
||||||
n2: float = Parameter(non_negative(float, int))
|
n2: float = Parameter(non_negative(*number))
|
||||||
chi3: float = Parameter(non_negative(float, int))
|
chi3: float = Parameter(non_negative(*number))
|
||||||
loss: str = Parameter(literal("capillary"))
|
loss: str = Parameter(literal("capillary"))
|
||||||
loss_file: DataFile = Parameter(DataFile.validate)
|
loss_file: DataFile = Parameter(DataFile.validate)
|
||||||
effective_mode_diameter: float = Parameter(positive(float, int))
|
effective_mode_diameter: float = Parameter(positive(*number))
|
||||||
effective_area: float = Parameter(non_negative(float, int))
|
effective_area: float = Parameter(non_negative(*number))
|
||||||
effective_area_file: DataFile = Parameter(DataFile.validate)
|
effective_area_file: DataFile = Parameter(DataFile.validate)
|
||||||
numerical_aperture: float = Parameter(in_range_excl(0, 1))
|
numerical_aperture: float = Parameter(in_range_excl(0, 1))
|
||||||
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
||||||
@@ -326,45 +329,43 @@ class Parameters:
|
|||||||
model: str = Parameter(
|
model: str = Parameter(
|
||||||
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
|
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
|
||||||
)
|
)
|
||||||
zero_dispersion_wavelength: float = Parameter(
|
zero_dispersion_wavelength: float = Parameter(validator_list(non_negative(*number)), unit="m")
|
||||||
validator_list(non_negative(float, int)), unit="m"
|
length: float = Parameter(non_negative(*number), unit="m")
|
||||||
)
|
capillary_num: int = Parameter(positive(*integer))
|
||||||
length: float = Parameter(non_negative(float, int), unit="m")
|
|
||||||
capillary_num: int = Parameter(positive(int))
|
|
||||||
capillary_radius: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
capillary_radius: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
||||||
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
||||||
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3), unit="m")
|
||||||
capillary_resonance_strengths: Iterable[float] = Parameter(
|
capillary_resonance_strengths: Iterable[float] = Parameter(
|
||||||
validator_list(type_checker(int, float, np.ndarray))
|
validator_list(type_checker(int, float, np.ndarray))
|
||||||
)
|
)
|
||||||
capillary_resonance_max_order: int = Parameter(non_negative(int), default=0)
|
capillary_resonance_max_order: int = Parameter(non_negative(*integer), default=0)
|
||||||
capillary_nested: int = Parameter(non_negative(int), default=0)
|
capillary_nested: int = Parameter(non_negative(*integer), default=0)
|
||||||
|
|
||||||
# gas
|
# gas
|
||||||
gas_name: str = Parameter(low_string, default="vacuum")
|
gas_name: str = Parameter(low_string, default="vacuum")
|
||||||
pressure: float = Parameter(non_negative(float, int), unit="bar")
|
pressure: float = Parameter(non_negative(*number), unit="bar")
|
||||||
pressure_in: float = Parameter(non_negative(float, int), unit="bar")
|
pressure_in: float = Parameter(non_negative(*number), unit="bar")
|
||||||
pressure_out: float = Parameter(non_negative(float, int), unit="bar")
|
pressure_out: float = Parameter(non_negative(*number), unit="bar")
|
||||||
temperature: float = Parameter(positive(float, int), unit="K", default=300)
|
temperature: float = Parameter(positive(*number), unit="K", default=300)
|
||||||
plasma_density: float = Parameter(non_negative(float, int), default=0)
|
plasma_density: float = Parameter(non_negative(*number), default=0)
|
||||||
|
|
||||||
# pulse
|
# pulse
|
||||||
field_file: DataFile = Parameter(DataFile.validate)
|
field_file: DataFile = Parameter(DataFile.validate)
|
||||||
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
|
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
|
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
repetition_rate: float = Parameter(non_negative(float, int), unit="Hz", default=40e6)
|
repetition_rate: float = Parameter(non_negative(*number), unit="Hz", default=40e6)
|
||||||
peak_power: float = Parameter(positive(float, int), unit="W")
|
peak_power: float = Parameter(positive(*number), unit="W")
|
||||||
mean_power: float = Parameter(positive(float, int), unit="W")
|
mean_power: float = Parameter(positive(*number), unit="W")
|
||||||
energy: float = Parameter(positive(float, int), unit="J")
|
energy: float = Parameter(positive(*number), unit="J")
|
||||||
soliton_num: float = Parameter(non_negative(float, int))
|
soliton_num: float = Parameter(non_negative(*number))
|
||||||
additional_noise_factor: float = Parameter(positive(float, int), default=1)
|
additional_noise_factor: float = Parameter(positive(*number), default=1)
|
||||||
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
|
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
|
||||||
wavelength: float = Parameter(in_range_incl(100e-9, 10000e-9), unit="m")
|
wavelength: float = Parameter(in_range_incl(100e-9, 10000e-9), unit="m")
|
||||||
intensity_noise: float = Parameter(in_range_incl(0, 1), unit="%", default=0)
|
intensity_noise: float = Parameter(in_range_incl(0, 1), unit="%", default=0)
|
||||||
noise_correlation: float = Parameter(in_range_incl(-10, 10), default=0)
|
noise_correlation: float = Parameter(in_range_incl(-10, 10), default=0)
|
||||||
width: float = Parameter(in_range_excl(0, 1e-9), unit="s")
|
width: float = Parameter(in_range_excl(0, 1e-9), unit="s")
|
||||||
t0: float = Parameter(in_range_excl(0, 1e-9), unit="s")
|
t0: float = Parameter(in_range_excl(0, 1e-9), unit="s")
|
||||||
delay: float = Parameter(type_checker(float, int), unit="s")
|
delay: float = Parameter(type_checker(*number), unit="s")
|
||||||
|
|
||||||
# Behaviors to include
|
# Behaviors to include
|
||||||
quantum_noise: bool = Parameter(boolean, default=False)
|
quantum_noise: bool = Parameter(boolean, default=False)
|
||||||
@@ -378,23 +379,25 @@ class Parameters:
|
|||||||
literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43"
|
literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43"
|
||||||
)
|
)
|
||||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"))
|
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"))
|
||||||
raman_fraction: float = Parameter(non_negative(float, int))
|
raman_fraction: float = Parameter(non_negative(*number))
|
||||||
spm: bool = Parameter(boolean, default=True)
|
spm: bool = Parameter(boolean, default=True)
|
||||||
repeat: int = Parameter(positive(int), default=1)
|
repeat: int = Parameter(positive(*integer), default=1)
|
||||||
t_num: int = Parameter(positive(int), default=4096)
|
t_num: int = Parameter(positive(*integer), default=4096)
|
||||||
z_num: int = Parameter(positive(int), default=128)
|
z_num: int = Parameter(positive(*integer), default=128)
|
||||||
time_window: float = Parameter(positive(float, int), unit="s")
|
time_window: float = Parameter(positive(*number), unit="s")
|
||||||
dt: float = Parameter(in_range_excl(0, 10e-15), unit="s")
|
dt: float = Parameter(in_range_excl(0, 10e-15), unit="s")
|
||||||
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
||||||
step_size: float = Parameter(non_negative(float, int), default=0)
|
step_size: float = Parameter(non_negative(*number), default=0)
|
||||||
wavelength_window: tuple[float, float] = Parameter(
|
wavelength_window: tuple[float, float] = Parameter(
|
||||||
validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9))), unit="m"
|
validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9))), unit="m"
|
||||||
)
|
)
|
||||||
interpolation_degree: int = Parameter(validator_and(type_checker(int), in_range_incl(2, 18)))
|
interpolation_degree: int = Parameter(
|
||||||
|
validator_and(type_checker(*integer), in_range_incl(2, 18))
|
||||||
|
)
|
||||||
prev_sim_dir: str = Parameter(string)
|
prev_sim_dir: str = Parameter(string)
|
||||||
recovery_last_stored: int = Parameter(non_negative(int), default=0)
|
recovery_last_stored: int = Parameter(non_negative(*integer), default=0)
|
||||||
parallel: bool = Parameter(boolean, default=True)
|
parallel: bool = Parameter(boolean, default=True)
|
||||||
worker_num: int = Parameter(positive(int))
|
worker_num: int = Parameter(positive(*integer))
|
||||||
|
|
||||||
# computed
|
# computed
|
||||||
linear_operator: VariableQuantity = Parameter(is_function, can_pickle=False)
|
linear_operator: VariableQuantity = Parameter(is_function, can_pickle=False)
|
||||||
@@ -404,27 +407,27 @@ class Parameters:
|
|||||||
ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False)
|
ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False)
|
||||||
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
beta2: float = Parameter(type_checker(int, float))
|
beta2: float = Parameter(type_checker(*number))
|
||||||
alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
alpha: float = Parameter(non_negative(float, int))
|
alpha: float = Parameter(non_negative(*number))
|
||||||
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
effective_area_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
effective_area_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
spectrum_factor: float = Parameter(type_checker(float))
|
spectrum_factor: float = Parameter(type_checker(*number))
|
||||||
c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray))
|
c_to_a_factor: np.ndarray = Parameter(type_checker(*number, np.ndarray))
|
||||||
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
l: np.ndarray = Parameter(type_checker(np.ndarray))
|
l: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
|
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
w0: float = Parameter(positive(float))
|
w0: float = Parameter(positive(*number))
|
||||||
t: np.ndarray = Parameter(type_checker(np.ndarray))
|
t: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
dispersion_length: float = Parameter(non_negative(float, int), unit="m")
|
dispersion_length: float = Parameter(non_negative(*number), unit="m")
|
||||||
nonlinear_length: float = Parameter(non_negative(float, int), unit="m")
|
nonlinear_length: float = Parameter(non_negative(*number), unit="m")
|
||||||
soliton_length: float = Parameter(non_negative(float, int), unit="m")
|
soliton_length: float = Parameter(non_negative(*number), unit="m")
|
||||||
adapt_step_size: bool = Parameter(boolean)
|
adapt_step_size: bool = Parameter(boolean)
|
||||||
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
|
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
|
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
|
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
|
||||||
num: int = Parameter(non_negative(int))
|
num: int = Parameter(non_negative(*integer))
|
||||||
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
||||||
version: str = Parameter(string)
|
version: str = Parameter(string)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user