From 1f0937d840e044cc2c49e8616fe49049dd0305fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Wed, 25 Aug 2021 12:18:55 +0200 Subject: [PATCH] solid ground work --- func_rewrite.py | 47 ++++ src/scgenerator/initialize.py | 320 ++++++++++++------------- src/scgenerator/math.py | 154 ++++++++++++ src/scgenerator/physics/materials.py | 2 - src/scgenerator/utils/evaluator.py | 341 +++++++++++++++++++++++++++ src/scgenerator/utils/parameter.py | 71 +++--- 6 files changed, 746 insertions(+), 189 deletions(-) create mode 100644 func_rewrite.py create mode 100644 src/scgenerator/utils/evaluator.py diff --git a/func_rewrite.py b/func_rewrite.py new file mode 100644 index 0000000..8ccc359 --- /dev/null +++ b/func_rewrite.py @@ -0,0 +1,47 @@ +from typing import Callable +import inspect +import re + + +def get_arg_names(func: Callable) -> list[str]: + spec = inspect.getfullargspec(func) + args = spec.args + if spec.defaults is not None and len(spec.defaults) > 0: + args = args[: -len(spec.defaults)] + return args + + +def validate_arg_names(names: list[str]): + for n in names: + if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: + raise ValueError(f"{n} is an invalid parameter name") + + +def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None): + if arg_names is None: + arg_names = get_arg_names(func) + else: + validate_arg_names(arg_names) + validate_arg_names(kwarg_names) + sign_arg_str = ", ".join(arg_names + kwarg_names) + call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) + tmp_name = f"{func.__name__}_0" + func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" + scope = dict(__func__=func) + exec(func_str, scope) + return scope[tmp_name] + + +def lol(a, b=None, c=None): + print(f"{a=}, {b=}, {c=}") + + +def main(): + lol1 = func_rewrite(lol, ["c"]) + print(inspect.getfullargspec(lol1)) + lol2 = func_rewrite(lol, ["b"]) + print(inspect.getfullargspec(lol2)) + + +if __name__ == "__main__": + main() diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 423390f..680d867 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -14,9 +14,11 @@ from .errors import * from .logger import get_logger from .math import power_fact from .physics import fiber, pulse, units -from .utils import override_config, required_simulations +from .utils import override_config, required_simulations, evaluator from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters +global_evaluator = evaluator.Evaluator() + @dataclass class Params(BareParams): @@ -541,65 +543,65 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]: return previous.name, count_variations(*configs) -def wspace(t, t_num=0): - """frequency array such that x(t) <-> np.fft(x)(w) - Parameters - ---------- - t : float or array - float : total width of the time window - array : time array - t_num : int- - if t is a float, specifies the number of points - Returns - ---------- - w : array - linspace of frencies corresponding to t - """ - if isinstance(t, (np.ndarray, list, tuple)): - dt = t[1] - t[0] - t_num = len(t) - t = t[-1] - t[0] + dt - else: - dt = t / t_num - w = 2 * pi * np.arange(t_num) / t - w = np.where(w >= pi / dt, w - 2 * pi / dt, w) - return w +# def wspace(t, t_num=0): +# """frequency array such that x(t) <-> np.fft(x)(w) +# Parameters +# ---------- +# t : float or array +# float : total width of the time window +# array : time array +# t_num : int- +# if t is a float, specifies the number of points +# Returns +# ---------- +# w : array +# linspace of frencies corresponding to t +# """ +# if isinstance(t, (np.ndarray, list, tuple)): +# dt = t[1] - t[0] +# t_num = len(t) +# t = t[-1] - t[0] + dt +# else: +# dt = t / t_num +# w = 2 * pi * np.arange(t_num) / t +# w = np.where(w >= pi / dt, w - 2 * pi / dt, w) +# return w -def tspace(time_window=None, t_num=None, dt=None): - """returns a time array centered on 0 - Parameters - ---------- - time_window : float - total time spanned - t_num : int - number of points - dt : float - time resolution +# def tspace(time_window=None, t_num=None, dt=None): +# """returns a time array centered on 0 +# Parameters +# ---------- +# time_window : float +# total time spanned +# t_num : int +# number of points +# dt : float +# time resolution - at least 2 arguments must be given. They are prioritize as such - t_num > time_window > dt +# at least 2 arguments must be given. They are prioritize as such +# t_num > time_window > dt - Returns - ------- - t : array - a linearily spaced time array - Raises - ------ - TypeError - missing at least 1 argument - """ - if t_num is not None: - if isinstance(time_window, (float, int)): - return np.linspace(-time_window / 2, time_window / 2, int(t_num)) - elif isinstance(dt, (float, int)): - time_window = (t_num - 1) * dt - return np.linspace(-time_window / 2, time_window / 2, t_num) - elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)): - t_num = int(time_window / dt) + 1 - return np.linspace(-time_window / 2, time_window / 2, t_num) - else: - raise TypeError("not enough parameter to determine time vector") +# Returns +# ------- +# t : array +# a linearily spaced time array +# Raises +# ------ +# TypeError +# missing at least 1 argument +# """ +# if t_num is not None: +# if isinstance(time_window, (float, int)): +# return np.linspace(-time_window / 2, time_window / 2, int(t_num)) +# elif isinstance(dt, (float, int)): +# time_window = (t_num - 1) * dt +# return np.linspace(-time_window / 2, time_window / 2, t_num) +# elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)): +# t_num = int(time_window / dt) + 1 +# return np.linspace(-time_window / 2, time_window / 2, t_num) +# else: +# raise TypeError("not enough parameter to determine time vector") def recover_params(params: BareParams, data_folder: Path) -> Params: @@ -620,115 +622,115 @@ def recover_params(params: BareParams, data_folder: Path) -> Params: return params -def build_sim_grid( - length: float, - z_num: int, - wavelength: float, - deg: int, - time_window: float = None, - t_num: int = None, - dt: float = None, -) -> tuple[ - np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray -]: - """computes a bunch of values that relate to the simulation grid +# def build_sim_grid( +# length: float, +# z_num: int, +# wavelength: float, +# deg: int, +# time_window: float = None, +# t_num: int = None, +# dt: float = None, +# ) -> tuple[ +# np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray +# ]: +# """computes a bunch of values that relate to the simulation grid - Parameters - ---------- - length : float - length of the fiber in m - z_num : int - number of spatial points - wavelength : float - pump wavelength in m - deg : int - dispersion interpolation degree - time_window : float, optional - total width of the temporal grid in s, by default None - t_num : int, optional - number of temporal grid points, by default None - dt : float, optional - spacing of the temporal grid in s, by default None +# Parameters +# ---------- +# length : float +# length of the fiber in m +# z_num : int +# number of spatial points +# wavelength : float +# pump wavelength in m +# deg : int +# dispersion interpolation degree +# time_window : float, optional +# total width of the temporal grid in s, by default None +# t_num : int, optional +# number of temporal grid points, by default None +# dt : float, optional +# spacing of the temporal grid in s, by default None - Returns - ------- - z_targets : np.ndarray, shape (z_num, ) - spatial points in m - t : np.ndarray, shape (t_num, ) - temporal points in s - time_window : float - total width of the temporal grid in s, by default None - t_num : int - number of temporal grid points, by default None - dt : float - spacing of the temporal grid in s, by default None - w_c : np.ndarray, shape (t_num, ) - centered angular frequencies in rad/s where 0 is the pump frequency - w0 : float - pump angular frequency - w : np.ndarray, shape (t_num, ) - actual angualr frequency grid in rad/s - w_power_fact : np.ndarray, shape (deg, t_num) - set of all the necessaray powers of w_c - l : np.ndarray, shape (t_num) - wavelengths in m - """ - t = tspace(time_window, t_num, dt) +# Returns +# ------- +# z_targets : np.ndarray, shape (z_num, ) +# spatial points in m +# t : np.ndarray, shape (t_num, ) +# temporal points in s +# time_window : float +# total width of the temporal grid in s, by default None +# t_num : int +# number of temporal grid points, by default None +# dt : float +# spacing of the temporal grid in s, by default None +# w_c : np.ndarray, shape (t_num, ) +# centered angular frequencies in rad/s where 0 is the pump frequency +# w0 : float +# pump angular frequency +# w : np.ndarray, shape (t_num, ) +# actual angualr frequency grid in rad/s +# w_power_fact : np.ndarray, shape (deg, t_num) +# set of all the necessaray powers of w_c +# l : np.ndarray, shape (t_num) +# wavelengths in m +# """ +# t = tspace(time_window, t_num, dt) - time_window = t.max() - t.min() - dt = t[1] - t[0] - t_num = len(t) - z_targets = np.linspace(0, length, z_num) - w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg) - l = units.To.m(w) - return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l +# time_window = t.max() - t.min() +# dt = t[1] - t[0] +# t_num = len(t) +# z_targets = np.linspace(0, length, z_num) +# w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg) +# l = units.To.m(w) +# return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l -def build_sim_grid_in_place(params: BareParams): - """similar to calling build_sim_grid, but sets the attributes in place""" - ( - params.z_targets, - params.t, - params.time_window, - params.t_num, - params.dt, - params.w_c, - params.w0, - params.w, - params.w_power_fact, - params.l, - ) = build_sim_grid( - params.length, - params.z_num, - params.wavelength, - params.interpolation_degree, - params.time_window, - params.t_num, - params.dt, - ) +# def build_sim_grid_in_place(params: BareParams): +# """similar to calling build_sim_grid, but sets the attributes in place""" +# ( +# params.z_targets, +# params.t, +# params.time_window, +# params.t_num, +# params.dt, +# params.w_c, +# params.w0, +# params.w, +# params.w_power_fact, +# params.l, +# ) = build_sim_grid( +# params.length, +# params.z_num, +# params.wavelength, +# params.interpolation_degree, +# params.time_window, +# params.t_num, +# params.dt, +# ) -def update_frequency_domain( - t: np.ndarray, wavelength: float, deg: int -) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]: - """updates the frequency grid +# def update_frequency_domain( +# t: np.ndarray, wavelength: float, deg: int +# ) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]: +# """updates the frequency grid - Parameters - ---------- - t : np.ndarray - time array - wavelength : float - wavelength - deg : int - interpolation degree of the dispersion +# Parameters +# ---------- +# t : np.ndarray +# time array +# wavelength : float +# wavelength +# deg : int +# interpolation degree of the dispersion - Returns - ------- - Tuple[np.ndarray, float, np.ndarray, np.ndarray] - w_c, w0, w, w_power_fact - """ - w_c = wspace(t) - w0 = units.m(wavelength) - w = w_c + w0 - w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)]) - return w_c, w0, w, w_power_fact +# Returns +# ------- +# Tuple[np.ndarray, float, np.ndarray, np.ndarray] +# w_c, w0, w, w_power_fact +# """ +# w_c = wspace(t) +# w0 = units.m(wavelength) +# w = w_c + w0 +# w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)]) +# return w_c, w0, w, w_power_fact diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 208f386..20db3dc 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -5,6 +5,9 @@ from scipy.interpolate import griddata, interp1d from scipy.special import jn_zeros from .utils.cache import np_cache +pi = np.pi +c = 299792458.0 + def span(*vec): """returns the min and max of whatever array-like is given. can accept many args""" @@ -218,3 +221,154 @@ def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray: pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0] m = (y[pos] - y[pos - 1]) / (x[pos] - x[pos - 1]) return -y[pos] / m + x[pos] + + +def wspace(t, t_num=0): + """frequency array such that x(t) <-> np.fft(x)(w) + Parameters + ---------- + t : float or array + float : total width of the time window + array : time array + t_num : int- + if t is a float, specifies the number of points + Returns + ---------- + w : array + linspace of frencies corresponding to t + """ + if isinstance(t, (np.ndarray, list, tuple)): + dt = t[1] - t[0] + t_num = len(t) + t = t[-1] - t[0] + dt + else: + dt = t / t_num + w = 2 * pi * np.arange(t_num) / t + w = np.where(w >= pi / dt, w - 2 * pi / dt, w) + return w + + +def tspace(time_window=None, t_num=None, dt=None): + """returns a time array centered on 0 + Parameters + ---------- + time_window : float + total time spanned + t_num : int + number of points + dt : float + time resolution + + at least 2 arguments must be given. They are prioritize as such + t_num > time_window > dt + + Returns + ------- + t : array + a linearily spaced time array + Raises + ------ + TypeError + missing at least 1 argument + """ + if t_num is not None: + if isinstance(time_window, (float, int)): + return np.linspace(-time_window / 2, time_window / 2, int(t_num)) + elif isinstance(dt, (float, int)): + time_window = (t_num - 1) * dt + return np.linspace(-time_window / 2, time_window / 2, t_num) + elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)): + t_num = int(time_window / dt) + 1 + return np.linspace(-time_window / 2, time_window / 2, t_num) + else: + raise TypeError("not enough parameter to determine time vector") + + +def build_sim_grid( + length: float, + z_num: int, + wavelength: float, + interpolation_degree: int, + time_window: float = None, + t_num: int = None, + dt: float = None, +) -> tuple[ + np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray +]: + """computes a bunch of values that relate to the simulation grid + + Parameters + ---------- + length : float + length of the fiber in m + z_num : int + number of spatial points + wavelength : float + pump wavelength in m + deg : int + dispersion interpolation degree + time_window : float, optional + total width of the temporal grid in s, by default None + t_num : int, optional + number of temporal grid points, by default None + dt : float, optional + spacing of the temporal grid in s, by default None + + Returns + ------- + z_targets : np.ndarray, shape (z_num, ) + spatial points in m + t : np.ndarray, shape (t_num, ) + temporal points in s + time_window : float + total width of the temporal grid in s, by default None + t_num : int + number of temporal grid points, by default None + dt : float + spacing of the temporal grid in s, by default None + w_c : np.ndarray, shape (t_num, ) + centered angular frequencies in rad/s where 0 is the pump frequency + w0 : float + pump angular frequency + w : np.ndarray, shape (t_num, ) + actual angualr frequency grid in rad/s + w_power_fact : np.ndarray, shape (deg, t_num) + set of all the necessaray powers of w_c + l : np.ndarray, shape (t_num) + wavelengths in m + """ + t = tspace(time_window, t_num, dt) + + time_window = t.max() - t.min() + dt = t[1] - t[0] + t_num = len(t) + z_targets = np.linspace(0, length, z_num) + w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, interpolation_degree) + l = 2 * pi * c / w + return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l + + +def update_frequency_domain( + t: np.ndarray, wavelength: float, deg: int +) -> tuple[np.ndarray, float, np.ndarray, np.ndarray]: + """updates the frequency grid + + Parameters + ---------- + t : np.ndarray + time array + wavelength : float + wavelength + deg : int + interpolation degree of the dispersion + + Returns + ------- + Tuple[np.ndarray, float, np.ndarray, np.ndarray] + w_c, w0, w, w_power_fact + """ + w_c = wspace(t) + w0 = 2 * pi * c / wavelength + w = w_c + w0 + w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)]) + return w_c, w0, w, w_power_fact diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index 7114626..f7a7ad3 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -3,8 +3,6 @@ import numpy as np import scipy.special from scipy.integrate import cumulative_trapezoid -from scgenerator import math - from ..logger import get_logger from . import units from .units import NA, c, kB, me, e, hbar diff --git a/src/scgenerator/utils/evaluator.py b/src/scgenerator/utils/evaluator.py new file mode 100644 index 0000000..f0c01e6 --- /dev/null +++ b/src/scgenerator/utils/evaluator.py @@ -0,0 +1,341 @@ +from collections import defaultdict +from typing import Any, Callable, Union +from typing import TypeVar, Optional +from dataclasses import dataclass +import numpy as np +import itertools +from functools import wraps +import re + +from ..physics import fiber, pulse, materials +from .. import math + +T = TypeVar("T") +import inspect + + +class Rule: + def __init__( + self, + target: Union[str, list[Optional[str]]], + func: Callable, + args: list[str] = None, + priorities: Union[int, list[int]] = None, + ): + targets = list(target) if isinstance(target, (list, tuple)) else [target] + self.func = func + if priorities is None: + priorities = [1] * len(targets) + elif isinstance(priorities, (int, float, np.integer, np.floating)): + priorities = [priorities] + self.targets = dict(zip(targets, priorities)) + if args is None: + args = get_arg_names(func) + self.args = args + + def __repr__(self) -> str: + return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})" + + @classmethod + def deduce( + cls, + target: Union[str, list[Optional[str]]], + func: Callable, + kwarg_names: list[str], + n_var: int, + args_const: list[str] = None, + ) -> list["Rule"]: + """given a function that doesn't need all its keyword arguemtn specified, will + return a list of Rule obj, one for each combination of n_var specified kwargs + + Parameters + ---------- + target : str | list[str | None] + name of the variable(s) that func returns + func : Callable + function to work with + kwarg_names : list[str] + list of all kwargs of the function to be used + n_var : int + how many shoulf be used per rule + arg_const : list[str], optional + override the name of the positional arguments + + Returns + ------- + list[Rule] + list of all possible rules + + Example + ------- + >> def lol(a, b=None, c=None): + pass + >> print(Rule.deduce(["d"], lol, ["b", "c"], 1)) + [ + Rule(targets={'d': 1}, func=, args=['a', 'b']), + Rule(targets={'d': 1}, func=, args=['a', 'c']) + ] + """ + rules: list[cls] = [] + for var_possibility in itertools.combinations(kwarg_names, n_var): + + new_func = func_rewrite(func, list(var_possibility), args_const) + + rules.append(cls(target, new_func)) + return rules + + +@dataclass +class EvalStat: + priority: float = np.inf + + +class Evaluator: + def __init__(self): + self.rules: dict[str, list[Rule]] = defaultdict(list) + self.params = {} + self.__curent_lookup = set() + self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat) + + def append(self, *rule: Rule): + for r in rule: + for t in r.targets: + if t is not None: + self.rules[t].append(r) + self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) + + def update(self, **params: Any): + self.params.update(params) + for k in params: + self.eval_stats[k].priority = np.inf + + def reset(self): + self.params = {} + self.eval_stats = defaultdict(EvalStat) + + def compute(self, target: str) -> Any: + """computes a target + + Parameters + ---------- + target : str + name of the target + + Returns + ------- + Any + return type of the target function + + Raises + ------ + RecursionError + a cyclic dependence exists + KeyError + there is no saved rule for the target + """ + value = self.params.get(target) + if value is None: + if target in self.__curent_lookup: + raise RecursionError( + "cyclic dependency detected : " + f"{target!r} seems to depend on itself, " + f"please provide a value for at least one variable in {self.__curent_lookup}" + ) + else: + self.__curent_lookup.add(target) + + error = None + for rule in reversed(self.rules[target]): + try: + args = [self.compute(k) for k in rule.args] + returned_values = rule.func(*args) + if len(rule.targets) == 1: + self.params[target] = returned_values + self.eval_stats[target].priority = rule.targets[target] + value = returned_values + else: + for ((k, p), v) in zip(rule.targets.items(), returned_values): + if ( + k == target + or k not in self.params + or self.eval_stats[k].priority < p + ): + self.params[k] = v + self.eval_stats[k] = p + if k == target: + value = v + break + except (RecursionError, KeyError) as e: + error = e + continue + + if value is None and error is not None: + raise error + + self.__curent_lookup.remove(target) + return value + + def __call__(self, target: str, args: list[str] = None): + """creates a wrapper that adds decorated functions to the set of rules + + Parameters + ---------- + target : str + name of the target + args : list[str], optional + list of name of arguments. Automatically deduced from function signature if + not provided, by default None + """ + + def wrapper(func): + self.append(Rule(target, func, args)) + return func + + return wrapper + + +def get_arg_names(func: Callable) -> list[str]: + spec = inspect.getfullargspec(func) + args = spec.args + if spec.defaults is not None and len(spec.defaults) > 0: + args = args[: -len(spec.defaults)] + return args + + +def validate_arg_names(names: list[str]): + for n in names: + if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: + raise ValueError(f"{n} is an invalid parameter name") + + +def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None): + if arg_names is None: + arg_names = get_arg_names(func) + else: + validate_arg_names(arg_names) + validate_arg_names(kwarg_names) + sign_arg_str = ", ".join(arg_names + kwarg_names) + call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) + tmp_name = f"{func.__name__}_0" + func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" + scope = dict(__func__=func) + exec(func_str, scope) + return scope[tmp_name] + + +default_rules: list[Rule] = [ + *Rule.deduce( + ["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"], + math.build_sim_grid, + ["time_window", "t_num", "dt"], + 2, + ) +] +""" + Rule("gamma", fiber.gamma_parameter), + Rule("gamma", lambda gamma_arr: gamma_arr[0]), + Rule(["beta", "gamma", "interp_range"], fiber.PCF_dispersion), + Rule("n2"), + Rule("loss"), + Rule("loss_file"), + Rule("effective_mode_diameter"), + Rule("A_eff"), + Rule("A_eff_file"), + Rule("pitch"), + Rule("pitch_ratio"), + Rule("core_radius"), + Rule("he_mode"), + Rule("fit_parameters"), + Rule("beta"), + Rule("dispersion_file"), + Rule("model"), + Rule("length"), + Rule("capillary_num"), + Rule("capillary_outer_d"), + Rule("capillary_thickness"), + Rule("capillary_spacing"), + Rule("capillary_resonance_strengths"), + Rule("capillary_nested"), + Rule("gas_name"), + Rule("pressure"), + Rule("temperature"), + Rule("plasma_density"), + Rule("field_file"), + Rule("repetition_rate"), + Rule("peak_power"), + Rule("mean_power"), + Rule("energy"), + Rule("soliton_num"), + Rule("quantum_noise"), + Rule("shape"), + Rule("wavelength"), + Rule("intensity_noise"), + Rule("width"), + Rule("t0"), + Rule("behaviors"), + Rule("parallel"), + Rule("raman_type"), + Rule("ideal_gas"), + Rule("repeat"), + Rule("t_num"), + Rule("z_num"), + Rule("time_window"), + Rule("dt"), + Rule("tolerated_error"), + Rule("step_size"), + Rule("lower_wavelength_interp_limit"), + Rule("upper_wavelength_interp_limit"), + Rule("interpolation_degree"), + Rule("prev_sim_dir"), + Rule("recovery_last_stored"), + Rule("worker_num"), + Rule("field_0"), + Rule("spec_0"), + Rule("alpha"), + Rule("gamma_arr"), + Rule("A_eff_arr"), + Rule("w"), + Rule("l"), + Rule("w_c"), + Rule("w0"), + Rule("w_power_fact"), + Rule("t"), + Rule("L_D"), + Rule("L_NL"), + Rule("L_sol"), + Rule("dynamic_dispersion"), + Rule("adapt_step_size"), + Rule("error_ok"), + Rule("hr_w"), + Rule("z_targets"), + Rule("const_qty"), + Rule("beta_func"), + Rule("gamma_func"), + Rule("interp_range"), + Rule("datetime"), + Rule("version"), +] +""" + + +def main(): + + evalor = Evaluator() + evalor.append(*default_rules) + evalor.update( + **{ + "length": 1, + "z_num": 128, + "wavelength": 1500e-9, + "interpolation_degree": 8, + "t_num": 16384, + "dt": 1e-15, + } + ) + evalor.compute("z_targets") + print(evalor.params.keys()) + print(evalor.params["l"][evalor.params["l"] > 0].min()) + + +if __name__ == "__main__": + main() diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index 338840d..efedb72 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -8,6 +8,9 @@ import numpy as np from ..const import __version__ +# from .evaluator import Rule, Evaluator +# from ..physics import pulse, fiber, materials + T = TypeVar("T") # Validator @@ -187,7 +190,7 @@ def translate(p_name: str, p_value: T) -> tuple[str, T]: class Parameter: - def __init__(self, validator, converter=None, default=None, display_info=None): + def __init__(self, validator, converter=None, default=None, display_info=None, rules=None): """Single parameter Parameters @@ -208,6 +211,10 @@ class Parameter: self.converter = converter self.default = default self.display_info = display_info + if rules is None: + self.rules = [] + else: + self.rules = rules def __set_name__(self, owner, name): self.name = name @@ -344,14 +351,14 @@ class BareParams: """ # root - name: str = Parameter(string) + name: str = Parameter(string, default="no name") prev_data_dir: str = Parameter(string) previous_config_file: str = Parameter(string) # # fiber - input_transmission: float = Parameter(in_range_incl(0, 1)) + input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0) gamma: float = Parameter(non_negative(float, int)) - n2: float = Parameter(non_negative(float, int)) + n2: float = Parameter(non_negative(float, int), default=2.2e-20) loss: str = Parameter(literal("capillary")) loss_file: str = Parameter(string) effective_mode_diameter: float = Parameter(positive(float, int)) @@ -360,58 +367,66 @@ class BareParams: pitch: float = Parameter(in_range_excl(0, 1e-3)) pitch_ratio: float = Parameter(in_range_excl(0, 1)) core_radius: float = Parameter(in_range_excl(0, 1e-3)) - he_mode: Tuple[int, int] = Parameter(int_pair) - fit_parameters: Tuple[int, int] = Parameter(int_pair) + he_mode: Tuple[int, int] = Parameter(int_pair, default=(1, 1)) + fit_parameters: Tuple[int, int] = Parameter(int_pair, default=(0.08, 200e-9)) beta: Iterable[float] = Parameter(num_list) dispersion_file: str = Parameter(string) - model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom")) - length: float = Parameter(non_negative(float, int)) + model: str = Parameter( + literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), default="custom" + ) + length: float = Parameter(non_negative(float, int), default=1.0) capillary_num: int = Parameter(positive(int)) capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3)) capillary_thickness: float = Parameter(in_range_excl(0, 1e-3)) capillary_spacing: float = Parameter(in_range_excl(0, 1e-3)) - capillary_resonance_strengths: Iterable[float] = Parameter(num_list) - capillary_nested: int = Parameter(non_negative(int)) + capillary_resonance_strengths: Iterable[float] = Parameter(num_list, default=[]) + capillary_nested: int = Parameter(non_negative(int), default=0) # gas - gas_name: str = Parameter(string, converter=str.lower) + gas_name: str = Parameter(string, converter=str.lower, default="vacuum") pressure: Union[float, Iterable[float]] = Parameter( - validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar") + validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar"), default=1e5 ) - temperature: float = Parameter(positive(float, int), display_info=(1, "K")) - plasma_density: float = Parameter(non_negative(float, int)) + temperature: float = Parameter(positive(float, int), display_info=(1, "K"), default=300) + plasma_density: float = Parameter(non_negative(float, int), default=0) # pulse field_file: str = Parameter(string) - repetition_rate: float = Parameter(non_negative(float, int), display_info=(1e-6, "MHz")) + repetition_rate: float = Parameter( + non_negative(float, int), display_info=(1e-6, "MHz"), default=40e6 + ) peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW")) mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW")) energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ")) soliton_num: float = Parameter(non_negative(float, int)) - quantum_noise: bool = Parameter(boolean) - shape: str = Parameter(literal("gaussian", "sech")) + quantum_noise: bool = Parameter(boolean, default=False) + shape: str = Parameter(literal("gaussian", "sech"), default="gaussian") wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm")) - intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%")) + intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"), default=0) width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) # simulation - behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss"))) - parallel: bool = Parameter(boolean) - raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower) - ideal_gas: bool = Parameter(boolean) - repeat: int = Parameter(positive(int)) + behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")), default=["spm", "ss"]) + parallel: bool = Parameter(boolean, default=True) + raman_type: str = Parameter( + literal("measured", "agrawal", "stolen"), converter=str.lower, default="agrawal" + ) + ideal_gas: bool = Parameter(boolean, default=False) + repeat: int = Parameter(positive(int), default=1) t_num: int = Parameter(positive(int)) z_num: int = Parameter(positive(int)) time_window: float = Parameter(positive(float, int)) dt: float = Parameter(in_range_excl(0, 5e-15)) - tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3)) + tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11) step_size: float = Parameter(positive(float, int)) - lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9)) - upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9)) - interpolation_degree: int = Parameter(positive(int)) + lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9), default=100e-9) + upper_wavelength_interp_limit: float = Parameter( + in_range_incl(200e-9, 5000e-9), default=2000e-9 + ) + interpolation_degree: int = Parameter(positive(int), default=8) prev_sim_dir: str = Parameter(string) - recovery_last_stored: int = Parameter(non_negative(int)) + recovery_last_stored: int = Parameter(non_negative(int), default=0) worker_num: int = Parameter(positive(int)) # computed