some renaming, added operators Rules and Parameter
This commit is contained in:
@@ -25,3 +25,80 @@ SPEC1_FN_N = "spectrum_{}_{}.npy"
|
|||||||
Z_FN = "z.npy"
|
Z_FN = "z.npy"
|
||||||
PARAM_FN = "params.toml"
|
PARAM_FN = "params.toml"
|
||||||
PARAM_SEPARATOR = " "
|
PARAM_SEPARATOR = " "
|
||||||
|
|
||||||
|
VALID_VARIABLE = {
|
||||||
|
"dispersion_file",
|
||||||
|
"prev_data_dir",
|
||||||
|
"field_file",
|
||||||
|
"loss_file",
|
||||||
|
"A_eff_file",
|
||||||
|
"beta2_coefficients",
|
||||||
|
"gamma",
|
||||||
|
"pitch",
|
||||||
|
"pitch_ratio",
|
||||||
|
"effective_mode_diameter",
|
||||||
|
"core_radius",
|
||||||
|
"model",
|
||||||
|
"capillary_num",
|
||||||
|
"capillary_radius",
|
||||||
|
"capillary_thickness",
|
||||||
|
"capillary_spacing",
|
||||||
|
"capillary_resonance_strengths",
|
||||||
|
"capillary_resonance_max_order",
|
||||||
|
"capillary_nested",
|
||||||
|
"he_mode",
|
||||||
|
"fit_parameters",
|
||||||
|
"input_transmission",
|
||||||
|
"n2",
|
||||||
|
"pressure",
|
||||||
|
"temperature",
|
||||||
|
"gas_name",
|
||||||
|
"plasma_density",
|
||||||
|
"peak_power",
|
||||||
|
"mean_power",
|
||||||
|
"peak_power",
|
||||||
|
"energy",
|
||||||
|
"quantum_noise",
|
||||||
|
"shape",
|
||||||
|
"wavelength",
|
||||||
|
"intensity_noise",
|
||||||
|
"width",
|
||||||
|
"t0",
|
||||||
|
"soliton_num",
|
||||||
|
"behaviors",
|
||||||
|
"raman_type",
|
||||||
|
"tolerated_error",
|
||||||
|
"step_size",
|
||||||
|
"interpolation_degree",
|
||||||
|
"ideal_gas",
|
||||||
|
"length",
|
||||||
|
"num",
|
||||||
|
}
|
||||||
|
|
||||||
|
MANDATORY_PARAMETERS = [
|
||||||
|
"name",
|
||||||
|
"w_c",
|
||||||
|
"w",
|
||||||
|
"w0",
|
||||||
|
"w_power_fact",
|
||||||
|
"alpha",
|
||||||
|
"spec_0",
|
||||||
|
"field_0",
|
||||||
|
"mean_power",
|
||||||
|
"input_transmission",
|
||||||
|
"z_targets",
|
||||||
|
"length",
|
||||||
|
"beta2_coefficients",
|
||||||
|
"gamma_arr",
|
||||||
|
"behaviors",
|
||||||
|
"raman_type",
|
||||||
|
"hr_w",
|
||||||
|
"adapt_step_size",
|
||||||
|
"tolerated_error",
|
||||||
|
"dynamic_dispersion",
|
||||||
|
"recovery_last_stored",
|
||||||
|
"output_path",
|
||||||
|
"repeat",
|
||||||
|
"linear_operator",
|
||||||
|
"nonlinear_op",
|
||||||
|
]
|
||||||
|
|||||||
383
src/scgenerator/evaluator.py
Normal file
383
src/scgenerator/evaluator.py
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
from typing import Optional, Callable, Union, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from .physics import fiber, pulse, materials, units
|
||||||
|
from .utils import _mock_function, get_arg_names, get_logger, func_rewrite
|
||||||
|
from .errors import *
|
||||||
|
from collections import defaultdict
|
||||||
|
from .const import MANDATORY_PARAMETERS
|
||||||
|
import numpy as np
|
||||||
|
import itertools
|
||||||
|
from . import math, utils, operators
|
||||||
|
|
||||||
|
|
||||||
|
class Rule:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
target: Union[str, list[Optional[str]]],
|
||||||
|
func: Callable,
|
||||||
|
args: list[str] = None,
|
||||||
|
priorities: Union[int, list[int]] = None,
|
||||||
|
conditions: dict[str, str] = None,
|
||||||
|
):
|
||||||
|
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||||
|
self.func = func
|
||||||
|
if priorities is None:
|
||||||
|
priorities = [1] * len(targets)
|
||||||
|
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
||||||
|
priorities = [priorities]
|
||||||
|
self.targets = dict(zip(targets, priorities))
|
||||||
|
if args is None:
|
||||||
|
args = get_arg_names(func)
|
||||||
|
self.args = args
|
||||||
|
self.mock_func = _mock_function(len(self.args), len(self.targets))
|
||||||
|
self.conditions = conditions or {}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deduce(
|
||||||
|
cls,
|
||||||
|
target: Union[str, list[Optional[str]]],
|
||||||
|
func: Callable,
|
||||||
|
kwarg_names: list[str],
|
||||||
|
n_var: int,
|
||||||
|
args_const: list[str] = None,
|
||||||
|
priorities: Union[int, list[int]] = None,
|
||||||
|
) -> list["Rule"]:
|
||||||
|
"""given a function that doesn't need all its keyword arguemtn specified, will
|
||||||
|
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
target : str | list[str | None]
|
||||||
|
name of the variable(s) that func returns
|
||||||
|
func : Callable
|
||||||
|
function to work with
|
||||||
|
kwarg_names : list[str]
|
||||||
|
list of all kwargs of the function to be used
|
||||||
|
n_var : int
|
||||||
|
how many shoulf be used per rule
|
||||||
|
arg_const : list[str], optional
|
||||||
|
override the name of the positional arguments
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[Rule]
|
||||||
|
list of all possible rules
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
>> def lol(a, b=None, c=None):
|
||||||
|
pass
|
||||||
|
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
|
||||||
|
[
|
||||||
|
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
|
||||||
|
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
rules: list[cls] = []
|
||||||
|
for var_possibility in itertools.combinations(kwarg_names, n_var):
|
||||||
|
|
||||||
|
new_func = func_rewrite(func, list(var_possibility), args_const)
|
||||||
|
|
||||||
|
rules.append(cls(target, new_func, priorities=priorities))
|
||||||
|
return rules
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvalStat:
|
||||||
|
priority: float = np.inf
|
||||||
|
|
||||||
|
|
||||||
|
class Evaluator:
|
||||||
|
defaults: dict[str, Any] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> "Evaluator":
|
||||||
|
evaluator = cls()
|
||||||
|
evaluator.append(*default_rules)
|
||||||
|
return evaluator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def evaluate_default(cls, params: dict[str, Any], check_only=False) -> dict[str, Any]:
|
||||||
|
evaluator = cls.default()
|
||||||
|
evaluator.set(**params)
|
||||||
|
for target in MANDATORY_PARAMETERS:
|
||||||
|
evaluator.compute(target, check_only=check_only)
|
||||||
|
return evaluator.params
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_default_param(cls, key, value):
|
||||||
|
cls.defaults[key] = value
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
||||||
|
self.params = {}
|
||||||
|
self.__curent_lookup: list[str] = []
|
||||||
|
self.__failed_rules: dict[str, list[Rule]] = defaultdict(list)
|
||||||
|
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
|
def append(self, *rule: Rule):
|
||||||
|
for r in rule:
|
||||||
|
for t in r.targets:
|
||||||
|
if t is not None:
|
||||||
|
self.rules[t].append(r)
|
||||||
|
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||||
|
|
||||||
|
def set(self, **params: Any):
|
||||||
|
self.params.update(params)
|
||||||
|
for k in params:
|
||||||
|
self.eval_stats[k].priority = np.inf
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.params = {}
|
||||||
|
self.eval_stats = defaultdict(EvalStat)
|
||||||
|
|
||||||
|
def compute(self, target: str, check_only=False) -> Any:
|
||||||
|
"""computes a target
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
target : str
|
||||||
|
name of the target
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Any
|
||||||
|
return type of the target function
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
EvaluatorError
|
||||||
|
a cyclic dependence exists
|
||||||
|
KeyError
|
||||||
|
there is no saved rule for the target
|
||||||
|
"""
|
||||||
|
value = self.params.get(target)
|
||||||
|
if value is None:
|
||||||
|
prefix = "\t" * len(self.__curent_lookup)
|
||||||
|
# Avoid cycles
|
||||||
|
if target in self.__curent_lookup:
|
||||||
|
raise EvaluatorError(
|
||||||
|
"cyclic dependency detected : "
|
||||||
|
f"{target!r} seems to depend on itself, "
|
||||||
|
f"please provide a value for at least one variable in {self.__curent_lookup!r}. "
|
||||||
|
+ self.attempted_rules_str(target)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.__curent_lookup.append(target)
|
||||||
|
|
||||||
|
if len(self.rules[target]) == 0:
|
||||||
|
error = EvaluatorError(f"no rule for {target}")
|
||||||
|
else:
|
||||||
|
error = None
|
||||||
|
|
||||||
|
# try every rule until one succeeds
|
||||||
|
for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])):
|
||||||
|
self.logger.debug(
|
||||||
|
prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
args = [self.compute(k, check_only=check_only) for k in rule.args]
|
||||||
|
if check_only:
|
||||||
|
returned_values = rule.mock_func(*args)
|
||||||
|
else:
|
||||||
|
returned_values = rule.func(*args)
|
||||||
|
if len(rule.targets) == 1:
|
||||||
|
returned_values = [returned_values]
|
||||||
|
for ((param_name, param_priority), returned_value) in zip(
|
||||||
|
rule.targets.items(), returned_values
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
param_name == target
|
||||||
|
or param_name not in self.params
|
||||||
|
or self.eval_stats[param_name].priority < param_priority
|
||||||
|
):
|
||||||
|
if check_only:
|
||||||
|
success_str = f"able to compute {param_name} "
|
||||||
|
else:
|
||||||
|
v_str = format(returned_value).replace("\n", "")
|
||||||
|
success_str = f"computed {param_name}={v_str} "
|
||||||
|
self.logger.info(
|
||||||
|
prefix
|
||||||
|
+ success_str
|
||||||
|
+ f"using {rule.func.__name__} from {rule.func.__module__}"
|
||||||
|
)
|
||||||
|
self.set_value(param_name, returned_value, param_priority)
|
||||||
|
if param_name == target:
|
||||||
|
value = returned_value
|
||||||
|
break
|
||||||
|
except (EvaluatorError, KeyError, NoDefaultError) as e:
|
||||||
|
error = e
|
||||||
|
self.logger.debug(
|
||||||
|
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
|
||||||
|
)
|
||||||
|
self.__failed_rules[target].append(rule)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
default = self.defaults.get(target)
|
||||||
|
if default is None:
|
||||||
|
error = error or NoDefaultError(
|
||||||
|
prefix
|
||||||
|
+ f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. "
|
||||||
|
+ self.attempted_rules_str(target)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
value = default
|
||||||
|
self.logger.info(prefix + f"using default value of {value} for {target}")
|
||||||
|
self.set_value(target, value, 0)
|
||||||
|
|
||||||
|
assert target == self.__curent_lookup.pop()
|
||||||
|
self.__failed_rules[target] = []
|
||||||
|
|
||||||
|
if value is None and error is not None:
|
||||||
|
raise error
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
return self.params[key]
|
||||||
|
|
||||||
|
def set_value(self, key: str, value: Any, priority: int):
|
||||||
|
self.params[key] = value
|
||||||
|
self.eval_stats[key].priority = priority
|
||||||
|
|
||||||
|
def validate_condition(self, rule: Rule) -> bool:
|
||||||
|
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||||
|
|
||||||
|
def attempted_rules_str(self, target: str) -> str:
|
||||||
|
rules = ", ".join(str(r) for r in self.__failed_rules[target])
|
||||||
|
if len(rules) == 0:
|
||||||
|
return ""
|
||||||
|
return "attempted rules : " + rules
|
||||||
|
|
||||||
|
|
||||||
|
default_rules: list[Rule] = [
|
||||||
|
# Grid
|
||||||
|
*Rule.deduce(
|
||||||
|
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "l"],
|
||||||
|
math.build_sim_grid,
|
||||||
|
["time_window", "t_num", "dt"],
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
Rule("adapt_step_size", lambda step_size: step_size == 0),
|
||||||
|
Rule("dynamic_dispersion", lambda pressure: isinstance(pressure, (list, tuple, np.ndarray))),
|
||||||
|
# Pulse
|
||||||
|
Rule("spec_0", np.fft.fft, ["field_0"]),
|
||||||
|
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
||||||
|
Rule("spec_0", utils.load_previous_spectrum, ["recovery_data_dir"], priorities=4),
|
||||||
|
Rule("spec_0", utils.load_previous_spectrum, priorities=3),
|
||||||
|
*Rule.deduce(
|
||||||
|
["pre_field_0", "peak_power", "energy", "width"],
|
||||||
|
pulse.load_and_adjust_field_file,
|
||||||
|
["energy", "peak_power"],
|
||||||
|
1,
|
||||||
|
priorities=[2, 1, 1, 1],
|
||||||
|
),
|
||||||
|
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
||||||
|
Rule(
|
||||||
|
"field_0",
|
||||||
|
pulse.finalize_pulse,
|
||||||
|
[
|
||||||
|
"pre_field_0",
|
||||||
|
"quantum_noise",
|
||||||
|
"w_c",
|
||||||
|
"w0",
|
||||||
|
"time_window",
|
||||||
|
"dt",
|
||||||
|
"additional_noise_factor",
|
||||||
|
"input_transmission",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||||
|
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
||||||
|
Rule("mean_power", pulse.energy_to_mean_power),
|
||||||
|
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
||||||
|
Rule("energy", pulse.mean_power_to_energy, priorities=2),
|
||||||
|
Rule("t0", pulse.width_to_t0),
|
||||||
|
Rule("t0", pulse.soliton_num_to_t0),
|
||||||
|
Rule("width", pulse.t0_to_width),
|
||||||
|
Rule("soliton_num", pulse.soliton_num),
|
||||||
|
Rule("L_D", pulse.L_D),
|
||||||
|
Rule("L_NL", pulse.L_NL),
|
||||||
|
Rule("L_sol", pulse.L_sol),
|
||||||
|
# Fiber Dispersion
|
||||||
|
Rule("wl_for_disp", fiber.lambda_for_dispersion),
|
||||||
|
Rule("w_for_disp", units.m, ["wl_for_disp"]),
|
||||||
|
Rule(
|
||||||
|
"beta2_coefficients",
|
||||||
|
fiber.dispersion_coefficients,
|
||||||
|
["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"],
|
||||||
|
),
|
||||||
|
Rule("beta2_arr", fiber.beta2),
|
||||||
|
Rule("beta2_arr", fiber.dispersion_from_coefficients),
|
||||||
|
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
|
||||||
|
Rule(
|
||||||
|
["wl_for_disp", "beta2_arr", "interpolation_range"],
|
||||||
|
fiber.load_custom_dispersion,
|
||||||
|
priorities=[2, 2, 2],
|
||||||
|
),
|
||||||
|
Rule("hr_w", fiber.delayed_raman_w),
|
||||||
|
Rule("n_gas_2", materials.n_gas_2),
|
||||||
|
Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
|
||||||
|
Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
|
||||||
|
Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")),
|
||||||
|
Rule(
|
||||||
|
"n_eff",
|
||||||
|
fiber.n_eff_pcf,
|
||||||
|
["wl_for_disp", "pitch", "pitch_ratio"],
|
||||||
|
conditions=dict(model="pcf"),
|
||||||
|
),
|
||||||
|
Rule("capillary_spacing", fiber.capillary_spacing_hasan),
|
||||||
|
Rule("capillary_resonance_strengths", fiber.capillary_resonance_strengths),
|
||||||
|
Rule("capillary_resonance_strengths", lambda: [], priorities=-1),
|
||||||
|
# Fiber nonlinearity
|
||||||
|
Rule("A_eff", fiber.A_eff_from_V),
|
||||||
|
Rule("A_eff", fiber.A_eff_from_diam),
|
||||||
|
Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")),
|
||||||
|
Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1),
|
||||||
|
Rule("A_eff", fiber.A_eff_marcatili, priorities=-2),
|
||||||
|
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
|
||||||
|
Rule("A_eff_arr", fiber.load_custom_A_eff),
|
||||||
|
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
|
||||||
|
Rule(
|
||||||
|
"V_eff",
|
||||||
|
fiber.V_parameter_koshiba,
|
||||||
|
["wavelength", "pitch", "pitch_ratio"],
|
||||||
|
conditions=dict(model="pcf"),
|
||||||
|
),
|
||||||
|
Rule("V_eff", fiber.V_eff_step_index, ["wavelength", "core_radius", "numerical_aperture"]),
|
||||||
|
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
|
||||||
|
Rule(
|
||||||
|
"V_eff_arr",
|
||||||
|
fiber.V_eff_step_index,
|
||||||
|
["l", "core_radius", "numerical_aperture", "interpolation_range"],
|
||||||
|
),
|
||||||
|
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
|
||||||
|
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
|
||||||
|
Rule("n2", materials.gas_n2),
|
||||||
|
Rule("n2", lambda: 2.2e-20, priorities=-1),
|
||||||
|
# Operators
|
||||||
|
Rule("gamma_op", operators.ConstantGamma),
|
||||||
|
Rule("gamma_op", operators.NoGamma, priorities=-1),
|
||||||
|
Rule("ss_op", operators.SelfSteepening),
|
||||||
|
Rule("ss_op", operators.NoSelfSteepening, priorities=-1),
|
||||||
|
Rule("spm_op", operators.SPM),
|
||||||
|
Rule("spm_op", operators.NoSPM, priorities=-1),
|
||||||
|
Rule("raman_op", operators.Raman),
|
||||||
|
Rule("raman_op", operators.NoRaman, priorities=-1),
|
||||||
|
Rule("nonlinear_operator", operators.EnvelopeNonLinearOperator),
|
||||||
|
Rule("loss_op", operators.CustomConstantLoss, priorities=3),
|
||||||
|
Rule("loss_op", operators.CapillaryLoss, priorities=2),
|
||||||
|
Rule("loss_op", operators.ConstantLoss, priorities=1),
|
||||||
|
Rule("loss_op", operators.NoLoss, priorities=-1),
|
||||||
|
Rule("disp_op", operators.ConstantPolyDispersion),
|
||||||
|
Rule("linear_operator", operators.LinearOperator),
|
||||||
|
# gas
|
||||||
|
Rule("n_gas_2", materials.n_gas_2),
|
||||||
|
]
|
||||||
@@ -10,8 +10,8 @@ from dataclasses import dataclass, field
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
|
||||||
from . import fiber
|
from .physics import fiber
|
||||||
from .. import math
|
from . import math
|
||||||
|
|
||||||
|
|
||||||
class SpectrumDescriptor:
|
class SpectrumDescriptor:
|
||||||
@@ -45,6 +45,20 @@ class CurrentState:
|
|||||||
return self.z / self.length
|
return self.z / self.length
|
||||||
|
|
||||||
|
|
||||||
|
class Operator(ABC):
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
self.__class__.__name__
|
||||||
|
+ "("
|
||||||
|
+ ", ".join(k + "=" + repr(v) for k, v in self.__dict__.items())
|
||||||
|
+ ")"
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NoOp:
|
class NoOp:
|
||||||
def __init__(self, w: np.ndarray):
|
def __init__(self, w: np.ndarray):
|
||||||
self.zero_arr = np.zeros_like(w)
|
self.zero_arr = np.zeros_like(w)
|
||||||
@@ -55,7 +69,7 @@ class NoOp:
|
|||||||
##################################################
|
##################################################
|
||||||
|
|
||||||
|
|
||||||
class AbstractDispersion(ABC):
|
class AbstractDispersion(Operator):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
"""returns the dispersion in the frequency domain
|
"""returns the dispersion in the frequency domain
|
||||||
@@ -74,7 +88,7 @@ class AbstractDispersion(ABC):
|
|||||||
|
|
||||||
class ConstantPolyDispersion(AbstractDispersion):
|
class ConstantPolyDispersion(AbstractDispersion):
|
||||||
"""
|
"""
|
||||||
dispersion approximated by fitting a polynom on the dispersion and
|
dispersion approximated by fitting a polynome on the dispersion and
|
||||||
evaluating on the envelope
|
evaluating on the envelope
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -87,8 +101,8 @@ class ConstantPolyDispersion(AbstractDispersion):
|
|||||||
beta2_arr: np.ndarray,
|
beta2_arr: np.ndarray,
|
||||||
w0: float,
|
w0: float,
|
||||||
w_c: np.ndarray,
|
w_c: np.ndarray,
|
||||||
interpolation_range: tuple[float, float] = None,
|
interpolation_range: tuple[float, float],
|
||||||
interpolation_degree: int = 8,
|
interpolation_degree: int,
|
||||||
):
|
):
|
||||||
self.coefs = fiber.dispersion_coefficients(
|
self.coefs = fiber.dispersion_coefficients(
|
||||||
wl_for_disp, beta2_arr, w0, interpolation_range, interpolation_degree
|
wl_for_disp, beta2_arr, w0, interpolation_range, interpolation_degree
|
||||||
@@ -108,9 +122,9 @@ class ConstantPolyDispersion(AbstractDispersion):
|
|||||||
|
|
||||||
|
|
||||||
class LinearOperator:
|
class LinearOperator:
|
||||||
def __init__(self, disp: AbstractDispersion, loss: AbstractLoss):
|
def __init__(self, disp_op: AbstractDispersion, loss_op: AbstractLoss):
|
||||||
self.disp = disp
|
self.disp = disp_op
|
||||||
self.loss = loss
|
self.loss = loss_op
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
"""returns the linear operator to be multiplied by the spectrum in the frequency domain
|
"""returns the linear operator to be multiplied by the spectrum in the frequency domain
|
||||||
@@ -135,7 +149,7 @@ class LinearOperator:
|
|||||||
# Raman
|
# Raman
|
||||||
|
|
||||||
|
|
||||||
class AbstractRaman(ABC):
|
class AbstractRaman(Operator):
|
||||||
f_r: float = 0.0
|
f_r: float = 0.0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -171,7 +185,7 @@ class Raman(AbstractRaman):
|
|||||||
# SPM
|
# SPM
|
||||||
|
|
||||||
|
|
||||||
class AbstractSPM(ABC):
|
class AbstractSPM(Operator):
|
||||||
fraction: float = 1.0
|
fraction: float = 1.0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -206,7 +220,7 @@ class SPM(AbstractSPM):
|
|||||||
# Selt Steepening
|
# Selt Steepening
|
||||||
|
|
||||||
|
|
||||||
class AbstractSelfSteepening(ABC):
|
class AbstractSelfSteepening(Operator):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
"""returns the self-steepening component
|
"""returns the self-steepening component
|
||||||
@@ -239,7 +253,7 @@ class SelfSteepening(AbstractSelfSteepening):
|
|||||||
# Gamma operator
|
# Gamma operator
|
||||||
|
|
||||||
|
|
||||||
class AbstractGamma(ABC):
|
class AbstractGamma(Operator):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
"""returns the gamma component
|
"""returns the gamma component
|
||||||
@@ -275,7 +289,7 @@ class ConstantGamma(AbstractSelfSteepening):
|
|||||||
# Nonlinear combination
|
# Nonlinear combination
|
||||||
|
|
||||||
|
|
||||||
class AbstractNonLinearOperator(ABC):
|
class NonLinearOperator(Operator):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
"""returns the nonlinear operator applied on the spectrum in the frequency domain
|
"""returns the nonlinear operator applied on the spectrum in the frequency domain
|
||||||
@@ -292,7 +306,7 @@ class AbstractNonLinearOperator(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class EnvelopeNonLinearOperator(AbstractNonLinearOperator):
|
class EnvelopeNonLinearOperator(NonLinearOperator):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gamma_op: AbstractGamma,
|
gamma_op: AbstractGamma,
|
||||||
@@ -319,7 +333,7 @@ class EnvelopeNonLinearOperator(AbstractNonLinearOperator):
|
|||||||
##################################################
|
##################################################
|
||||||
|
|
||||||
|
|
||||||
class AbstractLoss(ABC):
|
class AbstractLoss(Operator):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
"""returns the loss in the frequency domain
|
"""returns the loss in the frequency domain
|
||||||
@@ -342,10 +356,15 @@ class ConstantLoss(AbstractLoss):
|
|||||||
def __init__(self, alpha: float, w: np.ndarray):
|
def __init__(self, alpha: float, w: np.ndarray):
|
||||||
self.alpha_arr = alpha * np.ones_like(w)
|
self.alpha_arr = alpha * np.ones_like(w)
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState = None) -> np.ndarray:
|
||||||
return self.alpha_arr
|
return self.alpha_arr
|
||||||
|
|
||||||
|
|
||||||
|
class NoLoss(ConstantLoss):
|
||||||
|
def __init__(self, w: np.ndarray):
|
||||||
|
super().__init__(0, w)
|
||||||
|
|
||||||
|
|
||||||
class CapillaryLoss(ConstantLoss):
|
class CapillaryLoss(ConstantLoss):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -2,10 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import datetime as datetime_module
|
import datetime as datetime_module
|
||||||
import enum
|
import enum
|
||||||
import inspect
|
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from copy import copy
|
from copy import copy
|
||||||
@@ -18,94 +16,18 @@ import numpy as np
|
|||||||
from numpy.lib import isin
|
from numpy.lib import isin
|
||||||
|
|
||||||
from . import env, math, utils
|
from . import env, math, utils
|
||||||
from .const import PARAM_FN, __version__
|
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
|
||||||
from .errors import EvaluatorError, NoDefaultError
|
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .physics import fiber, materials, pulse, units
|
from .utils import fiber_folder, update_path_name
|
||||||
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path_name
|
|
||||||
from .variationer import VariationDescriptor, Variationer
|
from .variationer import VariationDescriptor, Variationer
|
||||||
|
from .evaluator import Evaluator
|
||||||
|
from .operators import NonLinearOperator, LinearOperator
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
# Validator
|
# Validator
|
||||||
|
|
||||||
|
|
||||||
VALID_VARIABLE = {
|
|
||||||
"dispersion_file",
|
|
||||||
"prev_data_dir",
|
|
||||||
"field_file",
|
|
||||||
"loss_file",
|
|
||||||
"A_eff_file",
|
|
||||||
"beta2_coefficients",
|
|
||||||
"gamma",
|
|
||||||
"pitch",
|
|
||||||
"pitch_ratio",
|
|
||||||
"effective_mode_diameter",
|
|
||||||
"core_radius",
|
|
||||||
"model",
|
|
||||||
"capillary_num",
|
|
||||||
"capillary_radius",
|
|
||||||
"capillary_thickness",
|
|
||||||
"capillary_spacing",
|
|
||||||
"capillary_resonance_strengths",
|
|
||||||
"capillary_resonance_max_order",
|
|
||||||
"capillary_nested",
|
|
||||||
"he_mode",
|
|
||||||
"fit_parameters",
|
|
||||||
"input_transmission",
|
|
||||||
"n2",
|
|
||||||
"pressure",
|
|
||||||
"temperature",
|
|
||||||
"gas_name",
|
|
||||||
"plasma_density",
|
|
||||||
"peak_power",
|
|
||||||
"mean_power",
|
|
||||||
"peak_power",
|
|
||||||
"energy",
|
|
||||||
"quantum_noise",
|
|
||||||
"shape",
|
|
||||||
"wavelength",
|
|
||||||
"intensity_noise",
|
|
||||||
"width",
|
|
||||||
"t0",
|
|
||||||
"soliton_num",
|
|
||||||
"behaviors",
|
|
||||||
"raman_type",
|
|
||||||
"tolerated_error",
|
|
||||||
"step_size",
|
|
||||||
"interpolation_degree",
|
|
||||||
"ideal_gas",
|
|
||||||
"length",
|
|
||||||
"num",
|
|
||||||
}
|
|
||||||
|
|
||||||
MANDATORY_PARAMETERS = [
|
|
||||||
"name",
|
|
||||||
"w_c",
|
|
||||||
"w",
|
|
||||||
"w0",
|
|
||||||
"w_power_fact",
|
|
||||||
"alpha",
|
|
||||||
"spec_0",
|
|
||||||
"field_0",
|
|
||||||
"mean_power",
|
|
||||||
"input_transmission",
|
|
||||||
"z_targets",
|
|
||||||
"length",
|
|
||||||
"beta2_coefficients",
|
|
||||||
"gamma_arr",
|
|
||||||
"behaviors",
|
|
||||||
"raman_type",
|
|
||||||
"hr_w",
|
|
||||||
"adapt_step_size",
|
|
||||||
"tolerated_error",
|
|
||||||
"dynamic_dispersion",
|
|
||||||
"recovery_last_stored",
|
|
||||||
"output_path",
|
|
||||||
"repeat",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def type_checker(*types):
|
def type_checker(*types):
|
||||||
def _type_checker_wrapper(validator, n=None):
|
def _type_checker_wrapper(validator, n=None):
|
||||||
@@ -286,6 +208,8 @@ class Parameter:
|
|||||||
|
|
||||||
def __set_name__(self, owner, name):
|
def __set_name__(self, owner, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
if self.default is not None:
|
||||||
|
Evaluator.register_default_param(self.name, self.default)
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
if instance is None:
|
if instance is None:
|
||||||
@@ -405,9 +329,7 @@ class Parameters(_AbstractParameters):
|
|||||||
validator_list(literal("spm", "raman", "ss")), converter=tuple, default=("spm", "ss")
|
validator_list(literal("spm", "raman", "ss")), converter=tuple, default=("spm", "ss")
|
||||||
)
|
)
|
||||||
parallel: bool = Parameter(boolean, default=True)
|
parallel: bool = Parameter(boolean, default=True)
|
||||||
raman_type: str = Parameter(
|
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
||||||
literal("measured", "agrawal", "stolen"), converter=str.lower, default="agrawal"
|
|
||||||
)
|
|
||||||
ideal_gas: bool = Parameter(boolean, default=False)
|
ideal_gas: bool = Parameter(boolean, default=False)
|
||||||
repeat: int = Parameter(positive(int), default=1)
|
repeat: int = Parameter(positive(int), default=1)
|
||||||
t_num: int = Parameter(positive(int))
|
t_num: int = Parameter(positive(int))
|
||||||
@@ -423,6 +345,8 @@ class Parameters(_AbstractParameters):
|
|||||||
worker_num: int = Parameter(positive(int))
|
worker_num: int = Parameter(positive(int))
|
||||||
|
|
||||||
# computed
|
# computed
|
||||||
|
linear_operator: LinearOperator = Parameter(type_checker(LinearOperator))
|
||||||
|
nonlinear_operator: NonLinearOperator = Parameter(type_checker(NonLinearOperator))
|
||||||
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
beta2: float = Parameter(type_checker(int, float))
|
beta2: float = Parameter(type_checker(int, float))
|
||||||
@@ -538,253 +462,6 @@ class Parameters(_AbstractParameters):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Rule:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
target: Union[str, list[Optional[str]]],
|
|
||||||
func: Callable,
|
|
||||||
args: list[str] = None,
|
|
||||||
priorities: Union[int, list[int]] = None,
|
|
||||||
conditions: dict[str, str] = None,
|
|
||||||
):
|
|
||||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
|
||||||
self.func = func
|
|
||||||
if priorities is None:
|
|
||||||
priorities = [1] * len(targets)
|
|
||||||
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
|
||||||
priorities = [priorities]
|
|
||||||
self.targets = dict(zip(targets, priorities))
|
|
||||||
if args is None:
|
|
||||||
args = get_arg_names(func)
|
|
||||||
self.args = args
|
|
||||||
self.mock_func = _mock_function(len(self.args), len(self.targets))
|
|
||||||
self.conditions = conditions or {}
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def deduce(
|
|
||||||
cls,
|
|
||||||
target: Union[str, list[Optional[str]]],
|
|
||||||
func: Callable,
|
|
||||||
kwarg_names: list[str],
|
|
||||||
n_var: int,
|
|
||||||
args_const: list[str] = None,
|
|
||||||
priorities: Union[int, list[int]] = None,
|
|
||||||
) -> list["Rule"]:
|
|
||||||
"""given a function that doesn't need all its keyword arguemtn specified, will
|
|
||||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
target : str | list[str | None]
|
|
||||||
name of the variable(s) that func returns
|
|
||||||
func : Callable
|
|
||||||
function to work with
|
|
||||||
kwarg_names : list[str]
|
|
||||||
list of all kwargs of the function to be used
|
|
||||||
n_var : int
|
|
||||||
how many shoulf be used per rule
|
|
||||||
arg_const : list[str], optional
|
|
||||||
override the name of the positional arguments
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[Rule]
|
|
||||||
list of all possible rules
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
>> def lol(a, b=None, c=None):
|
|
||||||
pass
|
|
||||||
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
|
|
||||||
[
|
|
||||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
|
|
||||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
rules: list[cls] = []
|
|
||||||
for var_possibility in itertools.combinations(kwarg_names, n_var):
|
|
||||||
|
|
||||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
|
||||||
|
|
||||||
rules.append(cls(target, new_func, priorities=priorities))
|
|
||||||
return rules
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EvalStat:
|
|
||||||
priority: float = np.inf
|
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
|
||||||
@classmethod
|
|
||||||
def default(cls) -> "Evaluator":
|
|
||||||
evaluator = cls()
|
|
||||||
evaluator.append(*default_rules)
|
|
||||||
return evaluator
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def evaluate_default(cls, params: dict[str, Any], check_only=False) -> dict[str, Any]:
|
|
||||||
evaluator = cls.default()
|
|
||||||
evaluator.set(**params)
|
|
||||||
for target in MANDATORY_PARAMETERS:
|
|
||||||
evaluator.compute(target, check_only=check_only)
|
|
||||||
return evaluator.params
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
|
||||||
self.params = {}
|
|
||||||
self.__curent_lookup: list[str] = []
|
|
||||||
self.__failed_rules: dict[str, list[Rule]] = defaultdict(list)
|
|
||||||
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
|
|
||||||
self.logger = get_logger(__name__)
|
|
||||||
|
|
||||||
def append(self, *rule: Rule):
|
|
||||||
for r in rule:
|
|
||||||
for t in r.targets:
|
|
||||||
if t is not None:
|
|
||||||
self.rules[t].append(r)
|
|
||||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
|
||||||
|
|
||||||
def set(self, **params: Any):
|
|
||||||
self.params.update(params)
|
|
||||||
for k in params:
|
|
||||||
self.eval_stats[k].priority = np.inf
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.params = {}
|
|
||||||
self.eval_stats = defaultdict(EvalStat)
|
|
||||||
|
|
||||||
def get_default(self, key: str) -> Any:
|
|
||||||
try:
|
|
||||||
return getattr(Parameters, key).default
|
|
||||||
except AttributeError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def compute(self, target: str, check_only=False) -> Any:
|
|
||||||
"""computes a target
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
target : str
|
|
||||||
name of the target
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Any
|
|
||||||
return type of the target function
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
EvaluatorError
|
|
||||||
a cyclic dependence exists
|
|
||||||
KeyError
|
|
||||||
there is no saved rule for the target
|
|
||||||
"""
|
|
||||||
value = self.params.get(target)
|
|
||||||
if value is None:
|
|
||||||
prefix = "\t" * len(self.__curent_lookup)
|
|
||||||
# Avoid cycles
|
|
||||||
if target in self.__curent_lookup:
|
|
||||||
raise EvaluatorError(
|
|
||||||
"cyclic dependency detected : "
|
|
||||||
f"{target!r} seems to depend on itself, "
|
|
||||||
f"please provide a value for at least one variable in {self.__curent_lookup!r}. "
|
|
||||||
+ self.attempted_rules_str(target)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.__curent_lookup.append(target)
|
|
||||||
|
|
||||||
if len(self.rules[target]) == 0:
|
|
||||||
error = EvaluatorError(f"no rule for {target}")
|
|
||||||
else:
|
|
||||||
error = None
|
|
||||||
|
|
||||||
# try every rule until one succeeds
|
|
||||||
for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])):
|
|
||||||
self.logger.debug(
|
|
||||||
prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
args = [self.compute(k, check_only=check_only) for k in rule.args]
|
|
||||||
if check_only:
|
|
||||||
returned_values = rule.mock_func(*args)
|
|
||||||
else:
|
|
||||||
returned_values = rule.func(*args)
|
|
||||||
if len(rule.targets) == 1:
|
|
||||||
returned_values = [returned_values]
|
|
||||||
for ((param_name, param_priority), returned_value) in zip(
|
|
||||||
rule.targets.items(), returned_values
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
param_name == target
|
|
||||||
or param_name not in self.params
|
|
||||||
or self.eval_stats[param_name].priority < param_priority
|
|
||||||
):
|
|
||||||
if check_only:
|
|
||||||
success_str = f"able to compute {param_name} "
|
|
||||||
else:
|
|
||||||
v_str = format(returned_value).replace("\n", "")
|
|
||||||
success_str = f"computed {param_name}={v_str} "
|
|
||||||
self.logger.info(
|
|
||||||
prefix
|
|
||||||
+ success_str
|
|
||||||
+ f"using {rule.func.__name__} from {rule.func.__module__}"
|
|
||||||
)
|
|
||||||
self.set_value(param_name, returned_value, param_priority)
|
|
||||||
if param_name == target:
|
|
||||||
value = returned_value
|
|
||||||
break
|
|
||||||
except (EvaluatorError, KeyError, NoDefaultError) as e:
|
|
||||||
error = e
|
|
||||||
self.logger.debug(
|
|
||||||
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
|
|
||||||
)
|
|
||||||
self.__failed_rules[target].append(rule)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
default = self.get_default(target)
|
|
||||||
if default is None:
|
|
||||||
error = error or NoDefaultError(
|
|
||||||
prefix
|
|
||||||
+ f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. "
|
|
||||||
+ self.attempted_rules_str(target)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
value = default
|
|
||||||
self.logger.info(prefix + f"using default value of {value} for {target}")
|
|
||||||
self.set_value(target, value, 0)
|
|
||||||
|
|
||||||
assert target == self.__curent_lookup.pop()
|
|
||||||
self.__failed_rules[target] = []
|
|
||||||
|
|
||||||
if value is None and error is not None:
|
|
||||||
raise error
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> Any:
|
|
||||||
return self.params[key]
|
|
||||||
|
|
||||||
def set_value(self, key: str, value: Any, priority: int):
|
|
||||||
self.params[key] = value
|
|
||||||
self.eval_stats[key].priority = priority
|
|
||||||
|
|
||||||
def validate_condition(self, rule: Rule) -> bool:
|
|
||||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
|
||||||
|
|
||||||
def attempted_rules_str(self, target: str) -> str:
|
|
||||||
rules = ", ".join(str(r) for r in self.__failed_rules[target])
|
|
||||||
if len(rules) == 0:
|
|
||||||
return ""
|
|
||||||
return "attempted rules : " + rules
|
|
||||||
|
|
||||||
|
|
||||||
class Configuration:
|
class Configuration:
|
||||||
"""
|
"""
|
||||||
Primary role is to load the final config file of the simulation and deduce every
|
Primary role is to load the final config file of the simulation and deduce every
|
||||||
@@ -1041,120 +718,6 @@ class Configuration:
|
|||||||
return param
|
return param
|
||||||
|
|
||||||
|
|
||||||
default_rules: list[Rule] = [
|
|
||||||
# Grid
|
|
||||||
*Rule.deduce(
|
|
||||||
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "l"],
|
|
||||||
math.build_sim_grid,
|
|
||||||
["time_window", "t_num", "dt"],
|
|
||||||
2,
|
|
||||||
),
|
|
||||||
Rule("adapt_step_size", lambda step_size: step_size == 0),
|
|
||||||
Rule("dynamic_dispersion", lambda pressure: isinstance(pressure, (list, tuple, np.ndarray))),
|
|
||||||
# Pulse
|
|
||||||
Rule("spec_0", np.fft.fft, ["field_0"]),
|
|
||||||
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
|
||||||
Rule("spec_0", utils.load_previous_spectrum, ["recovery_data_dir"], priorities=4),
|
|
||||||
Rule("spec_0", utils.load_previous_spectrum, priorities=3),
|
|
||||||
*Rule.deduce(
|
|
||||||
["pre_field_0", "peak_power", "energy", "width"],
|
|
||||||
pulse.load_and_adjust_field_file,
|
|
||||||
["energy", "peak_power"],
|
|
||||||
1,
|
|
||||||
priorities=[2, 1, 1, 1],
|
|
||||||
),
|
|
||||||
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
|
||||||
Rule(
|
|
||||||
"field_0",
|
|
||||||
pulse.finalize_pulse,
|
|
||||||
[
|
|
||||||
"pre_field_0",
|
|
||||||
"quantum_noise",
|
|
||||||
"w_c",
|
|
||||||
"w0",
|
|
||||||
"time_window",
|
|
||||||
"dt",
|
|
||||||
"additional_noise_factor",
|
|
||||||
"input_transmission",
|
|
||||||
],
|
|
||||||
),
|
|
||||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
|
||||||
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
|
||||||
Rule("mean_power", pulse.energy_to_mean_power),
|
|
||||||
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
|
||||||
Rule("energy", pulse.mean_power_to_energy, priorities=2),
|
|
||||||
Rule("t0", pulse.width_to_t0),
|
|
||||||
Rule("t0", pulse.soliton_num_to_t0),
|
|
||||||
Rule("width", pulse.t0_to_width),
|
|
||||||
Rule("soliton_num", pulse.soliton_num),
|
|
||||||
Rule("L_D", pulse.L_D),
|
|
||||||
Rule("L_NL", pulse.L_NL),
|
|
||||||
Rule("L_sol", pulse.L_sol),
|
|
||||||
# Fiber Dispersion
|
|
||||||
Rule("wl_for_disp", fiber.lambda_for_dispersion),
|
|
||||||
Rule("w_for_disp", units.m, ["wl_for_disp"]),
|
|
||||||
Rule(
|
|
||||||
"beta2_coefficients",
|
|
||||||
fiber.dispersion_coefficients,
|
|
||||||
["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"],
|
|
||||||
),
|
|
||||||
Rule("beta2_arr", fiber.beta2),
|
|
||||||
Rule("beta2_arr", fiber.dispersion_from_coefficients),
|
|
||||||
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
|
|
||||||
Rule(
|
|
||||||
["wl_for_disp", "beta2_arr", "interpolation_range"],
|
|
||||||
fiber.load_custom_dispersion,
|
|
||||||
priorities=[2, 2, 2],
|
|
||||||
),
|
|
||||||
Rule("hr_w", fiber.delayed_raman_w),
|
|
||||||
Rule("n_gas_2", materials.n_gas_2),
|
|
||||||
Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
|
|
||||||
Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
|
|
||||||
Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")),
|
|
||||||
Rule(
|
|
||||||
"n_eff",
|
|
||||||
fiber.n_eff_pcf,
|
|
||||||
["wl_for_disp", "pitch", "pitch_ratio"],
|
|
||||||
conditions=dict(model="pcf"),
|
|
||||||
),
|
|
||||||
Rule("capillary_spacing", fiber.capillary_spacing_hasan),
|
|
||||||
Rule("capillary_resonance_strengths", fiber.capillary_resonance_strengths),
|
|
||||||
Rule("capillary_resonance_strengths", lambda: [], priorities=-1),
|
|
||||||
# Fiber nonlinearity
|
|
||||||
Rule("A_eff", fiber.A_eff_from_V),
|
|
||||||
Rule("A_eff", fiber.A_eff_from_diam),
|
|
||||||
Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")),
|
|
||||||
Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1),
|
|
||||||
Rule("A_eff", fiber.A_eff_marcatili, priorities=-2),
|
|
||||||
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
|
|
||||||
Rule("A_eff_arr", fiber.load_custom_A_eff),
|
|
||||||
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
|
|
||||||
Rule(
|
|
||||||
"V_eff",
|
|
||||||
fiber.V_parameter_koshiba,
|
|
||||||
["wavelength", "pitch", "pitch_ratio"],
|
|
||||||
conditions=dict(model="pcf"),
|
|
||||||
),
|
|
||||||
Rule("V_eff", fiber.V_eff_step_index, ["wavelength", "core_radius", "numerical_aperture"]),
|
|
||||||
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
|
|
||||||
Rule(
|
|
||||||
"V_eff_arr",
|
|
||||||
fiber.V_eff_step_index,
|
|
||||||
["l", "core_radius", "numerical_aperture", "interpolation_range"],
|
|
||||||
),
|
|
||||||
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
|
|
||||||
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
|
|
||||||
Rule("n2", materials.gas_n2),
|
|
||||||
Rule("n2", lambda: 2.2e-20, priorities=-1),
|
|
||||||
# Fiber loss
|
|
||||||
Rule("alpha_arr", fiber.compute_capillary_loss),
|
|
||||||
Rule("alpha_arr", fiber.load_custom_loss),
|
|
||||||
Rule("alpha_arr", lambda alpha, t: np.ones_like(t) * alpha, priorities=-1),
|
|
||||||
# gas
|
|
||||||
Rule("n_gas_2", materials.n_gas_2),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
numero = type_checker(int)
|
numero = type_checker(int)
|
||||||
|
|||||||
@@ -843,19 +843,6 @@ def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
|
|||||||
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
||||||
|
|
||||||
|
|
||||||
def compute_capillary_loss(
|
|
||||||
l: np.ndarray,
|
|
||||||
core_radius: float,
|
|
||||||
interpolation_range: tuple[float, float],
|
|
||||||
he_mode: tuple[int, int],
|
|
||||||
) -> np.ndarray:
|
|
||||||
mask = (l < interpolation_range[1]) & (l > 0)
|
|
||||||
alpha = capillary_loss(l[mask], he_mode, core_radius)
|
|
||||||
out = np.zeros_like(l)
|
|
||||||
out[mask] = alpha
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@np_cache
|
@np_cache
|
||||||
def dispersion_coefficients(
|
def dispersion_coefficients(
|
||||||
wl_for_disp: np.ndarray,
|
wl_for_disp: np.ndarray,
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import multiprocessing
|
|||||||
import multiprocessing.connection
|
import multiprocessing.connection
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from datetime import datetime
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generator, Type, Union
|
from typing import Any, Generator, Type, Union
|
||||||
|
|
||||||
@@ -13,9 +13,9 @@ from .. import utils
|
|||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..parameter import Configuration, Parameters
|
from ..parameter import Configuration, Parameters
|
||||||
from ..pbar import PBars, ProgressBarActor, progress_worker
|
from ..pbar import PBars, ProgressBarActor, progress_worker
|
||||||
|
from ..operators import CurrentState
|
||||||
from . import pulse
|
from . import pulse
|
||||||
from .fiber import create_non_linear_op, fast_dispersion_op
|
from .fiber import create_non_linear_op, fast_dispersion_op
|
||||||
from .properties import CurrentState
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ scgenerator module but some function may be used in any python program
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from dataclasses import dataclass
|
|
||||||
import inspect
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import printable as str_printable
|
from string import printable as str_printable
|
||||||
@@ -373,11 +374,19 @@ def to_62(i: int) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_arg_names(func: Callable) -> list[str]:
|
def get_arg_names(func: Callable) -> list[str]:
|
||||||
# spec = inspect.getfullargspec(func)
|
"""returns the positional argument names of func.
|
||||||
# args = spec.args
|
|
||||||
# if spec.defaults is not None and len(spec.defaults) > 0:
|
Parameters
|
||||||
# args = args[: -len(spec.defaults)]
|
----------
|
||||||
# return args
|
func : Callable
|
||||||
|
if a function, returns the names of the positional arguments
|
||||||
|
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[str]
|
||||||
|
[description]
|
||||||
|
"""
|
||||||
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
|
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user