Update evaluator

This commit is contained in:
Benoît Sierro
2023-08-02 12:50:34 +02:00
parent a8f0379c38
commit 04806ff814
5 changed files with 255 additions and 141 deletions

View File

@@ -19,6 +19,8 @@ def pbar_format(worker_id: int) -> dict[str, Any]:
)
INF = float("inf")
SPEC1_FN = "spectrum_{}.npy"
SPECN_FN1 = "spectra_{}.npy"
SPEC1_FN_N = "spectrum_{}_{}.npy"

View File

@@ -30,15 +30,3 @@ class MissingParameterError(Exception):
class IncompleteDataFolderError(FileNotFoundError):
pass
class EvaluatorError(Exception):
pass
class NoDefaultError(EvaluatorError):
pass
class OperatorError(EvaluatorError):
pass

View File

@@ -1,25 +1,96 @@
from __future__ import annotations
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import traceback
from collections import ChainMap, defaultdict
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Type, Union
import numpy as np
from scgenerator import math, operators, utils
from scgenerator.const import MANDATORY_PARAMETERS
from scgenerator.errors import EvaluatorError, NoDefaultError
from scgenerator.const import INF, MANDATORY_PARAMETERS
from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
class ErrorRecord(NamedTuple):
    """Pairs a caught exception with the traceback text recorded for it."""

    error: Exception  # the exception instance that was caught
    traceback: str  # formatted traceback text stored alongside the exception
class EvaluatorError(Exception):
    """Base class of every error raised while evaluating a parameter."""


class NoValidRuleError(EvaluatorError):
    """Raised when no rule is registered for the requested target.

    Parameters
    ----------
    target : str
        name of the parameter that could not be computed
    """

    def __init__(self, target: str):
        super().__init__(f"no rule available to compute {target!r}")
        self.target = target


class CyclicDependencyError(EvaluatorError):
    """Raised when a target ends up depending on itself.

    Parameters
    ----------
    target : str
        name of the parameter that was requested a second time
    lookup_stack : list[str]
        targets that were being computed when the cycle was detected
    """

    def __init__(self, target: str, lookup_stack: list[str]):
        super().__init__(
            f"cyclic dependency detected : {target!r} seems to depend on itself, "
            f"please provide a value for at least one variable in {lookup_stack!r}"
        )
        self.target = target
        # private copy so later mutation of the caller's stack cannot alter the report
        self.lookup_stack = list(lookup_stack)


class AllRulesExhaustedError(EvaluatorError):
    """Raised when every rule registered for a target has been tried without success.

    Parameters
    ----------
    target : str
        name of the parameter that could not be computed
    """

    def __init__(self, target: str):
        super().__init__(f"every rule available to compute {target!r} failed")
        self.target = target
class EvaluatorErrorTree:
    """Collects the errors met while trying to compute a single target.

    Errors are grouped by exception type, and every error of a given type is
    kept so that ``summary`` can report how many of each kind occurred.

    Attributes
    ----------
    target : str
        name of the parameter being computed
    all_errors : dict[Type, list[ErrorRecord]]
        recorded errors, grouped by exception type, in order of arrival
    level : int
        current nesting depth, maintained through ``push`` and ``pop``
    """

    target: str
    all_errors: dict[Type, list[ErrorRecord]]
    level: int

    def __init__(self, target: str):
        self.target = target
        self.all_errors = {}
        self.level = 0

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(target={self.target!r})"

    def push(self):
        """Enter one nesting level."""
        self.level += 1

    def pop(self):
        """Leave one nesting level."""
        self.level -= 1

    def append(self, error: Exception):
        """Record ``error`` together with the traceback of the exception being handled.

        Meant to be called from inside an ``except`` block, since the traceback
        text comes from ``traceback.format_exc()``.
        """
        tr = traceback.format_exc().splitlines()
        # keep the header line and drop the next two lines (the first frame entry)
        # NOTE(review): the original comment said this removes the `append` frame,
        # but format_exc() describes the handled exception - confirm which frame
        # is meant to be hidden.
        tr = "\n".join(tr[:1] + tr[3:])
        self.all_errors.setdefault(type(error), []).append(ErrorRecord(error, tr))

    def compile_error(self) -> EvaluatorError:
        """Build (without raising) the error summarising the failure.

        Returning the exception lets the caller decide how to raise it, e.g.
        ``raise errors.compile_error() from None``.
        """
        return EvaluatorError(f"Couldn't compute {self.target}.")

    def summary(self) -> dict[str, int]:
        """Number of recorded errors for each exception type name."""
        return {k.__name__: len(v) for k, v in self.all_errors.items()}

    def format_summary(self) -> str:
        """One-line, human readable version of ``summary``; empty when nothing was recorded."""
        types = [
            f"{count} error{'' if count == 1 else 's'} of type {name!r}"
            for name, count in self.summary().items()
        ]
        if not types:
            return ""
        # comma separated list whose last two items are joined by "and"
        return ", ".join(types[:-2] + [" and ".join(types[-2:])])

    def details(self) -> str:
        """Traceback text of every recorded error, separated by blank lines."""
        return "\n\n".join(
            record.traceback for records in self.all_errors.values() for record in records
        )
