solid ground work
This commit is contained in:
47
func_rewrite.py
Normal file
47
func_rewrite.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Callable
|
||||
import inspect
|
||||
import re
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
spec = inspect.getfullargspec(func)
|
||||
args = spec.args
|
||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||
args = args[: -len(spec.defaults)]
|
||||
return args
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
return scope[tmp_name]
|
||||
|
||||
|
||||
def lol(a, b=None, c=None):
|
||||
print(f"{a=}, {b=}, {c=}")
|
||||
|
||||
|
||||
def main():
|
||||
lol1 = func_rewrite(lol, ["c"])
|
||||
print(inspect.getfullargspec(lol1))
|
||||
lol2 = func_rewrite(lol, ["b"])
|
||||
print(inspect.getfullargspec(lol2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,9 +14,11 @@ from .errors import *
|
||||
from .logger import get_logger
|
||||
from .math import power_fact
|
||||
from .physics import fiber, pulse, units
|
||||
from .utils import override_config, required_simulations
|
||||
from .utils import override_config, required_simulations, evaluator
|
||||
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
||||
|
||||
global_evaluator = evaluator.Evaluator()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Params(BareParams):
|
||||
@@ -541,65 +543,65 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
|
||||
return previous.name, count_variations(*configs)
|
||||
|
||||
|
||||
def wspace(t, t_num=0):
|
||||
"""frequency array such that x(t) <-> np.fft(x)(w)
|
||||
Parameters
|
||||
----------
|
||||
t : float or array
|
||||
float : total width of the time window
|
||||
array : time array
|
||||
t_num : int-
|
||||
if t is a float, specifies the number of points
|
||||
Returns
|
||||
----------
|
||||
w : array
|
||||
linspace of frencies corresponding to t
|
||||
"""
|
||||
if isinstance(t, (np.ndarray, list, tuple)):
|
||||
dt = t[1] - t[0]
|
||||
t_num = len(t)
|
||||
t = t[-1] - t[0] + dt
|
||||
else:
|
||||
dt = t / t_num
|
||||
w = 2 * pi * np.arange(t_num) / t
|
||||
w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
|
||||
return w
|
||||
# def wspace(t, t_num=0):
|
||||
# """frequency array such that x(t) <-> np.fft(x)(w)
|
||||
# Parameters
|
||||
# ----------
|
||||
# t : float or array
|
||||
# float : total width of the time window
|
||||
# array : time array
|
||||
# t_num : int-
|
||||
# if t is a float, specifies the number of points
|
||||
# Returns
|
||||
# ----------
|
||||
# w : array
|
||||
# linspace of frencies corresponding to t
|
||||
# """
|
||||
# if isinstance(t, (np.ndarray, list, tuple)):
|
||||
# dt = t[1] - t[0]
|
||||
# t_num = len(t)
|
||||
# t = t[-1] - t[0] + dt
|
||||
# else:
|
||||
# dt = t / t_num
|
||||
# w = 2 * pi * np.arange(t_num) / t
|
||||
# w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
|
||||
# return w
|
||||
|
||||
|
||||
def tspace(time_window=None, t_num=None, dt=None):
|
||||
"""returns a time array centered on 0
|
||||
Parameters
|
||||
----------
|
||||
time_window : float
|
||||
total time spanned
|
||||
t_num : int
|
||||
number of points
|
||||
dt : float
|
||||
time resolution
|
||||
# def tspace(time_window=None, t_num=None, dt=None):
|
||||
# """returns a time array centered on 0
|
||||
# Parameters
|
||||
# ----------
|
||||
# time_window : float
|
||||
# total time spanned
|
||||
# t_num : int
|
||||
# number of points
|
||||
# dt : float
|
||||
# time resolution
|
||||
|
||||
at least 2 arguments must be given. They are prioritize as such
|
||||
t_num > time_window > dt
|
||||
# at least 2 arguments must be given. They are prioritize as such
|
||||
# t_num > time_window > dt
|
||||
|
||||
Returns
|
||||
-------
|
||||
t : array
|
||||
a linearily spaced time array
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
missing at least 1 argument
|
||||
"""
|
||||
if t_num is not None:
|
||||
if isinstance(time_window, (float, int)):
|
||||
return np.linspace(-time_window / 2, time_window / 2, int(t_num))
|
||||
elif isinstance(dt, (float, int)):
|
||||
time_window = (t_num - 1) * dt
|
||||
return np.linspace(-time_window / 2, time_window / 2, t_num)
|
||||
elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
|
||||
t_num = int(time_window / dt) + 1
|
||||
return np.linspace(-time_window / 2, time_window / 2, t_num)
|
||||
else:
|
||||
raise TypeError("not enough parameter to determine time vector")
|
||||
# Returns
|
||||
# -------
|
||||
# t : array
|
||||
# a linearily spaced time array
|
||||
# Raises
|
||||
# ------
|
||||
# TypeError
|
||||
# missing at least 1 argument
|
||||
# """
|
||||
# if t_num is not None:
|
||||
# if isinstance(time_window, (float, int)):
|
||||
# return np.linspace(-time_window / 2, time_window / 2, int(t_num))
|
||||
# elif isinstance(dt, (float, int)):
|
||||
# time_window = (t_num - 1) * dt
|
||||
# return np.linspace(-time_window / 2, time_window / 2, t_num)
|
||||
# elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
|
||||
# t_num = int(time_window / dt) + 1
|
||||
# return np.linspace(-time_window / 2, time_window / 2, t_num)
|
||||
# else:
|
||||
# raise TypeError("not enough parameter to determine time vector")
|
||||
|
||||
|
||||
def recover_params(params: BareParams, data_folder: Path) -> Params:
|
||||
@@ -620,115 +622,115 @@ def recover_params(params: BareParams, data_folder: Path) -> Params:
|
||||
return params
|
||||
|
||||
|
||||
def build_sim_grid(
|
||||
length: float,
|
||||
z_num: int,
|
||||
wavelength: float,
|
||||
deg: int,
|
||||
time_window: float = None,
|
||||
t_num: int = None,
|
||||
dt: float = None,
|
||||
) -> tuple[
|
||||
np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
|
||||
]:
|
||||
"""computes a bunch of values that relate to the simulation grid
|
||||
# def build_sim_grid(
|
||||
# length: float,
|
||||
# z_num: int,
|
||||
# wavelength: float,
|
||||
# deg: int,
|
||||
# time_window: float = None,
|
||||
# t_num: int = None,
|
||||
# dt: float = None,
|
||||
# ) -> tuple[
|
||||
# np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
|
||||
# ]:
|
||||
# """computes a bunch of values that relate to the simulation grid
|
||||
|
||||
Parameters
|
||||
----------
|
||||
length : float
|
||||
length of the fiber in m
|
||||
z_num : int
|
||||
number of spatial points
|
||||
wavelength : float
|
||||
pump wavelength in m
|
||||
deg : int
|
||||
dispersion interpolation degree
|
||||
time_window : float, optional
|
||||
total width of the temporal grid in s, by default None
|
||||
t_num : int, optional
|
||||
number of temporal grid points, by default None
|
||||
dt : float, optional
|
||||
spacing of the temporal grid in s, by default None
|
||||
# Parameters
|
||||
# ----------
|
||||
# length : float
|
||||
# length of the fiber in m
|
||||
# z_num : int
|
||||
# number of spatial points
|
||||
# wavelength : float
|
||||
# pump wavelength in m
|
||||
# deg : int
|
||||
# dispersion interpolation degree
|
||||
# time_window : float, optional
|
||||
# total width of the temporal grid in s, by default None
|
||||
# t_num : int, optional
|
||||
# number of temporal grid points, by default None
|
||||
# dt : float, optional
|
||||
# spacing of the temporal grid in s, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
z_targets : np.ndarray, shape (z_num, )
|
||||
spatial points in m
|
||||
t : np.ndarray, shape (t_num, )
|
||||
temporal points in s
|
||||
time_window : float
|
||||
total width of the temporal grid in s, by default None
|
||||
t_num : int
|
||||
number of temporal grid points, by default None
|
||||
dt : float
|
||||
spacing of the temporal grid in s, by default None
|
||||
w_c : np.ndarray, shape (t_num, )
|
||||
centered angular frequencies in rad/s where 0 is the pump frequency
|
||||
w0 : float
|
||||
pump angular frequency
|
||||
w : np.ndarray, shape (t_num, )
|
||||
actual angualr frequency grid in rad/s
|
||||
w_power_fact : np.ndarray, shape (deg, t_num)
|
||||
set of all the necessaray powers of w_c
|
||||
l : np.ndarray, shape (t_num)
|
||||
wavelengths in m
|
||||
"""
|
||||
t = tspace(time_window, t_num, dt)
|
||||
# Returns
|
||||
# -------
|
||||
# z_targets : np.ndarray, shape (z_num, )
|
||||
# spatial points in m
|
||||
# t : np.ndarray, shape (t_num, )
|
||||
# temporal points in s
|
||||
# time_window : float
|
||||
# total width of the temporal grid in s, by default None
|
||||
# t_num : int
|
||||
# number of temporal grid points, by default None
|
||||
# dt : float
|
||||
# spacing of the temporal grid in s, by default None
|
||||
# w_c : np.ndarray, shape (t_num, )
|
||||
# centered angular frequencies in rad/s where 0 is the pump frequency
|
||||
# w0 : float
|
||||
# pump angular frequency
|
||||
# w : np.ndarray, shape (t_num, )
|
||||
# actual angualr frequency grid in rad/s
|
||||
# w_power_fact : np.ndarray, shape (deg, t_num)
|
||||
# set of all the necessaray powers of w_c
|
||||
# l : np.ndarray, shape (t_num)
|
||||
# wavelengths in m
|
||||
# """
|
||||
# t = tspace(time_window, t_num, dt)
|
||||
|
||||
time_window = t.max() - t.min()
|
||||
dt = t[1] - t[0]
|
||||
t_num = len(t)
|
||||
z_targets = np.linspace(0, length, z_num)
|
||||
w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg)
|
||||
l = units.To.m(w)
|
||||
return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
|
||||
# time_window = t.max() - t.min()
|
||||
# dt = t[1] - t[0]
|
||||
# t_num = len(t)
|
||||
# z_targets = np.linspace(0, length, z_num)
|
||||
# w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, deg)
|
||||
# l = units.To.m(w)
|
||||
# return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
|
||||
|
||||
|
||||
def build_sim_grid_in_place(params: BareParams):
|
||||
"""similar to calling build_sim_grid, but sets the attributes in place"""
|
||||
(
|
||||
params.z_targets,
|
||||
params.t,
|
||||
params.time_window,
|
||||
params.t_num,
|
||||
params.dt,
|
||||
params.w_c,
|
||||
params.w0,
|
||||
params.w,
|
||||
params.w_power_fact,
|
||||
params.l,
|
||||
) = build_sim_grid(
|
||||
params.length,
|
||||
params.z_num,
|
||||
params.wavelength,
|
||||
params.interpolation_degree,
|
||||
params.time_window,
|
||||
params.t_num,
|
||||
params.dt,
|
||||
)
|
||||
# def build_sim_grid_in_place(params: BareParams):
|
||||
# """similar to calling build_sim_grid, but sets the attributes in place"""
|
||||
# (
|
||||
# params.z_targets,
|
||||
# params.t,
|
||||
# params.time_window,
|
||||
# params.t_num,
|
||||
# params.dt,
|
||||
# params.w_c,
|
||||
# params.w0,
|
||||
# params.w,
|
||||
# params.w_power_fact,
|
||||
# params.l,
|
||||
# ) = build_sim_grid(
|
||||
# params.length,
|
||||
# params.z_num,
|
||||
# params.wavelength,
|
||||
# params.interpolation_degree,
|
||||
# params.time_window,
|
||||
# params.t_num,
|
||||
# params.dt,
|
||||
# )
|
||||
|
||||
|
||||
def update_frequency_domain(
|
||||
t: np.ndarray, wavelength: float, deg: int
|
||||
) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]:
|
||||
"""updates the frequency grid
|
||||
# def update_frequency_domain(
|
||||
# t: np.ndarray, wavelength: float, deg: int
|
||||
# ) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray]:
|
||||
# """updates the frequency grid
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : np.ndarray
|
||||
time array
|
||||
wavelength : float
|
||||
wavelength
|
||||
deg : int
|
||||
interpolation degree of the dispersion
|
||||
# Parameters
|
||||
# ----------
|
||||
# t : np.ndarray
|
||||
# time array
|
||||
# wavelength : float
|
||||
# wavelength
|
||||
# deg : int
|
||||
# interpolation degree of the dispersion
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[np.ndarray, float, np.ndarray, np.ndarray]
|
||||
w_c, w0, w, w_power_fact
|
||||
"""
|
||||
w_c = wspace(t)
|
||||
w0 = units.m(wavelength)
|
||||
w = w_c + w0
|
||||
w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
|
||||
return w_c, w0, w, w_power_fact
|
||||
# Returns
|
||||
# -------
|
||||
# Tuple[np.ndarray, float, np.ndarray, np.ndarray]
|
||||
# w_c, w0, w, w_power_fact
|
||||
# """
|
||||
# w_c = wspace(t)
|
||||
# w0 = units.m(wavelength)
|
||||
# w = w_c + w0
|
||||
# w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
|
||||
# return w_c, w0, w, w_power_fact
|
||||
|
||||
@@ -5,6 +5,9 @@ from scipy.interpolate import griddata, interp1d
|
||||
from scipy.special import jn_zeros
|
||||
from .utils.cache import np_cache
|
||||
|
||||
pi = np.pi
|
||||
c = 299792458.0
|
||||
|
||||
|
||||
def span(*vec):
|
||||
"""returns the min and max of whatever array-like is given. can accept many args"""
|
||||
@@ -218,3 +221,154 @@ def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0]
|
||||
m = (y[pos] - y[pos - 1]) / (x[pos] - x[pos - 1])
|
||||
return -y[pos] / m + x[pos]
|
||||
|
||||
|
||||
def wspace(t, t_num=0):
|
||||
"""frequency array such that x(t) <-> np.fft(x)(w)
|
||||
Parameters
|
||||
----------
|
||||
t : float or array
|
||||
float : total width of the time window
|
||||
array : time array
|
||||
t_num : int-
|
||||
if t is a float, specifies the number of points
|
||||
Returns
|
||||
----------
|
||||
w : array
|
||||
linspace of frencies corresponding to t
|
||||
"""
|
||||
if isinstance(t, (np.ndarray, list, tuple)):
|
||||
dt = t[1] - t[0]
|
||||
t_num = len(t)
|
||||
t = t[-1] - t[0] + dt
|
||||
else:
|
||||
dt = t / t_num
|
||||
w = 2 * pi * np.arange(t_num) / t
|
||||
w = np.where(w >= pi / dt, w - 2 * pi / dt, w)
|
||||
return w
|
||||
|
||||
|
||||
def tspace(time_window=None, t_num=None, dt=None):
|
||||
"""returns a time array centered on 0
|
||||
Parameters
|
||||
----------
|
||||
time_window : float
|
||||
total time spanned
|
||||
t_num : int
|
||||
number of points
|
||||
dt : float
|
||||
time resolution
|
||||
|
||||
at least 2 arguments must be given. They are prioritize as such
|
||||
t_num > time_window > dt
|
||||
|
||||
Returns
|
||||
-------
|
||||
t : array
|
||||
a linearily spaced time array
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
missing at least 1 argument
|
||||
"""
|
||||
if t_num is not None:
|
||||
if isinstance(time_window, (float, int)):
|
||||
return np.linspace(-time_window / 2, time_window / 2, int(t_num))
|
||||
elif isinstance(dt, (float, int)):
|
||||
time_window = (t_num - 1) * dt
|
||||
return np.linspace(-time_window / 2, time_window / 2, t_num)
|
||||
elif isinstance(time_window, (float, int)) and isinstance(dt, (float, int)):
|
||||
t_num = int(time_window / dt) + 1
|
||||
return np.linspace(-time_window / 2, time_window / 2, t_num)
|
||||
else:
|
||||
raise TypeError("not enough parameter to determine time vector")
|
||||
|
||||
|
||||
def build_sim_grid(
|
||||
length: float,
|
||||
z_num: int,
|
||||
wavelength: float,
|
||||
interpolation_degree: int,
|
||||
time_window: float = None,
|
||||
t_num: int = None,
|
||||
dt: float = None,
|
||||
) -> tuple[
|
||||
np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
|
||||
]:
|
||||
"""computes a bunch of values that relate to the simulation grid
|
||||
|
||||
Parameters
|
||||
----------
|
||||
length : float
|
||||
length of the fiber in m
|
||||
z_num : int
|
||||
number of spatial points
|
||||
wavelength : float
|
||||
pump wavelength in m
|
||||
deg : int
|
||||
dispersion interpolation degree
|
||||
time_window : float, optional
|
||||
total width of the temporal grid in s, by default None
|
||||
t_num : int, optional
|
||||
number of temporal grid points, by default None
|
||||
dt : float, optional
|
||||
spacing of the temporal grid in s, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
z_targets : np.ndarray, shape (z_num, )
|
||||
spatial points in m
|
||||
t : np.ndarray, shape (t_num, )
|
||||
temporal points in s
|
||||
time_window : float
|
||||
total width of the temporal grid in s, by default None
|
||||
t_num : int
|
||||
number of temporal grid points, by default None
|
||||
dt : float
|
||||
spacing of the temporal grid in s, by default None
|
||||
w_c : np.ndarray, shape (t_num, )
|
||||
centered angular frequencies in rad/s where 0 is the pump frequency
|
||||
w0 : float
|
||||
pump angular frequency
|
||||
w : np.ndarray, shape (t_num, )
|
||||
actual angualr frequency grid in rad/s
|
||||
w_power_fact : np.ndarray, shape (deg, t_num)
|
||||
set of all the necessaray powers of w_c
|
||||
l : np.ndarray, shape (t_num)
|
||||
wavelengths in m
|
||||
"""
|
||||
t = tspace(time_window, t_num, dt)
|
||||
|
||||
time_window = t.max() - t.min()
|
||||
dt = t[1] - t[0]
|
||||
t_num = len(t)
|
||||
z_targets = np.linspace(0, length, z_num)
|
||||
w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, interpolation_degree)
|
||||
l = 2 * pi * c / w
|
||||
return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l
|
||||
|
||||
|
||||
def update_frequency_domain(
|
||||
t: np.ndarray, wavelength: float, deg: int
|
||||
) -> tuple[np.ndarray, float, np.ndarray, np.ndarray]:
|
||||
"""updates the frequency grid
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : np.ndarray
|
||||
time array
|
||||
wavelength : float
|
||||
wavelength
|
||||
deg : int
|
||||
interpolation degree of the dispersion
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[np.ndarray, float, np.ndarray, np.ndarray]
|
||||
w_c, w0, w, w_power_fact
|
||||
"""
|
||||
w_c = wspace(t)
|
||||
w0 = 2 * pi * c / wavelength
|
||||
w = w_c + w0
|
||||
w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)])
|
||||
return w_c, w0, w, w_power_fact
|
||||
|
||||
@@ -3,8 +3,6 @@ import numpy as np
|
||||
import scipy.special
|
||||
from scipy.integrate import cumulative_trapezoid
|
||||
|
||||
from scgenerator import math
|
||||
|
||||
from ..logger import get_logger
|
||||
from . import units
|
||||
from .units import NA, c, kB, me, e, hbar
|
||||
|
||||
341
src/scgenerator/utils/evaluator.py
Normal file
341
src/scgenerator/utils/evaluator.py
Normal file
@@ -0,0 +1,341 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Union
|
||||
from typing import TypeVar, Optional
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import itertools
|
||||
from functools import wraps
|
||||
import re
|
||||
|
||||
from ..physics import fiber, pulse, materials
|
||||
from .. import math
|
||||
|
||||
T = TypeVar("T")
|
||||
import inspect
|
||||
|
||||
|
||||
class Rule:
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
args: list[str] = None,
|
||||
priorities: Union[int, list[int]] = None,
|
||||
):
|
||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||
self.func = func
|
||||
if priorities is None:
|
||||
priorities = [1] * len(targets)
|
||||
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
||||
priorities = [priorities]
|
||||
self.targets = dict(zip(targets, priorities))
|
||||
if args is None:
|
||||
args = get_arg_names(func)
|
||||
self.args = args
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
|
||||
|
||||
@classmethod
|
||||
def deduce(
|
||||
cls,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
kwarg_names: list[str],
|
||||
n_var: int,
|
||||
args_const: list[str] = None,
|
||||
) -> list["Rule"]:
|
||||
"""given a function that doesn't need all its keyword arguemtn specified, will
|
||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str | list[str | None]
|
||||
name of the variable(s) that func returns
|
||||
func : Callable
|
||||
function to work with
|
||||
kwarg_names : list[str]
|
||||
list of all kwargs of the function to be used
|
||||
n_var : int
|
||||
how many shoulf be used per rule
|
||||
arg_const : list[str], optional
|
||||
override the name of the positional arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Rule]
|
||||
list of all possible rules
|
||||
|
||||
Example
|
||||
-------
|
||||
>> def lol(a, b=None, c=None):
|
||||
pass
|
||||
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
|
||||
[
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
|
||||
]
|
||||
"""
|
||||
rules: list[cls] = []
|
||||
for var_possibility in itertools.combinations(kwarg_names, n_var):
|
||||
|
||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
||||
|
||||
rules.append(cls(target, new_func))
|
||||
return rules
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalStat:
|
||||
priority: float = np.inf
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self):
|
||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
||||
self.params = {}
|
||||
self.__curent_lookup = set()
|
||||
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
|
||||
|
||||
def append(self, *rule: Rule):
|
||||
for r in rule:
|
||||
for t in r.targets:
|
||||
if t is not None:
|
||||
self.rules[t].append(r)
|
||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||
|
||||
def update(self, **params: Any):
|
||||
self.params.update(params)
|
||||
for k in params:
|
||||
self.eval_stats[k].priority = np.inf
|
||||
|
||||
def reset(self):
|
||||
self.params = {}
|
||||
self.eval_stats = defaultdict(EvalStat)
|
||||
|
||||
def compute(self, target: str) -> Any:
|
||||
"""computes a target
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str
|
||||
name of the target
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
return type of the target function
|
||||
|
||||
Raises
|
||||
------
|
||||
RecursionError
|
||||
a cyclic dependence exists
|
||||
KeyError
|
||||
there is no saved rule for the target
|
||||
"""
|
||||
value = self.params.get(target)
|
||||
if value is None:
|
||||
if target in self.__curent_lookup:
|
||||
raise RecursionError(
|
||||
"cyclic dependency detected : "
|
||||
f"{target!r} seems to depend on itself, "
|
||||
f"please provide a value for at least one variable in {self.__curent_lookup}"
|
||||
)
|
||||
else:
|
||||
self.__curent_lookup.add(target)
|
||||
|
||||
error = None
|
||||
for rule in reversed(self.rules[target]):
|
||||
try:
|
||||
args = [self.compute(k) for k in rule.args]
|
||||
returned_values = rule.func(*args)
|
||||
if len(rule.targets) == 1:
|
||||
self.params[target] = returned_values
|
||||
self.eval_stats[target].priority = rule.targets[target]
|
||||
value = returned_values
|
||||
else:
|
||||
for ((k, p), v) in zip(rule.targets.items(), returned_values):
|
||||
if (
|
||||
k == target
|
||||
or k not in self.params
|
||||
or self.eval_stats[k].priority < p
|
||||
):
|
||||
self.params[k] = v
|
||||
self.eval_stats[k] = p
|
||||
if k == target:
|
||||
value = v
|
||||
break
|
||||
except (RecursionError, KeyError) as e:
|
||||
error = e
|
||||
continue
|
||||
|
||||
if value is None and error is not None:
|
||||
raise error
|
||||
|
||||
self.__curent_lookup.remove(target)
|
||||
return value
|
||||
|
||||
def __call__(self, target: str, args: list[str] = None):
|
||||
"""creates a wrapper that adds decorated functions to the set of rules
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str
|
||||
name of the target
|
||||
args : list[str], optional
|
||||
list of name of arguments. Automatically deduced from function signature if
|
||||
not provided, by default None
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
self.append(Rule(target, func, args))
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
spec = inspect.getfullargspec(func)
|
||||
args = spec.args
|
||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||
args = args[: -len(spec.defaults)]
|
||||
return args
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
return scope[tmp_name]
|
||||
|
||||
|
||||
default_rules: list[Rule] = [
|
||||
*Rule.deduce(
|
||||
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"],
|
||||
math.build_sim_grid,
|
||||
["time_window", "t_num", "dt"],
|
||||
2,
|
||||
)
|
||||
]
|
||||
"""
|
||||
Rule("gamma", fiber.gamma_parameter),
|
||||
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
|
||||
Rule(["beta", "gamma", "interp_range"], fiber.PCF_dispersion),
|
||||
Rule("n2"),
|
||||
Rule("loss"),
|
||||
Rule("loss_file"),
|
||||
Rule("effective_mode_diameter"),
|
||||
Rule("A_eff"),
|
||||
Rule("A_eff_file"),
|
||||
Rule("pitch"),
|
||||
Rule("pitch_ratio"),
|
||||
Rule("core_radius"),
|
||||
Rule("he_mode"),
|
||||
Rule("fit_parameters"),
|
||||
Rule("beta"),
|
||||
Rule("dispersion_file"),
|
||||
Rule("model"),
|
||||
Rule("length"),
|
||||
Rule("capillary_num"),
|
||||
Rule("capillary_outer_d"),
|
||||
Rule("capillary_thickness"),
|
||||
Rule("capillary_spacing"),
|
||||
Rule("capillary_resonance_strengths"),
|
||||
Rule("capillary_nested"),
|
||||
Rule("gas_name"),
|
||||
Rule("pressure"),
|
||||
Rule("temperature"),
|
||||
Rule("plasma_density"),
|
||||
Rule("field_file"),
|
||||
Rule("repetition_rate"),
|
||||
Rule("peak_power"),
|
||||
Rule("mean_power"),
|
||||
Rule("energy"),
|
||||
Rule("soliton_num"),
|
||||
Rule("quantum_noise"),
|
||||
Rule("shape"),
|
||||
Rule("wavelength"),
|
||||
Rule("intensity_noise"),
|
||||
Rule("width"),
|
||||
Rule("t0"),
|
||||
Rule("behaviors"),
|
||||
Rule("parallel"),
|
||||
Rule("raman_type"),
|
||||
Rule("ideal_gas"),
|
||||
Rule("repeat"),
|
||||
Rule("t_num"),
|
||||
Rule("z_num"),
|
||||
Rule("time_window"),
|
||||
Rule("dt"),
|
||||
Rule("tolerated_error"),
|
||||
Rule("step_size"),
|
||||
Rule("lower_wavelength_interp_limit"),
|
||||
Rule("upper_wavelength_interp_limit"),
|
||||
Rule("interpolation_degree"),
|
||||
Rule("prev_sim_dir"),
|
||||
Rule("recovery_last_stored"),
|
||||
Rule("worker_num"),
|
||||
Rule("field_0"),
|
||||
Rule("spec_0"),
|
||||
Rule("alpha"),
|
||||
Rule("gamma_arr"),
|
||||
Rule("A_eff_arr"),
|
||||
Rule("w"),
|
||||
Rule("l"),
|
||||
Rule("w_c"),
|
||||
Rule("w0"),
|
||||
Rule("w_power_fact"),
|
||||
Rule("t"),
|
||||
Rule("L_D"),
|
||||
Rule("L_NL"),
|
||||
Rule("L_sol"),
|
||||
Rule("dynamic_dispersion"),
|
||||
Rule("adapt_step_size"),
|
||||
Rule("error_ok"),
|
||||
Rule("hr_w"),
|
||||
Rule("z_targets"),
|
||||
Rule("const_qty"),
|
||||
Rule("beta_func"),
|
||||
Rule("gamma_func"),
|
||||
Rule("interp_range"),
|
||||
Rule("datetime"),
|
||||
Rule("version"),
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
evalor = Evaluator()
|
||||
evalor.append(*default_rules)
|
||||
evalor.update(
|
||||
**{
|
||||
"length": 1,
|
||||
"z_num": 128,
|
||||
"wavelength": 1500e-9,
|
||||
"interpolation_degree": 8,
|
||||
"t_num": 16384,
|
||||
"dt": 1e-15,
|
||||
}
|
||||
)
|
||||
evalor.compute("z_targets")
|
||||
print(evalor.params.keys())
|
||||
print(evalor.params["l"][evalor.params["l"] > 0].min())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -8,6 +8,9 @@ import numpy as np
|
||||
|
||||
from ..const import __version__
|
||||
|
||||
# from .evaluator import Rule, Evaluator
|
||||
# from ..physics import pulse, fiber, materials
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Validator
|
||||
@@ -187,7 +190,7 @@ def translate(p_name: str, p_value: T) -> tuple[str, T]:
|
||||
|
||||
|
||||
class Parameter:
|
||||
def __init__(self, validator, converter=None, default=None, display_info=None):
|
||||
def __init__(self, validator, converter=None, default=None, display_info=None, rules=None):
|
||||
"""Single parameter
|
||||
|
||||
Parameters
|
||||
@@ -208,6 +211,10 @@ class Parameter:
|
||||
self.converter = converter
|
||||
self.default = default
|
||||
self.display_info = display_info
|
||||
if rules is None:
|
||||
self.rules = []
|
||||
else:
|
||||
self.rules = rules
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
self.name = name
|
||||
@@ -344,14 +351,14 @@ class BareParams:
|
||||
"""
|
||||
|
||||
# root
|
||||
name: str = Parameter(string)
|
||||
name: str = Parameter(string, default="no name")
|
||||
prev_data_dir: str = Parameter(string)
|
||||
previous_config_file: str = Parameter(string)
|
||||
|
||||
# # fiber
|
||||
input_transmission: float = Parameter(in_range_incl(0, 1))
|
||||
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
||||
gamma: float = Parameter(non_negative(float, int))
|
||||
n2: float = Parameter(non_negative(float, int))
|
||||
n2: float = Parameter(non_negative(float, int), default=2.2e-20)
|
||||
loss: str = Parameter(literal("capillary"))
|
||||
loss_file: str = Parameter(string)
|
||||
effective_mode_diameter: float = Parameter(positive(float, int))
|
||||
@@ -360,58 +367,66 @@ class BareParams:
|
||||
pitch: float = Parameter(in_range_excl(0, 1e-3))
|
||||
pitch_ratio: float = Parameter(in_range_excl(0, 1))
|
||||
core_radius: float = Parameter(in_range_excl(0, 1e-3))
|
||||
he_mode: Tuple[int, int] = Parameter(int_pair)
|
||||
fit_parameters: Tuple[int, int] = Parameter(int_pair)
|
||||
he_mode: Tuple[int, int] = Parameter(int_pair, default=(1, 1))
|
||||
fit_parameters: Tuple[int, int] = Parameter(int_pair, default=(0.08, 200e-9))
|
||||
beta: Iterable[float] = Parameter(num_list)
|
||||
dispersion_file: str = Parameter(string)
|
||||
model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"))
|
||||
length: float = Parameter(non_negative(float, int))
|
||||
model: str = Parameter(
|
||||
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), default="custom"
|
||||
)
|
||||
length: float = Parameter(non_negative(float, int), default=1.0)
|
||||
capillary_num: int = Parameter(positive(int))
|
||||
capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
|
||||
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
|
||||
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
|
||||
capillary_resonance_strengths: Iterable[float] = Parameter(num_list)
|
||||
capillary_nested: int = Parameter(non_negative(int))
|
||||
capillary_resonance_strengths: Iterable[float] = Parameter(num_list, default=[])
|
||||
capillary_nested: int = Parameter(non_negative(int), default=0)
|
||||
|
||||
# gas
|
||||
gas_name: str = Parameter(string, converter=str.lower)
|
||||
gas_name: str = Parameter(string, converter=str.lower, default="vacuum")
|
||||
pressure: Union[float, Iterable[float]] = Parameter(
|
||||
validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar")
|
||||
validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar"), default=1e5
|
||||
)
|
||||
temperature: float = Parameter(positive(float, int), display_info=(1, "K"))
|
||||
plasma_density: float = Parameter(non_negative(float, int))
|
||||
temperature: float = Parameter(positive(float, int), display_info=(1, "K"), default=300)
|
||||
plasma_density: float = Parameter(non_negative(float, int), default=0)
|
||||
|
||||
# pulse
|
||||
field_file: str = Parameter(string)
|
||||
repetition_rate: float = Parameter(non_negative(float, int), display_info=(1e-6, "MHz"))
|
||||
repetition_rate: float = Parameter(
|
||||
non_negative(float, int), display_info=(1e-6, "MHz"), default=40e6
|
||||
)
|
||||
peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW"))
|
||||
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
|
||||
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
|
||||
soliton_num: float = Parameter(non_negative(float, int))
|
||||
quantum_noise: bool = Parameter(boolean)
|
||||
shape: str = Parameter(literal("gaussian", "sech"))
|
||||
quantum_noise: bool = Parameter(boolean, default=False)
|
||||
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
|
||||
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
|
||||
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"))
|
||||
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"), default=0)
|
||||
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||
|
||||
# simulation
|
||||
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))
|
||||
parallel: bool = Parameter(boolean)
|
||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
||||
ideal_gas: bool = Parameter(boolean)
|
||||
repeat: int = Parameter(positive(int))
|
||||
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")), default=["spm", "ss"])
|
||||
parallel: bool = Parameter(boolean, default=True)
|
||||
raman_type: str = Parameter(
|
||||
literal("measured", "agrawal", "stolen"), converter=str.lower, default="agrawal"
|
||||
)
|
||||
ideal_gas: bool = Parameter(boolean, default=False)
|
||||
repeat: int = Parameter(positive(int), default=1)
|
||||
t_num: int = Parameter(positive(int))
|
||||
z_num: int = Parameter(positive(int))
|
||||
time_window: float = Parameter(positive(float, int))
|
||||
dt: float = Parameter(in_range_excl(0, 5e-15))
|
||||
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3))
|
||||
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
||||
step_size: float = Parameter(positive(float, int))
|
||||
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
|
||||
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
|
||||
interpolation_degree: int = Parameter(positive(int))
|
||||
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9), default=100e-9)
|
||||
upper_wavelength_interp_limit: float = Parameter(
|
||||
in_range_incl(200e-9, 5000e-9), default=2000e-9
|
||||
)
|
||||
interpolation_degree: int = Parameter(positive(int), default=8)
|
||||
prev_sim_dir: str = Parameter(string)
|
||||
recovery_last_stored: int = Parameter(non_negative(int))
|
||||
recovery_last_stored: int = Parameter(non_negative(int), default=0)
|
||||
worker_num: int = Parameter(positive(int))
|
||||
|
||||
# computed
|
||||
|
||||
Reference in New Issue
Block a user