Update evaluator
This commit is contained in:
@@ -19,6 +19,8 @@ def pbar_format(worker_id: int) -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
INF = float("inf")
|
||||||
|
|
||||||
SPEC1_FN = "spectrum_{}.npy"
|
SPEC1_FN = "spectrum_{}.npy"
|
||||||
SPECN_FN1 = "spectra_{}.npy"
|
SPECN_FN1 = "spectra_{}.npy"
|
||||||
SPEC1_FN_N = "spectrum_{}_{}.npy"
|
SPEC1_FN_N = "spectrum_{}_{}.npy"
|
||||||
|
|||||||
@@ -30,15 +30,3 @@ class MissingParameterError(Exception):
|
|||||||
|
|
||||||
class IncompleteDataFolderError(FileNotFoundError):
|
class IncompleteDataFolderError(FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EvaluatorError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NoDefaultError(EvaluatorError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class OperatorError(EvaluatorError):
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -1,25 +1,96 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict
|
import traceback
|
||||||
from dataclasses import dataclass
|
from collections import ChainMap, defaultdict
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from scgenerator import math, operators, utils
|
from scgenerator import math, operators, utils
|
||||||
from scgenerator.const import MANDATORY_PARAMETERS
|
from scgenerator.const import INF, MANDATORY_PARAMETERS
|
||||||
from scgenerator.errors import EvaluatorError, NoDefaultError
|
|
||||||
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
||||||
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
|
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorRecord(NamedTuple):
|
||||||
|
error: Exception
|
||||||
|
traceback: str
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluatorError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoValidRuleError(EvaluatorError):
|
||||||
|
def __init__(self, target: str):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class CyclicDependencyError(EvaluatorError):
|
||||||
|
def __init__(self, target: str, lookup_stack: list[str]):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class AllRulesExhaustedError(EvaluatorError):
|
||||||
|
def __init__(self, target: str):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluatorErrorTree:
|
||||||
|
target: str
|
||||||
|
all_errors: dict[Type, ErrorRecord]
|
||||||
|
|
||||||
|
level: int
|
||||||
|
|
||||||
|
def __init__(self, target: str):
|
||||||
|
self.target = target
|
||||||
|
self.all_errors = {}
|
||||||
|
self.level = 0
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}(target={self.target!r})"
|
||||||
|
|
||||||
|
def push(self):
|
||||||
|
self.level += 1
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
self.level -= 1
|
||||||
|
|
||||||
|
def append(self, error: Exception):
|
||||||
|
tr = traceback.format_exc().splitlines()
|
||||||
|
tr = "\n".join(tr[:1] + tr[3:]) # get rid of the `append` frame
|
||||||
|
self.all_errors[type(error)] = ErrorRecord(error, tr)
|
||||||
|
|
||||||
|
def compile_error(self) -> EvaluatorError:
|
||||||
|
raise EvaluatorError(f"Couldn't compute {self.target}.")
|
||||||
|
|
||||||
|
def summary(self) -> dict[str, int]:
|
||||||
|
return {k.__name__: len(v) for k, v in self.all_errors.items()}
|
||||||
|
|
||||||
|
def format_summary(self) -> str:
|
||||||
|
types = [f"{v} errors of type {k!r}" for k, v in self.summary().items()]
|
||||||
|
return ", ".join(types[:-2] + [" and ".join(types[-2:])])
|
||||||
|
|
||||||
|
def details(self) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class Rule:
|
class Rule:
|
||||||
|
targets: dict[str, int]
|
||||||
|
func: Callable
|
||||||
|
args: list[str]
|
||||||
|
conditions: dict[str, Any]
|
||||||
|
|
||||||
|
mock_func: Callable
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: Union[str, list[Optional[str]]],
|
target: Union[str, list[Optional[str]]],
|
||||||
func: Callable,
|
func: Callable,
|
||||||
args: list[str] = None,
|
args: list[str] | None = None,
|
||||||
priorities: Union[int, list[int]] = None,
|
priorities: Union[int, list[int]] | None = None,
|
||||||
conditions: dict[str, str] = None,
|
conditions: dict[str, str] | None = None,
|
||||||
):
|
):
|
||||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||||
self.func = func
|
self.func = func
|
||||||
@@ -43,6 +114,16 @@ class Rule:
|
|||||||
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
|
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __eq__(self, other: Rule) -> bool:
|
||||||
|
return self.func == other.func
|
||||||
|
|
||||||
|
def pretty_format(self) -> str:
|
||||||
|
func_name_elements = self.func_name.split(".")
|
||||||
|
targets = list(self.targets)
|
||||||
|
|
||||||
|
arg_size = max(self.args, key=len)
|
||||||
|
func_size = max(func_name_elements, key=len)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def func_name(self) -> str:
|
def func_name(self) -> str:
|
||||||
return f"{self.func.__module__}.{self.func.__name__}"
|
return f"{self.func.__module__}.{self.func.__name__}"
|
||||||
@@ -58,7 +139,7 @@ class Rule:
|
|||||||
priorities: Union[int, list[int]] = None,
|
priorities: Union[int, list[int]] = None,
|
||||||
) -> list["Rule"]:
|
) -> list["Rule"]:
|
||||||
"""
|
"""
|
||||||
given a function that doesn't need all its keyword arguemtn specified, will
|
given a function that doesn't need all its keyword arguments specified, will
|
||||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -97,15 +178,21 @@ class Rule:
|
|||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class EvaluatedValue(NamedTuple):
|
||||||
class EvalStat:
|
value: Any
|
||||||
priority: float = np.inf
|
priority: float = INF
|
||||||
rule: Rule = None
|
rule: Rule = None
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
defaults: dict[str, Any] = {}
|
defaults: dict[str, Any] = {}
|
||||||
|
|
||||||
|
rules: dict[str, list[Rule]]
|
||||||
|
main_map: dict[str, EvaluatedValue]
|
||||||
|
param_chain_map: MutableMapping[str, EvaluatedValue]
|
||||||
|
lookup_stack: list[str]
|
||||||
|
failed_rules: dict[str, list[Rule]]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls, full_field: bool = False) -> "Evaluator":
|
def default(cls, full_field: bool = False) -> "Evaluator":
|
||||||
evaluator = cls()
|
evaluator = cls()
|
||||||
@@ -124,28 +211,33 @@ class Evaluator:
|
|||||||
evaluator.set(**params)
|
evaluator.set(**params)
|
||||||
for target in MANDATORY_PARAMETERS:
|
for target in MANDATORY_PARAMETERS:
|
||||||
evaluator.compute(target, check_only=check_only)
|
evaluator.compute(target, check_only=check_only)
|
||||||
return evaluator.params
|
return evaluator.main_map
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_default_param(cls, key, value):
|
def register_default_param(cls, key, value):
|
||||||
cls.defaults[key] = value
|
cls.defaults[key] = value
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, *rules: Rule):
|
||||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
self.rules = defaultdict(list)
|
||||||
self.params = {}
|
self.main_map = {}
|
||||||
self.__curent_lookup: list[str] = []
|
|
||||||
self.__failed_rules: dict[str, list[Rule]] = defaultdict(list)
|
|
||||||
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
|
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
|
self.failed_rules = defaultdict(list)
|
||||||
|
|
||||||
|
self.append(*rules)
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
return self.main_map[key].value
|
||||||
|
|
||||||
def append(self, *rule: Rule):
|
def append(self, *rule: Rule):
|
||||||
for r in rule:
|
for r in rule:
|
||||||
for t in r.targets:
|
for t in r.targets:
|
||||||
if t is not None:
|
if t is None:
|
||||||
self.rules[t].append(r)
|
continue
|
||||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
self.rules[t].append(r)
|
||||||
|
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||||
|
|
||||||
def set(self, dico: dict[str, Any] = None, **params: Any):
|
def set(self, **params: Any):
|
||||||
"""
|
"""
|
||||||
sets the internal set of parameters
|
sets the internal set of parameters
|
||||||
|
|
||||||
@@ -157,18 +249,11 @@ class Evaluator:
|
|||||||
params : Any
|
params : Any
|
||||||
if dico is None, update the internal dict of parameters with params
|
if dico is None, update the internal dict of parameters with params
|
||||||
"""
|
"""
|
||||||
if dico is None:
|
for k, v in params.items():
|
||||||
dico = params
|
self.main_map[k] = EvaluatedValue(v, np.inf)
|
||||||
self.params.update(dico)
|
|
||||||
else:
|
|
||||||
self.reset()
|
|
||||||
self.params = dico
|
|
||||||
for k in dico:
|
|
||||||
self.eval_stats[k].priority = np.inf
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.params = {}
|
self.main_map = {}
|
||||||
self.eval_stats = defaultdict(EvalStat)
|
|
||||||
|
|
||||||
def compute(self, target: str, check_only=False) -> Any:
|
def compute(self, target: str, check_only=False) -> Any:
|
||||||
"""
|
"""
|
||||||
@@ -191,99 +276,115 @@ class Evaluator:
|
|||||||
KeyError
|
KeyError
|
||||||
there is no saved rule for the target
|
there is no saved rule for the target
|
||||||
"""
|
"""
|
||||||
value = self.params.get(target)
|
errors = EvaluatorErrorTree(target)
|
||||||
if value is None:
|
param_chain_map = ChainMap(self.main_map)
|
||||||
prefix = "\t" * len(self.__curent_lookup)
|
try:
|
||||||
# Avoid cycles
|
value = self._compute(target, check_only, param_chain_map, [], errors)
|
||||||
if target in self.__curent_lookup:
|
except EvaluatorError as e:
|
||||||
raise EvaluatorError(
|
errors.append(e)
|
||||||
"cyclic dependency detected : "
|
raise errors.compile_error() from None
|
||||||
f"{target!r} seems to depend on itself, please provide "
|
self.merge_chain_map(param_chain_map)
|
||||||
f"a value for at least one variable in {self.__curent_lookup!r}. "
|
|
||||||
+ self.attempted_rules_str(target)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.__curent_lookup.append(target)
|
|
||||||
|
|
||||||
if len(self.rules[target]) == 0:
|
|
||||||
error = EvaluatorError(
|
|
||||||
f"no rule for {target}, trying to evaluate {self.__curent_lookup!r}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
error = None
|
|
||||||
|
|
||||||
# try every rule until one succeeds
|
|
||||||
for ii, rule in enumerate(filter(self.validate_condition, self.rules[target])):
|
|
||||||
self.logger.debug(
|
|
||||||
prefix + f"attempt {ii+1} to compute {target}, this time using {rule!r}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
args = [self.compute(k, check_only=check_only) for k in rule.args]
|
|
||||||
if check_only:
|
|
||||||
returned_values = rule.mock_func(*args)
|
|
||||||
else:
|
|
||||||
returned_values = rule.func(*args)
|
|
||||||
if len(rule.targets) == 1:
|
|
||||||
returned_values = [returned_values]
|
|
||||||
for (param_name, param_priority), returned_value in zip(
|
|
||||||
rule.targets.items(), returned_values
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
param_name == target
|
|
||||||
or param_name not in self.params
|
|
||||||
or self.eval_stats[param_name].priority < param_priority
|
|
||||||
):
|
|
||||||
if check_only:
|
|
||||||
success_str = f"able to compute {param_name} "
|
|
||||||
else:
|
|
||||||
v_str = format(returned_value).replace("\n", "")
|
|
||||||
success_str = f"computed {param_name}={v_str} "
|
|
||||||
self.logger.debug(
|
|
||||||
prefix
|
|
||||||
+ success_str
|
|
||||||
+ f"using {rule.func.__name__} from {rule.func.__module__}"
|
|
||||||
)
|
|
||||||
self.set_value(param_name, returned_value, param_priority, rule)
|
|
||||||
if param_name == target:
|
|
||||||
value = returned_value
|
|
||||||
break
|
|
||||||
except EvaluatorError as e:
|
|
||||||
error = e
|
|
||||||
self.logger.debug(
|
|
||||||
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
|
|
||||||
)
|
|
||||||
self.__failed_rules[target].append(rule)
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
raise type(e)(f"error while evaluating {target!r}")
|
|
||||||
else:
|
|
||||||
default = self.defaults.get(target)
|
|
||||||
if default is None:
|
|
||||||
error = error or NoDefaultError(
|
|
||||||
prefix
|
|
||||||
+ f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. "
|
|
||||||
+ self.attempted_rules_str(target)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
value = default
|
|
||||||
self.logger.info(prefix + f"using default value of {value} for {target}")
|
|
||||||
self.set_value(target, value, 0, None)
|
|
||||||
last_target = self.__curent_lookup.pop()
|
|
||||||
assert target == last_target
|
|
||||||
self.__failed_rules[target] = []
|
|
||||||
|
|
||||||
if value is None and error is not None:
|
|
||||||
raise error
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> Any:
|
def _compute(
|
||||||
return self.params[key]
|
self,
|
||||||
|
target: str,
|
||||||
|
check_only: bool,
|
||||||
|
param_chain_map: MutableMapping[str, EvaluatedValue],
|
||||||
|
lookup_stack: list[str],
|
||||||
|
errors: EvaluatorErrorTree,
|
||||||
|
) -> Any:
|
||||||
|
errors.push()
|
||||||
|
if target in param_chain_map:
|
||||||
|
return param_chain_map[target].value
|
||||||
|
if target not in self.rules or len(self.rules[target]) == 0:
|
||||||
|
raise NoValidRuleError(target)
|
||||||
|
if target in lookup_stack:
|
||||||
|
raise CyclicDependencyError(target, lookup_stack.copy())
|
||||||
|
|
||||||
def set_value(self, key: str, value: Any, priority: int, rule: Rule):
|
lookup_stack.append(target)
|
||||||
self.params[key] = value
|
base_cm_length = len(param_chain_map.maps)
|
||||||
self.eval_stats[key].priority = priority
|
for rule in self.rules[target]:
|
||||||
self.eval_stats[key].rule = rule
|
values = self.apply_rule(rule, check_only, param_chain_map, lookup_stack, errors)
|
||||||
|
if self.valid_values(values, rule, check_only, param_chain_map, lookup_stack, errors):
|
||||||
|
break
|
||||||
|
lookup_stack.pop()
|
||||||
|
|
||||||
|
if values is None:
|
||||||
|
if target in self.defaults:
|
||||||
|
values = {target: self.defaults[target]}
|
||||||
|
else:
|
||||||
|
self.clear_chain_map(param_chain_map, base_cm_length)
|
||||||
|
raise AllRulesExhaustedError(target)
|
||||||
|
|
||||||
|
param_chain_map.maps.insert(0, values)
|
||||||
|
errors.pop()
|
||||||
|
return values[target].value
|
||||||
|
|
||||||
|
def valid_values(
|
||||||
|
self,
|
||||||
|
values: dict[str, EvaluatedValue] | None,
|
||||||
|
rule: Rule,
|
||||||
|
check_only: bool,
|
||||||
|
param_chain_map: MutableMapping[str, EvaluatedValue],
|
||||||
|
lookup_stack: list[str],
|
||||||
|
errors: EvaluatorErrorTree,
|
||||||
|
) -> bool:
|
||||||
|
if values is None:
|
||||||
|
return False
|
||||||
|
if check_only:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for arg, target_value in rule.conditions.items():
|
||||||
|
value = self._compute(arg, False, param_chain_map, lookup_stack, errors)
|
||||||
|
if value != target_value:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply_rule(
|
||||||
|
self,
|
||||||
|
rule: Rule,
|
||||||
|
check_only: bool,
|
||||||
|
param_chain_map: MutableMapping[str, EvaluatedValue],
|
||||||
|
lookup_stack: list[str],
|
||||||
|
errors: EvaluatorErrorTree,
|
||||||
|
) -> dict[str, EvaluatedValue] | None:
|
||||||
|
try:
|
||||||
|
for arg in rule.args:
|
||||||
|
self._compute(arg, check_only, param_chain_map, lookup_stack, errors)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
args = [param_chain_map[k].value for k in rule.args]
|
||||||
|
|
||||||
|
func = rule.mock_func if check_only else rule.func
|
||||||
|
|
||||||
|
try:
|
||||||
|
values = func(*args)
|
||||||
|
except Exception as e:
|
||||||
|
self.record_error(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(values, tuple):
|
||||||
|
values = (values,)
|
||||||
|
|
||||||
|
return {k: EvaluatedValue(v, rule.targets[k], rule) for k, v in zip(rule.targets, values)}
|
||||||
|
|
||||||
|
def merge_chain_map(self, param_chain_map: dict[str, EvaluatedValue]):
|
||||||
|
while len(param_chain_map.maps) > 1:
|
||||||
|
params = param_chain_map.maps.pop(0)
|
||||||
|
for k, v in params.items():
|
||||||
|
target_priority = self.main_map[k].priority if k in self.main_map else -INF
|
||||||
|
if v.priority > target_priority:
|
||||||
|
self.main_map[k] = v
|
||||||
|
|
||||||
|
def clear_chain_map(
|
||||||
|
self, param_chain_map: MutableMapping[str, EvaluatedValue], target_size: int
|
||||||
|
):
|
||||||
|
while len(param_chain_map.maps) > target_size:
|
||||||
|
param_chain_map.pop(0)
|
||||||
|
|
||||||
def validate_condition(self, rule: Rule) -> bool:
|
def validate_condition(self, rule: Rule) -> bool:
|
||||||
try:
|
try:
|
||||||
@@ -292,7 +393,7 @@ class Evaluator:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def attempted_rules_str(self, target: str) -> str:
|
def attempted_rules_str(self, target: str) -> str:
|
||||||
rules = ", ".join(str(r) for r in self.__failed_rules[target])
|
rules = ", ".join(str(r) for r in self.failed_rules[target])
|
||||||
if len(rules) == 0:
|
if len(rules) == 0:
|
||||||
return ""
|
return ""
|
||||||
return "attempted rules : " + rules
|
return "attempted rules : " + rules
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ import numpy as np
|
|||||||
|
|
||||||
from scgenerator import utils
|
from scgenerator import utils
|
||||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||||
from scgenerator.errors import EvaluatorError
|
from scgenerator.evaluator import Evaluator, EvaluatorError
|
||||||
from scgenerator.evaluator import Evaluator
|
|
||||||
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
|
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
|
||||||
from scgenerator.operators import Qualifier, SpecOperator
|
from scgenerator.operators import Qualifier, SpecOperator
|
||||||
from scgenerator.utils import update_path_name
|
from scgenerator.utils import update_path_name
|
||||||
@@ -475,7 +474,7 @@ class Parameters:
|
|||||||
|
|
||||||
def get_evaluator(self):
|
def get_evaluator(self):
|
||||||
evaluator = Evaluator.default(self.full_field)
|
evaluator = Evaluator.default(self.full_field)
|
||||||
evaluator.set(self._param_dico.copy())
|
evaluator.set(**self._param_dico)
|
||||||
return evaluator
|
return evaluator
|
||||||
|
|
||||||
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||||
@@ -546,14 +545,15 @@ class Parameters:
|
|||||||
) from None
|
) from None
|
||||||
if exhaustive:
|
if exhaustive:
|
||||||
for p in self._p_names:
|
for p in self._p_names:
|
||||||
if p not in evaluator.params:
|
if p not in evaluator.main_map:
|
||||||
try:
|
try:
|
||||||
evaluator.compute(p)
|
evaluator.compute(p)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
computed = self.__class__(
|
computed = self.__class__(
|
||||||
**{k: v for k, v in evaluator.params.items() if k in self._p_names}
|
**{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
|
||||||
)
|
)
|
||||||
|
computed._frozen = True
|
||||||
return computed
|
return computed
|
||||||
|
|
||||||
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
||||||
|
|||||||
23
tests/test_evaluator.py
Normal file
23
tests/test_evaluator.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scgenerator.evaluator import Evaluator, Rule
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def disk_rules() -> list[Rule]:
|
||||||
|
return [
|
||||||
|
Rule("radius", lambda diameter: diameter / 2),
|
||||||
|
Rule("diameter", lambda radius: radius * 2),
|
||||||
|
Rule("diameter", lambda perimeter: perimeter / np.pi),
|
||||||
|
Rule("perimeter", lambda diameter: diameter * np.pi),
|
||||||
|
Rule("area", lambda radius: np.pi * radius**2),
|
||||||
|
Rule("radius", lambda area: np.sqrt(area / np.pi)),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple(disk_rules: list[Rule]):
|
||||||
|
evaluator = Evaluator(*disk_rules, Rule("radius", lambda lol: lol * 3))
|
||||||
|
evaluator.set()
|
||||||
|
|
||||||
|
assert evaluator.compute("area") == pytest.approx(78.53981633974483)
|
||||||
Reference in New Issue
Block a user