From aa821eb52d46db31300cf5b48c0b91bff21fb52e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 15 Aug 2023 12:16:04 +0200 Subject: [PATCH] added default values to Rule functions --- src/scgenerator/evaluator.py | 198 +++++++++++++++++-------------- src/scgenerator/physics/pulse.py | 10 +- src/scgenerator/plotting.py | 4 +- src/scgenerator/utils.py | 55 --------- tests/test_evaluator.py | 30 ++++- 5 files changed, 144 insertions(+), 153 deletions(-) diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 52985de..7980b49 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -1,8 +1,10 @@ from __future__ import annotations -import itertools +import inspect import traceback from collections import ChainMap, defaultdict +from functools import cache +from inspect import Parameter as PARAM from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union import numpy as np @@ -10,13 +12,13 @@ import numpy as np from scgenerator import io, math, operators, utils from scgenerator.const import INF, MANDATORY_PARAMETERS from scgenerator.physics import fiber, materials, plasma, pulse, units -from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger +from scgenerator.utils import get_logger class ErrorRecord(NamedTuple): error: Exception - lookup_stack: list[str] - rules_stack: list[Rule] + lookup_stack: tuple[str] + rules_stack: tuple[Rule] traceback: str @@ -40,7 +42,7 @@ class CyclicDependencyError(EvaluatorError): class AllRulesExhaustedError(EvaluatorError): def __init__(self, target: str): self.target = target - super().__init__(f"tried every rule for {target!r} and no default value is set") + super().__init__(f"every rule for {target!r} failed and no default value is set") class EvaluatorErrorTree: @@ -60,11 +62,14 @@ class EvaluatorErrorTree: def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]): tr = traceback.format_exc().splitlines() tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame - self.all_errors.append(ErrorRecord(error, lookup_stack.copy(), rules_stack.copy(), tr)) + self.all_errors.append(ErrorRecord(error, tuple(lookup_stack), tuple(rules_stack), tr)) def compile_error(self) -> EvaluatorError: - failed_rules = set(rec.rules_stack[-1] for rec in self.all_errors if rec.rules_stack) - failed_rules = ("\n" + "-" * 80 + "\n").join(rule.pretty_format() for rule in failed_rules) + failed_rules = set(rec for rec in self.all_errors if rec.rules_stack) + failed_rules = [ + rec.rules_stack[-1].pretty_format() + "\n" + str(rec.error) for rec in failed_rules + ] + failed_rules = ("\n" + "-" * 80 + "\n").join(failed_rules) raise EvaluatorError( f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}" @@ -76,6 +81,7 @@ class Rule: func: Callable args: tuple[str] conditions: dict[str, Any] + arg_defaults: dict[str, Any] mock_func: Callable @@ -86,6 +92,7 @@ class Rule: args: tuple[str] | None = None, priorities: Union[int, list[int]] | None = None, conditions: dict[str, str] | None = None, + defaults: dict[str, Any] | None = None, ): targets = list(target) if isinstance(target, (list, tuple)) else [target] self.func = func @@ -94,9 +101,26 @@ class Rule: elif isinstance(priorities, (int, float, np.integer, np.floating)): priorities = [priorities] * len(targets) self.targets = dict(zip(targets, priorities)) - if args is None: - args = get_arg_names(func) - self.args = tuple(args) + + func_args, func_defaults = get_arg_names(func) + + if args is not None: + try: + for func_arg, user_arg in zip(func_args, args, strict=True): + if func_arg in func_defaults: + func_defaults[user_arg] = func_defaults.pop(func_arg) + except ValueError as e: + raise ValueError( + f"length mismatch between arguments of {func.__name__}: " + f"{func_args!r} and provided ones: {args!r}" + ) from e + func_args = args + + func_defaults |= defaults or {} + + self.args = tuple(func_args) + self.arg_defaults = func_defaults + self.mock_func = _mock_function(len(self.args), len(self.targets)) self.conditions = conditions or {} @@ -126,55 +150,6 @@ class Rule: def func_name(self) -> str: return f"{self.func.__module__}.{self.func.__name__}" - @classmethod - def deduce( - cls, - target: Union[str, list[Optional[str]]], - func: Callable, - kwarg_names: list[str], - n_var: int, - args_const: list[str] = None, - priorities: Union[int, list[int]] = None, - ) -> list["Rule"]: - """ - given a function that doesn't need all its keyword arguments specified, will - return a list of Rule obj, one for each combination of n_var specified kwargs - - Parameters - ---------- - target : str | list[str | None] - name of the variable(s) that func returns - func : Callable - function to work with - kwarg_names : list[str] - list of all kwargs of the function to be used - n_var : int - how many shoulf be used per rule - arg_const : list[str], optional - override the name of the positional arguments - - Returns - ------- - list[Rule] - list of all possible rules - - Example - ------- - >> def lol(a, b=None, c=None): - pass - >> print(Rule.deduce(["d"], lol, ["b", "c"], 1)) - [ - Rule(targets={'d': 1}, func=, args=['a', 'b']), - Rule(targets={'d': 1}, func=, args=['a', 'c']) - ] - """ - rules: list[cls] = [] - for var_possibility in itertools.combinations(kwarg_names, n_var): - new_func = func_rewrite(func, list(var_possibility), args_const) - - rules.append(cls(target, new_func, priorities=priorities)) - return rules - class EvaluatedValue(NamedTuple): value: Any @@ -187,9 +162,7 @@ class Evaluator: rules: dict[str, list[Rule]] main_map: dict[str, EvaluatedValue] - param_chain_map: MutableMapping[str, EvaluatedValue] lookup_stack: list[str] - failed_rules: dict[str, list[Rule]] @classmethod def default(cls, full_field: bool = False) -> "Evaluator": @@ -220,8 +193,6 @@ class Evaluator: self.main_map = {} self.logger = get_logger(__name__) - self.failed_rules = defaultdict(list) - self.append(*rules) def __getitem__(self, key: str) -> Any: @@ -302,7 +273,7 @@ class Evaluator: if target not in self.rules or len(self.rules[target]) == 0: raise NoValidRuleError(target) if target in lookup_stack: - raise CyclicDependencyError(target) + raise CyclicDependencyError(target, lookup_stack) lookup_stack.append(target) base_cm_length = len(param_chain_map.maps) @@ -359,19 +330,24 @@ class Evaluator: rules_stack: list[Rule], errors: EvaluatorErrorTree, ) -> dict[str, EvaluatedValue] | None: - try: - for arg in rule.args: - self._compute(arg, check_only, param_chain_map, lookup_stack, rules_stack, errors) - except Exception as e: - errors.append(e, lookup_stack, rules_stack) - return None - - args = [param_chain_map[k].value for k in rule.args] + arg_values = [] + for arg in rule.args: + try: + val = self._compute( + arg, check_only, param_chain_map, lookup_stack, rules_stack, errors + ) + except Exception as e: + if arg in rule.arg_defaults: + val = rule.arg_defaults[arg] + else: + errors.append(e, lookup_stack, rules_stack) + return None + arg_values.append(val) func = rule.mock_func if check_only else rule.func try: - values = func(*args) + values = func(*arg_values) except Exception as e: errors.append(e, lookup_stack, rules_stack) return None @@ -395,27 +371,73 @@ class Evaluator: while len(param_chain_map.maps) > target_size: param_chain_map.maps.pop(0) + def clear_computed(self): + """deletes all computed values to leave only hard-set ones""" + self.main_map = {k: v for k, v in self.main_map.items() if v.rule is None} + def validate_condition(self, rule: Rule) -> bool: try: return all(self.compute(k) == v for k, v in rule.conditions.items()) except EvaluatorError: return False - def attempted_rules_str(self, target: str) -> str: - rules = ", ".join(str(r) for r in self.failed_rules[target]) - if len(rules) == 0: - return "" - return "attempted rules : " + rules + +def get_arg_names(func: Callable) -> tuple[list[str], dict[str, Any]]: + """ + returns the positional argument names of func. + + Parameters + ---------- + func : Callable + if a function, returns the names of the positional arguments + + + Returns + ------- + list[str] + list of argument names + dict[str, Any] + default values of arguments, if provided + """ + defaults = {} + args = [] + for p in inspect.signature(func).parameters.values(): + if p.kind in {PARAM.VAR_KEYWORD, PARAM.VAR_POSITIONAL}: + raise TypeError( + f"function {func.__name__} has variadic argument {p.name!r}, which is not allowed" + ) + + if p.kind is PARAM.KEYWORD_ONLY: + if p.default is PARAM.empty: + raise TypeError( + f"function {func.__name__} has keyword-only argument {p.name!r} " + "with no default. This is not supported at the moment." + ) + continue + + args.append(p.name) + if p.default is not PARAM.empty: + defaults[p.name] = p.default + + return args, defaults + + +@cache +def _mock_function(num_args: int, num_returns: int) -> Callable: + arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) + return_str = ", ".join("True" for _ in range(num_returns)) + func_name = f"__mock_{num_args}_{num_returns}" + func_str = f"def {func_name}({arg_str}):\n return {return_str}" + scope = {} + exec(func_str, scope) + out_func = scope[func_name] + out_func.__module__ = "evaluator" + return out_func default_rules: list[Rule] = [ # Grid - *Rule.deduce( - ["t", "time_window", "dt", "t_num"], - math.build_t_grid, - ["time_window", "t_num", "dt"], - 2, - ), + Rule(["t", "time_window", "dt", "t_num"], math.build_t_grid), Rule("z_targets", math.build_z_grid), Rule("adapt_step_size", lambda step_size: step_size == 0), Rule( @@ -434,11 +456,9 @@ default_rules: list[Rule] = [ Rule(["input_time", "input_field"], pulse.load_custom_field), Rule("spec_0", lambda fft, field_0: fft(field_0)), Rule("field_0", lambda ifft, spec_0: ifft(spec_0)), - *Rule.deduce( + Rule( ["pre_field_0", "peak_power", "energy", "width"], pulse.adjust_custom_field, - ["energy", "peak_power"], - 1, priorities=[2, 1, 1, 1], ), Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), @@ -500,7 +520,7 @@ default_rules: list[Rule] = [ Rule( "V_eff", fiber.V_eff_step_index, - ["wavelength", "core_radius", "numerical_aperture"], + ["l", "wavelength", "core_radius", "numerical_aperture"], ), Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")), Rule( diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index ddd9c2a..bd9e9cf 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -174,7 +174,7 @@ def initial_field_envelope( raised when shape is not recognized """ if delay is not None: - t = t + delay + t = t - delay if shape == "gaussian": return gaussian_pulse(t, t0, peak_power) @@ -472,14 +472,14 @@ def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndar return units.m.inv(units.m(init_wavelength) - delta_w) -def E0_to_P0(E0: float, t0: float, shape: str): +def E0_to_P0(energy: float, t0: float, shape: str): """convert an initial total pulse energy to a pulse peak peak_power""" - return E0 / (t0 * P0T0_to_E0_fac[shape]) + return energy / (t0 * P0T0_to_E0_fac[shape]) -def P0_to_E0(P0: float, t0: float, shape: str): +def P0_to_E0(peak_power: float, t0: float, shape: str): """converts initial peak peak_power to pulse energy""" - return P0 * t0 * P0T0_to_E0_fac[shape] + return peak_power * t0 * P0T0_to_E0_fac[shape] def sech_pulse(t: np.ndarray, t0: float, P0: float): diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 121835f..8595ecb 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -545,7 +545,7 @@ def mean_values_plot( mean_style: dict[str, Any] = None, individual_style: dict[str, Any] = None, ) -> tuple[plt.Line2D, list[plt.Line2D]]: - x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing) + x_axis, mean_values, values = transform_mean_values_1d(values, plt_range, params, log, spacing) if renormalize and log is False: maxi = mean_values.max() mean_values = mean_values / maxi @@ -570,7 +570,7 @@ def mean_values_plot( ) -def transform_mean_values( +def transform_mean_values_1d( values: np.ndarray, plt_range: Union[PlotRange, RangeType], params: Parameters, diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index 44490c9..b68520c 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -6,7 +6,6 @@ scgenerator module but some function may be used in any python program from __future__ import annotations import datetime -import inspect import itertools import json import os @@ -283,60 +282,6 @@ def to_62(i: int) -> str: return "".join(reversed(arr)) -def get_arg_names(func: Callable) -> list[str]: - """ - returns the positional argument names of func. - - Parameters - ---------- - func : Callable - if a function, returns the names of the positional arguments - - - Returns - ------- - list[str] - [description] - """ - return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty] - - -def validate_arg_names(names: list[str]): - for n in names: - if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: - raise ValueError(f"{n} is an invalid parameter name") - - -def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable: - if arg_names is None: - arg_names = get_arg_names(func) - else: - validate_arg_names(arg_names) - validate_arg_names(kwarg_names) - sign_arg_str = ", ".join(arg_names + kwarg_names) - call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) - tmp_name = f"{func.__name__}_0" - func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" - scope = dict(__func__=func) - exec(func_str, scope) - out_func = scope[tmp_name] - out_func.__module__ = "evaluator" - return out_func - - -@cache -def _mock_function(num_args: int, num_returns: int) -> Callable: - arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) - return_str = ", ".join("True" for _ in range(num_returns)) - func_name = f"__mock_{num_args}_{num_returns}" - func_str = f"def {func_name}({arg_str}):\n return {return_str}" - scope = {} - exec(func_str, scope) - out_func = scope[func_name] - out_func.__module__ = "evaluator" - return out_func - - def fft_functions( full_field: bool, ) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]: diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index cb08a7c..0397a11 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from scgenerator.evaluator import Evaluator, Rule +from scgenerator import math, units +from scgenerator.evaluator import Evaluator, EvaluatorError, Rule @pytest.fixture @@ -16,7 +17,7 @@ def disk_rules() -> list[Rule]: ] -def test_simple(disk_rules: list[Rule]): +def test_trivial(disk_rules: list[Rule]): evaluator = Evaluator(*disk_rules) evaluator.set(radius=5) @@ -25,3 +26,28 @@ def test_simple(disk_rules: list[Rule]): evaluator.set(area=5) assert evaluator.compute("area") == 5 assert evaluator.compute("radius") == 5 + + +def test_simple(): + evaluator = Evaluator.default() + evaluator.set(wavelength=800e-9, t_num=1024, dt=5e-15) + + assert evaluator.compute("t") == pytest.approx(math.tspace(t_num=1024, dt=5e-15)) + assert evaluator.compute("w0") == pytest.approx(units.nm(800)) + + +def test_default_args(): + def some_function(a: int, b: int, c: int = 5): + return a + b + c + + evaluator = Evaluator(Rule("d", some_function)) + evaluator.set(a=1, b=3) + + with pytest.raises(EvaluatorError): + evaluator.compute("c") + assert evaluator.compute("d") == 9 + + evaluator.clear_computed() + evaluator.set(c=10) + assert evaluator.compute("c") == 10 + assert evaluator.compute("d") == 14