better evaluator errors

This commit is contained in:
Benoît Sierro
2023-08-03 10:26:56 +02:00
parent ef30f618cc
commit df5e974860
2 changed files with 57 additions and 43 deletions

View File

@@ -15,71 +15,66 @@ from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_l
class ErrorRecord(NamedTuple): class ErrorRecord(NamedTuple):
error: Exception error: Exception
lookup_stack: list[str]
rules_stack: list[Rule]
traceback: str traceback: str
class EvaluatorError(Exception): class EvaluatorError(Exception):
pass target: str | None = None
class NoValidRuleError(EvaluatorError): class NoValidRuleError(EvaluatorError):
def __init__(self, target: str): def __init__(self, target: str):
... self.target = target
super().__init__(f"no valid rule to compute {target!r}")
class CyclicDependencyError(EvaluatorError): class CyclicDependencyError(EvaluatorError):
def __init__(self, target: str, lookup_stack: list[str]): def __init__(self, target: str, lookup_stack: list[str]):
... self.target = target
cycle = "".join(lookup_stack)
super().__init__(f"cycle detected while computing {target!r}:\n{cycle}")
class AllRulesExhaustedError(EvaluatorError): class AllRulesExhaustedError(EvaluatorError):
def __init__(self, target: str): def __init__(self, target: str):
... self.target = target
super().__init__(f"tried every rule for {target!r} and no default value is set")
class EvaluatorErrorTree: class EvaluatorErrorTree:
target: str target: str
all_errors: dict[Type, ErrorRecord] all_errors: list[ErrorRecord]
level: int
def __init__(self, target: str): def __init__(self, target: str):
self.target = target self.target = target
self.all_errors = {} self.all_errors = []
self.level = 0
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}(target={self.target!r})" return f"{self.__class__.__name__}(target={self.target!r})"
def push(self): def __len__(self) -> int:
self.level += 1 return len(self.all_errors)
def pop(self): def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]):
self.level -= 1
def append(self, error: Exception):
tr = traceback.format_exc().splitlines() tr = traceback.format_exc().splitlines()
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `append` frame tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `EvaluatorErrorTree.append` frame
self.all_errors[type(error)] = ErrorRecord(error, tr) self.all_errors.append(ErrorRecord(error, lookup_stack.copy(), rules_stack.copy(), tr))
def compile_error(self) -> EvaluatorError: def compile_error(self) -> EvaluatorError:
raise EvaluatorError(f"Couldn't compute {self.target}.") failed_rules = set(rec.rules_stack[-1] for rec in self.all_errors if rec.rules_stack)
failed_rules = ("\n" + "-" * 80 + "\n").join(rule.pretty_format() for rule in failed_rules)
def summary(self) -> dict[str, int]: raise EvaluatorError(
return {k.__name__: len(v) for k, v in self.all_errors.items()} f"Couldn't compute {self.target}. {len(self)} rules failed.\n{failed_rules}"
)
def format_summary(self) -> str:
types = [f"{v} errors of type {k!r}" for k, v in self.summary().items()]
return ", ".join(types[:-2] + [" and ".join(types[-2:])])
def details(self) -> str:
...
class Rule: class Rule:
targets: dict[str, int] targets: dict[str, int]
func: Callable func: Callable
args: list[str] args: tuple[str]
conditions: dict[str, Any] conditions: dict[str, Any]
mock_func: Callable mock_func: Callable
@@ -88,7 +83,7 @@ class Rule:
self, self,
target: Union[str, list[Optional[str]]], target: Union[str, list[Optional[str]]],
func: Callable, func: Callable,
args: list[str] | None = None, args: tuple[str] | None = None,
priorities: Union[int, list[int]] | None = None, priorities: Union[int, list[int]] | None = None,
conditions: dict[str, str] | None = None, conditions: dict[str, str] | None = None,
): ):
@@ -101,7 +96,7 @@ class Rule:
self.targets = dict(zip(targets, priorities)) self.targets = dict(zip(targets, priorities))
if args is None: if args is None:
args = get_arg_names(func) args = get_arg_names(func)
self.args = args self.args = tuple(args)
self.mock_func = _mock_function(len(self.args), len(self.targets)) self.mock_func = _mock_function(len(self.args), len(self.targets))
self.conditions = conditions or {} self.conditions = conditions or {}
@@ -115,7 +110,14 @@ class Rule:
) )
def __eq__(self, other: Rule) -> bool: def __eq__(self, other: Rule) -> bool:
return self.func == other.func return (
self.args == other.args
and tuple(self.targets) == tuple(other.targets)
and self.func == other.func
)
def __hash__(self) -> int:
return hash((self.args, self.func, tuple(self.targets)))
def pretty_format(self) -> str: def pretty_format(self) -> str:
return io.format_graph(self.args, self.func_name, self.targets) return io.format_graph(self.args, self.func_name, self.targets)
@@ -274,10 +276,14 @@ class Evaluator:
""" """
errors = EvaluatorErrorTree(target) errors = EvaluatorErrorTree(target)
param_chain_map = ChainMap(self.main_map) param_chain_map = ChainMap(self.main_map)
lookup_stack = []
rules_stack = []
try: try:
value = self._compute(target, check_only, param_chain_map, [], errors) value = self._compute(
target, check_only, param_chain_map, lookup_stack, rules_stack, errors
)
except EvaluatorError as e: except EvaluatorError as e:
errors.append(e) errors.append(e, lookup_stack, rules_stack)
raise errors.compile_error() from None raise errors.compile_error() from None
self.merge_chain_map(param_chain_map) self.merge_chain_map(param_chain_map)
return value return value
@@ -288,22 +294,28 @@ class Evaluator:
check_only: bool, check_only: bool,
param_chain_map: MutableMapping[str, EvaluatedValue], param_chain_map: MutableMapping[str, EvaluatedValue],
lookup_stack: list[str], lookup_stack: list[str],
rules_stack: list[Rule],
errors: EvaluatorErrorTree, errors: EvaluatorErrorTree,
) -> Any: ) -> Any:
errors.push()
if target in param_chain_map: if target in param_chain_map:
return param_chain_map[target].value return param_chain_map[target].value
if target not in self.rules or len(self.rules[target]) == 0: if target not in self.rules or len(self.rules[target]) == 0:
raise NoValidRuleError(target) raise NoValidRuleError(target)
if target in lookup_stack: if target in lookup_stack:
raise CyclicDependencyError(target, lookup_stack.copy()) raise CyclicDependencyError(target)
lookup_stack.append(target) lookup_stack.append(target)
base_cm_length = len(param_chain_map.maps) base_cm_length = len(param_chain_map.maps)
for rule in self.rules[target]: for rule in self.rules[target]:
values = self.apply_rule(rule, check_only, param_chain_map, lookup_stack, errors) rules_stack.append(rule)
if self.valid_values(values, rule, check_only, param_chain_map, lookup_stack, errors): values = self.apply_rule(
rule, check_only, param_chain_map, lookup_stack, rules_stack, errors
)
if self.valid_values(
values, rule, check_only, param_chain_map, lookup_stack, rules_stack, errors
):
break break
rules_stack.pop()
lookup_stack.pop() lookup_stack.pop()
if values is None: if values is None:
@@ -314,7 +326,6 @@ class Evaluator:
raise AllRulesExhaustedError(target) raise AllRulesExhaustedError(target)
param_chain_map.maps.insert(0, values) param_chain_map.maps.insert(0, values)
errors.pop()
return values[target].value return values[target].value
def valid_values( def valid_values(
@@ -324,6 +335,7 @@ class Evaluator:
check_only: bool, check_only: bool,
param_chain_map: MutableMapping[str, EvaluatedValue], param_chain_map: MutableMapping[str, EvaluatedValue],
lookup_stack: list[str], lookup_stack: list[str],
rules_stack: list[Rule],
errors: EvaluatorErrorTree, errors: EvaluatorErrorTree,
) -> bool: ) -> bool:
if values is None: if values is None:
@@ -332,7 +344,7 @@ class Evaluator:
return True return True
for arg, target_value in rule.conditions.items(): for arg, target_value in rule.conditions.items():
value = self._compute(arg, False, param_chain_map, lookup_stack, errors) value = self._compute(arg, False, param_chain_map, lookup_stack, rules_stack, errors)
if value != target_value: if value != target_value:
return False return False
@@ -344,13 +356,14 @@ class Evaluator:
check_only: bool, check_only: bool,
param_chain_map: MutableMapping[str, EvaluatedValue], param_chain_map: MutableMapping[str, EvaluatedValue],
lookup_stack: list[str], lookup_stack: list[str],
rules_stack: list[Rule],
errors: EvaluatorErrorTree, errors: EvaluatorErrorTree,
) -> dict[str, EvaluatedValue] | None: ) -> dict[str, EvaluatedValue] | None:
try: try:
for arg in rule.args: for arg in rule.args:
self._compute(arg, check_only, param_chain_map, lookup_stack, errors) self._compute(arg, check_only, param_chain_map, lookup_stack, rules_stack, errors)
except Exception as e: except Exception as e:
errors.append(e) errors.append(e, lookup_stack, rules_stack)
return None return None
args = [param_chain_map[k].value for k in rule.args] args = [param_chain_map[k].value for k in rule.args]
@@ -360,7 +373,7 @@ class Evaluator:
try: try:
values = func(*args) values = func(*args)
except Exception as e: except Exception as e:
self.record_error(e) errors.append(e, lookup_stack, rules_stack)
return None return None
if not isinstance(values, tuple): if not isinstance(values, tuple):
@@ -380,7 +393,7 @@ class Evaluator:
self, param_chain_map: MutableMapping[str, EvaluatedValue], target_size: int self, param_chain_map: MutableMapping[str, EvaluatedValue], target_size: int
): ):
while len(param_chain_map.maps) > target_size: while len(param_chain_map.maps) > target_size:
param_chain_map.pop(0) param_chain_map.maps.pop(0)
def validate_condition(self, rule: Rule) -> bool: def validate_condition(self, rule: Rule) -> bool:
try: try:

View File

@@ -1,6 +1,7 @@
import datetime import datetime
import json import json
from pathlib import Path from pathlib import Path
from typing import Sequence
import pkg_resources import pkg_resources