solid ground work

This commit is contained in:
Benoît Sierro
2021-08-25 12:18:55 +02:00
parent e0979b39f3
commit 1f0937d840
6 changed files with 746 additions and 189 deletions

47
func_rewrite.py Normal file
View File

@@ -0,0 +1,47 @@
from typing import Callable
import inspect
import re
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
return scope[tmp_name]
def lol(a, b=None, c=None):
print(f"{a=}, {b=}, {c=}")
def main():
lol1 = func_rewrite(lol, ["c"])
print(inspect.getfullargspec(lol1))
lol2 = func_rewrite(lol, ["b"])
print(inspect.getfullargspec(lol2))
if __name__ == "__main__":
main()

View File

@@ -14,9 +14,11 @@ from .errors import *
from .logger import get_logger
from .math import power_fact
from .physics import fiber, pulse, units
from .utils import override_config, required_simulations
from .utils import override_config, required_simulations, evaluator
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
global_evaluator = evaluator.Evaluator()
@dataclass
class Params(BareParams):
@@ -541,65 +543,65 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
return previous.name, count_variations(*configs)
def wspace(t, t_num=0):
"""frequency array such that x(t) <-> np.fft(x)(w)
Parameters
----------
t : float or array
float : total width of the time window
array : time array
t_num : int-
if t is a float, specifies the number of points
Returns
----------
w : array
linspace of frencies corresponding to t
"""
if isinstance(t, (np.ndarray, list, tuple)):
dt = t[1] - t[0]
t_num = len(t)
t = t[-1] - t[0] + dt
else:
dt = t / t_num
w = 2 * pi * np.arange(t_num) / t
w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
return w
# def wspace(t, t_num=0):
# """frequency array such that x(t) <-> np.fft(x)(w)
# Parameters
# ----------
# t : float or array
# float : total width of the time window
# array : time array
# t_num : int-
# if t is a float, specifies the number of points
# Returns
# ----------
# w : array
# linspace of frencies corresponding to t
# """
# if isinstance(t, (np.ndarray, list, tuple)):
# dt = t[1] - t[0]
# t_num = len(t)
# t = t[-1] - t[0] + dt
# else:
# dt = t / t_num
# w = 2 * pi * np.arange(t_num) / t
# w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
# return w
def tspace(time_window=None, t_num=None, dt=None):
"""returns a time array centered on 0
Parameters
----------
time_window : float
total time spanned
t_num : int
number of points
dt : float
time resolution
# def tspace(time_window=None, t_num=None, dt=None):
# """returns a time array centered on 0
# Parameters
# ----------
# time_window : float
# total time spanned
# t_num : int
# number of points
# dt : float
# time resolution
at least 2 arguments must be given. They are prioritize as such
t_num > time_window > dt
# at least 2 arguments must be given. They are prioritize as such
# t_num > time_window > dt
Returns
-------
t : array
a linearily spaced time array
Raises
------
TypeError
missing at least 1 argument
"""
if t_num is not None:
if isinstance(time_window, (float, int)):
return np.linspace(-time_window / 2, time_window / 2, int(t_num))
elif isinstance(dt, (float, int)):
time_window = (t_num - 1) * dt
return np.linspace(-time_window / 2, time_window / 2, t_num)
elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
t_num = int(time_window / dt) + 1
return np.linspace(-time_window / 2, time_window / 2, t_num)
else:
raise TypeError("not enough parameter to determine time vector")
# Returns
# -------
# t : array
# a linearily spaced time array
# Raises
# ------
# TypeError
# missing at least 1 argument
# """
# if t_num is not None:
# if isinstance(time_window, (float, int)):
# return np.linspace(-time_window / 2, time_window / 2, int(t_num))
# elif isinstance(dt, (float, int)):
# time_window = (t_num - 1) * dt
# return np.linspace(-time_window / 2, time_window / 2, t_num)
# elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
# t_num = int(time_window / dt) + 1
# return np.linspace(-time_window / 2, time_window / 2, t_num)
# else:
# raise TypeError("not enough parameter to determine time vector")
def recover_params(params: BareParams, data_folder: Path) -> Params:
@@ -620,115 +622,115 @@ def recover_params(params: BareParams, data_folder: Path) -> Params:
return params
def build_sim_grid(
length: float,
z_num: int,
wavelength: float,
deg: int,
time_window: float = None,
t_num: int = None,
dt: float = None,
) -> tuple[
np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
]:
"""computes a bunch of values that relate to the simulation grid
# def build_sim_grid(
# length: float,
# z_num: int,
# wavelength: float,
# deg: int,
# time_window: float = None,
# t_num: int = None,
# dt: float = None,
# ) -> tuple[
# np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
# ]:
# """computes a bunch of values that relate to the simulation grid
Parameters
----------
length : float
length of the fiber in m
z_num : int
number of spatial points
wavelength : float
pump wavelength in m
deg : int
dispersion interpolation degree
time_window : float, optional
total width of the temporal grid in s, by default None
t_num : int, optional
number of temporal grid points, by default None
dt : float, optional
spacing of the temporal grid in s, by default None
# Parameters
# ----------
# length : float
# length of the fiber in m
# z_num : int
# number of spatial points
# wavelength : float
# pump wavelength in m
# deg : int
# dispersion interpolation degree
# time_window : float, optional
# total width of the temporal grid in s, by default None
# t_num : int, optional
# number of temporal grid points, by default None
# dt : float, optional
# spacing of the temporal grid in s, by default None
Returns
-------
z_targets : np.ndarray, shape (z_num, )
spatial points in m
t : np.ndarray, shape (t_num, )
temporal points in s
time_window : float
total width of the temporal grid in s, by default None
t_num : int
number of temporal grid points, by default None
dt : float
spacing of the temporal grid in s, by default None
w_c : np.ndarray, shape (t_num, )
centered angular frequencies in rad/s where 0 is the pump frequency
w0 : float
pump angular frequency
w : np.ndarray, shape (t_num, )
actual angualr frequency grid in rad/s
w_power_fact : np.ndarray, shape (deg, t_num)
set of all the necessaray powers of w_c
l : np.ndarray, shape (t_num)
wavelengths in m
"""
t = tspace(time_window, t_num, dt)
# Returns
# -------
# z_targets : np.ndarray, shape (z_num, )
# spatial points in m
# t : np.ndarray, shape (t_num, )
# temporal points in s
# time_window : float
# total width of the temporal grid in s, by default None
# t_num : int
# number of temporal grid points, by default None
# dt : float
# spacing of the temporal grid in s, by default None
# w_c : np.ndarray, shape (t_num, )
# centered angular frequencies in rad/s where 0 is the pump frequency
# w0 : float
# pump angular frequency
# w : np.ndarray, shape (t_num, )
# actual angualr frequency grid in rad/s
# w_power_fact : np.ndarray, shape (deg, t_num)
# set of all the necessaray powers of w_c
# l : np.ndarray, shape (t_num)
# wavelengths in m
# """
# t = tspace(time_window, t_num, dt)
time_window = t.max() - t.min()
dt = t[1] - t[0]
t_num = len(t)
z_targets = np.linspace(0, length, z_num)
w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg)
l = units.To.m(w)
return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
# time_window = t.max() - t.min()
# dt = t[1] - t[0]
# t_num = len(t)
# z_targets = np.linspace(0, length, z_num)
# w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg)
# l = units.To.m(w)
# return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
def build_sim_grid_in_place(params: BareParams):
"""similar to calling build_sim_grid, but sets the attributes in place"""
(
params.z_targets,
params.t,
params.time_window,
params.t_num,
params.dt,
params.w_c,
params.w0,
params.w,
params.w_power_fact,
params.l,
) = build_sim_grid(
params.length,
params.z_num,
params.wavelength,
params.interpolation_degree,
params.time_window,
params.t_num,
params.dt,
)
# def build_sim_grid_in_place(params: BareParams):
# """similar to calling build_sim_grid, but sets the attributes in place"""
# (
# params.z_targets,
# params.t,
# params.time_window,
# params.t_num,
# params.dt,
# params.w_c,
# params.w0,
# params.w,
# params.w_power_fact,
# params.l,
# ) = build_sim_grid(
# params.length,
# params.z_num,
# params.wavelength,
# params.interpolation_degree,
# params.time_window,
# params.t_num,
# params.dt,
# )
def update_frequency_domain(
t: np.ndarray, wavelength: float, deg: int
) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]:
"""updates the frequency grid
# def update_frequency_domain(
# t: np.ndarray, wavelength: float, deg: int
# ) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]:
# """updates the frequency grid
Parameters
----------
t : np.ndarray
time array
wavelength : float
wavelength
deg : int
interpolation degree of the dispersion
# Parameters
# ----------
# t : np.ndarray
# time array
# wavelength : float
# wavelength
# deg : int
# interpolation degree of the dispersion
Returns
-------
Tuple[np.ndarray, float, np.ndarray, np.ndarray]
w_c, w0, w, w_power_fact
"""
w_c = wspace(t)
w0 = units.m(wavelength)
w = w_c + w0
w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
return w_c, w0, w, w_power_fact
# Returns
# -------
# Tuple[np.ndarray, float, np.ndarray, np.ndarray]
# w_c, w0, w, w_power_fact
# """
# w_c = wspace(t)
# w0 = units.m(wavelength)
# w = w_c + w0
# w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
# return w_c, w0, w, w_power_fact

