Update evaluator

This commit is contained in:
Benoît Sierro
2023-08-02 12:50:34 +02:00
parent a8f0379c38
commit 04806ff814
5 changed files with 255 additions and 141 deletions

View File

@@ -19,6 +19,8 @@ def pbar_format(worker_id: int) -> dict[str, Any]:
) )
INF = float("inf")
SPEC1_FN = "spectrum_{}.npy" SPEC1_FN = "spectrum_{}.npy"
SPECN_FN1 = "spectra_{}.npy" SPECN_FN1 = "spectra_{}.npy"
SPEC1_FN_N = "spectrum_{}_{}.npy" SPEC1_FN_N = "spectrum_{}_{}.npy"

View File

@@ -30,15 +30,3 @@ class MissingParameterError(Exception):
class IncompleteDataFolderError(FileNotFoundError): class IncompleteDataFolderError(FileNotFoundError):
pass pass
class EvaluatorError(Exception):
    """Base class of the errors raised while evaluating parameters."""


class NoDefaultError(EvaluatorError):
    """Evaluator error for a target that has no default value to fall back on."""


class OperatorError(EvaluatorError):
    """Evaluator error related to operators.

    NOTE(review): no raising site is visible in this view; confirm the intended use.
    """

View File

@@ -1,25 +1,96 @@
from __future__ import annotations
import itertools import itertools
from collections import defaultdict import traceback
from dataclasses import dataclass from collections import ChainMap, defaultdict
from typing import Any, Callable, Optional, Union from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Type, Union
import numpy as np import numpy as np
from scgenerator import math, operators, utils from scgenerator import math, operators, utils
from scgenerator.const import MANDATORY_PARAMETERS from scgenerator.const import INF, MANDATORY_PARAMETERS
from scgenerator.errors import EvaluatorError, NoDefaultError
from scgenerator.physics import fiber, materials, plasma, pulse, units from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
class ErrorRecord(NamedTuple):
    """Immutable pair made of a caught exception and its traceback text."""

    # the exception instance that was caught
    error: Exception
    # formatted traceback text associated with ``error``
    traceback: str
class EvaluatorError(Exception):
    """Base class of every error raised while evaluating a parameter."""


class NoValidRuleError(EvaluatorError):
    """Raised when no rule is registered for the requested target.

    Attributes
    ----------
    target : str
        name of the parameter that could not be computed
    """

    def __init__(self, target: str):
        self.target = target
        super().__init__(f"no rule available to compute {target!r}")


class CyclicDependencyError(EvaluatorError):
    """Raised when a target is requested while it is already being computed.

    Attributes
    ----------
    target : str
        name of the parameter that depends on itself
    lookup_stack : list[str]
        copy of the chain of targets being computed when the cycle was found
    """

    def __init__(self, target: str, lookup_stack: list[str]):
        self.target = target
        # copy so that later mutation of the caller's stack cannot alter the report
        self.lookup_stack = list(lookup_stack)
        super().__init__(
            f"cyclic dependency detected: {target!r} depends on itself "
            f"through {self.lookup_stack!r}"
        )


class AllRulesExhaustedError(EvaluatorError):
    """Raised when every rule for a target failed and no default exists.

    Attributes
    ----------
    target : str
        name of the parameter that could not be computed
    """

    def __init__(self, target: str):
        self.target = target
        super().__init__(f"every rule available to compute {target!r} failed")
