added default values to Rule functions

This commit is contained in:
Benoît Sierro
2023-08-15 12:16:04 +02:00
parent 50d404158c
commit aa821eb52d
5 changed files with 144 additions and 153 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import itertools
import inspect
import traceback
from collections import ChainMap, defaultdict
from functools import cache
from inspect import Parameter as PARAM
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union
import numpy as np
@@ -10,13 +12,13 @@ import numpy as np
from scgenerator import io, math, operators, utils
from scgenerator.const import INF, MANDATORY_PARAMETERS
from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
from scgenerator.utils import get_logger
class ErrorRecord(NamedTuple):
error: Exception
lookup_stack: list[str]
rules_stack: list[Rule]
lookup_stack: tuple[str]
rules_stack: tuple[Rule]
traceback: str
@@ -40,7 +42,7 @@ class CyclicDependencyError(EvaluatorError):
class AllRulesExhaustedError(EvaluatorError):
def __init__(self, target: str):
self.target = target
super().__init__(f"tried every rule for {target!r} and no default value is set")
super().__init__(f"every rule for {target!r} failed and no default value is set")
class EvaluatorErrorTree:
@@ -60,11 +62,14 @@ class EvaluatorErrorTree:
def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]):
tr = traceback.format_exc().splitlines()
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame
self.all_errors.append(ErrorRecord(error, lookup_stack.copy(), rules_stack.copy(), tr))
self.all_errors.append(ErrorRecord(error, tuple(lookup_stack), tuple(rules_stack), tr))
def compile_error(self) -> EvaluatorError:
failed_rules = set(rec.rules_stack[-1] for rec in self.all_errors if rec.rules_stack)
failed_rules = ("\n" + "-" * 80 + "\n").join(rule.pretty_format() for rule in failed_rules)
failed_rules = set(rec for rec in self.all_errors if rec.rules_stack)
failed_rules = [
rec.rules_stack[-1].pretty_format() + "\n" + str(rec.error) for rec in failed_rules
]
failed_rules = ("\n" + "-" * 80 + "\n").join(failed_rules)
raise EvaluatorError(
f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}"
@@ -76,6 +81,7 @@ class Rule:
func: Callable
args: tuple[str]
conditions: dict[str, Any]
arg_defaults: dict[str, Any]
mock_func: Callable
@@ -86,6 +92,7 @@ class Rule:
args: tuple[str] | None = None,
priorities: Union[int, list[int]] | None = None,
conditions: dict[str, str] | None = None,
defaults: dict[str, Any] | None = None,
):
targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func
@@ -94,9 +101,26 @@ class Rule:
elif isinstance(priorities, (int, float, np.integer, np.floating)):
priorities = [priorities] * len(targets)
self.targets = dict(zip(targets, priorities))
if args is None:
args = get_arg_names(func)
self.args = tuple(args)
func_args, func_defaults = get_arg_names(func)
if args is not None:
try:
for func_arg, user_arg in zip(func_args, args, strict=True):
if func_arg in func_defaults:
func_defaults[user_arg] = func_defaults.pop(func_arg)
except ValueError as e:
raise ValueError(
f"length mismatch between arguments of {func.__name__}: "
f"{func_args!r} and provided ones: {args!r}"
) from e
func_args = args
func_defaults |= defaults or {}
self.args = tuple(func_args)
self.arg_defaults = func_defaults
self.mock_func = _mock_function(len(self.args), len(self.targets))
self.conditions = conditions or {}
@@ -126,55 +150,6 @@ class Rule:
def func_name(self) -> str:
return f"{self.func.__module__}.{self.func.__name__}"
@classmethod
def deduce(
cls,
target: Union[str, list[Optional[str]]],
func: Callable,
kwarg_names: list[str],
n_var: int,
args_const: list[str] = None,
priorities: Union[int, list[int]] = None,
) -> list["Rule"]:
"""
given a function that doesn't need all its keyword arguments specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters
----------
target : str | list[str | None]
name of the variable(s) that func returns
func : Callable
function to work with
kwarg_names : list[str]
list of all kwargs of the function to be used
n_var : int
how many shoulf be used per rule
arg_const : list[str], optional
override the name of the positional arguments
Returns
-------
list[Rule]
list of all possible rules
Example
-------
>> def lol(a, b=None, c=None):
pass
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
[
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
]
"""
rules: list[cls] = []
for var_possibility in itertools.combinations(kwarg_names, n_var):
new_func = func_rewrite(func, list(var_possibility), args_const)
rules.append(cls(target, new_func, priorities=priorities))
return rules
class EvaluatedValue(NamedTuple):
value: Any
@@ -187,9 +162,7 @@ class Evaluator:
rules: dict[str, list[Rule]]
main_map: dict[str, EvaluatedValue]
param_chain_map: MutableMapping[str, EvaluatedValue]
lookup_stack: list[str]
failed_rules: dict[str, list[Rule]]
@classmethod
def default(cls, full_field: bool = False) -> "Evaluator":
@@ -220,8 +193,6 @@ class Evaluator:
self.main_map = {}
self.logger = get_logger(__name__)
self.failed_rules = defaultdict(list)
self.append(*rules)
def __getitem__(self, key: str) -> Any:
@@ -302,7 +273,7 @@ class Evaluator:
if target not in self.rules or len(self.rules[target]) == 0:
raise NoValidRuleError(target)
if target in lookup_stack:
raise CyclicDependencyError(target)
raise CyclicDependencyError(target, lookup_stack)
lookup_stack.append(target)
base_cm_length = len(param_chain_map.maps)
@@ -359,19 +330,24 @@ class Evaluator:
rules_stack: list[Rule],
errors: EvaluatorErrorTree,
) -> dict[str, EvaluatedValue] | None:
try:
for arg in rule.args:
self._compute(arg, check_only, param_chain_map, lookup_stack, rules_stack, errors)
except Exception as e:
errors.append(e, lookup_stack, rules_stack)
return None
args = [param_chain_map[k].value for k in rule.args]
arg_values = []
for arg in rule.args:
try:
val = self._compute(
arg, check_only, param_chain_map, lookup_stack, rules_stack, errors
)
except Exception as e:
if arg in rule.arg_defaults:
val = rule.arg_defaults[arg]
else:
errors.append(e, lookup_stack, rules_stack)
return None
arg_values.append(val)
func = rule.mock_func if check_only else rule.func
try:
values = func(*args)
values = func(*arg_values)
except Exception as e:
errors.append(e, lookup_stack, rules_stack)
return None
@@ -395,27 +371,73 @@ class Evaluator:
while len(param_chain_map.maps) > target_size:
param_chain_map.maps.pop(0)
def clear_computed(self):
"""deletes all computed values to leave only hard-set ones"""
self.main_map = {k: v for k, v in self.main_map.items() if v.rule is None}
def validate_condition(self, rule: Rule) -> bool:
try:
return all(self.compute(k) == v for k, v in rule.conditions.items())
except EvaluatorError:
return False
def attempted_rules_str(self, target: str) -> str:
rules = ", ".join(str(r) for r in self.failed_rules[target])
if len(rules) == 0:
return ""
return "attempted rules : " + rules
def get_arg_names(func: Callable) -> tuple[list[str], dict[str, Any]]:
"""
returns the positional argument names of func.
Parameters
----------
func : Callable
if a function, returns the names of the positional arguments
Returns
-------
list[str]
list of argument names
dict[str, Any]
default values of arguments, if provided
"""
defaults = {}
args = []
for p in inspect.signature(func).parameters.values():
if p.kind in {PARAM.VAR_KEYWORD, PARAM.VAR_POSITIONAL}:
raise TypeError(
f"function {func.__name__} has variadic argument {p.name!r}, which is not allowed"
)
if p.kind is PARAM.KEYWORD_ONLY:
if p.default is PARAM.empty:
raise TypeError(
f"function {func.__name__} has keyword-only argument {p.name!r} "
"with no default. This is not supported at the moment."
)
continue
args.append(p.name)
if p.default is not PARAM.empty:
defaults[p.name] = p.default
return args, defaults
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
default_rules: list[Rule] = [
# Grid
*Rule.deduce(
["t", "time_window", "dt", "t_num"],
math.build_t_grid,
["time_window", "t_num", "dt"],
2,
),
Rule(["t", "time_window", "dt", "t_num"], math.build_t_grid),
Rule("z_targets", math.build_z_grid),
Rule("adapt_step_size", lambda step_size: step_size == 0),
Rule(
@@ -434,11 +456,9 @@ default_rules: list[Rule] = [
Rule(["input_time", "input_field"], pulse.load_custom_field),
Rule("spec_0", lambda fft, field_0: fft(field_0)),
Rule("field_0", lambda ifft, spec_0: ifft(spec_0)),
*Rule.deduce(
Rule(
["pre_field_0", "peak_power", "energy", "width"],
pulse.adjust_custom_field,
["energy", "peak_power"],
1,
priorities=[2, 1, 1, 1],
),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
@@ -500,7 +520,7 @@ default_rules: list[Rule] = [
Rule(
"V_eff",
fiber.V_eff_step_index,
["wavelength", "core_radius", "numerical_aperture"],
["l", "wavelength", "core_radius", "numerical_aperture"],
),
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
Rule(

View File

@@ -174,7 +174,7 @@ def initial_field_envelope(
raised when shape is not recognized
"""
if delay is not None:
t = t + delay
t = t - delay
if shape == "gaussian":
return gaussian_pulse(t, t0, peak_power)
@@ -472,14 +472,14 @@ def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndar
return units.m.inv(units.m(init_wavelength) - delta_w)
def E0_to_P0(E0: float, t0: float, shape: str):
def E0_to_P0(energy: float, t0: float, shape: str):
"""convert an initial total pulse energy to a pulse peak peak_power"""
return E0 / (t0 * P0T0_to_E0_fac[shape])
return energy / (t0 * P0T0_to_E0_fac[shape])
def P0_to_E0(P0: float, t0: float, shape: str):
def P0_to_E0(peak_power: float, t0: float, shape: str):
"""converts initial peak peak_power to pulse energy"""
return P0 * t0 * P0T0_to_E0_fac[shape]
return peak_power * t0 * P0T0_to_E0_fac[shape]
def sech_pulse(t: np.ndarray, t0: float, P0: float):

View File

@@ -545,7 +545,7 @@ def mean_values_plot(
mean_style: dict[str, Any] = None,
individual_style: dict[str, Any] = None,
) -> tuple[plt.Line2D, list[plt.Line2D]]:
x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing)
x_axis, mean_values, values = transform_mean_values_1d(values, plt_range, params, log, spacing)
if renormalize and log is False:
maxi = mean_values.max()
mean_values = mean_values / maxi
@@ -570,7 +570,7 @@ def mean_values_plot(
)
def transform_mean_values(
def transform_mean_values_1d(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
params: Parameters,

View File

@@ -6,7 +6,6 @@ scgenerator module but some function may be used in any python program
from __future__ import annotations
import datetime
import inspect
import itertools
import json
import os
@@ -283,60 +282,6 @@ def to_62(i: int) -> str:
return "".join(reversed(arr))
def get_arg_names(func: Callable) -> list[str]:
"""
returns the positional argument names of func.
Parameters
----------
func : Callable
if a function, returns the names of the positional arguments
Returns
-------
list[str]
[description]
"""
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
def fft_functions(
full_field: bool,
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:

View File

@@ -1,7 +1,8 @@
import numpy as np
import pytest
from scgenerator.evaluator import Evaluator, Rule
from scgenerator import math, units
from scgenerator.evaluator import Evaluator, EvaluatorError, Rule
@pytest.fixture
@@ -16,7 +17,7 @@ def disk_rules() -> list[Rule]:
]
def test_simple(disk_rules: list[Rule]):
def test_trivial(disk_rules: list[Rule]):
evaluator = Evaluator(*disk_rules)
evaluator.set(radius=5)
@@ -25,3 +26,28 @@ def test_simple(disk_rules: list[Rule]):
evaluator.set(area=5)
assert evaluator.compute("area") == 5
assert evaluator.compute("radius") == 5
def test_simple():
evaluator = Evaluator.default()
evaluator.set(wavelength=800e-9, t_num=1024, dt=5e-15)
assert evaluator.compute("t") == pytest.approx(math.tspace(t_num=1024, dt=5e-15))
assert evaluator.compute("w0") == pytest.approx(units.nm(800))
def test_default_args():
def some_function(a: int, b: int, c: int = 5):
return a + b + c
evaluator = Evaluator(Rule("d", some_function))
evaluator.set(a=1, b=3)
with pytest.raises(EvaluatorError):
evaluator.compute("c")
assert evaluator.compute("d") == 9
evaluator.clear_computed()
evaluator.set(c=10)
assert evaluator.compute("c") == 10
assert evaluator.compute("d") == 14