added default values to Rule functions
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import inspect
|
||||
import traceback
|
||||
from collections import ChainMap, defaultdict
|
||||
from functools import cache
|
||||
from inspect import Parameter as PARAM
|
||||
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -10,13 +12,13 @@ import numpy as np
|
||||
from scgenerator import io, math, operators, utils
|
||||
from scgenerator.const import INF, MANDATORY_PARAMETERS
|
||||
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
||||
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
|
||||
from scgenerator.utils import get_logger
|
||||
|
||||
|
||||
class ErrorRecord(NamedTuple):
|
||||
error: Exception
|
||||
lookup_stack: list[str]
|
||||
rules_stack: list[Rule]
|
||||
lookup_stack: tuple[str]
|
||||
rules_stack: tuple[Rule]
|
||||
traceback: str
|
||||
|
||||
|
||||
@@ -40,7 +42,7 @@ class CyclicDependencyError(EvaluatorError):
|
||||
class AllRulesExhaustedError(EvaluatorError):
|
||||
def __init__(self, target: str):
|
||||
self.target = target
|
||||
super().__init__(f"tried every rule for {target!r} and no default value is set")
|
||||
super().__init__(f"every rule for {target!r} failed and no default value is set")
|
||||
|
||||
|
||||
class EvaluatorErrorTree:
|
||||
@@ -60,11 +62,14 @@ class EvaluatorErrorTree:
|
||||
def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]):
|
||||
tr = traceback.format_exc().splitlines()
|
||||
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame
|
||||
self.all_errors.append(ErrorRecord(error, lookup_stack.copy(), rules_stack.copy(), tr))
|
||||
self.all_errors.append(ErrorRecord(error, tuple(lookup_stack), tuple(rules_stack), tr))
|
||||
|
||||
def compile_error(self) -> EvaluatorError:
|
||||
failed_rules = set(rec.rules_stack[-1] for rec in self.all_errors if rec.rules_stack)
|
||||
failed_rules = ("\n" + "-" * 80 + "\n").join(rule.pretty_format() for rule in failed_rules)
|
||||
failed_rules = set(rec for rec in self.all_errors if rec.rules_stack)
|
||||
failed_rules = [
|
||||
rec.rules_stack[-1].pretty_format() + "\n" + str(rec.error) for rec in failed_rules
|
||||
]
|
||||
failed_rules = ("\n" + "-" * 80 + "\n").join(failed_rules)
|
||||
|
||||
raise EvaluatorError(
|
||||
f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}"
|
||||
@@ -76,6 +81,7 @@ class Rule:
|
||||
func: Callable
|
||||
args: tuple[str]
|
||||
conditions: dict[str, Any]
|
||||
arg_defaults: dict[str, Any]
|
||||
|
||||
mock_func: Callable
|
||||
|
||||
@@ -86,6 +92,7 @@ class Rule:
|
||||
args: tuple[str] | None = None,
|
||||
priorities: Union[int, list[int]] | None = None,
|
||||
conditions: dict[str, str] | None = None,
|
||||
defaults: dict[str, Any] | None = None,
|
||||
):
|
||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||
self.func = func
|
||||
@@ -94,9 +101,26 @@ class Rule:
|
||||
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
||||
priorities = [priorities] * len(targets)
|
||||
self.targets = dict(zip(targets, priorities))
|
||||
if args is None:
|
||||
args = get_arg_names(func)
|
||||
self.args = tuple(args)
|
||||
|
||||
func_args, func_defaults = get_arg_names(func)
|
||||
|
||||
if args is not None:
|
||||
try:
|
||||
for func_arg, user_arg in zip(func_args, args, strict=True):
|
||||
if func_arg in func_defaults:
|
||||
func_defaults[user_arg] = func_defaults.pop(func_arg)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"length mismatch between arguments of {func.__name__}: "
|
||||
f"{func_args!r} and provided ones: {args!r}"
|
||||
) from e
|
||||
func_args = args
|
||||
|
||||
func_defaults |= defaults or {}
|
||||
|
||||
self.args = tuple(func_args)
|
||||
self.arg_defaults = func_defaults
|
||||
|
||||
self.mock_func = _mock_function(len(self.args), len(self.targets))
|
||||
self.conditions = conditions or {}
|
||||
|
||||
@@ -126,55 +150,6 @@ class Rule:
|
||||
def func_name(self) -> str:
|
||||
return f"{self.func.__module__}.{self.func.__name__}"
|
||||
|
||||
@classmethod
|
||||
def deduce(
|
||||
cls,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
kwarg_names: list[str],
|
||||
n_var: int,
|
||||
args_const: list[str] = None,
|
||||
priorities: Union[int, list[int]] = None,
|
||||
) -> list["Rule"]:
|
||||
"""
|
||||
given a function that doesn't need all its keyword arguments specified, will
|
||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str | list[str | None]
|
||||
name of the variable(s) that func returns
|
||||
func : Callable
|
||||
function to work with
|
||||
kwarg_names : list[str]
|
||||
list of all kwargs of the function to be used
|
||||
n_var : int
|
||||
how many shoulf be used per rule
|
||||
arg_const : list[str], optional
|
||||
override the name of the positional arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Rule]
|
||||
list of all possible rules
|
||||
|
||||
Example
|
||||
-------
|
||||
>> def lol(a, b=None, c=None):
|
||||
pass
|
||||
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
|
||||
[
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
|
||||
]
|
||||
"""
|
||||
rules: list[cls] = []
|
||||
for var_possibility in itertools.combinations(kwarg_names, n_var):
|
||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
||||
|
||||
rules.append(cls(target, new_func, priorities=priorities))
|
||||
return rules
|
||||
|
||||
|
||||
class EvaluatedValue(NamedTuple):
|
||||
value: Any
|
||||
@@ -187,9 +162,7 @@ class Evaluator:
|
||||
|
||||
rules: dict[str, list[Rule]]
|
||||
main_map: dict[str, EvaluatedValue]
|
||||
param_chain_map: MutableMapping[str, EvaluatedValue]
|
||||
lookup_stack: list[str]
|
||||
failed_rules: dict[str, list[Rule]]
|
||||
|
||||
@classmethod
|
||||
def default(cls, full_field: bool = False) -> "Evaluator":
|
||||
@@ -220,8 +193,6 @@ class Evaluator:
|
||||
self.main_map = {}
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
self.failed_rules = defaultdict(list)
|
||||
|
||||
self.append(*rules)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
@@ -302,7 +273,7 @@ class Evaluator:
|
||||
if target not in self.rules or len(self.rules[target]) == 0:
|
||||
raise NoValidRuleError(target)
|
||||
if target in lookup_stack:
|
||||
raise CyclicDependencyError(target)
|
||||
raise CyclicDependencyError(target, lookup_stack)
|
||||
|
||||
lookup_stack.append(target)
|
||||
base_cm_length = len(param_chain_map.maps)
|
||||
@@ -359,19 +330,24 @@ class Evaluator:
|
||||
rules_stack: list[Rule],
|
||||
errors: EvaluatorErrorTree,
|
||||
) -> dict[str, EvaluatedValue] | None:
|
||||
try:
|
||||
for arg in rule.args:
|
||||
self._compute(arg, check_only, param_chain_map, lookup_stack, rules_stack, errors)
|
||||
except Exception as e:
|
||||
errors.append(e, lookup_stack, rules_stack)
|
||||
return None
|
||||
|
||||
args = [param_chain_map[k].value for k in rule.args]
|
||||
arg_values = []
|
||||
for arg in rule.args:
|
||||
try:
|
||||
val = self._compute(
|
||||
arg, check_only, param_chain_map, lookup_stack, rules_stack, errors
|
||||
)
|
||||
except Exception as e:
|
||||
if arg in rule.arg_defaults:
|
||||
val = rule.arg_defaults[arg]
|
||||
else:
|
||||
errors.append(e, lookup_stack, rules_stack)
|
||||
return None
|
||||
arg_values.append(val)
|
||||
|
||||
func = rule.mock_func if check_only else rule.func
|
||||
|
||||
try:
|
||||
values = func(*args)
|
||||
values = func(*arg_values)
|
||||
except Exception as e:
|
||||
errors.append(e, lookup_stack, rules_stack)
|
||||
return None
|
||||
@@ -395,27 +371,73 @@ class Evaluator:
|
||||
while len(param_chain_map.maps) > target_size:
|
||||
param_chain_map.maps.pop(0)
|
||||
|
||||
def clear_computed(self):
|
||||
"""deletes all computed values to leave only hard-set ones"""
|
||||
self.main_map = {k: v for k, v in self.main_map.items() if v.rule is None}
|
||||
|
||||
def validate_condition(self, rule: Rule) -> bool:
|
||||
try:
|
||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||
except EvaluatorError:
|
||||
return False
|
||||
|
||||
def attempted_rules_str(self, target: str) -> str:
|
||||
rules = ", ".join(str(r) for r in self.failed_rules[target])
|
||||
if len(rules) == 0:
|
||||
return ""
|
||||
return "attempted rules : " + rules
|
||||
|
||||
def get_arg_names(func: Callable) -> tuple[list[str], dict[str, Any]]:
|
||||
"""
|
||||
returns the positional argument names of func.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
if a function, returns the names of the positional arguments
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
list of argument names
|
||||
dict[str, Any]
|
||||
default values of arguments, if provided
|
||||
"""
|
||||
defaults = {}
|
||||
args = []
|
||||
for p in inspect.signature(func).parameters.values():
|
||||
if p.kind in {PARAM.VAR_KEYWORD, PARAM.VAR_POSITIONAL}:
|
||||
raise TypeError(
|
||||
f"function {func.__name__} has variadic argument {p.name!r}, which is not allowed"
|
||||
)
|
||||
|
||||
if p.kind is PARAM.KEYWORD_ONLY:
|
||||
if p.default is PARAM.empty:
|
||||
raise TypeError(
|
||||
f"function {func.__name__} has keyword-only argument {p.name!r} "
|
||||
"with no default. This is not supported at the moment."
|
||||
)
|
||||
continue
|
||||
|
||||
args.append(p.name)
|
||||
if p.default is not PARAM.empty:
|
||||
defaults[p.name] = p.default
|
||||
|
||||
return args, defaults
|
||||
|
||||
|
||||
@cache
|
||||
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
||||
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
||||
return_str = ", ".join("True" for _ in range(num_returns))
|
||||
func_name = f"__mock_{num_args}_{num_returns}"
|
||||
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
||||
scope = {}
|
||||
exec(func_str, scope)
|
||||
out_func = scope[func_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
default_rules: list[Rule] = [
|
||||
# Grid
|
||||
*Rule.deduce(
|
||||
["t", "time_window", "dt", "t_num"],
|
||||
math.build_t_grid,
|
||||
["time_window", "t_num", "dt"],
|
||||
2,
|
||||
),
|
||||
Rule(["t", "time_window", "dt", "t_num"], math.build_t_grid),
|
||||
Rule("z_targets", math.build_z_grid),
|
||||
Rule("adapt_step_size", lambda step_size: step_size == 0),
|
||||
Rule(
|
||||
@@ -434,11 +456,9 @@ default_rules: list[Rule] = [
|
||||
Rule(["input_time", "input_field"], pulse.load_custom_field),
|
||||
Rule("spec_0", lambda fft, field_0: fft(field_0)),
|
||||
Rule("field_0", lambda ifft, spec_0: ifft(spec_0)),
|
||||
*Rule.deduce(
|
||||
Rule(
|
||||
["pre_field_0", "peak_power", "energy", "width"],
|
||||
pulse.adjust_custom_field,
|
||||
["energy", "peak_power"],
|
||||
1,
|
||||
priorities=[2, 1, 1, 1],
|
||||
),
|
||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||
@@ -500,7 +520,7 @@ default_rules: list[Rule] = [
|
||||
Rule(
|
||||
"V_eff",
|
||||
fiber.V_eff_step_index,
|
||||
["wavelength", "core_radius", "numerical_aperture"],
|
||||
["l", "wavelength", "core_radius", "numerical_aperture"],
|
||||
),
|
||||
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
|
||||
Rule(
|
||||
|
||||
@@ -174,7 +174,7 @@ def initial_field_envelope(
|
||||
raised when shape is not recognized
|
||||
"""
|
||||
if delay is not None:
|
||||
t = t + delay
|
||||
t = t - delay
|
||||
|
||||
if shape == "gaussian":
|
||||
return gaussian_pulse(t, t0, peak_power)
|
||||
@@ -472,14 +472,14 @@ def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndar
|
||||
return units.m.inv(units.m(init_wavelength) - delta_w)
|
||||
|
||||
|
||||
def E0_to_P0(E0: float, t0: float, shape: str):
|
||||
def E0_to_P0(energy: float, t0: float, shape: str):
|
||||
"""convert an initial total pulse energy to a pulse peak peak_power"""
|
||||
return E0 / (t0 * P0T0_to_E0_fac[shape])
|
||||
return energy / (t0 * P0T0_to_E0_fac[shape])
|
||||
|
||||
|
||||
def P0_to_E0(P0: float, t0: float, shape: str):
|
||||
def P0_to_E0(peak_power: float, t0: float, shape: str):
|
||||
"""converts initial peak peak_power to pulse energy"""
|
||||
return P0 * t0 * P0T0_to_E0_fac[shape]
|
||||
return peak_power * t0 * P0T0_to_E0_fac[shape]
|
||||
|
||||
|
||||
def sech_pulse(t: np.ndarray, t0: float, P0: float):
|
||||
|
||||
@@ -545,7 +545,7 @@ def mean_values_plot(
|
||||
mean_style: dict[str, Any] = None,
|
||||
individual_style: dict[str, Any] = None,
|
||||
) -> tuple[plt.Line2D, list[plt.Line2D]]:
|
||||
x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing)
|
||||
x_axis, mean_values, values = transform_mean_values_1d(values, plt_range, params, log, spacing)
|
||||
if renormalize and log is False:
|
||||
maxi = mean_values.max()
|
||||
mean_values = mean_values / maxi
|
||||
@@ -570,7 +570,7 @@ def mean_values_plot(
|
||||
)
|
||||
|
||||
|
||||
def transform_mean_values(
|
||||
def transform_mean_values_1d(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[PlotRange, RangeType],
|
||||
params: Parameters,
|
||||
|
||||
@@ -6,7 +6,6 @@ scgenerator module but some function may be used in any python program
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
@@ -283,60 +282,6 @@ def to_62(i: int) -> str:
|
||||
return "".join(reversed(arr))
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
"""
|
||||
returns the positional argument names of func.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
if a function, returns the names of the positional arguments
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
[description]
|
||||
"""
|
||||
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
out_func = scope[tmp_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
@cache
|
||||
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
||||
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
||||
return_str = ", ".join("True" for _ in range(num_returns))
|
||||
func_name = f"__mock_{num_args}_{num_returns}"
|
||||
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
||||
scope = {}
|
||||
exec(func_str, scope)
|
||||
out_func = scope[func_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
def fft_functions(
|
||||
full_field: bool,
|
||||
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scgenerator.evaluator import Evaluator, Rule
|
||||
from scgenerator import math, units
|
||||
from scgenerator.evaluator import Evaluator, EvaluatorError, Rule
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -16,7 +17,7 @@ def disk_rules() -> list[Rule]:
|
||||
]
|
||||
|
||||
|
||||
def test_simple(disk_rules: list[Rule]):
|
||||
def test_trivial(disk_rules: list[Rule]):
|
||||
evaluator = Evaluator(*disk_rules)
|
||||
evaluator.set(radius=5)
|
||||
|
||||
@@ -25,3 +26,28 @@ def test_simple(disk_rules: list[Rule]):
|
||||
evaluator.set(area=5)
|
||||
assert evaluator.compute("area") == 5
|
||||
assert evaluator.compute("radius") == 5
|
||||
|
||||
|
||||
def test_simple():
|
||||
evaluator = Evaluator.default()
|
||||
evaluator.set(wavelength=800e-9, t_num=1024, dt=5e-15)
|
||||
|
||||
assert evaluator.compute("t") == pytest.approx(math.tspace(t_num=1024, dt=5e-15))
|
||||
assert evaluator.compute("w0") == pytest.approx(units.nm(800))
|
||||
|
||||
|
||||
def test_default_args():
|
||||
def some_function(a: int, b: int, c: int = 5):
|
||||
return a + b + c
|
||||
|
||||
evaluator = Evaluator(Rule("d", some_function))
|
||||
evaluator.set(a=1, b=3)
|
||||
|
||||
with pytest.raises(EvaluatorError):
|
||||
evaluator.compute("c")
|
||||
assert evaluator.compute("d") == 9
|
||||
|
||||
evaluator.clear_computed()
|
||||
evaluator.set(c=10)
|
||||
assert evaluator.compute("c") == 10
|
||||
assert evaluator.compute("d") == 14
|
||||
|
||||
Reference in New Issue
Block a user