class EvaluatorErrorTree:
    """Accumulate the errors met while trying to compute a single target.

    Errors are grouped by exception type so that a short summary (count per
    type) as well as the full tracebacks can be reported once the computation
    has definitely failed.

    Attributes
    ----------
    target : str
        name of the parameter whose computation is being tracked
    all_errors : dict[Type, list[ErrorRecord]]
        every recorded error, grouped by exception type, in recording order
    level : int
        nesting counter, adjusted by `push` and `pop`
    """

    target: str
    all_errors: "dict[Type, list[ErrorRecord]]"
    level: int

    def __init__(self, target: str):
        self.target = target
        self.all_errors = {}
        self.level = 0

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(target={self.target!r})"

    def push(self):
        """Enter one more level of nested computation."""
        self.level += 1

    def pop(self):
        """Leave the current level of nested computation."""
        self.level -= 1

    def append(self, error: Exception):
        """Record `error` together with its formatted traceback.

        The traceback is taken from the exception object itself, so the record
        is correct whether or not this is called from inside an ``except``
        block, and no frame is dropped by line-count based slicing.
        """
        tr = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        ).rstrip("\n")
        # keep every error of a given type instead of overwriting the previous one
        self.all_errors.setdefault(type(error), []).append(ErrorRecord(error, tr))

    def compile_error(self) -> "EvaluatorError":
        """Build (but do not raise) the error reporting the overall failure.

        Returning the exception lets the caller write
        ``raise tree.compile_error() from None`` and control the chaining.
        """
        return EvaluatorError(f"Couldn't compute {self.target}.")

    def summary(self) -> dict[str, int]:
        """Return the number of recorded errors for each exception type name."""
        return {k.__name__: len(v) for k, v in self.all_errors.items()}

    def format_summary(self) -> str:
        """Return the summary as one human readable sentence fragment."""
        types = [f"{v} errors of type {k!r}" for k, v in self.summary().items()]
        return ", ".join(types[:-2] + [" and ".join(types[-2:])])

    def details(self) -> str:
        """Return every recorded traceback, separated by blank lines."""
        return "\n\n".join(
            record.traceback for records in self.all_errors.values() for record in records
        )
