solid ground work

This commit is contained in:
Benoît Sierro
2021-08-25 12:18:55 +02:00
parent e0979b39f3
commit 1f0937d840
6 changed files with 746 additions and 189 deletions

47
func_rewrite.py Normal file
View File

@@ -0,0 +1,47 @@
from typing import Callable
import inspect
import re
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
return scope[tmp_name]
def lol(a, b=None, c=None):
print(f"{a=}, {b=}, {c=}")
def main():
lol1 = func_rewrite(lol, ["c"])
print(inspect.getfullargspec(lol1))
lol2 = func_rewrite(lol, ["b"])
print(inspect.getfullargspec(lol2))
if __name__ == "__main__":
main()

View File

@@ -14,9 +14,11 @@ from .errors import *
from .logger import get_logger from .logger import get_logger
from .math import power_fact from .math import power_fact
from .physics import fiber, pulse, units from .physics import fiber, pulse, units
from .utils import override_config, required_simulations from .utils import override_config, required_simulations, evaluator
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
global_evaluator = evaluator.Evaluator()
@dataclass @dataclass
class Params(BareParams): class Params(BareParams):
@@ -541,65 +543,65 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
return previous.name, count_variations(*configs) return previous.name, count_variations(*configs)
def wspace(t, t_num=0): # def wspace(t, t_num=0):
"""frequency array such that x(t) <-> np.fft(x)(w) # """frequency array such that x(t) <-> np.fft(x)(w)
Parameters # Parameters
---------- # ----------
t : float or array # t : float or array
float : total width of the time window # float : total width of the time window
array : time array # array : time array
t_num : int- # t_num : int-
if t is a float, specifies the number of points # if t is a float, specifies the number of points
Returns # Returns
---------- # ----------
w : array # w : array
linspace of frencies corresponding to t # linspace of frencies corresponding to t
""" # """
if isinstance(t, (np.ndarray, list, tuple)): # if isinstance(t, (np.ndarray, list, tuple)):
dt = t[1] - t[0] # dt = t[1] - t[0]
t_num = len(t) # t_num = len(t)
t = t[-1] - t[0] + dt # t = t[-1] - t[0] + dt
else: # else:
dt = t / t_num # dt = t / t_num
w = 2 * pi * np.arange(t_num) / t # w = 2 * pi * np.arange(t_num) / t
w = np.where(w >= pi / dt, w - 2 * pi / dt, w) # w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
return w # return w
def tspace(time_window=None, t_num=None, dt=None): # def tspace(time_window=None, t_num=None, dt=None):
"""returns a time array centered on 0 # """returns a time array centered on 0
Parameters # Parameters
---------- # ----------
time_window : float # time_window : float
total time spanned # total time spanned
t_num : int # t_num : int
number of points # number of points
dt : float # dt : float
time resolution # time resolution
at least 2 arguments must be given. They are prioritize as such # at least 2 arguments must be given. They are prioritize as such
t_num > time_window > dt # t_num > time_window > dt
Returns # Returns
------- # -------
t : array # t : array
a linearily spaced time array # a linearily spaced time array
Raises # Raises
------ # ------
TypeError # TypeError
missing at least 1 argument # missing at least 1 argument
""" # """
if t_num is not None: # if t_num is not None:
if isinstance(time_window, (float, int)): # if isinstance(time_window, (float, int)):
return np.linspace(-time_window / 2, time_window / 2, int(t_num)) # return np.linspace(-time_window / 2, time_window / 2, int(t_num))
elif isinstance(dt, (float, int)): # elif isinstance(dt, (float, int)):
time_window = (t_num - 1) * dt # time_window = (t_num - 1) * dt
return np.linspace(-time_window / 2, time_window / 2, t_num) # return np.linspace(-time_window / 2, time_window / 2, t_num)
elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)): # elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
t_num = int(time_window / dt) + 1 # t_num = int(time_window / dt) + 1
return np.linspace(-time_window / 2, time_window / 2, t_num) # return np.linspace(-time_window / 2, time_window / 2, t_num)
else: # else:
raise TypeError("not enough parameter to determine time vector") # raise TypeError("not enough parameter to determine time vector")
def recover_params(params: BareParams, data_folder: Path) -> Params: def recover_params(params: BareParams, data_folder: Path) -> Params:
@@ -620,115 +622,115 @@ def recover_params(params: BareParams, data_folder: Path) -> Params:
return params return params
def build_sim_grid( # def build_sim_grid(
length: float, # length: float,
z_num: int, # z_num: int,
wavelength: float, # wavelength: float,
deg: int, # deg: int,
time_window: float = None, # time_window: float = None,
t_num: int = None, # t_num: int = None,
dt: float = None, # dt: float = None,
) -> tuple[ # ) -> tuple[
np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray # np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
]: # ]:
"""computes a bunch of values that relate to the simulation grid # """computes a bunch of values that relate to the simulation grid
Parameters # Parameters
---------- # ----------
length : float # length : float
length of the fiber in m # length of the fiber in m
z_num : int # z_num : int
number of spatial points # number of spatial points
wavelength : float # wavelength : float
pump wavelength in m # pump wavelength in m
deg : int # deg : int
dispersion interpolation degree # dispersion interpolation degree
time_window : float, optional # time_window : float, optional
total width of the temporal grid in s, by default None # total width of the temporal grid in s, by default None
t_num : int, optional # t_num : int, optional
number of temporal grid points, by default None # number of temporal grid points, by default None
dt : float, optional # dt : float, optional
spacing of the temporal grid in s, by default None # spacing of the temporal grid in s, by default None
Returns # Returns
------- # -------
z_targets : np.ndarray, shape (z_num, ) # z_targets : np.ndarray, shape (z_num, )
spatial points in m # spatial points in m
t : np.ndarray, shape (t_num, ) # t : np.ndarray, shape (t_num, )
temporal points in s # temporal points in s
time_window : float # time_window : float
total width of the temporal grid in s, by default None # total width of the temporal grid in s, by default None
t_num : int # t_num : int
number of temporal grid points, by default None # number of temporal grid points, by default None
dt : float # dt : float
spacing of the temporal grid in s, by default None # spacing of the temporal grid in s, by default None
w_c : np.ndarray, shape (t_num, ) # w_c : np.ndarray, shape (t_num, )
centered angular frequencies in rad/s where 0 is the pump frequency # centered angular frequencies in rad/s where 0 is the pump frequency
w0 : float # w0 : float
pump angular frequency # pump angular frequency
w : np.ndarray, shape (t_num, ) # w : np.ndarray, shape (t_num, )
actual angualr frequency grid in rad/s # actual angualr frequency grid in rad/s
w_power_fact : np.ndarray, shape (deg, t_num) # w_power_fact : np.ndarray, shape (deg, t_num)
set of all the necessaray powers of w_c # set of all the necessaray powers of w_c
l : np.ndarray, shape (t_num) # l : np.ndarray, shape (t_num)
wavelengths in m # wavelengths in m
""" # """
t = tspace(time_window, t_num, dt) # t = tspace(time_window, t_num, dt)
time_window = t.max() - t.min() # time_window = t.max() - t.min()
dt = t[1] - t[0] # dt = t[1] - t[0]
t_num = len(t) # t_num = len(t)
z_targets = np.linspace(0, length, z_num) # z_targets = np.linspace(0, length, z_num)
w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg) # w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg)
l = units.To.m(w) # l = units.To.m(w)
return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l # return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
def build_sim_grid_in_place(params: BareParams): # def build_sim_grid_in_place(params: BareParams):
"""similar to calling build_sim_grid, but sets the attributes in place""" # """similar to calling build_sim_grid, but sets the attributes in place"""
( # (
params.z_targets, # params.z_targets,
params.t, # params.t,
params.time_window, # params.time_window,
params.t_num, # params.t_num,
params.dt, # params.dt,
params.w_c, # params.w_c,
params.w0, # params.w0,
params.w, # params.w,
params.w_power_fact, # params.w_power_fact,
params.l, # params.l,
) = build_sim_grid( # ) = build_sim_grid(
params.length, # params.length,
params.z_num, # params.z_num,
params.wavelength, # params.wavelength,
params.interpolation_degree, # params.interpolation_degree,
params.time_window, # params.time_window,
params.t_num, # params.t_num,
params.dt, # params.dt,
) # )
def update_frequency_domain( # def update_frequency_domain(
t: np.ndarray, wavelength: float, deg: int # t: np.ndarray, wavelength: float, deg: int
) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]: # ) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]:
"""updates the frequency grid # """updates the frequency grid
Parameters # Parameters
---------- # ----------
t : np.ndarray # t : np.ndarray
time array # time array
wavelength : float # wavelength : float
wavelength # wavelength
deg : int # deg : int
interpolation degree of the dispersion # interpolation degree of the dispersion
Returns # Returns
------- # -------
Tuple[np.ndarray, float, np.ndarray, np.ndarray] # Tuple[np.ndarray, float, np.ndarray, np.ndarray]
w_c, w0, w, w_power_fact # w_c, w0, w, w_power_fact
""" # """
w_c = wspace(t) # w_c = wspace(t)
w0 = units.m(wavelength) # w0 = units.m(wavelength)
w = w_c + w0 # w = w_c + w0
w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)]) # w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
return w_c, w0, w, w_power_fact # return w_c, w0, w, w_power_fact

