diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py
index 3ec5b5c..470ec17 100644
--- a/src/scgenerator/const.py
+++ b/src/scgenerator/const.py
@@ -19,6 +19,8 @@ def pbar_format(worker_id: int) -> dict[str, Any]:
     )
 
 
+INF = float("inf")
+
 SPEC1_FN = "spectrum_{}.npy"
 SPECN_FN1 = "spectra_{}.npy"
 SPEC1_FN_N = "spectrum_{}_{}.npy"
diff --git a/src/scgenerator/errors.py b/src/scgenerator/errors.py
index 416ec08..df5a1a0 100644
--- a/src/scgenerator/errors.py
+++ b/src/scgenerator/errors.py
@@ -30,15 +30,3 @@ class MissingParameterError(Exception):
 
 class IncompleteDataFolderError(FileNotFoundError):
     pass
-
-
-class EvaluatorError(Exception):
-    pass
-
-
-class NoDefaultError(EvaluatorError):
-    pass
-
-
-class OperatorError(EvaluatorError):
-    pass
diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py
index a8c7d18..b2544e2 100644
--- a/src/scgenerator/evaluator.py
+++ b/src/scgenerator/evaluator.py
@@ -1,25 +1,96 @@
+from __future__ import annotations
+
 import itertools
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union
+import traceback
+from collections import ChainMap, defaultdict
+from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Type, Union
 
 import numpy as np
 
 from scgenerator import math, operators, utils
-from scgenerator.const import MANDATORY_PARAMETERS
-from scgenerator.errors import EvaluatorError, NoDefaultError
+from scgenerator.const import INF, MANDATORY_PARAMETERS
 from scgenerator.physics import fiber, materials, plasma, pulse, units
 from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
 
 
+class ErrorRecord(NamedTuple):
+    error: Exception
+    traceback: str
+
+
+class EvaluatorError(Exception):
+    pass
+
+
+class NoValidRuleError(EvaluatorError):
+    def __init__(self, target: str):
+        super().__init__(f"no rule registered to compute {target!r}")
+        self.target = target
+
+
+class CyclicDependencyError(EvaluatorError):
+    def __init__(self, target: str, lookup_stack: list[str]):
+        super().__init__(
+            f"cyclic dependency detected: {target!r} depends on itself, "
+            f"please provide a value for at least one variable in {lookup_stack!r}"
+        )
+        self.target = target
+        self.lookup_stack = lookup_stack
+
+
+class AllRulesExhaustedError(EvaluatorError):
+    def __init__(self, target: str):
+        super().__init__(f"every registered rule to compute {target!r} failed")
+        self.target = target
+
+
+class EvaluatorErrorTree:
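+    """
+    collects every exception raised while looking up a target so that a single,
+    readable report can be built if the evaluation ultimately fails
+    """
+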
+    target: str
+    all_errors: dict[Type[Exception], list[ErrorRecord]]
+
+    level: int
+
+    def __init__(self, target: str):
+        self.target = target
+        self.all_errors = {}
+        self.level = 0
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(target={self.target!r})"
+
+    def push(self):
+        self.level += 1
+
+    def pop(self):
+        self.level -= 1
+
+    def append(self, error: Exception):
+        tr = traceback.format_exc().splitlines()
+        tr = "\n".join(tr[:1] + tr[3:])  # get rid of the `append` frame
+        self.all_errors.setdefault(type(error), []).append(ErrorRecord(error, tr))
+
+    def compile_error(self) -> EvaluatorError:
+        return EvaluatorError(f"Couldn't compute {self.target!r}. {self.format_summary()}")
+
+    def summary(self) -> dict[str, int]:
+        return {k.__name__: len(v) for k, v in self.all_errors.items()}
+
+    def format_summary(self) -> str:
+        types = [f"{v} error(s) of type {k}" for k, v in self.summary().items()]
+        return ", ".join(types[:-2] + [" and ".join(types[-2:])])
+
+    def details(self) -> str:
+        return "\n\n".join(
+            record.traceback for records in self.all_errors.values() for record in records
+        )
+
+
 class Rule:
+    targets: dict[str, int]
+    func: Callable
+    args: list[str]
+    conditions: dict[str, Any]
+
+    mock_func: Callable
+
     def __init__(
         self,
         target: Union[str, list[Optional[str]]],
         func: Callable,
-        args: list[str] = None,
-        priorities: Union[int, list[int]] = None,
-        conditions: dict[str, str] = None,
+        args: list[str] | None = None,
+        priorities: Union[int, list[int]] | None = None,
+        conditions: dict[str, str] | None = None,
     ):
         targets = list(target) if isinstance(target, (list, tuple)) else [target]
         self.func = func
@@ -43,6 +114,16 @@ class Rule:
             f"{self.func.__name__} --> [{', '.join(self.targets)}]"
         )
 
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Rule):
+            return NotImplemented
+        return self.func == other.func
+
+    def pretty_format(self) -> str:
+        """multi-line, human-readable description of the rule (minimal implementation)"""
+        args = ", ".join(self.args)
+        targets = ", ".join(self.targets)
+        return f"{self.func_name}\n    args    : {args}\n    targets : {targets}"
+
     @property
     def func_name(self) -> str:
         return f"{self.func.__module__}.{self.func.__name__}"
@@ -58,7 +139,7 @@ class Rule:
         priorities: Union[int, list[int]] = None,
     ) -> list["Rule"]:
         """
-        given a function that doesn't need all its keyword arguemtn specified, will
+        given a function that doesn't need all its keyword arguments specified, will
         return a list of Rule obj, one for each combination of n_var specified kwargs
 
         Parameters
         ----------
@@ -97,15 +178,21 @@ class Rule:
         return rules
 
 
-@dataclass
-class EvalStat:
-    priority: float = np.inf
+class EvaluatedValue(NamedTuple):
+    value: Any
+    priority: float = INF
     rule: Rule = None
 
 
 class Evaluator:
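+    """
+    computes parameters lazily from a set of interdependent Rule objects
+
+    Illustrative example (toy rule, not part of the code base)::
+
+        >>> ev = Evaluator(Rule("diameter", lambda radius: 2 * radius))
+        >>> ev.set(radius=1.5)
+        >>> ev.compute("diameter")
+        3.0
+    """
+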
     defaults: dict[str, Any] = {}
 
+    rules: dict[str, list[Rule]]
+    main_map: dict[str, EvaluatedValue]
+    failed_rules: dict[str, list[Rule]]
+
     @classmethod
     def default(cls, full_field: bool = False) -> "Evaluator":
         evaluator = cls()
@@ -124,28 +211,33 @@ class Evaluator:
         evaluator.set(**params)
         for target in MANDATORY_PARAMETERS:
             evaluator.compute(target, check_only=check_only)
-        return evaluator.params
+        return evaluator.main_map
 
     @classmethod
     def register_default_param(cls, key, value):
         cls.defaults[key] = value
 
-    def __init__(self):
-        self.rules: dict[str, list[Rule]] = defaultdict(list)
-        self.params = {}
-        self.__curent_lookup: list[str] = []
-        self.__failed_rules: dict[str, list[Rule]] = defaultdict(list)
-        self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
+    def __init__(self, *rules: Rule):
+        self.rules = defaultdict(list)
+        self.main_map = {}
         self.logger = get_logger(__name__)
+        self.failed_rules = defaultdict(list)
+
+        self.append(*rules)
+
+    def __getitem__(self, key: str) -> Any:
+        return self.main_map[key].value
 
     def append(self, *rule: Rule):
         for r in rule:
             for t in r.targets:
-                if t is not None:
-                    self.rules[t].append(r)
-                    self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
+                if t is None:
+                    continue
+                self.rules[t].append(r)
+                self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
 
-    def set(self, dico: dict[str, Any] = None, **params: Any):
+    def set(self, **params: Any):
         """
         sets the internal set of parameters
 
-        params : Any
-            if dico is None, update the internal dict of parameters with params
+        params : Any
+            key/value pairs stored as already-evaluated parameters
         """
-        if dico is None:
-            dico = params
-            self.params.update(dico)
-        else:
-            self.reset()
-            self.params = dico
-            for k in dico:
-                self.eval_stats[k].priority = np.inf
+        for k, v in params.items():
+            self.main_map[k] = EvaluatedValue(v, np.inf)
 
     def reset(self):
-        self.params = {}
-        self.eval_stats = defaultdict(EvalStat)
+        self.main_map = {}
 
     def compute(self, target: str, check_only=False) -> Any:
         """
@@ -191,99 +276,115 @@ class Evaluator:
         KeyError
             there is no saved rule for the target
         """
-        value = self.params.get(target)
-        if value is None:
-            prefix = "\t" * len(self.__curent_lookup)
-            # Avoid cycles
-            if target in self.__curent_lookup:
-                raise EvaluatorError(
-                    "cyclic dependency detected : "
-                    f"{target!r} seems to depend on itself, please provide "
-                    f"a value for at least one variable in {self.__curent_lookup!r}. "
-                    + self.attempted_rules_str(target)
-                )
-            else:
-                self.__curent_lookup.append(target)
-
-            if len(self.rules[target]) == 0:
-                error = EvaluatorError(
-                    f"no rule for {target}, trying to evaluate {self.__curent_lookup!r}"
-                )
-            else:
-                error = None
-
-            # try every rule until one succeeds
-            for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])):
-                self.logger.debug(
-                    prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}"
-                )
-                try:
-                    args = [self.compute(k, check_only=check_only) for k in rule.args]
-                    if check_only:
-                        returned_values = rule.mock_func(*args)
-                    else:
-                        returned_values = rule.func(*args)
-                    if len(rule.targets) == 1:
-                        returned_values = [returned_values]
-                    for (param_name, param_priority), returned_value in zip(
-                        rule.targets.items(), returned_values
-                    ):
-                        if (
-                            param_name == target
-                            or param_name not in self.params
-                            or self.eval_stats[param_name].priority < param_priority
-                        ):
-                            if check_only:
-                                success_str = f"able to compute {param_name} "
-                            else:
-                                v_str = format(returned_value).replace("\n", "")
-                                success_str = f"computed {param_name}={v_str} "
-                            self.logger.debug(
-                                prefix
-                                + success_str
-                                + f"using {rule.func.__name__} from {rule.func.__module__}"
-                            )
-                            self.set_value(param_name, returned_value, param_priority, rule)
-                    if param_name == target:
-                        value = returned_value
-                    break
-                except EvaluatorError as e:
-                    error = e
-                    self.logger.debug(
-                        prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
-                    )
-                    self.__failed_rules[target].append(rule)
-                    continue
-                except Exception as e:
-                    raise type(e)(f"error while evaluating {target!r}")
-            else:
-                default = self.defaults.get(target)
-                if default is None:
-                    error = error or NoDefaultError(
-                        prefix
-                        + f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. "
-                        + self.attempted_rules_str(target)
-                    )
-                else:
-                    value = default
-                    self.logger.info(prefix + f"using default value of {value} for {target}")
-                    self.set_value(target, value, 0, None)
-            last_target = self.__curent_lookup.pop()
-            assert target == last_target
-            self.__failed_rules[target] = []
-
-        if value is None and error is not None:
-            raise error
-
+        errors = EvaluatorErrorTree(target)
+        param_chain_map = ChainMap(self.main_map)
+        try:
+            value = self._compute(target, check_only, param_chain_map, [], errors)
+        except EvaluatorError as e:
+            errors.append(e)
+            raise errors.compile_error() from None
+        self.merge_chain_map(param_chain_map)
         return value
 
-    def __getitem__(self, key: str) -> Any:
-        return self.params[key]
+    def _compute(
+        self,
+        target: str,
+        check_only: bool,
+        param_chain_map: MutableMapping[str, EvaluatedValue],
+        lookup_stack: list[str],
+        errors: EvaluatorErrorTree,
+    ) -> Any:
+        """recursively computes target, storing results in param_chain_map and failures in errors"""
+        errors.push()
+        if target in param_chain_map:
+            return param_chain_map[target].value
+        if target in lookup_stack:
+            raise CyclicDependencyError(target, lookup_stack.copy())
 
-    def set_value(self, key: str, value: Any, priority: int, rule: Rule):
-        self.params[key] = value
-        self.eval_stats[key].priority = priority
-        self.eval_stats[key].rule = rule
+        rules = self.rules.get(target, [])
+        lookup_stack.append(target)
+        base_cm_length = len(param_chain_map.maps)
+        values = None
+        for rule in rules:
+            candidate = self.apply_rule(rule, check_only, param_chain_map, lookup_stack, errors)
+            if self.valid_values(candidate, rule, check_only, param_chain_map, lookup_stack, errors):
+                values = candidate
+                break
+        lookup_stack.pop()
+
+        if values is None:
+            if target in self.defaults:
+                values = {target: EvaluatedValue(self.defaults[target], 0, None)}
+            else:
+                self.clear_chain_map(param_chain_map, base_cm_length)
+                if not rules:
+                    raise NoValidRuleError(target)
+                raise AllRulesExhaustedError(target)
+
+        param_chain_map.maps.insert(0, values)
+        errors.pop()
+        return values[target].value
+
+    def valid_values(
+        self,
+        values: dict[str, EvaluatedValue] | None,
+        rule: Rule,
+        check_only: bool,
+        param_chain_map: MutableMapping[str, EvaluatedValue],
+        lookup_stack: list[str],
+        errors: EvaluatorErrorTree,
+    ) -> bool:
+        if values is None:
+            return False
+        if check_only:
+            return True
+
+        for arg, target_value in rule.conditions.items():
+            value = self._compute(arg, False, param_chain_map, lookup_stack, errors)
+            if value != target_value:
+                return False
+
+        return True
+
+    def apply_rule(
+        self,
+        rule: Rule,
+        check_only: bool,
+        param_chain_map: MutableMapping[str, EvaluatedValue],
+        lookup_stack: list[str],
+        errors: EvaluatorErrorTree,
+    ) -> dict[str, EvaluatedValue] | None:
+        try:
+            for arg in rule.args:
+                self._compute(arg, check_only, param_chain_map, lookup_stack, errors)
+        except Exception as e:
+            errors.append(e)
+            return None
+
+        args = [param_chain_map[k].value for k in rule.args]
+
+        func = rule.mock_func if check_only else rule.func
+
+        try:
+            values = func(*args)
+        except Exception as e:
+            errors.append(e)
+            return None
+
+        if not isinstance(values, tuple):
+            values = (values,)
+
+        return {k: EvaluatedValue(v, rule.targets[k], rule) for k, v in zip(rule.targets, values)}
+
+    def merge_chain_map(self, param_chain_map: ChainMap[str, EvaluatedValue]):
+        """folds temporary ChainMap layers back into main_map, keeping highest-priority values"""
+        while len(param_chain_map.maps) > 1:
+            params = param_chain_map.maps.pop(0)
+            for k, v in params.items():
+                target_priority = self.main_map[k].priority if k in self.main_map else -INF
+                if v.priority > target_priority:
+                    self.main_map[k] = v
+
+    def clear_chain_map(
+        self, param_chain_map: ChainMap[str, EvaluatedValue], target_size: int
+    ):
+        while len(param_chain_map.maps) > target_size:
+            param_chain_map.maps.pop(0)
 
     def validate_condition(self, rule: Rule) -> bool:
         try:
@@ -292,7 +393,7 @@ class Evaluator:
             return False
 
     def attempted_rules_str(self, target: str) -> str:
-        rules = ", ".join(str(r) for r in self.__failed_rules[target])
+        rules = ", ".join(str(r) for r in self.failed_rules[target])
         if len(rules) == 0:
             return ""
         return "attempted rules : " + rules
diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py
index c484022..695295f 100644
--- a/src/scgenerator/parameter.py
+++ b/src/scgenerator/parameter.py
@@ -14,8 +14,7 @@ import numpy as np
 
 from scgenerator import utils
 from scgenerator.const import MANDATORY_PARAMETERS, __version__
-from scgenerator.errors import EvaluatorError
-from scgenerator.evaluator import Evaluator
+from scgenerator.evaluator import Evaluator, EvaluatorError
 from scgenerator.io import DatetimeEncoder, decode_datetime_hook
 from scgenerator.operators import Qualifier, SpecOperator
 from scgenerator.utils import update_path_name
@@ -475,7 +474,7 @@ class Parameters:
 
     def get_evaluator(self):
         evaluator = Evaluator.default(self.full_field)
-        evaluator.set(self._param_dico.copy())
+        evaluator.set(**self._param_dico)
         return evaluator
 
     def dump_dict(self, add_metadata=True) -> dict[str, Any]:
@@ -546,14 +545,15 @@
             ) from None
         if exhaustive:
             for p in self._p_names:
-                if p not in evaluator.params:
+                if p not in evaluator.main_map:
                     try:
                         evaluator.compute(p)
                     except Exception:
                         pass
         computed = self.__class__(
-            **{k: v for k, v in evaluator.params.items() if k in self._p_names}
+            **{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
         )
+        computed._frozen = True
        return computed
 
     def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py
new file mode 100644
index 0000000..0f22328
--- /dev/null
+++ b/tests/test_evaluator.py
@@ -0,0 +1,23 @@
+import numpy as np
+import pytest
+
+from scgenerator.evaluator import Evaluator, Rule
+
+
+@pytest.fixture
+def disk_rules() -> list[Rule]:
+    return [
+        Rule("radius", lambda diameter: diameter / 2),
+        Rule("diameter", lambda radius: radius * 2),
+        Rule("diameter", lambda perimeter: perimeter / np.pi),
+        Rule("perimeter", lambda diameter: diameter * np.pi),
+        Rule("area", lambda radius: np.pi * radius**2),
+        Rule("radius", lambda area: np.sqrt(area / np.pi)),
+    ]
+
+
+def test_simple(disk_rules: list[Rule]):
+    evaluator = Evaluator(*disk_rules, Rule("radius", lambda lol: lol * 3))
+    evaluator.set(perimeter=10 * np.pi)
+
+    assert evaluator.compute("area") == pytest.approx(78.53981633974483)
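+
+
+def test_chained_rules(disk_rules: list[Rule]):
+    # extra sanity check (illustrative): intermediate values computed along the
+    # way are merged back into the evaluator and exposed via __getitem__
+    evaluator = Evaluator(*disk_rules)
+    evaluator.set(radius=2.0)
+
+    assert evaluator.compute("perimeter") == pytest.approx(4 * np.pi)
+    assert evaluator["diameter"] == pytest.approx(4.0)
+
+
+def test_unresolvable_target_raises(disk_rules: list[Rule]):
+    # extra sanity check (illustrative): with no value set at all, every rule
+    # chain is circular, so compute() is expected to fail with an EvaluatorError
+    from scgenerator.evaluator import EvaluatorError
+
+    evaluator = Evaluator(*disk_rules)
+
+    with pytest.raises(EvaluatorError):
+        evaluator.compute("area")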