class Rule: class Rule:
targets: dict[str, int]
func: Callable
args: list[str]
conditions: dict[str, Any]
mock_func: Callable
def __init__( def __init__(
self, self,
target: Union[str, list[Optional[str]]], target: Union[str, list[Optional[str]]],
func: Callable, func: Callable,
args: list[str] = None, args: list[str] | None = None,
priorities: Union[int, list[int]] = None, priorities: Union[int, list[int]] | None = None,
conditions: dict[str, str] = None, conditions: dict[str, str] | None = None,
): ):
targets = list(target) if isinstance(target, (list, tuple)) else [target] targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func self.func = func
@@ -43,6 +114,16 @@ class Rule:
f"{self.func.__name__} --> [{', '.join(self.targets)}]" f"{self.func.__name__} --> [{', '.join(self.targets)}]"
) )
def __eq__(self, other: object) -> bool:
    """Two rules are equal when they wrap the same function.

    Returns ``NotImplemented`` for operands of another type so that Python
    falls back to its default comparison instead of raising AttributeError.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.func == other.func

def __hash__(self) -> int:
    """Hash consistent with `__eq__` (defining `__eq__` alone disables hashing)."""
    return hash(self.func)
def pretty_format(self) -> str:
    """Return a multi-line, column-aligned description of the rule.

    The output has three columns separated by ``" | "``: the argument names,
    the dot-separated elements of the function name and the targets. Each
    column is listed top to bottom; shorter columns are padded with blanks
    and trailing whitespace is removed from every line.
    """
    func_name_elements = self.func_name.split(".")
    # targets may contain None placeholders; show them explicitly
    targets = [str(t) for t in self.targets]
    args = list(self.args)

    # widths, not the longest strings themselves; default=0 handles empty columns
    arg_size = max(map(len, args), default=0)
    func_size = max(map(len, func_name_elements), default=0)

    rows = itertools.zip_longest(args, func_name_elements, targets, fillvalue="")
    lines = [
        f"{arg:<{arg_size}} | {element:<{func_size}} | {target}".rstrip()
        for arg, element, target in rows
    ]
    return "\n".join(lines)
@property @property
def func_name(self) -> str: def func_name(self) -> str:
return f"{self.func.__module__}.{self.func.__name__}" return f"{self.func.__module__}.{self.func.__name__}"
@@ -58,7 +139,7 @@ class Rule:
priorities: Union[int, list[int]] = None, priorities: Union[int, list[int]] = None,
) -> list["Rule"]: ) -> list["Rule"]:
""" """
given a function that doesn't need all its keyword arguemtn specified, will given a function that doesn't need all its keyword arguments specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters Parameters
@@ -97,15 +178,21 @@ class Rule:
return rules return rules
@dataclass class EvaluatedValue(NamedTuple):
class EvalStat: value: Any
priority: float = np.inf priority: float = INF
rule: Rule = None rule: Rule = None
class Evaluator: class Evaluator:
defaults: dict[str, Any] = {} defaults: dict[str, Any] = {}
rules: dict[str, list[Rule]]
main_map: dict[str, EvaluatedValue]
param_chain_map: MutableMapping[str, EvaluatedValue]
lookup_stack: list[str]
failed_rules: dict[str, list[Rule]]
@classmethod @classmethod
def default(cls, full_field: bool = False) -> "Evaluator": def default(cls, full_field: bool = False) -> "Evaluator":
evaluator = cls() evaluator = cls()
@@ -124,28 +211,33 @@ class Evaluator:
evaluator.set(**params) evaluator.set(**params)
for target in MANDATORY_PARAMETERS: for target in MANDATORY_PARAMETERS:
evaluator.compute(target, check_only=check_only) evaluator.compute(target, check_only=check_only)
return evaluator.params return evaluator.main_map
@classmethod @classmethod
def register_default_param(cls, key, value): def register_default_param(cls, key, value):
cls.defaults[key] = value cls.defaults[key] = value
def __init__(self): def __init__(self, *rules: Rule):
self.rules: dict[str, list[Rule]] = defaultdict(list) self.rules = defaultdict(list)
self.params = {} self.main_map = {}
self.__curent_lookup: list[str] = []
self.__failed_rules: dict[str, list[Rule]] = defaultdict(list)
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.failed_rules = defaultdict(list)
self.append(*rules)
def __getitem__(self, key: str) -> Any:
    """Return the bare value stored under `key` in `main_map`.

    Raises
    ------
    KeyError
        `key` is not present in `main_map`
    """
    entry = self.main_map[key]
    return entry.value
def append(self, *rule: Rule): def append(self, *rule: Rule):
for r in rule: for r in rule:
for t in r.targets: for t in r.targets:
if t is not None: if t is None:
self.rules[t].append(r) continue
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def set(self, dico: dict[str, Any] = None, **params: Any): def set(self, **params: Any):
""" """
sets the internal set of parameters sets the internal set of parameters
@@ -157,18 +249,11 @@ class Evaluator:
params : Any params : Any
if dico is None, update the internal dict of parameters with params if dico is None, update the internal dict of parameters with params
""" """
if dico is None: for k, v in params.items():
dico = params self.main_map[k] = EvaluatedValue(v, np.inf)
self.params.update(dico)
else:
self.reset()
self.params = dico
for k in dico:
self.eval_stats[k].priority = np.inf
def reset(self): def reset(self):
self.params = {} self.main_map = {}
self.eval_stats = defaultdict(EvalStat)
def compute(self, target: str, check_only=False) -> Any: def compute(self, target: str, check_only=False) -> Any:
""" """
@@ -191,99 +276,115 @@ class Evaluator:
KeyError KeyError
there is no saved rule for the target there is no saved rule for the target
""" """
value = self.params.get(target) errors = EvaluatorErrorTree(target)
if value is None: param_chain_map = ChainMap(self.main_map)
prefix = "\t" * len(self.__curent_lookup) try:
# Avoid cycles value = self._compute(target, check_only, param_chain_map, [], errors)
if target in self.__curent_lookup: except EvaluatorError as e:
raise EvaluatorError( errors.append(e)
"cyclic dependency detected : " raise errors.compile_error() from None
f"{target!r} seems to depend on itself, please provide " self.merge_chain_map(param_chain_map)
f"a value for at least one variable in {self.__curent_lookup!r}. "
+ self.attempted_rules_str(target)
)
else:
self.__curent_lookup.append(target)
if len(self.rules[target]) == 0:
error = EvaluatorError(
f"no rule for {target}, trying to evaluate {self.__curent_lookup!r}"
)
else:
error = None
# try every rule until one succeeds
for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])):
self.logger.debug(
prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}"
)
try:
args = [self.compute(k, check_only=check_only) for k in rule.args]
if check_only:
returned_values = rule.mock_func(*args)
else:
returned_values = rule.func(*args)
if len(rule.targets) == 1:
returned_values = [returned_values]
for (param_name, param_priority), returned_value in zip(
rule.targets.items(), returned_values
):
if (
param_name == target
or param_name not in self.params
or self.eval_stats[param_name].priority < param_priority
):
if check_only:
success_str = f"able to compute {param_name} "
else:
v_str = format(returned_value).replace("\n", "")
success_str = f"computed {param_name}={v_str} "
self.logger.debug(
prefix
+ success_str
+ f"using {rule.func.__name__} from {rule.func.__module__}"
)
self.set_value(param_name, returned_value, param_priority, rule)
if param_name == target:
value = returned_value
break
except EvaluatorError as e:
error = e
self.logger.debug(
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
)
self.__failed_rules[target].append(rule)
continue
except Exception as e:
raise type(e)(f"error while evaluating {target!r}")
else:
default = self.defaults.get(target)
if default is None:
error = error or NoDefaultError(
prefix
+ f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. "
+ self.attempted_rules_str(target)
)
else:
value = default
self.logger.info(prefix + f"using default value of {value} for {target}")
self.set_value(target, value, 0, None)
last_target = self.__curent_lookup.pop()
assert target == last_target
self.__failed_rules[target] = []
if value is None and error is not None:
raise error
return value return value
def __getitem__(self, key: str) -> Any: def _compute(
return self.params[key] self,
target: str,
check_only: bool,
param_chain_map: MutableMapping[str, EvaluatedValue],
lookup_stack: list[str],
errors: EvaluatorErrorTree,
) -> Any:
errors.push()
if target in param_chain_map:
return param_chain_map[target].value
if target not in self.rules or len(self.rules[target]) == 0:
raise NoValidRuleError(target)
if target in lookup_stack:
raise CyclicDependencyError(target, lookup_stack.copy())
def set_value(self, key: str, value: Any, priority: int, rule: Rule): lookup_stack.append(target)
self.params[key] = value base_cm_length = len(param_chain_map.maps)
self.eval_stats[key].priority = priority for rule in self.rules[target]:
self.eval_stats[key].rule = rule values = self.apply_rule(rule, check_only, param_chain_map, lookup_stack, errors)
if self.valid_values(values, rule, check_only, param_chain_map, lookup_stack, errors):
break
lookup_stack.pop()
if values is None:
if target in self.defaults:
values = {target: self.defaults[target]}
else:
self.clear_chain_map(param_chain_map, base_cm_length)
raise AllRulesExhaustedError(target)
param_chain_map.maps.insert(0, values)
errors.pop()
return values[target].value
def valid_values(
    self,
    values: "dict[str, EvaluatedValue] | None",
    rule: "Rule",
    check_only: bool,
    param_chain_map: "MutableMapping[str, EvaluatedValue]",
    lookup_stack: list[str],
    errors: "EvaluatorErrorTree",
) -> bool:
    """Tell whether the output of `rule` can be accepted.

    Parameters
    ----------
    values : dict[str, EvaluatedValue] | None
        output of `apply_rule`, None when the rule could not be applied
    rule : Rule
        rule that produced `values`; its `conditions` are checked here
    check_only : bool
        when True, conditions are not evaluated and any non-None output is accepted
    param_chain_map : MutableMapping[str, EvaluatedValue]
        layered parameter map forwarded to `_compute`
    lookup_stack : list[str]
        targets currently being computed, forwarded to `_compute`
    errors : EvaluatorErrorTree
        collector receiving any error met while evaluating a condition

    Returns
    -------
    bool
        True if `values` is not None and every condition of the rule holds
    """
    if values is None:
        return False
    if check_only:
        return True
    for arg, target_value in rule.conditions.items():
        try:
            value = self._compute(arg, False, param_chain_map, lookup_stack, errors)
        except Exception as e:
            # a condition that cannot be evaluated invalidates this rule only;
            # letting the error escape would abort the caller's rule loop midway
            errors.append(e)
            return False
        if value != target_value:
            return False
    return True
def apply_rule(
    self,
    rule: "Rule",
    check_only: bool,
    param_chain_map: "MutableMapping[str, EvaluatedValue]",
    lookup_stack: list[str],
    errors: "EvaluatorErrorTree",
) -> "dict[str, EvaluatedValue] | None":
    """Try to apply `rule` and return what it produced.

    Parameters
    ----------
    rule : Rule
        rule to apply
    check_only : bool
        when True, `rule.mock_func` is called instead of `rule.func`
    param_chain_map : MutableMapping[str, EvaluatedValue]
        layered parameter map from which the arguments are read
    lookup_stack : list[str]
        targets currently being computed, forwarded to `_compute`
    errors : EvaluatorErrorTree
        collector receiving any error that makes the rule fail

    Returns
    -------
    dict[str, EvaluatedValue] | None
        one entry per named target of the rule, or None if an argument could
        not be computed or the function raised
    """
    try:
        for arg in rule.args:
            self._compute(arg, check_only, param_chain_map, lookup_stack, errors)
    except Exception as e:
        errors.append(e)
        return None

    # read the arguments only once all of them are available, so that the
    # most recent layer of the chain map is the one that gets used
    args = [param_chain_map[k].value for k in rule.args]
    func = rule.mock_func if check_only else rule.func
    try:
        values = func(*args)
    except Exception as e:
        # same collector as above; `self` has no `record_error` method
        errors.append(e)
        return None

    # a single-target rule returns its value directly, even when that value
    # happens to be a tuple, so it must not be unpacked across targets
    if len(rule.targets) == 1 or not isinstance(values, tuple):
        values = (values,)
    return {
        k: EvaluatedValue(v, rule.targets[k], rule)
        for k, v in zip(rule.targets, values)
        # presumably None marks an output to ignore (`append` skips such targets)
        if k is not None
    }
def merge_chain_map(self, param_chain_map: "dict[str, EvaluatedValue]"):
    """Fold every temporary layer of `param_chain_map` into `main_map`.

    All maps except the last one are removed from `param_chain_map.maps` and
    visited front to back. An entry replaces the one in `main_map` only when
    its priority is strictly higher; a key absent from `main_map` is always
    added.
    """
    pending_layers = param_chain_map.maps[:-1]
    del param_chain_map.maps[:-1]
    for layer in pending_layers:
        for name, candidate in layer.items():
            if name in self.main_map:
                current_priority = self.main_map[name].priority
            else:
                current_priority = -INF
            if candidate.priority > current_priority:
                self.main_map[name] = candidate
def clear_chain_map(
    self, param_chain_map: "MutableMapping[str, EvaluatedValue]", target_size: int
):
    """Drop the most recent layers of `param_chain_map` down to `target_size` maps.

    New layers are inserted at the front of `param_chain_map.maps`, so the
    layers removed here are the leading ones. Nothing happens when the chain
    map already holds `target_size` maps or fewer.

    Parameters
    ----------
    param_chain_map : MutableMapping[str, EvaluatedValue]
        chain map (must expose a `maps` list) to shrink in place
    target_size : int
        number of maps to keep
    """
    # operate on `.maps`: `ChainMap.pop(0)` would try to remove the *key* 0
    excess = len(param_chain_map.maps) - target_size
    if excess > 0:
        del param_chain_map.maps[:excess]
def validate_condition(self, rule: Rule) -> bool: def validate_condition(self, rule: Rule) -> bool:
try: try:
@@ -292,7 +393,7 @@ class Evaluator:
return False return False
def attempted_rules_str(self, target: str) -> str: def attempted_rules_str(self, target: str) -> str:
rules = ", ".join(str(r) for r in self.__failed_rules[target]) rules = ", ".join(str(r) for r in self.failed_rules[target])
if len(rules) == 0: if len(rules) == 0:
return "" return ""
return "attempted rules : " + rules return "attempted rules : " + rules

View File

@@ -14,8 +14,7 @@ import numpy as np
from scgenerator import utils from scgenerator import utils
from scgenerator.const import MANDATORY_PARAMETERS, __version__ from scgenerator.const import MANDATORY_PARAMETERS, __version__
from scgenerator.errors import EvaluatorError from scgenerator.evaluator import Evaluator, EvaluatorError
from scgenerator.evaluator import Evaluator
from scgenerator.io import DatetimeEncoder, decode_datetime_hook from scgenerator.io import DatetimeEncoder, decode_datetime_hook
from scgenerator.operators import Qualifier, SpecOperator from scgenerator.operators import Qualifier, SpecOperator
from scgenerator.utils import update_path_name from scgenerator.utils import update_path_name
@@ -475,7 +474,7 @@ class Parameters:
def get_evaluator(self): def get_evaluator(self):
evaluator = Evaluator.default(self.full_field) evaluator = Evaluator.default(self.full_field)
evaluator.set(self._param_dico.copy()) evaluator.set(**self._param_dico)
return evaluator return evaluator
def dump_dict(self, add_metadata=True) -> dict[str, Any]: def dump_dict(self, add_metadata=True) -> dict[str, Any]:
@@ -546,14 +545,15 @@ class Parameters:
) from None ) from None
if exhaustive: if exhaustive:
for p in self._p_names: for p in self._p_names:
if p not in evaluator.params: if p not in evaluator.main_map:
try: try:
evaluator.compute(p) evaluator.compute(p)
except Exception: except Exception:
pass pass
computed = self.__class__( computed = self.__class__(
**{k: v for k, v in evaluator.params.items() if k in self._p_names} **{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
) )
computed._frozen = True
return computed return computed
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str: def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:

23
tests/test_evaluator.py Normal file
View File

@@ -0,0 +1,23 @@
import numpy as np
import pytest
from scgenerator.evaluator import Evaluator, Rule
@pytest.fixture
def disk_rules() -> list[Rule]:
    """Rules linking the radius, diameter, perimeter and area of a disk."""
    # (target, function) pairs; each lambda's parameter names are left
    # untouched since they name the quantity the function consumes
    specs = (
        ("radius", lambda diameter: diameter / 2),
        ("diameter", lambda radius: radius * 2),
        ("diameter", lambda perimeter: perimeter / np.pi),
        ("perimeter", lambda diameter: diameter * np.pi),
        ("area", lambda radius: np.pi * radius**2),
        ("radius", lambda area: np.sqrt(area / np.pi)),
    )
    return [Rule(target, func) for target, func in specs]
def test_simple(disk_rules: list[Rule]):
    """The area of a disk is computed from a single known quantity.

    The extra rule depends on `lol`, for which no rule and no value exist, so
    it can never be applied.
    """
    evaluator = Evaluator(*disk_rules, Rule("radius", lambda lol: lol * 3))
    # at least one quantity must be provided, otherwise every rule chain is
    # cyclic or ends on `lol`; the expected area is pi * 5**2, i.e. a diameter of 10
    # NOTE(review): the original called `set()` without arguments - confirm
    # which quantity was meant to seed the test
    evaluator.set(diameter=10)
    assert evaluator.compute("area") == pytest.approx(78.53981633974483)