View File

@@ -5,6 +5,9 @@ from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros from scipy.special import jn_zeros
from .utils.cache import np_cache from .utils.cache import np_cache
pi = np.pi
c = 299792458.0
def span(*vec): def span(*vec):
"""returns the min and max of whatever array-like is given. can accept many args""" """returns the min and max of whatever array-like is given. can accept many args"""
@@ -218,3 +221,154 @@ def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray:
pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0] pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0]
m = (y[pos] - y[pos - 1]) / (x[pos] - x[pos - 1]) m = (y[pos] - y[pos - 1]) / (x[pos] - x[pos - 1])
return -y[pos] / m + x[pos] return -y[pos] / m + x[pos]
def wspace(t, t_num=0):
"""frequency array such that x(t) <-> np.fft(x)(w)
Parameters
----------
t : float or array
float : total width of the time window
array : time array
t_num : int-
if t is a float, specifies the number of points
Returns
----------
w : array
linspace of frencies corresponding to t
"""
if isinstance(t, (np.ndarray, list, tuple)):
dt = t[1] - t[0]
t_num = len(t)
t = t[-1] - t[0] + dt
else:
dt = t / t_num
w = 2 * pi * np.arange(t_num) / t
w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
return w
def tspace(time_window=None, t_num=None, dt=None):
"""returns a time array centered on 0
Parameters
----------
time_window : float
total time spanned
t_num : int
number of points
dt : float
time resolution
at least 2 arguments must be given. They are prioritize as such
t_num > time_window > dt
Returns
-------
t : array
a linearily spaced time array
Raises
------
TypeError
missing at least 1 argument
"""
if t_num is not None:
if isinstance(time_window, (float, int)):
return np.linspace(-time_window / 2, time_window / 2, int(t_num))
elif isinstance(dt, (float, int)):
time_window = (t_num - 1) * dt
return np.linspace(-time_window / 2, time_window / 2, t_num)
elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
t_num = int(time_window / dt) + 1
return np.linspace(-time_window / 2, time_window / 2, t_num)
else:
raise TypeError("not enough parameter to determine time vector")
def build_sim_grid(
length: float,
z_num: int,
wavelength: float,
interpolation_degree: int,
time_window: float = None,
t_num: int = None,
dt: float = None,
) -> tuple[
np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
]:
"""computes a bunch of values that relate to the simulation grid
Parameters
----------
length : float
length of the fiber in m
z_num : int
number of spatial points
wavelength : float
pump wavelength in m
deg : int
dispersion interpolation degree
time_window : float, optional
total width of the temporal grid in s, by default None
t_num : int, optional
number of temporal grid points, by default None
dt : float, optional
spacing of the temporal grid in s, by default None
Returns
-------
z_targets : np.ndarray, shape (z_num, )
spatial points in m
t : np.ndarray, shape (t_num, )
temporal points in s
time_window : float
total width of the temporal grid in s, by default None
t_num : int
number of temporal grid points, by default None
dt : float
spacing of the temporal grid in s, by default None
w_c : np.ndarray, shape (t_num, )
centered angular frequencies in rad/s where 0 is the pump frequency
w0 : float
pump angular frequency
w : np.ndarray, shape (t_num, )
actual angualr frequency grid in rad/s
w_power_fact : np.ndarray, shape (deg, t_num)
set of all the necessaray powers of w_c
l : np.ndarray, shape (t_num)
wavelengths in m
"""
t = tspace(time_window, t_num, dt)
time_window = t.max() - t.min()
dt = t[1] - t[0]
t_num = len(t)
z_targets = np.linspace(0, length, z_num)
w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, interpolation_degree)
l = 2 * pi * c / w
return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
def update_frequency_domain(
t: np.ndarray, wavelength: float, deg: int
) -> tuple[np.ndarray, float, np.ndarray, np.ndarray]:
"""updates the frequency grid
Parameters
----------
t : np.ndarray
time array
wavelength : float
wavelength
deg : int
interpolation degree of the dispersion
Returns
-------
Tuple[np.ndarray, float, np.ndarray, np.ndarray]
w_c, w0, w, w_power_fact
"""
w_c = wspace(t)
w0 = 2 * pi * c / wavelength
w = w_c + w0
w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
return w_c, w0, w, w_power_fact