View File

@@ -5,6 +5,9 @@ from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from .utils.cache import np_cache
pi = np.pi
c = 299792458.0
def span(*vec):
"""returns the min and max of whatever array-like is given. can accept many args"""
@@ -218,3 +221,154 @@ def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray:
pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0]
m = (y[pos] - y[pos - 1]) / (x[pos] - x[pos - 1])
return -y[pos] / m + x[pos]
def wspace(t, t_num=0):
"""frequency array such that x(t) <-> np.fft(x)(w)
Parameters
----------
t : float or array
float : total width of the time window
array : time array
t_num : int-
if t is a float, specifies the number of points
Returns
----------
w : array
linspace of frencies corresponding to t
"""
if isinstance(t, (np.ndarray, list, tuple)):
dt = t[1] - t[0]
t_num = len(t)
t = t[-1] - t[0] + dt
else:
dt = t / t_num
w = 2 * pi * np.arange(t_num) / t
w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
return w
def tspace(time_window=None, t_num=None, dt=None):
"""returns a time array centered on 0
Parameters
----------
time_window : float
total time spanned
t_num : int
number of points
dt : float
time resolution
at least 2 arguments must be given. They are prioritize as such
t_num > time_window > dt
Returns
-------
t : array
a linearily spaced time array
Raises
------
TypeError
missing at least 1 argument
"""
if t_num is not None:
if isinstance(time_window, (float, int)):
return np.linspace(-time_window / 2, time_window / 2, int(t_num))
elif isinstance(dt, (float, int)):
time_window = (t_num - 1) * dt
return np.linspace(-time_window / 2, time_window / 2, t_num)
elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
t_num = int(time_window / dt) + 1
return np.linspace(-time_window / 2, time_window / 2, t_num)
else:
raise TypeError("not enough parameter to determine time vector")
def build_sim_grid(
length: float,
z_num: int,
wavelength: float,
interpolation_degree: int,
time_window: float = None,
t_num: int = None,
dt: float = None,
) -> tuple[
np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
]:
"""computes a bunch of values that relate to the simulation grid
Parameters
----------
length : float
length of the fiber in m
z_num : int
number of spatial points
wavelength : float
pump wavelength in m
deg : int
dispersion interpolation degree
time_window : float, optional
total width of the temporal grid in s, by default None
t_num : int, optional
number of temporal grid points, by default None
dt : float, optional
spacing of the temporal grid in s, by default None
Returns
-------
z_targets : np.ndarray, shape (z_num, )
spatial points in m
t : np.ndarray, shape (t_num, )
temporal points in s
time_window : float
total width of the temporal grid in s, by default None
t_num : int
number of temporal grid points, by default None
dt : float
spacing of the temporal grid in s, by default None
w_c : np.ndarray, shape (t_num, )
centered angular frequencies in rad/s where 0 is the pump frequency
w0 : float
pump angular frequency
w : np.ndarray, shape (t_num, )
actual angualr frequency grid in rad/s
w_power_fact : np.ndarray, shape (deg, t_num)
set of all the necessaray powers of w_c
l : np.ndarray, shape (t_num)
wavelengths in m
"""
t = tspace(time_window, t_num, dt)
time_window = t.max() - t.min()
dt = t[1] - t[0]
t_num = len(t)
z_targets = np.linspace(0, length, z_num)
w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, interpolation_degree)
l = 2 * pi * c / w
return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
def update_frequency_domain(
t: np.ndarray, wavelength: float, deg: int
) -> tuple[np.ndarray, float, np.ndarray, np.ndarray]:
"""updates the frequency grid
Parameters
----------
t : np.ndarray
time array
wavelength : float
wavelength
deg : int
interpolation degree of the dispersion
Returns
-------
Tuple[np.ndarray, float, np.ndarray, np.ndarray]
w_c, w0, w, w_power_fact
"""
w_c = wspace(t)
w0 = 2 * pi * c / wavelength
w = w_c + w0
w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
return w_c, w0, w, w_power_fact