class Rule:
targets: dict[str, int]
func: Callable
args: list[str]
conditions: dict[str, Any]
mock_func: Callable
def __init__(
self,
target: Union[str, list[Optional[str]]],
func: Callable,
args: list[str] = None,
priorities: Union[int, list[int]] = None,
conditions: dict[str, str] = None,
args: list[str] | None = None,
priorities: Union[int, list[int]] | None = None,
conditions: dict[str, str] | None = None,
):
targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func
@@ -43,6 +114,16 @@ class Rule:
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
)
def __eq__(self, other: object) -> bool:
    """Two rules are equal when they wrap the same function.

    Returns ``NotImplemented`` for anything that is not a ``Rule`` so that
    comparisons such as ``rule == None`` fall back to Python's default
    handling instead of raising ``AttributeError``.
    """
    if not isinstance(other, Rule):
        return NotImplemented
    return self.func == other.func

def __hash__(self) -> int:
    """Hash consistent with ``__eq__``; defining ``__eq__`` alone would make rules unhashable."""
    return hash(self.func)
def pretty_format(self) -> str:
    """Gather the pieces needed to lay out a readable description of the rule.

    NOTE(review): the visible body only computes intermediate values and has no
    ``return`` statement, so it currently yields ``None`` despite the ``str``
    annotation - looks unfinished, confirm.
    """
    # components of the dotted function name (module path pieces, then the name)
    func_name_elements = self.func_name.split(".")
    targets = list(self.targets)
    # NOTE(review): max(..., key=len) yields the longest string itself, not its
    # length, and raises ValueError when the sequence is empty - confirm intent.
    arg_size = max(self.args, key=len)
    func_size = max(func_name_elements, key=len)
@property
def func_name(self) -> str:
    """Dotted name of the wrapped function: its ``__module__`` followed by its ``__name__``."""
    return f"{self.func.__module__}.{self.func.__name__}"
