added default values to Rule functions

This commit is contained in:
Benoît Sierro
2023-08-15 12:16:04 +02:00
parent 50d404158c
commit aa821eb52d
5 changed files with 144 additions and 153 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import itertools import inspect
import traceback import traceback
from collections import ChainMap, defaultdict from collections import ChainMap, defaultdict
from functools import cache
from inspect import Parameter as PARAM
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union
import numpy as np import numpy as np
@@ -10,13 +12,13 @@ import numpy as np
from scgenerator import io, math, operators, utils from scgenerator import io, math, operators, utils
from scgenerator.const import INF, MANDATORY_PARAMETERS from scgenerator.const import INF, MANDATORY_PARAMETERS
from scgenerator.physics import fiber, materials, plasma, pulse, units from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger from scgenerator.utils import get_logger
class ErrorRecord(NamedTuple): class ErrorRecord(NamedTuple):
error: Exception error: Exception
lookup_stack: list[str] lookup_stack: tuple[str]
rules_stack: list[Rule] rules_stack: tuple[Rule]
traceback: str traceback: str
@@ -40,7 +42,7 @@ class CyclicDependencyError(EvaluatorError):
class AllRulesExhaustedError(EvaluatorError): class AllRulesExhaustedError(EvaluatorError):
def __init__(self, target: str): def __init__(self, target: str):
self.target = target self.target = target
super().__init__(f"tried every rule for {target!r} and no default value is set") super().__init__(f"every rule for {target!r} failed and no default value is set")
class EvaluatorErrorTree: class EvaluatorErrorTree:
@@ -60,11 +62,14 @@ class EvaluatorErrorTree:
def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]): def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]):
tr = traceback.format_exc().splitlines() tr = traceback.format_exc().splitlines()
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame
self.all_errors.append(ErrorRecord(error, lookup_stack.copy(), rules_stack.copy(), tr)) self.all_errors.append(ErrorRecord(error, tuple(lookup_stack), tuple(rules_stack), tr))
def compile_error(self) -> EvaluatorError: def compile_error(self) -> EvaluatorError:
failed_rules = set(rec.rules_stack[-1] for rec in self.all_errors if rec.rules_stack) failed_rules = set(rec for rec in self.all_errors if rec.rules_stack)
failed_rules = ("\n" + "-" * 80 + "\n").join(rule.pretty_format() for rule in failed_rules) failed_rules = [
rec.rules_stack[-1].pretty_format() + "\n" + str(rec.error) for rec in failed_rules
]
failed_rules = ("\n" + "-" * 80 + "\n").join(failed_rules)
raise EvaluatorError( raise EvaluatorError(
f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}" f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}"
@@ -76,6 +81,7 @@ class Rule:
func: Callable func: Callable
args: tuple[str] args: tuple[str]
conditions: dict[str, Any] conditions: dict[str, Any]
arg_defaults: dict[str, Any]
mock_func: Callable mock_func: Callable
@@ -86,6 +92,7 @@ class Rule:
args: tuple[str] | None = None, args: tuple[str] | None = None,
priorities: Union[int, list[int]] | None = None, priorities: Union[int, list[int]] | None = None,
conditions: dict[str, str] | None = None, conditions: dict[str, str] | None = None,
defaults: dict[str, Any] | None = None,
): ):
targets = list(target) if isinstance(target, (list, tuple)) else [target] targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func self.func = func
@@ -94,9 +101,26 @@ class Rule:
elif isinstance(priorities, (int, float, np.integer, np.floating)): elif isinstance(priorities, (int, float, np.integer, np.floating)):
priorities = [priorities] * len(targets) priorities = [priorities] * len(targets)
self.targets = dict(zip(targets, priorities)) self.targets = dict(zip(targets, priorities))
if args is None:
args = get_arg_names(func) func_args, func_defaults = get_arg_names(func)
self.args = tuple(args)
if args is not None:
try:
for func_arg, user_arg in zip(func_args, args, strict=True):
if func_arg in func_defaults:
func_defaults[user_arg] = func_defaults.pop(func_arg)
except ValueError as e:
raise ValueError(
f"length mismatch between arguments of {func.__name__}: "
f"{func_args!r} and provided ones: {args!r}"
) from e
func_args = args
func_defaults |= defaults or {}
self.args = tuple(func_args)
self.arg_defaults = func_defaults
self.mock_func = _mock_function(len(self.args), len(self.targets)) self.mock_func = _mock_function(len(self.args), len(self.targets))
self.conditions = conditions or {} self.conditions = conditions or {}
@@ -126,55 +150,6 @@ class Rule:
def func_name(self) -> str: def func_name(self) -> str:
return f"{self.func.__module__}.{self.func.__name__}" return f"{self.func.__module__}.{self.func.__name__}"
@classmethod
def deduce(
cls,
target: Union[str, list[Optional[str]]],
func: Callable,
kwarg_names: list[str],
n_var: int,
args_const: list[str] = None,
priorities: Union[int, list[int]] = None,
) -> list["Rule"]:
"""
given a function that doesn't need all its keyword arguments specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters
----------
target : str | list[str | None]
name of the variable(s) that func returns
func : Callable
function to work with
kwarg_names : list[str]
list of all kwargs of the function to be used
n_var : int
how many shoulf be used per rule
arg_const : list[str], optional
override the name of the positional arguments
Returns
-------
list[Rule]
list of all possible rules
Example
-------
>> def lol(a, b=None, c=None):
pass
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
[
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
]
"""
rules: list[cls] = []
for var_possibility in itertools.combinations(kwarg_names, n_var):
new_func = func_rewrite(func, list(var_possibility), args_const)
rules.append(cls(target, new_func, priorities=priorities))
return rules
class EvaluatedValue(NamedTuple): class EvaluatedValue(NamedTuple):
value: Any value: Any
@@ -187,9 +162,7 @@ class Evaluator:
rules: dict[str, list[Rule]] rules: dict[str, list[Rule]]
main_map: dict[str, EvaluatedValue] main_map: dict[str, EvaluatedValue]
param_chain_map: MutableMapping[str, EvaluatedValue]
lookup_stack: list[str] lookup_stack: list[str]
failed_rules: dict[str, list[Rule]]
@classmethod @classmethod
def default(cls, full_field: bool = False) -> "Evaluator": def default(cls, full_field: bool = False) -> "Evaluator":
@@ -220,8 +193,6 @@ class Evaluator:
self.main_map = {} self.main_map = {}
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.failed_rules = defaultdict(list)
self.append(*rules) self.append(*rules)
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:
@@ -302,7 +273,7 @@ class Evaluator:
if target not in self.rules or len(self.rules[target]) == 0: if target not in self.rules or len(self.rules[target]) == 0:
raise NoValidRuleError(target) raise NoValidRuleError(target)
if target in lookup_stack: if target in lookup_stack:
raise CyclicDependencyError(target) raise CyclicDependencyError(target, lookup_stack)
lookup_stack.append(target) lookup_stack.append(target)
base_cm_length = len(param_chain_map.maps) base_cm_length = len(param_chain_map.maps)
@@ -359,19 +330,24 @@ class Evaluator:
rules_stack: list[Rule], rules_stack: list[Rule],
errors: EvaluatorErrorTree, errors: EvaluatorErrorTree,
) -> dict[str, EvaluatedValue] | None: ) -> dict[str, EvaluatedValue] | None:
try: arg_values = []
for arg in rule.args: for arg in rule.args:
self._compute(arg, check_only, param_chain_map, lookup_stack, rules_stack, errors) try:
except Exception as e: val = self._compute(
errors.append(e, lookup_stack, rules_stack) arg, check_only, param_chain_map, lookup_stack, rules_stack, errors
return None )
except Exception as e:
args = [param_chain_map[k].value for k in rule.args] if arg in rule.arg_defaults:
val = rule.arg_defaults[arg]
else:
errors.append(e, lookup_stack, rules_stack)
return None
arg_values.append(val)
func = rule.mock_func if check_only else rule.func func = rule.mock_func if check_only else rule.func
try: try:
values = func(*args) values = func(*arg_values)
except Exception as e: except Exception as e:
errors.append(e, lookup_stack, rules_stack) errors.append(e, lookup_stack, rules_stack)
return None return None
@@ -395,27 +371,73 @@ class Evaluator:
while len(param_chain_map.maps) > target_size: while len(param_chain_map.maps) > target_size:
param_chain_map.maps.pop(0) param_chain_map.maps.pop(0)
def clear_computed(self):
"""deletes all computed values to leave only hard-set ones"""
self.main_map = {k: v for k, v in self.main_map.items() if v.rule is None}
def validate_condition(self, rule: Rule) -> bool: def validate_condition(self, rule: Rule) -> bool:
try: try:
return all(self.compute(k) == v for k, v in rule.conditions.items()) return all(self.compute(k) == v for k, v in rule.conditions.items())
except EvaluatorError: except EvaluatorError:
return False return False
def attempted_rules_str(self, target: str) -> str:
rules = ", ".join(str(r) for r in self.failed_rules[target]) def get_arg_names(func: Callable) -> tuple[list[str], dict[str, Any]]:
if len(rules) == 0: """
return "" returns the positional argument names of func.
return "attempted rules : " + rules
Parameters
----------
func : Callable
if a function, returns the names of the positional arguments
Returns
-------
list[str]
list of argument names
dict[str, Any]
default values of arguments, if provided
"""
defaults = {}
args = []
for p in inspect.signature(func).parameters.values():
if p.kind in {PARAM.VAR_KEYWORD, PARAM.VAR_POSITIONAL}:
raise TypeError(
f"function {func.__name__} has variadic argument {p.name!r}, which is not allowed"
)
if p.kind is PARAM.KEYWORD_ONLY:
if p.default is PARAM.empty:
raise TypeError(
f"function {func.__name__} has keyword-only argument {p.name!r} "
"with no default. This is not supported at the moment."
)
continue
args.append(p.name)
if p.default is not PARAM.empty:
defaults[p.name] = p.default
return args, defaults
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
default_rules: list[Rule] = [ default_rules: list[Rule] = [
# Grid # Grid
*Rule.deduce( Rule(["t", "time_window", "dt", "t_num"], math.build_t_grid),
["t", "time_window", "dt", "t_num"],
math.build_t_grid,
["time_window", "t_num", "dt"],
2,
),
Rule("z_targets", math.build_z_grid), Rule("z_targets", math.build_z_grid),
Rule("adapt_step_size", lambda step_size: step_size == 0), Rule("adapt_step_size", lambda step_size: step_size == 0),
Rule( Rule(
@@ -434,11 +456,9 @@ default_rules: list[Rule] = [
Rule(["input_time", "input_field"], pulse.load_custom_field), Rule(["input_time", "input_field"], pulse.load_custom_field),
Rule("spec_0", lambda fft, field_0: fft(field_0)), Rule("spec_0", lambda fft, field_0: fft(field_0)),
Rule("field_0", lambda ifft, spec_0: ifft(spec_0)), Rule("field_0", lambda ifft, spec_0: ifft(spec_0)),
*Rule.deduce( Rule(
["pre_field_0", "peak_power", "energy", "width"], ["pre_field_0", "peak_power", "energy", "width"],
pulse.adjust_custom_field, pulse.adjust_custom_field,
["energy", "peak_power"],
1,
priorities=[2, 1, 1, 1], priorities=[2, 1, 1, 1],
), ),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
@@ -500,7 +520,7 @@ default_rules: list[Rule] = [
Rule( Rule(
"V_eff", "V_eff",
fiber.V_eff_step_index, fiber.V_eff_step_index,
["wavelength", "core_radius", "numerical_aperture"], ["l", "wavelength", "core_radius", "numerical_aperture"],
), ),
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")), Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
Rule( Rule(

View File

@@ -174,7 +174,7 @@ def initial_field_envelope(
raised when shape is not recognized raised when shape is not recognized
""" """
if delay is not None: if delay is not None:
t = t + delay t = t - delay
if shape == "gaussian": if shape == "gaussian":
return gaussian_pulse(t, t0, peak_power) return gaussian_pulse(t, t0, peak_power)
@@ -472,14 +472,14 @@ def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndar
return units.m.inv(units.m(init_wavelength) - delta_w) return units.m.inv(units.m(init_wavelength) - delta_w)
def E0_to_P0(E0: float, t0: float, shape: str): def E0_to_P0(energy: float, t0: float, shape: str):
"""convert an initial total pulse energy to a pulse peak peak_power""" """convert an initial total pulse energy to a pulse peak peak_power"""
return E0 / (t0 * P0T0_to_E0_fac[shape]) return energy / (t0 * P0T0_to_E0_fac[shape])
def P0_to_E0(P0: float, t0: float, shape: str): def P0_to_E0(peak_power: float, t0: float, shape: str):
"""converts initial peak peak_power to pulse energy""" """converts initial peak peak_power to pulse energy"""
return P0 * t0 * P0T0_to_E0_fac[shape] return peak_power * t0 * P0T0_to_E0_fac[shape]
def sech_pulse(t: np.ndarray, t0: float, P0: float): def sech_pulse(t: np.ndarray, t0: float, P0: float):

View File

@@ -545,7 +545,7 @@ def mean_values_plot(
mean_style: dict[str, Any] = None, mean_style: dict[str, Any] = None,
individual_style: dict[str, Any] = None, individual_style: dict[str, Any] = None,
) -> tuple[plt.Line2D, list[plt.Line2D]]: ) -> tuple[plt.Line2D, list[plt.Line2D]]:
x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing) x_axis, mean_values, values = transform_mean_values_1d(values, plt_range, params, log, spacing)
if renormalize and log is False: if renormalize and log is False:
maxi = mean_values.max() maxi = mean_values.max()
mean_values = mean_values / maxi mean_values = mean_values / maxi
@@ -570,7 +570,7 @@ def mean_values_plot(
) )
def transform_mean_values( def transform_mean_values_1d(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
params: Parameters, params: Parameters,

View File

@@ -6,7 +6,6 @@ scgenerator module but some function may be used in any python program
from __future__ import annotations from __future__ import annotations
import datetime import datetime
import inspect
import itertools import itertools
import json import json
import os import os
@@ -283,60 +282,6 @@ def to_62(i: int) -> str:
return "".join(reversed(arr)) return "".join(reversed(arr))
def get_arg_names(func: Callable) -> list[str]:
"""
returns the positional argument names of func.
Parameters
----------
func : Callable
if a function, returns the names of the positional arguments
Returns
-------
list[str]
[description]
"""
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
def fft_functions( def fft_functions(
full_field: bool, full_field: bool,
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]: ) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:

View File

@@ -1,7 +1,8 @@
import numpy as np import numpy as np
import pytest import pytest
from scgenerator.evaluator import Evaluator, Rule from scgenerator import math, units
from scgenerator.evaluator import Evaluator, EvaluatorError, Rule
@pytest.fixture @pytest.fixture
@@ -16,7 +17,7 @@ def disk_rules() -> list[Rule]:
] ]
def test_simple(disk_rules: list[Rule]): def test_trivial(disk_rules: list[Rule]):
evaluator = Evaluator(*disk_rules) evaluator = Evaluator(*disk_rules)
evaluator.set(radius=5) evaluator.set(radius=5)
@@ -25,3 +26,28 @@ def test_simple(disk_rules: list[Rule]):
evaluator.set(area=5) evaluator.set(area=5)
assert evaluator.compute("area") == 5 assert evaluator.compute("area") == 5
assert evaluator.compute("radius") == 5 assert evaluator.compute("radius") == 5
def test_simple():
evaluator = Evaluator.default()
evaluator.set(wavelength=800e-9, t_num=1024, dt=5e-15)
assert evaluator.compute("t") == pytest.approx(math.tspace(t_num=1024, dt=5e-15))
assert evaluator.compute("w0") == pytest.approx(units.nm(800))
def test_default_args():
def some_function(a: int, b: int, c: int = 5):
return a + b + c
evaluator = Evaluator(Rule("d", some_function))
evaluator.set(a=1, b=3)
with pytest.raises(EvaluatorError):
evaluator.compute("c")
assert evaluator.compute("d") == 9
evaluator.clear_computed()
evaluator.set(c=10)
assert evaluator.compute("c") == 10
assert evaluator.compute("d") == 14