diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index b8843a2..bfd5c25 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -25,3 +25,80 @@ SPEC1_FN_N = "spectrum_{}_{}.npy" Z_FN = "z.npy" PARAM_FN = "params.toml" PARAM_SEPARATOR = " " + +VALID_VARIABLE = { + "dispersion_file", + "prev_data_dir", + "field_file", + "loss_file", + "A_eff_file", + "beta2_coefficients", + "gamma", + "pitch", + "pitch_ratio", + "effective_mode_diameter", + "core_radius", + "model", + "capillary_num", + "capillary_radius", + "capillary_thickness", + "capillary_spacing", + "capillary_resonance_strengths", + "capillary_resonance_max_order", + "capillary_nested", + "he_mode", + "fit_parameters", + "input_transmission", + "n2", + "pressure", + "temperature", + "gas_name", + "plasma_density", + "peak_power", + "mean_power", + "peak_power", + "energy", + "quantum_noise", + "shape", + "wavelength", + "intensity_noise", + "width", + "t0", + "soliton_num", + "behaviors", + "raman_type", + "tolerated_error", + "step_size", + "interpolation_degree", + "ideal_gas", + "length", + "num", +} + +MANDATORY_PARAMETERS = [ + "name", + "w_c", + "w", + "w0", + "w_power_fact", + "alpha", + "spec_0", + "field_0", + "mean_power", + "input_transmission", + "z_targets", + "length", + "beta2_coefficients", + "gamma_arr", + "behaviors", + "raman_type", + "hr_w", + "adapt_step_size", + "tolerated_error", + "dynamic_dispersion", + "recovery_last_stored", + "output_path", + "repeat", + "linear_operator", + "nonlinear_op", +] diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py new file mode 100644 index 0000000..a374929 --- /dev/null +++ b/src/scgenerator/evaluator.py @@ -0,0 +1,383 @@ +from typing import Optional, Callable, Union, Any +from dataclasses import dataclass +from .physics import fiber, pulse, materials, units +from .utils import _mock_function, get_arg_names, get_logger, func_rewrite +from .errors import * +from collections import defaultdict +from .const import MANDATORY_PARAMETERS +import numpy as np +import itertools +from . import math, utils, operators + + +class Rule: + def __init__( + self, + target: Union[str, list[Optional[str]]], + func: Callable, + args: list[str] = None, + priorities: Union[int, list[int]] = None, + conditions: dict[str, str] = None, + ): + targets = list(target) if isinstance(target, (list, tuple)) else [target] + self.func = func + if priorities is None: + priorities = [1] * len(targets) + elif isinstance(priorities, (int, float, np.integer, np.floating)): + priorities = [priorities] + self.targets = dict(zip(targets, priorities)) + if args is None: + args = get_arg_names(func) + self.args = args + self.mock_func = _mock_function(len(self.args), len(self.targets)) + self.conditions = conditions or {} + + def __repr__(self) -> str: + return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})" + + def __str__(self) -> str: + return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]" + + @classmethod + def deduce( + cls, + target: Union[str, list[Optional[str]]], + func: Callable, + kwarg_names: list[str], + n_var: int, + args_const: list[str] = None, + priorities: Union[int, list[int]] = None, + ) -> list["Rule"]: + """given a function that doesn't need all its keyword arguemtn specified, will + return a list of Rule obj, one for each combination of n_var specified kwargs + + Parameters + ---------- + target : str | list[str | None] + name of the variable(s) that func returns + func : Callable + function to work with + kwarg_names : list[str] + list of all kwargs of the function to be used + n_var : int + how many shoulf be used per rule + arg_const : list[str], optional + override the name of the positional arguments + + Returns + ------- + list[Rule] + list of all possible rules + + Example + ------- + >> def lol(a, b=None, c=None): + pass + >> print(Rule.deduce(["d"], lol, ["b", "c"], 1)) + [ + Rule(targets={'d': 1}, func=, args=['a', 'b']), + Rule(targets={'d': 1}, func=, args=['a', 'c']) + ] + """ + rules: list[cls] = [] + for var_possibility in itertools.combinations(kwarg_names, n_var): + + new_func = func_rewrite(func, list(var_possibility), args_const) + + rules.append(cls(target, new_func, priorities=priorities)) + return rules + + +@dataclass +class EvalStat: + priority: float = np.inf + + +class Evaluator: + defaults: dict[str, Any] = {} + + @classmethod + def default(cls) -> "Evaluator": + evaluator = cls() + evaluator.append(*default_rules) + return evaluator + + @classmethod + def evaluate_default(cls, params: dict[str, Any], check_only=False) -> dict[str, Any]: + evaluator = cls.default() + evaluator.set(**params) + for target in MANDATORY_PARAMETERS: + evaluator.compute(target, check_only=check_only) + return evaluator.params + + @classmethod + def register_default_param(cls, key, value): + cls.defaults[key] = value + + def __init__(self): + self.rules: dict[str, list[Rule]] = defaultdict(list) + self.params = {} + self.__curent_lookup: list[str] = [] + self.__failed_rules: dict[str, list[Rule]] = defaultdict(list) + self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat) + self.logger = get_logger(__name__) + + def append(self, *rule: Rule): + for r in rule: + for t in r.targets: + if t is not None: + self.rules[t].append(r) + self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) + + def set(self, **params: Any): + self.params.update(params) + for k in params: + self.eval_stats[k].priority = np.inf + + def reset(self): + self.params = {} + self.eval_stats = defaultdict(EvalStat) + + def compute(self, target: str, check_only=False) -> Any: + """computes a target + + Parameters + ---------- + target : str + name of the target + + Returns + ------- + Any + return type of the target function + + Raises + ------ + EvaluatorError + a cyclic dependence exists + KeyError + there is no saved rule for the target + """ + value = self.params.get(target) + if value is None: + prefix = "\t" * len(self.__curent_lookup) + # Avoid cycles + if target in self.__curent_lookup: + raise EvaluatorError( + "cyclic dependency detected : " + f"{target!r} seems to depend on itself, " + f"please provide a value for at least one variable in {self.__curent_lookup!r}. " + + self.attempted_rules_str(target) + ) + else: + self.__curent_lookup.append(target) + + if len(self.rules[target]) == 0: + error = EvaluatorError(f"no rule for {target}") + else: + error = None + + # try every rule until one succeeds + for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])): + self.logger.debug( + prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}" + ) + try: + args = [self.compute(k, check_only=check_only) for k in rule.args] + if check_only: + returned_values = rule.mock_func(*args) + else: + returned_values = rule.func(*args) + if len(rule.targets) == 1: + returned_values = [returned_values] + for ((param_name, param_priority), returned_value) in zip( + rule.targets.items(), returned_values + ): + if ( + param_name == target + or param_name not in self.params + or self.eval_stats[param_name].priority < param_priority + ): + if check_only: + success_str = f"able to compute {param_name} " + else: + v_str = format(returned_value).replace("\n", "") + success_str = f"computed {param_name}={v_str} " + self.logger.info( + prefix + + success_str + + f"using {rule.func.__name__} from {rule.func.__module__}" + ) + self.set_value(param_name, returned_value, param_priority) + if param_name == target: + value = returned_value + break + except (EvaluatorError, KeyError, NoDefaultError) as e: + error = e + self.logger.debug( + prefix + f"error using {rule.func.__name__} : {str(error).strip()}" + ) + self.__failed_rules[target].append(rule) + continue + else: + default = self.defaults.get(target) + if default is None: + error = error or NoDefaultError( + prefix + + f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. " + + self.attempted_rules_str(target) + ) + else: + value = default + self.logger.info(prefix + f"using default value of {value} for {target}") + self.set_value(target, value, 0) + + assert target == self.__curent_lookup.pop() + self.__failed_rules[target] = [] + + if value is None and error is not None: + raise error + + return value + + def __getitem__(self, key: str) -> Any: + return self.params[key] + + def set_value(self, key: str, value: Any, priority: int): + self.params[key] = value + self.eval_stats[key].priority = priority + + def validate_condition(self, rule: Rule) -> bool: + return all(self.compute(k) == v for k, v in rule.conditions.items()) + + def attempted_rules_str(self, target: str) -> str: + rules = ", ".join(str(r) for r in self.__failed_rules[target]) + if len(rules) == 0: + return "" + return "attempted rules : " + rules + + +default_rules: list[Rule] = [ + # Grid + *Rule.deduce( + ["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "l"], + math.build_sim_grid, + ["time_window", "t_num", "dt"], + 2, + ), + Rule("adapt_step_size", lambda step_size: step_size == 0), + Rule("dynamic_dispersion", lambda pressure: isinstance(pressure, (list, tuple, np.ndarray))), + # Pulse + Rule("spec_0", np.fft.fft, ["field_0"]), + Rule("field_0", np.fft.ifft, ["spec_0"]), + Rule("spec_0", utils.load_previous_spectrum, ["recovery_data_dir"], priorities=4), + Rule("spec_0", utils.load_previous_spectrum, priorities=3), + *Rule.deduce( + ["pre_field_0", "peak_power", "energy", "width"], + pulse.load_and_adjust_field_file, + ["energy", "peak_power"], + 1, + priorities=[2, 1, 1, 1], + ), + Rule("pre_field_0", pulse.initial_field, priorities=1), + Rule( + "field_0", + pulse.finalize_pulse, + [ + "pre_field_0", + "quantum_noise", + "w_c", + "w0", + "time_window", + "dt", + "additional_noise_factor", + "input_transmission", + ], + ), + Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), + Rule("peak_power", pulse.soliton_num_to_peak_power), + Rule("mean_power", pulse.energy_to_mean_power), + Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), + Rule("energy", pulse.mean_power_to_energy, priorities=2), + Rule("t0", pulse.width_to_t0), + Rule("t0", pulse.soliton_num_to_t0), + Rule("width", pulse.t0_to_width), + Rule("soliton_num", pulse.soliton_num), + Rule("L_D", pulse.L_D), + Rule("L_NL", pulse.L_NL), + Rule("L_sol", pulse.L_sol), + # Fiber Dispersion + Rule("wl_for_disp", fiber.lambda_for_dispersion), + Rule("w_for_disp", units.m, ["wl_for_disp"]), + Rule( + "beta2_coefficients", + fiber.dispersion_coefficients, + ["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"], + ), + Rule("beta2_arr", fiber.beta2), + Rule("beta2_arr", fiber.dispersion_from_coefficients), + Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]), + Rule( + ["wl_for_disp", "beta2_arr", "interpolation_range"], + fiber.load_custom_dispersion, + priorities=[2, 2, 2], + ), + Rule("hr_w", fiber.delayed_raman_w), + Rule("n_gas_2", materials.n_gas_2), + Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")), + Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")), + Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")), + Rule( + "n_eff", + fiber.n_eff_pcf, + ["wl_for_disp", "pitch", "pitch_ratio"], + conditions=dict(model="pcf"), + ), + Rule("capillary_spacing", fiber.capillary_spacing_hasan), + Rule("capillary_resonance_strengths", fiber.capillary_resonance_strengths), + Rule("capillary_resonance_strengths", lambda: [], priorities=-1), + # Fiber nonlinearity + Rule("A_eff", fiber.A_eff_from_V), + Rule("A_eff", fiber.A_eff_from_diam), + Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")), + Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1), + Rule("A_eff", fiber.A_eff_marcatili, priorities=-2), + Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]), + Rule("A_eff_arr", fiber.load_custom_A_eff), + Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1), + Rule( + "V_eff", + fiber.V_parameter_koshiba, + ["wavelength", "pitch", "pitch_ratio"], + conditions=dict(model="pcf"), + ), + Rule("V_eff", fiber.V_eff_step_index, ["wavelength", "core_radius", "numerical_aperture"]), + Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")), + Rule( + "V_eff_arr", + fiber.V_eff_step_index, + ["l", "core_radius", "numerical_aperture", "interpolation_range"], + ), + Rule("gamma", lambda gamma_arr: gamma_arr[0]), + Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]), + Rule("n2", materials.gas_n2), + Rule("n2", lambda: 2.2e-20, priorities=-1), + # Operators + Rule("gamma_op", operators.ConstantGamma), + Rule("gamma_op", operators.NoGamma, priorities=-1), + Rule("ss_op", operators.SelfSteepening), + Rule("ss_op", operators.NoSelfSteepening, priorities=-1), + Rule("spm_op", operators.SPM), + Rule("spm_op", operators.NoSPM, priorities=-1), + Rule("raman_op", operators.Raman), + Rule("raman_op", operators.NoRaman, priorities=-1), + Rule("nonlinear_operator", operators.EnvelopeNonLinearOperator), + Rule("loss_op", operators.CustomConstantLoss, priorities=3), + Rule("loss_op", operators.CapillaryLoss, priorities=2), + Rule("loss_op", operators.ConstantLoss, priorities=1), + Rule("loss_op", operators.NoLoss, priorities=-1), + Rule("disp_op", operators.ConstantPolyDispersion), + Rule("linear_operator", operators.LinearOperator), + # gas + Rule("n_gas_2", materials.n_gas_2), +] diff --git a/src/scgenerator/physics/properties.py b/src/scgenerator/operators.py similarity index 88% rename from src/scgenerator/physics/properties.py rename to src/scgenerator/operators.py index 4d82dc2..b7396c8 100644 --- a/src/scgenerator/physics/properties.py +++ b/src/scgenerator/operators.py @@ -10,8 +10,8 @@ from dataclasses import dataclass, field import numpy as np from scipy.interpolate import interp1d -from . import fiber -from .. import math +from .physics import fiber +from . import math class SpectrumDescriptor: @@ -45,6 +45,20 @@ class CurrentState: return self.z / self.length +class Operator(ABC): + def __repr__(self) -> str: + return ( + self.__class__.__name__ + + "(" + + ", ".join(k + "=" + repr(v) for k, v in self.__dict__.items()) + + ")" + ) + + @abstractmethod + def __call__(self, state: CurrentState) -> np.ndarray: + pass + + class NoOp: def __init__(self, w: np.ndarray): self.zero_arr = np.zeros_like(w) @@ -55,7 +69,7 @@ class NoOp: ################################################## -class AbstractDispersion(ABC): +class AbstractDispersion(Operator): @abstractmethod def __call__(self, state: CurrentState) -> np.ndarray: """returns the dispersion in the frequency domain @@ -74,7 +88,7 @@ class AbstractDispersion(ABC): class ConstantPolyDispersion(AbstractDispersion): """ - dispersion approximated by fitting a polynom on the dispersion and + dispersion approximated by fitting a polynome on the dispersion and evaluating on the envelope """ @@ -87,8 +101,8 @@ class ConstantPolyDispersion(AbstractDispersion): beta2_arr: np.ndarray, w0: float, w_c: np.ndarray, - interpolation_range: tuple[float, float] = None, - interpolation_degree: int = 8, + interpolation_range: tuple[float, float], + interpolation_degree: int, ): self.coefs = fiber.dispersion_coefficients( wl_for_disp, beta2_arr, w0, interpolation_range, interpolation_degree @@ -108,9 +122,9 @@ class ConstantPolyDispersion(AbstractDispersion): class LinearOperator: - def __init__(self, disp: AbstractDispersion, loss: AbstractLoss): - self.disp = disp - self.loss = loss + def __init__(self, disp_op: AbstractDispersion, loss_op: AbstractLoss): + self.disp = disp_op + self.loss = loss_op def __call__(self, state: CurrentState) -> np.ndarray: """returns the linear operator to be multiplied by the spectrum in the frequency domain @@ -135,7 +149,7 @@ class LinearOperator: # Raman -class AbstractRaman(ABC): +class AbstractRaman(Operator): f_r: float = 0.0 @abstractmethod @@ -171,7 +185,7 @@ class Raman(AbstractRaman): # SPM -class AbstractSPM(ABC): +class AbstractSPM(Operator): fraction: float = 1.0 @abstractmethod @@ -206,7 +220,7 @@ class SPM(AbstractSPM): # Selt Steepening -class AbstractSelfSteepening(ABC): +class AbstractSelfSteepening(Operator): @abstractmethod def __call__(self, state: CurrentState) -> np.ndarray: """returns the self-steepening component @@ -239,7 +253,7 @@ class SelfSteepening(AbstractSelfSteepening): # Gamma operator -class AbstractGamma(ABC): +class AbstractGamma(Operator): @abstractmethod def __call__(self, state: CurrentState) -> np.ndarray: """returns the gamma component @@ -275,7 +289,7 @@ class ConstantGamma(AbstractSelfSteepening): # Nonlinear combination -class AbstractNonLinearOperator(ABC): +class NonLinearOperator(Operator): @abstractmethod def __call__(self, state: CurrentState) -> np.ndarray: """returns the nonlinear operator applied on the spectrum in the frequency domain @@ -292,7 +306,7 @@ class AbstractNonLinearOperator(ABC): """ -class EnvelopeNonLinearOperator(AbstractNonLinearOperator): +class EnvelopeNonLinearOperator(NonLinearOperator): def __init__( self, gamma_op: AbstractGamma, @@ -319,7 +333,7 @@ class EnvelopeNonLinearOperator(AbstractNonLinearOperator): ################################################## -class AbstractLoss(ABC): +class AbstractLoss(Operator): @abstractmethod def __call__(self, state: CurrentState) -> np.ndarray: """returns the loss in the frequency domain @@ -342,10 +356,15 @@ class ConstantLoss(AbstractLoss): def __init__(self, alpha: float, w: np.ndarray): self.alpha_arr = alpha * np.ones_like(w) - def __call__(self, state: CurrentState) -> np.ndarray: + def __call__(self, state: CurrentState = None) -> np.ndarray: return self.alpha_arr +class NoLoss(ConstantLoss): + def __init__(self, w: np.ndarray): + super().__init__(0, w) + + class CapillaryLoss(ConstantLoss): def __init__( self, diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index c6729c6..804d366 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -2,10 +2,8 @@ from __future__ import annotations import datetime as datetime_module import enum -import inspect import itertools import os -import re import time from collections import defaultdict from copy import copy @@ -18,94 +16,18 @@ import numpy as np from numpy.lib import isin from . import env, math, utils -from .const import PARAM_FN, __version__ -from .errors import EvaluatorError, NoDefaultError +from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS from .logger import get_logger -from .physics import fiber, materials, pulse, units -from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path_name +from .utils import fiber_folder, update_path_name from .variationer import VariationDescriptor, Variationer +from .evaluator import Evaluator +from .operators import NonLinearOperator, LinearOperator T = TypeVar("T") # Validator -VALID_VARIABLE = { - "dispersion_file", - "prev_data_dir", - "field_file", - "loss_file", - "A_eff_file", - "beta2_coefficients", - "gamma", - "pitch", - "pitch_ratio", - "effective_mode_diameter", - "core_radius", - "model", - "capillary_num", - "capillary_radius", - "capillary_thickness", - "capillary_spacing", - "capillary_resonance_strengths", - "capillary_resonance_max_order", - "capillary_nested", - "he_mode", - "fit_parameters", - "input_transmission", - "n2", - "pressure", - "temperature", - "gas_name", - "plasma_density", - "peak_power", - "mean_power", - "peak_power", - "energy", - "quantum_noise", - "shape", - "wavelength", - "intensity_noise", - "width", - "t0", - "soliton_num", - "behaviors", - "raman_type", - "tolerated_error", - "step_size", - "interpolation_degree", - "ideal_gas", - "length", - "num", -} - -MANDATORY_PARAMETERS = [ - "name", - "w_c", - "w", - "w0", - "w_power_fact", - "alpha", - "spec_0", - "field_0", - "mean_power", - "input_transmission", - "z_targets", - "length", - "beta2_coefficients", - "gamma_arr", - "behaviors", - "raman_type", - "hr_w", - "adapt_step_size", - "tolerated_error", - "dynamic_dispersion", - "recovery_last_stored", - "output_path", - "repeat", -] - - @lru_cache def type_checker(*types): def _type_checker_wrapper(validator, n=None): @@ -286,6 +208,8 @@ class Parameter: def __set_name__(self, owner, name): self.name = name + if self.default is not None: + Evaluator.register_default_param(self.name, self.default) def __get__(self, instance, owner): if instance is None: @@ -405,9 +329,7 @@ class Parameters(_AbstractParameters): validator_list(literal("spm", "raman", "ss")), converter=tuple, default=("spm", "ss") ) parallel: bool = Parameter(boolean, default=True) - raman_type: str = Parameter( - literal("measured", "agrawal", "stolen"), converter=str.lower, default="agrawal" - ) + raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower) ideal_gas: bool = Parameter(boolean, default=False) repeat: int = Parameter(positive(int), default=1) t_num: int = Parameter(positive(int)) @@ -423,6 +345,8 @@ class Parameters(_AbstractParameters): worker_num: int = Parameter(positive(int)) # computed + linear_operator: LinearOperator = Parameter(type_checker(LinearOperator)) + nonlinear_operator: NonLinearOperator = Parameter(type_checker(NonLinearOperator)) field_0: np.ndarray = Parameter(type_checker(np.ndarray)) spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) beta2: float = Parameter(type_checker(int, float)) @@ -538,253 +462,6 @@ class Parameters(_AbstractParameters): return None -class Rule: - def __init__( - self, - target: Union[str, list[Optional[str]]], - func: Callable, - args: list[str] = None, - priorities: Union[int, list[int]] = None, - conditions: dict[str, str] = None, - ): - targets = list(target) if isinstance(target, (list, tuple)) else [target] - self.func = func - if priorities is None: - priorities = [1] * len(targets) - elif isinstance(priorities, (int, float, np.integer, np.floating)): - priorities = [priorities] - self.targets = dict(zip(targets, priorities)) - if args is None: - args = get_arg_names(func) - self.args = args - self.mock_func = _mock_function(len(self.args), len(self.targets)) - self.conditions = conditions or {} - - def __repr__(self) -> str: - return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})" - - def __str__(self) -> str: - return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]" - - @classmethod - def deduce( - cls, - target: Union[str, list[Optional[str]]], - func: Callable, - kwarg_names: list[str], - n_var: int, - args_const: list[str] = None, - priorities: Union[int, list[int]] = None, - ) -> list["Rule"]: - """given a function that doesn't need all its keyword arguemtn specified, will - return a list of Rule obj, one for each combination of n_var specified kwargs - - Parameters - ---------- - target : str | list[str | None] - name of the variable(s) that func returns - func : Callable - function to work with - kwarg_names : list[str] - list of all kwargs of the function to be used - n_var : int - how many shoulf be used per rule - arg_const : list[str], optional - override the name of the positional arguments - - Returns - ------- - list[Rule] - list of all possible rules - - Example - ------- - >> def lol(a, b=None, c=None): - pass - >> print(Rule.deduce(["d"], lol, ["b", "c"], 1)) - [ - Rule(targets={'d': 1}, func=, args=['a', 'b']), - Rule(targets={'d': 1}, func=, args=['a', 'c']) - ] - """ - rules: list[cls] = [] - for var_possibility in itertools.combinations(kwarg_names, n_var): - - new_func = func_rewrite(func, list(var_possibility), args_const) - - rules.append(cls(target, new_func, priorities=priorities)) - return rules - - -@dataclass -class EvalStat: - priority: float = np.inf - - -class Evaluator: - @classmethod - def default(cls) -> "Evaluator": - evaluator = cls() - evaluator.append(*default_rules) - return evaluator - - @classmethod - def evaluate_default(cls, params: dict[str, Any], check_only=False) -> dict[str, Any]: - evaluator = cls.default() - evaluator.set(**params) - for target in MANDATORY_PARAMETERS: - evaluator.compute(target, check_only=check_only) - return evaluator.params - - def __init__(self): - self.rules: dict[str, list[Rule]] = defaultdict(list) - self.params = {} - self.__curent_lookup: list[str] = [] - self.__failed_rules: dict[str, list[Rule]] = defaultdict(list) - self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat) - self.logger = get_logger(__name__) - - def append(self, *rule: Rule): - for r in rule: - for t in r.targets: - if t is not None: - self.rules[t].append(r) - self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) - - def set(self, **params: Any): - self.params.update(params) - for k in params: - self.eval_stats[k].priority = np.inf - - def reset(self): - self.params = {} - self.eval_stats = defaultdict(EvalStat) - - def get_default(self, key: str) -> Any: - try: - return getattr(Parameters, key).default - except AttributeError: - return None - - def compute(self, target: str, check_only=False) -> Any: - """computes a target - - Parameters - ---------- - target : str - name of the target - - Returns - ------- - Any - return type of the target function - - Raises - ------ - EvaluatorError - a cyclic dependence exists - KeyError - there is no saved rule for the target - """ - value = self.params.get(target) - if value is None: - prefix = "\t" * len(self.__curent_lookup) - # Avoid cycles - if target in self.__curent_lookup: - raise EvaluatorError( - "cyclic dependency detected : " - f"{target!r} seems to depend on itself, " - f"please provide a value for at least one variable in {self.__curent_lookup!r}. " - + self.attempted_rules_str(target) - ) - else: - self.__curent_lookup.append(target) - - if len(self.rules[target]) == 0: - error = EvaluatorError(f"no rule for {target}") - else: - error = None - - # try every rule until one succeeds - for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])): - self.logger.debug( - prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}" - ) - try: - args = [self.compute(k, check_only=check_only) for k in rule.args] - if check_only: - returned_values = rule.mock_func(*args) - else: - returned_values = rule.func(*args) - if len(rule.targets) == 1: - returned_values = [returned_values] - for ((param_name, param_priority), returned_value) in zip( - rule.targets.items(), returned_values - ): - if ( - param_name == target - or param_name not in self.params - or self.eval_stats[param_name].priority < param_priority - ): - if check_only: - success_str = f"able to compute {param_name} " - else: - v_str = format(returned_value).replace("\n", "") - success_str = f"computed {param_name}={v_str} " - self.logger.info( - prefix - + success_str - + f"using {rule.func.__name__} from {rule.func.__module__}" - ) - self.set_value(param_name, returned_value, param_priority) - if param_name == target: - value = returned_value - break - except (EvaluatorError, KeyError, NoDefaultError) as e: - error = e - self.logger.debug( - prefix + f"error using {rule.func.__name__} : {str(error).strip()}" - ) - self.__failed_rules[target].append(rule) - continue - else: - default = self.get_default(target) - if default is None: - error = error or NoDefaultError( - prefix - + f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. " - + self.attempted_rules_str(target) - ) - else: - value = default - self.logger.info(prefix + f"using default value of {value} for {target}") - self.set_value(target, value, 0) - - assert target == self.__curent_lookup.pop() - self.__failed_rules[target] = [] - - if value is None and error is not None: - raise error - - return value - - def __getitem__(self, key: str) -> Any: - return self.params[key] - - def set_value(self, key: str, value: Any, priority: int): - self.params[key] = value - self.eval_stats[key].priority = priority - - def validate_condition(self, rule: Rule) -> bool: - return all(self.compute(k) == v for k, v in rule.conditions.items()) - - def attempted_rules_str(self, target: str) -> str: - rules = ", ".join(str(r) for r in self.__failed_rules[target]) - if len(rules) == 0: - return "" - return "attempted rules : " + rules - - class Configuration: """ Primary role is to load the final config file of the simulation and deduce every @@ -1041,120 +718,6 @@ class Configuration: return param -default_rules: list[Rule] = [ - # Grid - *Rule.deduce( - ["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "l"], - math.build_sim_grid, - ["time_window", "t_num", "dt"], - 2, - ), - Rule("adapt_step_size", lambda step_size: step_size == 0), - Rule("dynamic_dispersion", lambda pressure: isinstance(pressure, (list, tuple, np.ndarray))), - # Pulse - Rule("spec_0", np.fft.fft, ["field_0"]), - Rule("field_0", np.fft.ifft, ["spec_0"]), - Rule("spec_0", utils.load_previous_spectrum, ["recovery_data_dir"], priorities=4), - Rule("spec_0", utils.load_previous_spectrum, priorities=3), - *Rule.deduce( - ["pre_field_0", "peak_power", "energy", "width"], - pulse.load_and_adjust_field_file, - ["energy", "peak_power"], - 1, - priorities=[2, 1, 1, 1], - ), - Rule("pre_field_0", pulse.initial_field, priorities=1), - Rule( - "field_0", - pulse.finalize_pulse, - [ - "pre_field_0", - "quantum_noise", - "w_c", - "w0", - "time_window", - "dt", - "additional_noise_factor", - "input_transmission", - ], - ), - Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), - Rule("peak_power", pulse.soliton_num_to_peak_power), - Rule("mean_power", pulse.energy_to_mean_power), - Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), - Rule("energy", pulse.mean_power_to_energy, priorities=2), - Rule("t0", pulse.width_to_t0), - Rule("t0", pulse.soliton_num_to_t0), - Rule("width", pulse.t0_to_width), - Rule("soliton_num", pulse.soliton_num), - Rule("L_D", pulse.L_D), - Rule("L_NL", pulse.L_NL), - Rule("L_sol", pulse.L_sol), - # Fiber Dispersion - Rule("wl_for_disp", fiber.lambda_for_dispersion), - Rule("w_for_disp", units.m, ["wl_for_disp"]), - Rule( - "beta2_coefficients", - fiber.dispersion_coefficients, - ["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"], - ), - Rule("beta2_arr", fiber.beta2), - Rule("beta2_arr", fiber.dispersion_from_coefficients), - Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]), - Rule( - ["wl_for_disp", "beta2_arr", "interpolation_range"], - fiber.load_custom_dispersion, - priorities=[2, 2, 2], - ), - Rule("hr_w", fiber.delayed_raman_w), - Rule("n_gas_2", materials.n_gas_2), - Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")), - Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")), - Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")), - Rule( - "n_eff", - fiber.n_eff_pcf, - ["wl_for_disp", "pitch", "pitch_ratio"], - conditions=dict(model="pcf"), - ), - Rule("capillary_spacing", fiber.capillary_spacing_hasan), - Rule("capillary_resonance_strengths", fiber.capillary_resonance_strengths), - Rule("capillary_resonance_strengths", lambda: [], priorities=-1), - # Fiber nonlinearity - Rule("A_eff", fiber.A_eff_from_V), - Rule("A_eff", fiber.A_eff_from_diam), - Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")), - Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1), - Rule("A_eff", fiber.A_eff_marcatili, priorities=-2), - Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]), - Rule("A_eff_arr", fiber.load_custom_A_eff), - Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1), - Rule( - "V_eff", - fiber.V_parameter_koshiba, - ["wavelength", "pitch", "pitch_ratio"], - conditions=dict(model="pcf"), - ), - Rule("V_eff", fiber.V_eff_step_index, ["wavelength", "core_radius", "numerical_aperture"]), - Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")), - Rule( - "V_eff_arr", - fiber.V_eff_step_index, - ["l", "core_radius", "numerical_aperture", "interpolation_range"], - ), - Rule("gamma", lambda gamma_arr: gamma_arr[0]), - Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]), - Rule("n2", materials.gas_n2), - Rule("n2", lambda: 2.2e-20, priorities=-1), - # Fiber loss - Rule("alpha_arr", fiber.compute_capillary_loss), - Rule("alpha_arr", fiber.load_custom_loss), - Rule("alpha_arr", lambda alpha, t: np.ones_like(t) * alpha, priorities=-1), - # gas - Rule("n_gas_2", materials.n_gas_2), -] - - if __name__ == "__main__": numero = type_checker(int) diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 9b33cae..5b63a16 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -843,19 +843,6 @@ def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray: return interp1d(wl, loss, fill_value=0, bounds_error=False)(l) -def compute_capillary_loss( - l: np.ndarray, - core_radius: float, - interpolation_range: tuple[float, float], - he_mode: tuple[int, int], -) -> np.ndarray: - mask = (l < interpolation_range[1]) & (l > 0) - alpha = capillary_loss(l[mask], he_mode, core_radius) - out = np.zeros_like(l) - out[mask] = alpha - return out - - @np_cache def dispersion_coefficients( wl_for_disp: np.ndarray, diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 0f6094f..dfe7189 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -2,8 +2,8 @@ import multiprocessing import multiprocessing.connection import os import random -from datetime import datetime from dataclasses import dataclass +from datetime import datetime from pathlib import Path from typing import Any, Generator, Type, Union @@ -13,9 +13,9 @@ from .. import utils from ..logger import get_logger from ..parameter import Configuration, Parameters from ..pbar import PBars, ProgressBarActor, progress_worker +from ..operators import CurrentState from . import pulse from .fiber import create_non_linear_op, fast_dispersion_op -from .properties import CurrentState try: import ray diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index 1e95d85..6d5acd2 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -4,12 +4,13 @@ scgenerator module but some function may be used in any python program """ from __future__ import annotations -from dataclasses import dataclass + import inspect import itertools import os import re from collections import defaultdict +from dataclasses import dataclass from functools import cache from pathlib import Path from string import printable as str_printable @@ -373,11 +374,19 @@ def to_62(i: int) -> str: def get_arg_names(func: Callable) -> list[str]: - # spec = inspect.getfullargspec(func) - # args = spec.args - # if spec.defaults is not None and len(spec.defaults) > 0: - # args = args[: -len(spec.defaults)] - # return args + """returns the positional argument names of func. + + Parameters + ---------- + func : Callable + if a function, returns the names of the positional arguments + + + Returns + ------- + list[str] + [description] + """ return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]