View File

@@ -3,8 +3,6 @@ import numpy as np
import scipy.special
from scipy.integrate import cumulative_trapezoid
from scgenerator import math
from ..logger import get_logger
from . import units
from .units import NA, c, kB, me, e, hbar

View File

@@ -0,0 +1,341 @@
from collections import defaultdict
from typing import Any, Callable, Union
from typing import TypeVar, Optional
from dataclasses import dataclass
import numpy as np
import itertools
from functools import wraps
import re
from ..physics import fiber, pulse, materials
from .. import math
T = TypeVar("T")
import inspect
class Rule:
def __init__(
self,
target: Union[str, list[Optional[str]]],
func: Callable,
args: list[str] = None,
priorities: Union[int, list[int]] = None,
):
targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func
if priorities is None:
priorities = [1] * len(targets)
elif isinstance(priorities, (int, float, np.integer, np.floating)):
priorities = [priorities]
self.targets = dict(zip(targets, priorities))
if args is None:
args = get_arg_names(func)
self.args = args
def __repr__(self) -> str:
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
@classmethod
def deduce(
cls,
target: Union[str, list[Optional[str]]],
func: Callable,
kwarg_names: list[str],
n_var: int,
args_const: list[str] = None,
) -> list["Rule"]:
"""given a function that doesn't need all its keyword arguemtn specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters
----------
target : str | list[str | None]
name of the variable(s) that func returns
func : Callable
function to work with
kwarg_names : list[str]
list of all kwargs of the function to be used
n_var : int
how many shoulf be used per rule
arg_const : list[str], optional
override the name of the positional arguments
Returns
-------
list[Rule]
list of all possible rules
Example
-------
>> def lol(a, b=None, c=None):
pass
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
[
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
]
"""
rules: list[cls] = []
for var_possibility in itertools.combinations(kwarg_names, n_var):
new_func = func_rewrite(func, list(var_possibility), args_const)
rules.append(cls(target, new_func))
return rules
@dataclass
class EvalStat:
priority: float = np.inf
class Evaluator:
def __init__(self):
self.rules: dict[str, list[Rule]] = defaultdict(list)
self.params = {}
self.__curent_lookup = set()
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
def append(self, *rule: Rule):
for r in rule:
for t in r.targets:
if t is not None:
self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def update(self, **params: Any):
self.params.update(params)
for k in params:
self.eval_stats[k].priority = np.inf
def reset(self):
self.params = {}
self.eval_stats = defaultdict(EvalStat)
def compute(self, target: str) -> Any:
"""computes a target
Parameters
----------
target : str
name of the target
Returns
-------
Any
return type of the target function
Raises
------
RecursionError
a cyclic dependence exists
KeyError
there is no saved rule for the target
"""
value = self.params.get(target)
if value is None:
if target in self.__curent_lookup:
raise RecursionError(
"cyclic dependency detected : "
f"{target!r} seems to depend on itself, "
f"please provide a value for at least one variable in {self.__curent_lookup}"
)
else:
self.__curent_lookup.add(target)
error = None
for rule in reversed(self.rules[target]):
try:
args = [self.compute(k) for k in rule.args]
returned_values = rule.func(*args)
if len(rule.targets) == 1:
self.params[target] = returned_values
self.eval_stats[target].priority = rule.targets[target]
value = returned_values
else:
for ((k, p), v) in zip(rule.targets.items(), returned_values):
if (
k == target
or k not in self.params
or self.eval_stats[k].priority < p
):
self.params[k] = v
self.eval_stats[k] = p
if k == target:
value = v
break
except (RecursionError, KeyError) as e:
error = e
continue
if value is None and error is not None:
raise error
self.__curent_lookup.remove(target)
return value
def __call__(self, target: str, args: list[str] = None):
"""creates a wrapper that adds decorated functions to the set of rules
Parameters
----------
target : str
name of the target
args : list[str], optional
list of name of arguments. Automatically deduced from function signature if
not provided, by default None
"""
def wrapper(func):
self.append(Rule(target, func, args))
return func
return wrapper
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
return scope[tmp_name]
default_rules: list[Rule] = [
*Rule.deduce(
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"],
math.build_sim_grid,
["time_window", "t_num", "dt"],
2,
)
]
"""
Rule("gamma", fiber.gamma_parameter),
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
Rule(["beta", "gamma", "interp_range"], fiber.PCF_dispersion),
Rule("n2"),
Rule("loss"),
Rule("loss_file"),
Rule("effective_mode_diameter"),
Rule("A_eff"),
Rule("A_eff_file"),
Rule("pitch"),
Rule("pitch_ratio"),
Rule("core_radius"),
Rule("he_mode"),
Rule("fit_parameters"),
Rule("beta"),
Rule("dispersion_file"),
Rule("model"),
Rule("length"),
Rule("capillary_num"),
Rule("capillary_outer_d"),
Rule("capillary_thickness"),
Rule("capillary_spacing"),
Rule("capillary_resonance_strengths"),
Rule("capillary_nested"),
Rule("gas_name"),
Rule("pressure"),
Rule("temperature"),
Rule("plasma_density"),
Rule("field_file"),
Rule("repetition_rate"),
Rule("peak_power"),
Rule("mean_power"),
Rule("energy"),
Rule("soliton_num"),
Rule("quantum_noise"),
Rule("shape"),
Rule("wavelength"),
Rule("intensity_noise"),
Rule("width"),
Rule("t0"),
Rule("behaviors"),
Rule("parallel"),
Rule("raman_type"),
Rule("ideal_gas"),
Rule("repeat"),
Rule("t_num"),
Rule("z_num"),
Rule("time_window"),
Rule("dt"),
Rule("tolerated_error"),
Rule("step_size"),
Rule("lower_wavelength_interp_limit"),
Rule("upper_wavelength_interp_limit"),
Rule("interpolation_degree"),
Rule("prev_sim_dir"),
Rule("recovery_last_stored"),
Rule("worker_num"),
Rule("field_0"),
Rule("spec_0"),
Rule("alpha"),
Rule("gamma_arr"),
Rule("A_eff_arr"),
Rule("w"),
Rule("l"),
Rule("w_c"),
Rule("w0"),
Rule("w_power_fact"),
Rule("t"),
Rule("L_D"),
Rule("L_NL"),
Rule("L_sol"),
Rule("dynamic_dispersion"),
Rule("adapt_step_size"),
Rule("error_ok"),
Rule("hr_w"),
Rule("z_targets"),
Rule("const_qty"),
Rule("beta_func"),
Rule("gamma_func"),
Rule("interp_range"),
Rule("datetime"),
Rule("version"),
]
"""
def main():
evalor = Evaluator()
evalor.append(*default_rules)
evalor.update(
**{
"length": 1,
"z_num": 128,
"wavelength": 1500e-9,
"interpolation_degree": 8,
"t_num": 16384,
"dt": 1e-15,
}
)
evalor.compute("z_targets")
print(evalor.params.keys())
print(evalor.params["l"][evalor.params["l"] > 0].min())
if __name__ == "__main__":
main()

