added default values to Rule functions
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import itertools
|
import inspect
|
||||||
import traceback
|
import traceback
|
||||||
from collections import ChainMap, defaultdict
|
from collections import ChainMap, defaultdict
|
||||||
|
from functools import cache
|
||||||
|
from inspect import Parameter as PARAM
|
||||||
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union
|
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -10,13 +12,13 @@ import numpy as np
|
|||||||
from scgenerator import io, math, operators, utils
|
from scgenerator import io, math, operators, utils
|
||||||
from scgenerator.const import INF, MANDATORY_PARAMETERS
|
from scgenerator.const import INF, MANDATORY_PARAMETERS
|
||||||
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
||||||
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
|
from scgenerator.utils import get_logger
|
||||||
|
|
||||||
|
|
||||||
class ErrorRecord(NamedTuple):
|
class ErrorRecord(NamedTuple):
|
||||||
error: Exception
|
error: Exception
|
||||||
lookup_stack: list[str]
|
lookup_stack: tuple[str]
|
||||||
rules_stack: list[Rule]
|
rules_stack: tuple[Rule]
|
||||||
traceback: str
|
traceback: str
|
||||||
|
|
||||||
|
|
||||||
@@ -40,7 +42,7 @@ class CyclicDependencyError(EvaluatorError):
|
|||||||
class AllRulesExhaustedError(EvaluatorError):
|
class AllRulesExhaustedError(EvaluatorError):
|
||||||
def __init__(self, target: str):
|
def __init__(self, target: str):
|
||||||
self.target = target
|
self.target = target
|
||||||
super().__init__(f"tried every rule for {target!r} and no default value is set")
|
super().__init__(f"every rule for {target!r} failed and no default value is set")
|
||||||
|
|
||||||
|
|
||||||
class EvaluatorErrorTree:
|
class EvaluatorErrorTree:
|
||||||
@@ -60,11 +62,14 @@ class EvaluatorErrorTree:
|
|||||||
def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]):
|
def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]):
|
||||||
tr = traceback.format_exc().splitlines()
|
tr = traceback.format_exc().splitlines()
|
||||||
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame
|
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame
|
||||||
self.all_errors.append(ErrorRecord(error, lookup_stack.copy(), rules_stack.copy(), tr))
|
self.all_errors.append(ErrorRecord(error, tuple(lookup_stack), tuple(rules_stack), tr))
|
||||||
|
|
||||||
def compile_error(self) -> EvaluatorError:
|
def compile_error(self) -> EvaluatorError:
|
||||||
failed_rules = set(rec.rules_stack[-1] for rec in self.all_errors if rec.rules_stack)
|
failed_rules = set(rec for rec in self.all_errors if rec.rules_stack)
|
||||||
failed_rules = ("\n" + "-" * 80 + "\n").join(rule.pretty_format() for rule in failed_rules)
|
failed_rules = [
|
||||||
|
rec.rules_stack[-1].pretty_format() + "\n" + str(rec.error) for rec in failed_rules
|
||||||
|
]
|
||||||
|
failed_rules = ("\n" + "-" * 80 + "\n").join(failed_rules)
|
||||||
|
|
||||||
raise EvaluatorError(
|
raise EvaluatorError(
|
||||||
f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}"
|
f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}"
|
||||||
@@ -76,6 +81,7 @@ class Rule:
|
|||||||
func: Callable
|
func: Callable
|
||||||
args: tuple[str]
|
args: tuple[str]
|
||||||
conditions: dict[str, Any]
|
conditions: dict[str, Any]
|
||||||
|
arg_defaults: dict[str, Any]
|
||||||
|
|
||||||
mock_func: Callable
|
mock_func: Callable
|
||||||
|
|
||||||
@@ -86,6 +92,7 @@ class Rule:
|
|||||||
args: tuple[str] | None = None,
|
args: tuple[str] | None = None,
|
||||||
priorities: Union[int, list[int]] | None = None,
|
priorities: Union[int, list[int]] | None = None,
|
||||||
conditions: dict[str, str] | None = None,
|
conditions: dict[str, str] | None = None,
|
||||||
|
defaults: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||||
self.func = func
|
self.func = func
|
||||||
@@ -94,9 +101,26 @@ class Rule:
|
|||||||
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
||||||
priorities = [priorities] * len(targets)
|
priorities = [priorities] * len(targets)
|
||||||
self.targets = dict(zip(targets, priorities))
|
self.targets = dict(zip(targets, priorities))
|
||||||
if args is None:
|
|
||||||
args = get_arg_names(func)
|
func_args, func_defaults = get_arg_names(func)
|
||||||
self.args = tuple(args)
|
|
||||||
|
if args is not None:
|
||||||
|
try:
|
||||||
|
for func_arg, user_arg in zip(func_args, args, strict=True):
|
||||||
|
if func_arg in func_defaults:
|
||||||
|
func_defaults[user_arg] = func_defaults.pop(func_arg)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"length mismatch between arguments of {func.__name__}: "
|
||||||
|
f"{func_args!r} and provided ones: {args!r}"
|
||||||
|
) from e
|
||||||
|
func_args = args
|
||||||
|
|
||||||
|
func_defaults |= defaults or {}
|
||||||
|
|
||||||
|
self.args = tuple(func_args)
|
||||||
|
self.arg_defaults = func_defaults
|
||||||
|
|
||||||
self.mock_func = _mock_function(len(self.args), len(self.targets))
|
self.mock_func = _mock_function(len(self.args), len(self.targets))
|
||||||
self.conditions = conditions or {}
|
self.conditions = conditions or {}
|
||||||
|
|
||||||
@@ -126,55 +150,6 @@ class Rule:
|
|||||||
def func_name(self) -> str:
|
def func_name(self) -> str:
|
||||||
return f"{self.func.__module__}.{self.func.__name__}"
|
return f"{self.func.__module__}.{self.func.__name__}"
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def deduce(
|
|
||||||
cls,
|
|
||||||
target: Union[str, list[Optional[str]]],
|
|
||||||
func: Callable,
|
|
||||||
kwarg_names: list[str],
|
|
||||||
n_var: int,
|
|
||||||
args_const: list[str] = None,
|
|
||||||
priorities: Union[int, list[int]] = None,
|
|
||||||
) -> list["Rule"]:
|
|
||||||
"""
|
|
||||||
given a function that doesn't need all its keyword arguments specified, will
|
|
||||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
target : str | list[str | None]
|
|
||||||
name of the variable(s) that func returns
|
|
||||||
func : Callable
|
|
||||||
function to work with
|
|
||||||
kwarg_names : list[str]
|
|
||||||
list of all kwargs of the function to be used
|
|
||||||
n_var : int
|
|
||||||
how many shoulf be used per rule
|
|
||||||
arg_const : list[str], optional
|
|
||||||
override the name of the positional arguments
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[Rule]
|
|
||||||
list of all possible rules
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
>> def lol(a, b=None, c=None):
|
|
||||||
pass
|
|
||||||
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
|
|
||||||
[
|
|
||||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
|
|
||||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
rules: list[cls] = []
|
|
||||||
for var_possibility in itertools.combinations(kwarg_names, n_var):
|
|
||||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
|
||||||
|
|
||||||
rules.append(cls(target, new_func, priorities=priorities))
|
|
||||||
return rules
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluatedValue(NamedTuple):
|
class EvaluatedValue(NamedTuple):
|
||||||
value: Any
|
value: Any
|
||||||
@@ -187,9 +162,7 @@ class Evaluator:
|
|||||||
|
|
||||||
rules: dict[str, list[Rule]]
|
rules: dict[str, list[Rule]]
|
||||||
main_map: dict[str, EvaluatedValue]
|
main_map: dict[str, EvaluatedValue]
|
||||||
param_chain_map: MutableMapping[str, EvaluatedValue]
|
|
||||||
lookup_stack: list[str]
|
lookup_stack: list[str]
|
||||||
failed_rules: dict[str, list[Rule]]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls, full_field: bool = False) -> "Evaluator":
|
def default(cls, full_field: bool = False) -> "Evaluator":
|
||||||
@@ -220,8 +193,6 @@ class Evaluator:
|
|||||||
self.main_map = {}
|
self.main_map = {}
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
self.failed_rules = defaultdict(list)
|
|
||||||
|
|
||||||
self.append(*rules)
|
self.append(*rules)
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> Any:
|
def __getitem__(self, key: str) -> Any:
|
||||||
@@ -302,7 +273,7 @@ class Evaluator:
|
|||||||
if target not in self.rules or len(self.rules[target]) == 0:
|
if target not in self.rules or len(self.rules[target]) == 0:
|
||||||
raise NoValidRuleError(target)
|
raise NoValidRuleError(target)
|
||||||
if target in lookup_stack:
|
if target in lookup_stack:
|
||||||
raise CyclicDependencyError(target)
|
raise CyclicDependencyError(target, lookup_stack)
|
||||||
|
|
||||||
lookup_stack.append(target)
|
lookup_stack.append(target)
|
||||||
base_cm_length = len(param_chain_map.maps)
|
base_cm_length = len(param_chain_map.maps)
|
||||||
@@ -359,19 +330,24 @@ class Evaluator:
|
|||||||
rules_stack: list[Rule],
|
rules_stack: list[Rule],
|
||||||
errors: EvaluatorErrorTree,
|
errors: EvaluatorErrorTree,
|
||||||
) -> dict[str, EvaluatedValue] | None:
|
) -> dict[str, EvaluatedValue] | None:
|
||||||
try:
|
arg_values = []
|
||||||
for arg in rule.args:
|
for arg in rule.args:
|
||||||
self._compute(arg, check_only, param_chain_map, lookup_stack, rules_stack, errors)
|
try:
|
||||||
except Exception as e:
|
val = self._compute(
|
||||||
errors.append(e, lookup_stack, rules_stack)
|
arg, check_only, param_chain_map, lookup_stack, rules_stack, errors
|
||||||
return None
|
)
|
||||||
|
except Exception as e:
|
||||||
args = [param_chain_map[k].value for k in rule.args]
|
if arg in rule.arg_defaults:
|
||||||
|
val = rule.arg_defaults[arg]
|
||||||
|
else:
|
||||||
|
errors.append(e, lookup_stack, rules_stack)
|
||||||
|
return None
|
||||||
|
arg_values.append(val)
|
||||||
|
|
||||||
func = rule.mock_func if check_only else rule.func
|
func = rule.mock_func if check_only else rule.func
|
||||||
|
|
||||||
try:
|
try:
|
||||||
values = func(*args)
|
values = func(*arg_values)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.append(e, lookup_stack, rules_stack)
|
errors.append(e, lookup_stack, rules_stack)
|
||||||
return None
|
return None
|
||||||
@@ -395,27 +371,73 @@ class Evaluator:
|
|||||||
while len(param_chain_map.maps) > target_size:
|
while len(param_chain_map.maps) > target_size:
|
||||||
param_chain_map.maps.pop(0)
|
param_chain_map.maps.pop(0)
|
||||||
|
|
||||||
|
def clear_computed(self):
|
||||||
|
"""deletes all computed values to leave only hard-set ones"""
|
||||||
|
self.main_map = {k: v for k, v in self.main_map.items() if v.rule is None}
|
||||||
|
|
||||||
def validate_condition(self, rule: Rule) -> bool:
|
def validate_condition(self, rule: Rule) -> bool:
|
||||||
try:
|
try:
|
||||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||||
except EvaluatorError:
|
except EvaluatorError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def attempted_rules_str(self, target: str) -> str:
|
|
||||||
rules = ", ".join(str(r) for r in self.failed_rules[target])
|
def get_arg_names(func: Callable) -> tuple[list[str], dict[str, Any]]:
|
||||||
if len(rules) == 0:
|
"""
|
||||||
return ""
|
returns the positional argument names of func.
|
||||||
return "attempted rules : " + rules
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
func : Callable
|
||||||
|
if a function, returns the names of the positional arguments
|
||||||
|
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[str]
|
||||||
|
list of argument names
|
||||||
|
dict[str, Any]
|
||||||
|
default values of arguments, if provided
|
||||||
|
"""
|
||||||
|
defaults = {}
|
||||||
|
args = []
|
||||||
|
for p in inspect.signature(func).parameters.values():
|
||||||
|
if p.kind in {PARAM.VAR_KEYWORD, PARAM.VAR_POSITIONAL}:
|
||||||
|
raise TypeError(
|
||||||
|
f"function {func.__name__} has variadic argument {p.name!r}, which is not allowed"
|
||||||
|
)
|
||||||
|
|
||||||
|
if p.kind is PARAM.KEYWORD_ONLY:
|
||||||
|
if p.default is PARAM.empty:
|
||||||
|
raise TypeError(
|
||||||
|
f"function {func.__name__} has keyword-only argument {p.name!r} "
|
||||||
|
"with no default. This is not supported at the moment."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
args.append(p.name)
|
||||||
|
if p.default is not PARAM.empty:
|
||||||
|
defaults[p.name] = p.default
|
||||||
|
|
||||||
|
return args, defaults
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
||||||
|
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
||||||
|
return_str = ", ".join("True" for _ in range(num_returns))
|
||||||
|
func_name = f"__mock_{num_args}_{num_returns}"
|
||||||
|
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
||||||
|
scope = {}
|
||||||
|
exec(func_str, scope)
|
||||||
|
out_func = scope[func_name]
|
||||||
|
out_func.__module__ = "evaluator"
|
||||||
|
return out_func
|
||||||
|
|
||||||
|
|
||||||
default_rules: list[Rule] = [
|
default_rules: list[Rule] = [
|
||||||
# Grid
|
# Grid
|
||||||
*Rule.deduce(
|
Rule(["t", "time_window", "dt", "t_num"], math.build_t_grid),
|
||||||
["t", "time_window", "dt", "t_num"],
|
|
||||||
math.build_t_grid,
|
|
||||||
["time_window", "t_num", "dt"],
|
|
||||||
2,
|
|
||||||
),
|
|
||||||
Rule("z_targets", math.build_z_grid),
|
Rule("z_targets", math.build_z_grid),
|
||||||
Rule("adapt_step_size", lambda step_size: step_size == 0),
|
Rule("adapt_step_size", lambda step_size: step_size == 0),
|
||||||
Rule(
|
Rule(
|
||||||
@@ -434,11 +456,9 @@ default_rules: list[Rule] = [
|
|||||||
Rule(["input_time", "input_field"], pulse.load_custom_field),
|
Rule(["input_time", "input_field"], pulse.load_custom_field),
|
||||||
Rule("spec_0", lambda fft, field_0: fft(field_0)),
|
Rule("spec_0", lambda fft, field_0: fft(field_0)),
|
||||||
Rule("field_0", lambda ifft, spec_0: ifft(spec_0)),
|
Rule("field_0", lambda ifft, spec_0: ifft(spec_0)),
|
||||||
*Rule.deduce(
|
Rule(
|
||||||
["pre_field_0", "peak_power", "energy", "width"],
|
["pre_field_0", "peak_power", "energy", "width"],
|
||||||
pulse.adjust_custom_field,
|
pulse.adjust_custom_field,
|
||||||
["energy", "peak_power"],
|
|
||||||
1,
|
|
||||||
priorities=[2, 1, 1, 1],
|
priorities=[2, 1, 1, 1],
|
||||||
),
|
),
|
||||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||||
@@ -500,7 +520,7 @@ default_rules: list[Rule] = [
|
|||||||
Rule(
|
Rule(
|
||||||
"V_eff",
|
"V_eff",
|
||||||
fiber.V_eff_step_index,
|
fiber.V_eff_step_index,
|
||||||
["wavelength", "core_radius", "numerical_aperture"],
|
["l", "wavelength", "core_radius", "numerical_aperture"],
|
||||||
),
|
),
|
||||||
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
|
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
|
||||||
Rule(
|
Rule(
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ def initial_field_envelope(
|
|||||||
raised when shape is not recognized
|
raised when shape is not recognized
|
||||||
"""
|
"""
|
||||||
if delay is not None:
|
if delay is not None:
|
||||||
t = t + delay
|
t = t - delay
|
||||||
|
|
||||||
if shape == "gaussian":
|
if shape == "gaussian":
|
||||||
return gaussian_pulse(t, t0, peak_power)
|
return gaussian_pulse(t, t0, peak_power)
|
||||||
@@ -472,14 +472,14 @@ def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndar
|
|||||||
return units.m.inv(units.m(init_wavelength) - delta_w)
|
return units.m.inv(units.m(init_wavelength) - delta_w)
|
||||||
|
|
||||||
|
|
||||||
def E0_to_P0(E0: float, t0: float, shape: str):
|
def E0_to_P0(energy: float, t0: float, shape: str):
|
||||||
"""convert an initial total pulse energy to a pulse peak peak_power"""
|
"""convert an initial total pulse energy to a pulse peak peak_power"""
|
||||||
return E0 / (t0 * P0T0_to_E0_fac[shape])
|
return energy / (t0 * P0T0_to_E0_fac[shape])
|
||||||
|
|
||||||
|
|
||||||
def P0_to_E0(P0: float, t0: float, shape: str):
|
def P0_to_E0(peak_power: float, t0: float, shape: str):
|
||||||
"""converts initial peak peak_power to pulse energy"""
|
"""converts initial peak peak_power to pulse energy"""
|
||||||
return P0 * t0 * P0T0_to_E0_fac[shape]
|
return peak_power * t0 * P0T0_to_E0_fac[shape]
|
||||||
|
|
||||||
|
|
||||||
def sech_pulse(t: np.ndarray, t0: float, P0: float):
|
def sech_pulse(t: np.ndarray, t0: float, P0: float):
|
||||||
|
|||||||
@@ -545,7 +545,7 @@ def mean_values_plot(
|
|||||||
mean_style: dict[str, Any] = None,
|
mean_style: dict[str, Any] = None,
|
||||||
individual_style: dict[str, Any] = None,
|
individual_style: dict[str, Any] = None,
|
||||||
) -> tuple[plt.Line2D, list[plt.Line2D]]:
|
) -> tuple[plt.Line2D, list[plt.Line2D]]:
|
||||||
x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing)
|
x_axis, mean_values, values = transform_mean_values_1d(values, plt_range, params, log, spacing)
|
||||||
if renormalize and log is False:
|
if renormalize and log is False:
|
||||||
maxi = mean_values.max()
|
maxi = mean_values.max()
|
||||||
mean_values = mean_values / maxi
|
mean_values = mean_values / maxi
|
||||||
@@ -570,7 +570,7 @@ def mean_values_plot(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def transform_mean_values(
|
def transform_mean_values_1d(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: Union[PlotRange, RangeType],
|
plt_range: Union[PlotRange, RangeType],
|
||||||
params: Parameters,
|
params: Parameters,
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ scgenerator module but some function may be used in any python program
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import inspect
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -283,60 +282,6 @@ def to_62(i: int) -> str:
|
|||||||
return "".join(reversed(arr))
|
return "".join(reversed(arr))
|
||||||
|
|
||||||
|
|
||||||
def get_arg_names(func: Callable) -> list[str]:
|
|
||||||
"""
|
|
||||||
returns the positional argument names of func.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
func : Callable
|
|
||||||
if a function, returns the names of the positional arguments
|
|
||||||
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[str]
|
|
||||||
[description]
|
|
||||||
"""
|
|
||||||
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
|
|
||||||
|
|
||||||
|
|
||||||
def validate_arg_names(names: list[str]):
|
|
||||||
for n in names:
|
|
||||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
|
||||||
raise ValueError(f"{n} is an invalid parameter name")
|
|
||||||
|
|
||||||
|
|
||||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
|
||||||
if arg_names is None:
|
|
||||||
arg_names = get_arg_names(func)
|
|
||||||
else:
|
|
||||||
validate_arg_names(arg_names)
|
|
||||||
validate_arg_names(kwarg_names)
|
|
||||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
|
||||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
|
||||||
tmp_name = f"{func.__name__}_0"
|
|
||||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
|
||||||
scope = dict(__func__=func)
|
|
||||||
exec(func_str, scope)
|
|
||||||
out_func = scope[tmp_name]
|
|
||||||
out_func.__module__ = "evaluator"
|
|
||||||
return out_func
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
|
||||||
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
|
||||||
return_str = ", ".join("True" for _ in range(num_returns))
|
|
||||||
func_name = f"__mock_{num_args}_{num_returns}"
|
|
||||||
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
|
||||||
scope = {}
|
|
||||||
exec(func_str, scope)
|
|
||||||
out_func = scope[func_name]
|
|
||||||
out_func.__module__ = "evaluator"
|
|
||||||
return out_func
|
|
||||||
|
|
||||||
|
|
||||||
def fft_functions(
|
def fft_functions(
|
||||||
full_field: bool,
|
full_field: bool,
|
||||||
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
|
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from scgenerator.evaluator import Evaluator, Rule
|
from scgenerator import math, units
|
||||||
|
from scgenerator.evaluator import Evaluator, EvaluatorError, Rule
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -16,7 +17,7 @@ def disk_rules() -> list[Rule]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_simple(disk_rules: list[Rule]):
|
def test_trivial(disk_rules: list[Rule]):
|
||||||
evaluator = Evaluator(*disk_rules)
|
evaluator = Evaluator(*disk_rules)
|
||||||
evaluator.set(radius=5)
|
evaluator.set(radius=5)
|
||||||
|
|
||||||
@@ -25,3 +26,28 @@ def test_simple(disk_rules: list[Rule]):
|
|||||||
evaluator.set(area=5)
|
evaluator.set(area=5)
|
||||||
assert evaluator.compute("area") == 5
|
assert evaluator.compute("area") == 5
|
||||||
assert evaluator.compute("radius") == 5
|
assert evaluator.compute("radius") == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple():
|
||||||
|
evaluator = Evaluator.default()
|
||||||
|
evaluator.set(wavelength=800e-9, t_num=1024, dt=5e-15)
|
||||||
|
|
||||||
|
assert evaluator.compute("t") == pytest.approx(math.tspace(t_num=1024, dt=5e-15))
|
||||||
|
assert evaluator.compute("w0") == pytest.approx(units.nm(800))
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_args():
|
||||||
|
def some_function(a: int, b: int, c: int = 5):
|
||||||
|
return a + b + c
|
||||||
|
|
||||||
|
evaluator = Evaluator(Rule("d", some_function))
|
||||||
|
evaluator.set(a=1, b=3)
|
||||||
|
|
||||||
|
with pytest.raises(EvaluatorError):
|
||||||
|
evaluator.compute("c")
|
||||||
|
assert evaluator.compute("d") == 9
|
||||||
|
|
||||||
|
evaluator.clear_computed()
|
||||||
|
evaluator.set(c=10)
|
||||||
|
assert evaluator.compute("c") == 10
|
||||||
|
assert evaluator.compute("d") == 14
|
||||||
|
|||||||
Reference in New Issue
Block a user