View File

@@ -3,8 +3,6 @@ import numpy as np
import scipy.special import scipy.special
from scipy.integrate import cumulative_trapezoid from scipy.integrate import cumulative_trapezoid
from scgenerator import math
from ..logger import get_logger from ..logger import get_logger
from . import units from . import units
from .units import NA, c, kB, me, e, hbar from .units import NA, c, kB, me, e, hbar

View File

@@ -0,0 +1,341 @@
from collections import defaultdict
from typing import Any, Callable, Union
from typing import TypeVar, Optional
from dataclasses import dataclass
import numpy as np
import itertools
from functools import wraps
import re
from ..physics import fiber, pulse, materials
from .. import math
T = TypeVar("T")
import inspect
class Rule:
def __init__(
self,
target: Union[str, list[Optional[str]]],
func: Callable,
args: list[str] = None,
priorities: Union[int, list[int]] = None,
):
targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func
if priorities is None:
priorities = [1] * len(targets)
elif isinstance(priorities, (int, float, np.integer, np.floating)):
priorities = [priorities]
self.targets = dict(zip(targets, priorities))
if args is None:
args = get_arg_names(func)
self.args = args
def __repr__(self) -> str:
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
@classmethod
def deduce(
cls,
target: Union[str, list[Optional[str]]],
func: Callable,
kwarg_names: list[str],
n_var: int,
args_const: list[str] = None,
) -> list["Rule"]:
"""given a function that doesn't need all its keyword arguemtn specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters
----------
target : str | list[str | None]
name of the variable(s) that func returns
func : Callable
function to work with
kwarg_names : list[str]
list of all kwargs of the function to be used
n_var : int
how many shoulf be used per rule
arg_const : list[str], optional
override the name of the positional arguments
Returns
-------
list[Rule]
list of all possible rules
Example
-------
>> def lol(a, b=None, c=None):
pass
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
[
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
]
"""
rules: list[cls] = []
for var_possibility in itertools.combinations(kwarg_names, n_var):
new_func = func_rewrite(func, list(var_possibility), args_const)
rules.append(cls(target, new_func))
return rules
@dataclass
class EvalStat:
priority: float = np.inf
class Evaluator:
def __init__(self):
self.rules: dict[str, list[Rule]] = defaultdict(list)
self.params = {}
self.__curent_lookup = set()
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
def append(self, *rule: Rule):
for r in rule:
for t in r.targets:
if t is not None:
self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def update(self, **params: Any):
self.params.update(params)
for k in params:
self.eval_stats[k].priority = np.inf
def reset(self):
self.params = {}
self.eval_stats = defaultdict(EvalStat)
def compute(self, target: str) -> Any:
"""computes a target
Parameters
----------
target : str
name of the target
Returns
-------
Any
return type of the target function
Raises
------
RecursionError
a cyclic dependence exists
KeyError
there is no saved rule for the target
"""
value = self.params.get(target)
if value is None:
if target in self.__curent_lookup:
raise RecursionError(
"cyclic dependency detected : "
f"{target!r} seems to depend on itself, "
f"please provide a value for at least one variable in {self.__curent_lookup}"
)
else:
self.__curent_lookup.add(target)
error = None
for rule in reversed(self.rules[target]):
try:
args = [self.compute(k) for k in rule.args]
returned_values = rule.func(*args)
if len(rule.targets) == 1:
self.params[target] = returned_values
self.eval_stats[target].priority = rule.targets[target]
value = returned_values
else:
for ((k, p), v) in zip(rule.targets.items(), returned_values):
if (
k == target
or k not in self.params
or self.eval_stats[k].priority < p
):
self.params[k] = v
self.eval_stats[k] = p
if k == target:
value = v
break
except (RecursionError, KeyError) as e:
error = e
continue
if value is None and error is not None:
raise error
self.__curent_lookup.remove(target)
return value
def __call__(self, target: str, args: list[str] = None):
"""creates a wrapper that adds decorated functions to the set of rules
Parameters
----------
target : str
name of the target
args : list[str], optional
list of name of arguments. Automatically deduced from function signature if
not provided, by default None
"""
def wrapper(func):
self.append(Rule(target, func, args))
return func
return wrapper
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
return scope[tmp_name]
default_rules: list[Rule] = [
*Rule.deduce(
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"],
math.build_sim_grid,
["time_window", "t_num", "dt"],
2,
)
]
"""
Rule("gamma", fiber.gamma_parameter),
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
Rule(["beta", "gamma", "interp_range"], fiber.PCF_dispersion),
Rule("n2"),
Rule("loss"),
Rule("loss_file"),
Rule("effective_mode_diameter"),
Rule("A_eff"),
Rule("A_eff_file"),
Rule("pitch"),
Rule("pitch_ratio"),
Rule("core_radius"),
Rule("he_mode"),
Rule("fit_parameters"),
Rule("beta"),
Rule("dispersion_file"),
Rule("model"),
Rule("length"),
Rule("capillary_num"),
Rule("capillary_outer_d"),
Rule("capillary_thickness"),
Rule("capillary_spacing"),
Rule("capillary_resonance_strengths"),
Rule("capillary_nested"),
Rule("gas_name"),
Rule("pressure"),
Rule("temperature"),
Rule("plasma_density"),
Rule("field_file"),
Rule("repetition_rate"),
Rule("peak_power"),
Rule("mean_power"),
Rule("energy"),
Rule("soliton_num"),
Rule("quantum_noise"),
Rule("shape"),
Rule("wavelength"),
Rule("intensity_noise"),
Rule("width"),
Rule("t0"),
Rule("behaviors"),
Rule("parallel"),
Rule("raman_type"),
Rule("ideal_gas"),
Rule("repeat"),
Rule("t_num"),
Rule("z_num"),
Rule("time_window"),
Rule("dt"),
Rule("tolerated_error"),
Rule("step_size"),
Rule("lower_wavelength_interp_limit"),
Rule("upper_wavelength_interp_limit"),
Rule("interpolation_degree"),
Rule("prev_sim_dir"),
Rule("recovery_last_stored"),
Rule("worker_num"),
Rule("field_0"),
Rule("spec_0"),
Rule("alpha"),
Rule("gamma_arr"),
Rule("A_eff_arr"),
Rule("w"),
Rule("l"),
Rule("w_c"),
Rule("w0"),
Rule("w_power_fact"),
Rule("t"),
Rule("L_D"),
Rule("L_NL"),
Rule("L_sol"),
Rule("dynamic_dispersion"),
Rule("adapt_step_size"),
Rule("error_ok"),
Rule("hr_w"),
Rule("z_targets"),
Rule("const_qty"),
Rule("beta_func"),
Rule("gamma_func"),
Rule("interp_range"),
Rule("datetime"),
Rule("version"),
]
"""
def main():
evalor = Evaluator()
evalor.append(*default_rules)
evalor.update(
**{
"length": 1,
"z_num": 128,
"wavelength": 1500e-9,
"interpolation_degree": 8,
"t_num": 16384,
"dt": 1e-15,
}
)
evalor.compute("z_targets")
print(evalor.params.keys())
print(evalor.params["l"][evalor.params["l"] > 0].min())
if __name__ == "__main__":
main()