View File

@@ -8,6 +8,9 @@ import numpy as np
from ..const import __version__
# from .evaluator import Rule, Evaluator
# from ..physics import pulse, fiber, materials
T = TypeVar("T")
# Validator
@@ -187,7 +190,7 @@ def translate(p_name: str, p_value: T) -> tuple[str, T]:
class Parameter:
def __init__(self, validator, converter=None, default=None, display_info=None):
def __init__(self, validator, converter=None, default=None, display_info=None, rules=None):
"""Single parameter
Parameters
@@ -208,6 +211,10 @@ class Parameter:
self.converter = converter
self.default = default
self.display_info = display_info
if rules is None:
self.rules = []
else:
self.rules = rules
def __set_name__(self, owner, name):
self.name = name
@@ -344,14 +351,14 @@ class BareParams:
"""
# root
name: str = Parameter(string)
name: str = Parameter(string, default="no name")
prev_data_dir: str = Parameter(string)
previous_config_file: str = Parameter(string)
# # fiber
input_transmission: float = Parameter(in_range_incl(0, 1))
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
gamma: float = Parameter(non_negative(float, int))
n2: float = Parameter(non_negative(float, int))
n2: float = Parameter(non_negative(float, int), default=2.2e-20)
loss: str = Parameter(literal("capillary"))
loss_file: str = Parameter(string)
effective_mode_diameter: float = Parameter(positive(float, int))
@@ -360,58 +367,66 @@ class BareParams:
pitch: float = Parameter(in_range_excl(0, 1e-3))
pitch_ratio: float = Parameter(in_range_excl(0, 1))
core_radius: float = Parameter(in_range_excl(0, 1e-3))
he_mode: Tuple[int, int] = Parameter(int_pair)
fit_parameters: Tuple[int, int] = Parameter(int_pair)
he_mode: Tuple[int, int] = Parameter(int_pair, default=(1, 1))
fit_parameters: Tuple[int, int] = Parameter(int_pair, default=(0.08, 200e-9))
beta: Iterable[float] = Parameter(num_list)
dispersion_file: str = Parameter(string)
model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"))
length: float = Parameter(non_negative(float, int))
model: str = Parameter(
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), default="custom"
)
length: float = Parameter(non_negative(float, int), default=1.0)
capillary_num: int = Parameter(positive(int))
capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
capillary_resonance_strengths: Iterable[float] = Parameter(num_list)
capillary_nested: int = Parameter(non_negative(int))
capillary_resonance_strengths: Iterable[float] = Parameter(num_list, default=[])
capillary_nested: int = Parameter(non_negative(int), default=0)
# gas
gas_name: str = Parameter(string, converter=str.lower)
gas_name: str = Parameter(string, converter=str.lower, default="vacuum")
pressure: Union[float, Iterable[float]] = Parameter(
validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar")
validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar"), default=1e5
)
temperature: float = Parameter(positive(float, int), display_info=(1, "K"))
plasma_density: float = Parameter(non_negative(float, int))
temperature: float = Parameter(positive(float, int), display_info=(1, "K"), default=300)
plasma_density: float = Parameter(non_negative(float, int), default=0)
# pulse
field_file: str = Parameter(string)
repetition_rate: float = Parameter(non_negative(float, int), display_info=(1e-6, "MHz"))
repetition_rate: float = Parameter(
non_negative(float, int), display_info=(1e-6, "MHz"), default=40e6
)
peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW"))
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
soliton_num: float = Parameter(non_negative(float, int))
quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech"))
quantum_noise: bool = Parameter(boolean, default=False)
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"))
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"), default=0)
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
# simulation
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))
parallel: bool = Parameter(boolean)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
ideal_gas: bool = Parameter(boolean)
repeat: int = Parameter(positive(int))
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")), default=["spm", "ss"])
parallel: bool = Parameter(boolean, default=True)
raman_type: str = Parameter(
literal("measured", "agrawal", "stolen"), converter=str.lower, default="agrawal"
)
ideal_gas: bool = Parameter(boolean, default=False)
repeat: int = Parameter(positive(int), default=1)
t_num: int = Parameter(positive(int))
z_num: int = Parameter(positive(int))
time_window: float = Parameter(positive(float, int))
dt: float = Parameter(in_range_excl(0, 5e-15))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
step_size: float = Parameter(positive(float, int))
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
interpolation_degree: int = Parameter(positive(int))
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9), default=100e-9)
upper_wavelength_interp_limit: float = Parameter(
in_range_incl(200e-9, 5000e-9), default=2000e-9
)
interpolation_degree: int = Parameter(positive(int), default=8)
prev_sim_dir: str = Parameter(string)
recovery_last_stored: int = Parameter(non_negative(int))
recovery_last_stored: int = Parameter(non_negative(int), default=0)
worker_num: int = Parameter(positive(int))
# computed