Update evaluator
This commit is contained in:
@@ -19,6 +19,8 @@ def pbar_format(worker_id: int) -> dict[str, Any]:
|
||||
)
|
||||
|
||||
|
||||
INF = float("inf")
|
||||
|
||||
SPEC1_FN = "spectrum_{}.npy"
|
||||
SPECN_FN1 = "spectra_{}.npy"
|
||||
SPEC1_FN_N = "spectrum_{}_{}.npy"
|
||||
|
||||
@@ -30,15 +30,3 @@ class MissingParameterError(Exception):
|
||||
|
||||
class IncompleteDataFolderError(FileNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
class EvaluatorError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoDefaultError(EvaluatorError):
|
||||
pass
|
||||
|
||||
|
||||
class OperatorError(EvaluatorError):
|
||||
pass
|
||||
|
||||
@@ -1,25 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
import traceback
|
||||
from collections import ChainMap, defaultdict
|
||||
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scgenerator import math, operators, utils
|
||||
from scgenerator.const import MANDATORY_PARAMETERS
|
||||
from scgenerator.errors import EvaluatorError, NoDefaultError
|
||||
from scgenerator.const import INF, MANDATORY_PARAMETERS
|
||||
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
||||
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
|
||||
|
||||
|
||||
class ErrorRecord(NamedTuple):
|
||||
error: Exception
|
||||
traceback: str
|
||||
|
||||
|
||||
class EvaluatorError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoValidRuleError(EvaluatorError):
|
||||
def __init__(self, target: str):
|
||||
...
|
||||
|
||||
|
||||
class CyclicDependencyError(EvaluatorError):
|
||||
def __init__(self, target: str, lookup_stack: list[str]):
|
||||
...
|
||||
|
||||
|
||||
class AllRulesExhaustedError(EvaluatorError):
|
||||
def __init__(self, target: str):
|
||||
...
|
||||
|
||||
|
||||
class EvaluatorErrorTree:
|
||||
target: str
|
||||
all_errors: dict[Type, ErrorRecord]
|
||||
|
||||
level: int
|
||||
|
||||
def __init__(self, target: str):
|
||||
self.target = target
|
||||
self.all_errors = {}
|
||||
self.level = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(target={self.target!r})"
|
||||
|
||||
def push(self):
|
||||
self.level += 1
|
||||
|
||||
def pop(self):
|
||||
self.level -= 1
|
||||
|
||||
def append(self, error: Exception):
|
||||
tr = traceback.format_exc().splitlines()
|
||||
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `append` frame
|
||||
self.all_errors[type(error)] = ErrorRecord(error, tr)
|
||||
|
||||
def compile_error(self) -> EvaluatorError:
|
||||
raise EvaluatorError(f"Couldn't compute {self.target}.")
|
||||
|
||||
def summary(self) -> dict[str, int]:
|
||||
return {k.__name__: len(v) for k, v in self.all_errors.items()}
|
||||
|
||||
def format_summary(self) -> str:
|
||||
types = [f"{v} errors of type {k!r}" for k, v in self.summary().items()]
|
||||
return ", ".join(types[:-2] + [" and ".join(types[-2:])])
|
||||
|
||||
def details(self) -> str:
|
||||
...
|
||||
|
||||
|
||||
class Rule:
|
||||
targets: dict[str, int]
|
||||
func: Callable
|
||||
args: list[str]
|
||||
conditions: dict[str, Any]
|
||||
|
||||
mock_func: Callable
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
args: list[str] = None,
|
||||
priorities: Union[int, list[int]] = None,
|
||||
conditions: dict[str, str] = None,
|
||||
args: list[str] | None = None,
|
||||
priorities: Union[int, list[int]] | None = None,
|
||||
conditions: dict[str, str] | None = None,
|
||||
):
|
||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||
self.func = func
|
||||
@@ -43,6 +114,16 @@ class Rule:
|
||||
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
|
||||
)
|
||||
|
||||
def __eq__(self, other: Rule) -> bool:
|
||||
return self.func == other.func
|
||||
|
||||
def pretty_format(self) -> str:
|
||||
func_name_elements = self.func_name.split(".")
|
||||
targets = list(self.targets)
|
||||
|
||||
arg_size = max(self.args, key=len)
|
||||
func_size = max(func_name_elements, key=len)
|
||||
|
||||
@property
|
||||
def func_name(self) -> str:
|
||||
return f"{self.func.__module__}.{self.func.__name__}"
|
||||
@@ -58,7 +139,7 @@ class Rule:
|
||||
priorities: Union[int, list[int]] = None,
|
||||
) -> list["Rule"]:
|
||||
"""
|
||||
given a function that doesn't need all its keyword arguemtn specified, will
|
||||
given a function that doesn't need all its keyword arguments specified, will
|
||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||
|
||||
Parameters
|
||||
@@ -97,15 +178,21 @@ class Rule:
|
||||
return rules
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalStat:
|
||||
priority: float = np.inf
|
||||
class EvaluatedValue(NamedTuple):
|
||||
value: Any
|
||||
priority: float = INF
|
||||
rule: Rule = None
|
||||
|
||||
|
||||
class Evaluator:
|
||||
defaults: dict[str, Any] = {}
|
||||
|
||||
rules: dict[str, list[Rule]]
|
||||
main_map: dict[str, EvaluatedValue]
|
||||
param_chain_map: MutableMapping[str, EvaluatedValue]
|
||||
lookup_stack: list[str]
|
||||
failed_rules: dict[str, list[Rule]]
|
||||
|
||||
@classmethod
|
||||
def default(cls, full_field: bool = False) -> "Evaluator":
|
||||
evaluator = cls()
|
||||
@@ -124,28 +211,33 @@ class Evaluator:
|
||||
evaluator.set(**params)
|
||||
for target in MANDATORY_PARAMETERS:
|
||||
evaluator.compute(target, check_only=check_only)
|
||||
return evaluator.params
|
||||
return evaluator.main_map
|
||||
|
||||
@classmethod
|
||||
def register_default_param(cls, key, value):
|
||||
cls.defaults[key] = value
|
||||
|
||||
def __init__(self):
|
||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
||||
self.params = {}
|
||||
self.__curent_lookup: list[str] = []
|
||||
self.__failed_rules: dict[str, list[Rule]] = defaultdict(list)
|
||||
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
|
||||
def __init__(self, *rules: Rule):
|
||||
self.rules = defaultdict(list)
|
||||
self.main_map = {}
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
self.failed_rules = defaultdict(list)
|
||||
|
||||
self.append(*rules)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.main_map[key].value
|
||||
|
||||
def append(self, *rule: Rule):
|
||||
for r in rule:
|
||||
for t in r.targets:
|
||||
if t is not None:
|
||||
self.rules[t].append(r)
|
||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||
if t is None:
|
||||
continue
|
||||
self.rules[t].append(r)
|
||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||
|
||||
def set(self, dico: dict[str, Any] = None, **params: Any):
|
||||
def set(self, **params: Any):
|
||||
"""
|
||||
sets the internal set of parameters
|
||||
|
||||
@@ -157,18 +249,11 @@ class Evaluator:
|
||||
params : Any
|
||||
if dico is None, update the internal dict of parameters with params
|
||||
"""
|
||||
if dico is None:
|
||||
dico = params
|
||||
self.params.update(dico)
|
||||
else:
|
||||
self.reset()
|
||||
self.params = dico
|
||||
for k in dico:
|
||||
self.eval_stats[k].priority = np.inf
|
||||
for k, v in params.items():
|
||||
self.main_map[k] = EvaluatedValue(v, np.inf)
|
||||
|
||||
def reset(self):
|
||||
self.params = {}
|
||||
self.eval_stats = defaultdict(EvalStat)
|
||||
self.main_map = {}
|
||||
|
||||
def compute(self, target: str, check_only=False) -> Any:
|
||||
"""
|
||||
@@ -191,99 +276,115 @@ class Evaluator:
|
||||
KeyError
|
||||
there is no saved rule for the target
|
||||
"""
|
||||
value = self.params.get(target)
|
||||
if value is None:
|
||||
prefix = "\t" * len(self.__curent_lookup)
|
||||
# Avoid cycles
|
||||
if target in self.__curent_lookup:
|
||||
raise EvaluatorError(
|
||||
"cyclic dependency detected : "
|
||||
f"{target!r} seems to depend on itself, please provide "
|
||||
f"a value for at least one variable in {self.__curent_lookup!r}. "
|
||||
+ self.attempted_rules_str(target)
|
||||
)
|
||||
else:
|
||||
self.__curent_lookup.append(target)
|
||||
|
||||
if len(self.rules[target]) == 0:
|
||||
error = EvaluatorError(
|
||||
f"no rule for {target}, trying to evaluate {self.__curent_lookup!r}"
|
||||
)
|
||||
else:
|
||||
error = None
|
||||
|
||||
# try every rule until one succeeds
|
||||
for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])):
|
||||
self.logger.debug(
|
||||
prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}"
|
||||
)
|
||||
try:
|
||||
args = [self.compute(k, check_only=check_only) for k in rule.args]
|
||||
if check_only:
|
||||
returned_values = rule.mock_func(*args)
|
||||
else:
|
||||
returned_values = rule.func(*args)
|
||||
if len(rule.targets) == 1:
|
||||
returned_values = [returned_values]
|
||||
for (param_name, param_priority), returned_value in zip(
|
||||
rule.targets.items(), returned_values
|
||||
):
|
||||
if (
|
||||
param_name == target
|
||||
or param_name not in self.params
|
||||
or self.eval_stats[param_name].priority < param_priority
|
||||
):
|
||||
if check_only:
|
||||
success_str = f"able to compute {param_name} "
|
||||
else:
|
||||
v_str = format(returned_value).replace("\n", "")
|
||||
success_str = f"computed {param_name}={v_str} "
|
||||
self.logger.debug(
|
||||
prefix
|
||||
+ success_str
|
||||
+ f"using {rule.func.__name__} from {rule.func.__module__}"
|
||||
)
|
||||
self.set_value(param_name, returned_value, param_priority, rule)
|
||||
if param_name == target:
|
||||
value = returned_value
|
||||
break
|
||||
except EvaluatorError as e:
|
||||
error = e
|
||||
self.logger.debug(
|
||||
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
|
||||
)
|
||||
self.__failed_rules[target].append(rule)
|
||||
continue
|
||||
except Exception as e:
|
||||
raise type(e)(f"error while evaluating {target!r}")
|
||||
else:
|
||||
default = self.defaults.get(target)
|
||||
if default is None:
|
||||
error = error or NoDefaultError(
|
||||
prefix
|
||||
+ f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. "
|
||||
+ self.attempted_rules_str(target)
|
||||
)
|
||||
else:
|
||||
value = default
|
||||
self.logger.info(prefix + f"using default value of {value} for {target}")
|
||||
self.set_value(target, value, 0, None)
|
||||
last_target = self.__curent_lookup.pop()
|
||||
assert target == last_target
|
||||
self.__failed_rules[target] = []
|
||||
|
||||
if value is None and error is not None:
|
||||
raise error
|
||||
|
||||
errors = EvaluatorErrorTree(target)
|
||||
param_chain_map = ChainMap(self.main_map)
|
||||
try:
|
||||
value = self._compute(target, check_only, param_chain_map, [], errors)
|
||||
except EvaluatorError as e:
|
||||
errors.append(e)
|
||||
raise errors.compile_error() from None
|
||||
self.merge_chain_map(param_chain_map)
|
||||
return value
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.params[key]
|
||||
def _compute(
|
||||
self,
|
||||
target: str,
|
||||
check_only: bool,
|
||||
param_chain_map: MutableMapping[str, EvaluatedValue],
|
||||
lookup_stack: list[str],
|
||||
errors: EvaluatorErrorTree,
|
||||
) -> Any:
|
||||
errors.push()
|
||||
if target in param_chain_map:
|
||||
return param_chain_map[target].value
|
||||
if target not in self.rules or len(self.rules[target]) == 0:
|
||||
raise NoValidRuleError(target)
|
||||
if target in lookup_stack:
|
||||
raise CyclicDependencyError(target, lookup_stack.copy())
|
||||
|
||||
def set_value(self, key: str, value: Any, priority: int, rule: Rule):
|
||||
self.params[key] = value
|
||||
self.eval_stats[key].priority = priority
|
||||
self.eval_stats[key].rule = rule
|
||||
lookup_stack.append(target)
|
||||
base_cm_length = len(param_chain_map.maps)
|
||||
for rule in self.rules[target]:
|
||||
values = self.apply_rule(rule, check_only, param_chain_map, lookup_stack, errors)
|
||||
if self.valid_values(values, rule, check_only, param_chain_map, lookup_stack, errors):
|
||||
break
|
||||
lookup_stack.pop()
|
||||
|
||||
if values is None:
|
||||
if target in self.defaults:
|
||||
values = {target: self.defaults[target]}
|
||||
else:
|
||||
self.clear_chain_map(param_chain_map, base_cm_length)
|
||||
raise AllRulesExhaustedError(target)
|
||||
|
||||
param_chain_map.maps.insert(0, values)
|
||||
errors.pop()
|
||||
return values[target].value
|
||||
|
||||
def valid_values(
|
||||
self,
|
||||
values: dict[str, EvaluatedValue] | None,
|
||||
rule: Rule,
|
||||
check_only: bool,
|
||||
param_chain_map: MutableMapping[str, EvaluatedValue],
|
||||
lookup_stack: list[str],
|
||||
errors: EvaluatorErrorTree,
|
||||
) -> bool:
|
||||
if values is None:
|
||||
return False
|
||||
if check_only:
|
||||
return True
|
||||
|
||||
for arg, target_value in rule.conditions.items():
|
||||
value = self._compute(arg, False, param_chain_map, lookup_stack, errors)
|
||||
if value != target_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply_rule(
|
||||
self,
|
||||
rule: Rule,
|
||||
check_only: bool,
|
||||
param_chain_map: MutableMapping[str, EvaluatedValue],
|
||||
lookup_stack: list[str],
|
||||
errors: EvaluatorErrorTree,
|
||||
) -> dict[str, EvaluatedValue] | None:
|
||||
try:
|
||||
for arg in rule.args:
|
||||
self._compute(arg, check_only, param_chain_map, lookup_stack, errors)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
return None
|
||||
|
||||
args = [param_chain_map[k].value for k in rule.args]
|
||||
|
||||
func = rule.mock_func if check_only else rule.func
|
||||
|
||||
try:
|
||||
values = func(*args)
|
||||
except Exception as e:
|
||||
self.record_error(e)
|
||||
return None
|
||||
|
||||
if not isinstance(values, tuple):
|
||||
values = (values,)
|
||||
|
||||
return {k: EvaluatedValue(v, rule.targets[k], rule) for k, v in zip(rule.targets, values)}
|
||||
|
||||
def merge_chain_map(self, param_chain_map: dict[str, EvaluatedValue]):
|
||||
while len(param_chain_map.maps) > 1:
|
||||
params = param_chain_map.maps.pop(0)
|
||||
for k, v in params.items():
|
||||
target_priority = self.main_map[k].priority if k in self.main_map else -INF
|
||||
if v.priority > target_priority:
|
||||
self.main_map[k] = v
|
||||
|
||||
def clear_chain_map(
|
||||
self, param_chain_map: MutableMapping[str, EvaluatedValue], target_size: int
|
||||
):
|
||||
while len(param_chain_map.maps) > target_size:
|
||||
param_chain_map.pop(0)
|
||||
|
||||
def validate_condition(self, rule: Rule) -> bool:
|
||||
try:
|
||||
@@ -292,7 +393,7 @@ class Evaluator:
|
||||
return False
|
||||
|
||||
def attempted_rules_str(self, target: str) -> str:
|
||||
rules = ", ".join(str(r) for r in self.__failed_rules[target])
|
||||
rules = ", ".join(str(r) for r in self.failed_rules[target])
|
||||
if len(rules) == 0:
|
||||
return ""
|
||||
return "attempted rules : " + rules
|
||||
|
||||
@@ -14,8 +14,7 @@ import numpy as np
|
||||
|
||||
from scgenerator import utils
|
||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||
from scgenerator.errors import EvaluatorError
|
||||
from scgenerator.evaluator import Evaluator
|
||||
from scgenerator.evaluator import Evaluator, EvaluatorError
|
||||
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
|
||||
from scgenerator.operators import Qualifier, SpecOperator
|
||||
from scgenerator.utils import update_path_name
|
||||
@@ -475,7 +474,7 @@ class Parameters:
|
||||
|
||||
def get_evaluator(self):
|
||||
evaluator = Evaluator.default(self.full_field)
|
||||
evaluator.set(self._param_dico.copy())
|
||||
evaluator.set(**self._param_dico)
|
||||
return evaluator
|
||||
|
||||
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||
@@ -546,14 +545,15 @@ class Parameters:
|
||||
) from None
|
||||
if exhaustive:
|
||||
for p in self._p_names:
|
||||
if p not in evaluator.params:
|
||||
if p not in evaluator.main_map:
|
||||
try:
|
||||
evaluator.compute(p)
|
||||
except Exception:
|
||||
pass
|
||||
computed = self.__class__(
|
||||
**{k: v for k, v in evaluator.params.items() if k in self._p_names}
|
||||
**{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
|
||||
)
|
||||
computed._frozen = True
|
||||
return computed
|
||||
|
||||
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
||||
|
||||
23
tests/test_evaluator.py
Normal file
23
tests/test_evaluator.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scgenerator.evaluator import Evaluator, Rule
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disk_rules() -> list[Rule]:
|
||||
return [
|
||||
Rule("radius", lambda diameter: diameter / 2),
|
||||
Rule("diameter", lambda radius: radius * 2),
|
||||
Rule("diameter", lambda perimeter: perimeter / np.pi),
|
||||
Rule("perimeter", lambda diameter: diameter * np.pi),
|
||||
Rule("area", lambda radius: np.pi * radius**2),
|
||||
Rule("radius", lambda area: np.sqrt(area / np.pi)),
|
||||
]
|
||||
|
||||
|
||||
def test_simple(disk_rules: list[Rule]):
|
||||
evaluator = Evaluator(*disk_rules, Rule("radius", lambda lol: lol * 3))
|
||||
evaluator.set()
|
||||
|
||||
assert evaluator.compute("area") == pytest.approx(78.53981633974483)
|
||||
Reference in New Issue
Block a user