View File

@@ -8,6 +8,9 @@ import numpy as np
from ..const import __version__ from ..const import __version__
# from .evaluator import Rule, Evaluator
# from ..physics import pulse, fiber, materials
T = TypeVar("T") T = TypeVar("T")
# Validator # Validator
@@ -187,7 +190,7 @@ def translate(p_name: str, p_value: T) -> tuple[str, T]:
class Parameter: class Parameter:
def __init__(self, validator, converter=None, default=None, display_info=None): def __init__(self, validator, converter=None, default=None, display_info=None, rules=None):
"""Single parameter """Single parameter
Parameters Parameters
@@ -208,6 +211,10 @@ class Parameter:
self.converter = converter self.converter = converter
self.default = default self.default = default
self.display_info = display_info self.display_info = display_info
if rules is None:
self.rules = []
else:
self.rules = rules
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
self.name = name self.name = name
@@ -344,14 +351,14 @@ class BareParams:
""" """
# root # root
name: str = Parameter(string) name: str = Parameter(string, default="no name")
prev_data_dir: str = Parameter(string) prev_data_dir: str = Parameter(string)
previous_config_file: str = Parameter(string) previous_config_file: str = Parameter(string)
# # fiber # # fiber
input_transmission: float = Parameter(in_range_incl(0, 1)) input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
gamma: float = Parameter(non_negative(float, int)) gamma: float = Parameter(non_negative(float, int))
n2: float = Parameter(non_negative(float, int)) n2: float = Parameter(non_negative(float, int), default=2.2e-20)
loss: str = Parameter(literal("capillary")) loss: str = Parameter(literal("capillary"))
loss_file: str = Parameter(string) loss_file: str = Parameter(string)
effective_mode_diameter: float = Parameter(positive(float, int)) effective_mode_diameter: float = Parameter(positive(float, int))
@@ -360,58 +367,66 @@ class BareParams:
pitch: float = Parameter(in_range_excl(0, 1e-3)) pitch: float = Parameter(in_range_excl(0, 1e-3))
pitch_ratio: float = Parameter(in_range_excl(0, 1)) pitch_ratio: float = Parameter(in_range_excl(0, 1))
core_radius: float = Parameter(in_range_excl(0, 1e-3)) core_radius: float = Parameter(in_range_excl(0, 1e-3))
he_mode: Tuple[int, int] = Parameter(int_pair) he_mode: Tuple[int, int] = Parameter(int_pair, default=(1, 1))
fit_parameters: Tuple[int, int] = Parameter(int_pair) fit_parameters: Tuple[int, int] = Parameter(int_pair, default=(0.08, 200e-9))
beta: Iterable[float] = Parameter(num_list) beta: Iterable[float] = Parameter(num_list)
dispersion_file: str = Parameter(string) dispersion_file: str = Parameter(string)
model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom")) model: str = Parameter(
length: float = Parameter(non_negative(float, int)) literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), default="custom"
)
length: float = Parameter(non_negative(float, int), default=1.0)
capillary_num: int = Parameter(positive(int)) capillary_num: int = Parameter(positive(int))
capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3)) capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3)) capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3)) capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
capillary_resonance_strengths: Iterable[float] = Parameter(num_list) capillary_resonance_strengths: Iterable[float] = Parameter(num_list, default=[])
capillary_nested: int = Parameter(non_negative(int)) capillary_nested: int = Parameter(non_negative(int), default=0)
# gas # gas
gas_name: str = Parameter(string, converter=str.lower) gas_name: str = Parameter(string, converter=str.lower, default="vacuum")
pressure: Union[float, Iterable[float]] = Parameter( pressure: Union[float, Iterable[float]] = Parameter(
validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar") validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar"), default=1e5
) )
temperature: float = Parameter(positive(float, int), display_info=(1, "K")) temperature: float = Parameter(positive(float, int), display_info=(1, "K"), default=300)
plasma_density: float = Parameter(non_negative(float, int)) plasma_density: float = Parameter(non_negative(float, int), default=0)
# pulse # pulse
field_file: str = Parameter(string) field_file: str = Parameter(string)
repetition_rate: float = Parameter(non_negative(float, int), display_info=(1e-6, "MHz")) repetition_rate: float = Parameter(
non_negative(float, int), display_info=(1e-6, "MHz"), default=40e6
)
peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW")) peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW"))
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW")) mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ")) energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
soliton_num: float = Parameter(non_negative(float, int)) soliton_num: float = Parameter(non_negative(float, int))
quantum_noise: bool = Parameter(boolean) quantum_noise: bool = Parameter(boolean, default=False)
shape: str = Parameter(literal("gaussian", "sech")) shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm")) wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%")) intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"), default=0)
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
# simulation # simulation
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss"))) behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")), default=["spm", "ss"])
parallel: bool = Parameter(boolean) parallel: bool = Parameter(boolean, default=True)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower) raman_type: str = Parameter(
ideal_gas: bool = Parameter(boolean) literal("measured", "agrawal", "stolen"), converter=str.lower, default="agrawal"
repeat: int = Parameter(positive(int)) )
ideal_gas: bool = Parameter(boolean, default=False)
repeat: int = Parameter(positive(int), default=1)
t_num: int = Parameter(positive(int)) t_num: int = Parameter(positive(int))
z_num: int = Parameter(positive(int)) z_num: int = Parameter(positive(int))
time_window: float = Parameter(positive(float, int)) time_window: float = Parameter(positive(float, int))
dt: float = Parameter(in_range_excl(0, 5e-15)) dt: float = Parameter(in_range_excl(0, 5e-15))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3)) tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
step_size: float = Parameter(positive(float, int)) step_size: float = Parameter(positive(float, int))
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9)) lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9), default=100e-9)
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9)) upper_wavelength_interp_limit: float = Parameter(
interpolation_degree: int = Parameter(positive(int)) in_range_incl(200e-9, 5000e-9), default=2000e-9
)
interpolation_degree: int = Parameter(positive(int), default=8)
prev_sim_dir: str = Parameter(string) prev_sim_dir: str = Parameter(string)
recovery_last_stored: int = Parameter(non_negative(int)) recovery_last_stored: int = Parameter(non_negative(int), default=0)
worker_num: int = Parameter(positive(int)) worker_num: int = Parameter(positive(int))
# computed # computed