@@ -58,7 +139,7 @@ class Rule:
priorities: Union[int, list[int]] = None,
) -> list["Rule"]:
"""
given a function that doesn't need all its keyword arguemtn specified, will
given a function that doesn't need all its keyword arguments specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters
@@ -97,15 +178,21 @@ class Rule:
return rules
@dataclass
class EvalStat:
priority: float = np.inf
class EvaluatedValue(NamedTuple):
value: Any
priority: float = INF
rule: Rule = None
class Evaluator:
defaults: dict[str, Any] = {}
rules: dict[str, list[Rule]]
main_map: dict[str, EvaluatedValue]
param_chain_map: MutableMapping[str, EvaluatedValue]
lookup_stack: list[str]
failed_rules: dict[str, list[Rule]]
@classmethod
def default(cls, full_field: bool = False) -> "Evaluator":
evaluator = cls()
@@ -124,28 +211,33 @@ class Evaluator:
evaluator.set(**params)
for target in MANDATORY_PARAMETERS:
evaluator.compute(target, check_only=check_only)
return evaluator.params
return evaluator.main_map
@classmethod
def register_default_param(cls, key, value):
    """Store ``value`` under ``key`` in the class-level ``defaults`` mapping."""
    cls.defaults[key] = value
def __init__(self):
self.rules: dict[str, list[Rule]] = defaultdict(list)
self.params = {}
self.__curent_lookup: list[str] = []
self.__failed_rules: dict[str, list[Rule]] = defaultdict(list)
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
def __init__(self, *rules: Rule):
self.rules = defaultdict(list)
self.main_map = {}
self.logger = get_logger(__name__)
self.failed_rules = defaultdict(list)
self.append(*rules)
def __getitem__(self, key: str) -> Any:
    """Return the raw value stored for ``key`` in ``main_map`` (``KeyError`` when absent)."""
    return self.main_map[key].value
def append(self, *rule: Rule):
for r in rule:
for t in r.targets:
if t is not None:
if t is None:
continue
self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def set(self, dico: dict[str, Any] = None, **params: Any):
def set(self, **params: Any):
"""
sets the internal set of parameters
@@ -157,18 +249,11 @@ class Evaluator:
params : Any
if dico is None, update the internal dict of parameters with params
"""
if dico is None:
dico = params
self.params.update(dico)
else:
self.reset()
self.params = dico
for k in dico:
self.eval_stats[k].priority = np.inf
for k, v in params.items():
self.main_map[k] = EvaluatedValue(v, np.inf)
def reset(self):
self.params = {}
self.eval_stats = defaultdict(EvalStat)
self.main_map = {}
def compute(self, target: str, check_only=False) -> Any:
"""
@@ -191,99 +276,115 @@ class Evaluator:
KeyError
there is no saved rule for the target
"""
value = self.params.get(target)
if value is None:
prefix = "\t" * len(self.__curent_lookup)
# Avoid cycles
if target in self.__curent_lookup:
raise EvaluatorError(
"cyclic dependency detected : "
f"{target!r} seems to depend on itself, please provide "
f"a value for at least one variable in {self.__curent_lookup!r}. "
+ self.attempted_rules_str(target)
)
else:
self.__curent_lookup.append(target)
if len(self.rules[target]) == 0:
error = EvaluatorError(
f"no rule for {target}, trying to evaluate {self.__curent_lookup!r}"
)
else:
error = None
# try every rule until one succeeds
for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])):
self.logger.debug(
prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}"
)
errors = EvaluatorErrorTree(target)
param_chain_map = ChainMap(self.main_map)
try:
args = [self.compute(k, check_only=check_only) for k in rule.args]
if check_only:
returned_values = rule.mock_func(*args)
else:
returned_values = rule.func(*args)
if len(rule.targets) == 1:
returned_values = [returned_values]
for (param_name, param_priority), returned_value in zip(
rule.targets.items(), returned_values
):
if (
param_name == target
or param_name not in self.params
or self.eval_stats[param_name].priority < param_priority
):
if check_only:
success_str = f"able to compute {param_name} "
else:
v_str = format(returned_value).replace("\n", "")
success_str = f"computed {param_name}={v_str} "
self.logger.debug(
prefix
+ success_str
+ f"using {rule.func.__name__} from {rule.func.__module__}"
)
self.set_value(param_name, returned_value, param_priority, rule)
if param_name == target:
value = returned_value
break
value = self._compute(target, check_only, param_chain_map, [], errors)
except EvaluatorError as e:
error = e
self.logger.debug(
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
)
self.__failed_rules[target].append(rule)
continue
except Exception as e:
raise type(e)(f"error while evaluating {target!r}")
else:
default = self.defaults.get(target)
if default is None:
error = error or NoDefaultError(
prefix
+ f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. "
+ self.attempted_rules_str(target)
)
else:
value = default
self.logger.info(prefix + f"using default value of {value} for {target}")
self.set_value(target, value, 0, None)
last_target = self.__curent_lookup.pop()
assert target == last_target
self.__failed_rules[target] = []
if value is None and error is not None:
raise error
errors.append(e)
raise errors.compile_error() from None
self.merge_chain_map(param_chain_map)
return value
def __getitem__(self, key: str) -> Any:
return self.params[key]
def _compute(
    self,
    target: str,
    check_only: bool,
    param_chain_map: MutableMapping[str, EvaluatedValue],
    lookup_stack: list[str],
    errors: EvaluatorErrorTree,
) -> Any:
    """Compute ``target``, trying each of its rules in order until one succeeds.

    Parameters
    ----------
    target : str
        name of the parameter to compute
    check_only : bool
        forwarded to ``apply_rule`` and ``valid_values``
    param_chain_map : MutableMapping[str, EvaluatedValue]
        chain map of already available values; newly computed values are
        inserted as a new first layer
    lookup_stack : list[str]
        targets currently being computed, used to detect cycles
    errors : EvaluatorErrorTree
        collector whose nesting level is raised for the duration of the call

    Returns
    -------
    Any
        the raw value of ``target``

    Raises
    ------
    CyclicDependencyError
        ``target`` is already being computed further up the stack
    NoValidRuleError
        ``target`` has neither a rule nor a registered default
    AllRulesExhaustedError
        every rule failed and no default is registered
    """
    errors.push()
    try:
        if target in param_chain_map:
            return param_chain_map[target].value
        if target in lookup_stack:
            raise CyclicDependencyError(target, lookup_stack.copy())

        rules = self.rules.get(target, ())
        base_cm_length = len(param_chain_map.maps)
        values = None

        lookup_stack.append(target)
        try:
            for rule in rules:
                candidate = self.apply_rule(
                    rule, check_only, param_chain_map, lookup_stack, errors
                )
                # only keep a result whose rule also passes its conditions, so that
                # a rejected last rule cannot leak its values
                if self.valid_values(
                    candidate, rule, check_only, param_chain_map, lookup_stack, errors
                ):
                    values = candidate
                    break
        finally:
            # always unwind, even when a condition lookup raises
            lookup_stack.pop()

        if values is None:
            if target in self.defaults:
                # wrap the default like any other result (priority 0, no rule)
                values = {target: EvaluatedValue(self.defaults[target], 0, None)}
            elif not rules:
                raise NoValidRuleError(target)
            else:
                self.clear_chain_map(param_chain_map, base_cm_length)
                raise AllRulesExhaustedError(target)

        param_chain_map.maps.insert(0, values)
        return values[target].value
    finally:
        # keep the nesting level balanced on every exit path
        errors.pop()

# NOTE(review): legacy setter kept verbatim; it still refers to ``self.params`` and
# ``self.eval_stats`` - confirm whether it is meant to survive the refactoring.
def set_value(self, key: str, value: Any, priority: int, rule: Rule):
    self.params[key] = value
    self.eval_stats[key].priority = priority
    self.eval_stats[key].rule = rule
def valid_values(
    self,
    values: dict[str, EvaluatedValue] | None,
    rule: Rule,
    check_only: bool,
    param_chain_map: MutableMapping[str, EvaluatedValue],
    lookup_stack: list[str],
    errors: EvaluatorErrorTree,
) -> bool:
    """Tell whether the result of ``rule`` can be accepted.

    A ``None`` result is never valid. In ``check_only`` mode any other result is
    accepted without looking at the conditions. Otherwise every entry of
    ``rule.conditions`` must compare equal to the computed value of the
    corresponding parameter.

    A condition whose parameter cannot be computed makes the rule invalid: the
    error is recorded in ``errors`` instead of propagating, so that the
    remaining rules of the target can still be tried.
    """
    if values is None:
        return False
    if check_only:
        return True
    for arg, target_value in rule.conditions.items():
        try:
            value = self._compute(arg, False, param_chain_map, lookup_stack, errors)
        except EvaluatorError as e:
            errors.append(e)
            return False
        if value != target_value:
            return False
    return True
def apply_rule(
    self,
    rule: Rule,
    check_only: bool,
    param_chain_map: MutableMapping[str, EvaluatedValue],
    lookup_stack: list[str],
    errors: EvaluatorErrorTree,
) -> dict[str, EvaluatedValue] | None:
    """Evaluate ``rule`` and return the values it produces, or ``None`` on failure.

    Every argument of the rule is computed first, then the rule's function
    (``mock_func`` when ``check_only`` is set, ``func`` otherwise) is called.
    Any exception raised along the way is recorded in ``errors`` and turns the
    result into ``None``.

    Returns
    -------
    dict[str, EvaluatedValue] | None
        one entry per target of the rule, carrying the target's priority and
        the rule itself, or ``None`` when the rule could not be applied
    """
    try:
        for arg in rule.args:
            self._compute(arg, check_only, param_chain_map, lookup_stack, errors)
    except Exception as e:
        errors.append(e)
        return None

    args = [param_chain_map[k].value for k in rule.args]
    func = rule.mock_func if check_only else rule.func

    try:
        values = func(*args)
    except Exception as e:
        # record on the shared collector, exactly like argument failures
        errors.append(e)
        return None

    # a single-target rule owns its whole return value, even when that is a tuple
    if len(rule.targets) == 1 or not isinstance(values, tuple):
        values = (values,)

    return {k: EvaluatedValue(v, rule.targets[k], rule) for k, v in zip(rule.targets, values)}
def merge_chain_map(self, param_chain_map: dict[str, EvaluatedValue]):
    """Fold every layer of the chain map except the last one into ``main_map``.

    Layers are removed from the front one at a time. An entry replaces the one
    in ``main_map`` only when its priority is strictly higher; a name missing
    from ``main_map`` counts as having the lowest possible priority.
    """
    layers = param_chain_map.maps
    while len(layers) > 1:
        for name, candidate in layers.pop(0).items():
            current = self.main_map[name].priority if name in self.main_map else -INF
            if candidate.priority > current:
                self.main_map[name] = candidate
def clear_chain_map(
    self, param_chain_map: MutableMapping[str, EvaluatedValue], target_size: int
):
    """Drop the most recently added layers until ``target_size`` layers remain.

    Layers are removed from the front of ``param_chain_map.maps``, which is
    where new results are inserted. Nothing happens when the chain map already
    has ``target_size`` layers or fewer.
    """
    # operate on the list of layers: ChainMap.pop(0) would look up the *key* 0
    excess = len(param_chain_map.maps) - target_size
    if excess > 0:
        del param_chain_map.maps[:excess]
def validate_condition(self, rule: Rule) -> bool:
try:
@@ -292,7 +393,7 @@ class Evaluator:
return False
def attempted_rules_str(self, target: str) -> str:
rules = ", ".join(str(r) for r in self.__failed_rules[target])
rules = ", ".join(str(r) for r in self.failed_rules[target])
if len(rules) == 0:
return ""
return "attempted rules : " + rules

View File

@@ -14,8 +14,7 @@ import numpy as np
from scgenerator import utils
from scgenerator.const import MANDATORY_PARAMETERS, __version__
from scgenerator.errors import EvaluatorError
from scgenerator.evaluator import Evaluator
from scgenerator.evaluator import Evaluator, EvaluatorError
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
from scgenerator.operators import Qualifier, SpecOperator
from scgenerator.utils import update_path_name
@@ -475,7 +474,7 @@ class Parameters:
def get_evaluator(self):
evaluator = Evaluator.default(self.full_field)
evaluator.set(self._param_dico.copy())
evaluator.set(**self._param_dico)
return evaluator
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
@@ -546,14 +545,15 @@ class Parameters:
) from None
if exhaustive:
for p in self._p_names:
if p not in evaluator.params:
if p not in evaluator.main_map:
try:
evaluator.compute(p)
except Exception:
pass
computed = self.__class__(
**{k: v for k, v in evaluator.params.items() if k in self._p_names}
**{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
)
computed._frozen = True
return computed
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:

23
tests/test_evaluator.py Normal file
View File

@@ -0,0 +1,23 @@
import numpy as np
import pytest
from scgenerator.evaluator import Evaluator, Rule
@pytest.fixture
def disk_rules() -> list[Rule]:
    """Rules linking the radius, diameter, perimeter and area of a disk to one another."""
    return [
        Rule("radius", lambda diameter: diameter / 2),
        Rule("diameter", lambda radius: radius * 2),
        Rule("diameter", lambda perimeter: perimeter / np.pi),
        Rule("perimeter", lambda diameter: diameter * np.pi),
        Rule("area", lambda radius: np.pi * radius**2),
        Rule("radius", lambda area: np.sqrt(area / np.pi)),
    ]
def test_simple(disk_rules: list[Rule]):
    """The area of a disk can be derived from its diameter through the radius rule.

    An extra rule depending on the unknown parameter ``lol`` is registered to
    make sure a rule that cannot be applied does not get in the way.
    """
    evaluator = Evaluator(*disk_rules, Rule("radius", lambda lol: lol * 3))
    # a known quantity is required: with nothing set, every rule is cyclic or unsatisfiable
    # NOTE(review): diameter=10 was chosen to match the expected area (pi * 5**2) - confirm
    evaluator.set(diameter=10)
    assert evaluator.compute("area") == pytest.approx(78.53981633974483)