added strict flag to Parameters.compute

This commit is contained in:
2024-02-06 16:28:13 +01:00
parent 3741954d69
commit c65eebb2dc
2 changed files with 31 additions and 3 deletions

View File

@@ -534,7 +534,9 @@ class Parameters:
param["version"] = __version__ param["version"] = __version__
return param return param
def compute(self, p_name: str, *other_p_names: str, **with_values: Any) -> Any | tuple[Any]: def compute(
self, p_name: str, *other_p_names: str, strict: bool = True, **with_values: Any
) -> Any | tuple[Any]:
""" """
compute a single or a set of value compute a single or a set of value
@@ -544,6 +546,8 @@ class Parameters:
parameter to compute0 parameter to compute0
other_p_names : str, positional only other_p_names : str, positional only
other parameters to compute other parameters to compute
strict : bool, optional
raise an exception when something cannot be computed, by default True
with_values : Any, keyword only with_values : Any, keyword only
compute the desired parameters as if self is updated with `with_values`. compute the desired parameters as if self is updated with `with_values`.
@@ -565,9 +569,22 @@ class Parameters:
""" """
evaluator = self.get_evaluator() evaluator = self.get_evaluator()
evaluator.set(**with_values) evaluator.set(**with_values)
first = evaluator.compute(p_name) if strict:
def comp(_pn):
return evaluator.compute(_pn)
else:
def comp(_pn):
try:
return evaluator.compute(_pn)
except EvaluatorError:
return None
first = comp(p_name)
if other_p_names: if other_p_names:
return (first, *(evaluator.compute(p) for p in other_p_names)) return (first, *(comp(p) for p in other_p_names))
else: else:
return first return first

View File

@@ -1,6 +1,7 @@
import pytest import pytest
import scgenerator as sc import scgenerator as sc
from scgenerator.evaluator import EvaluatorError
def test_dispersion_logic(): def test_dispersion_logic():
@@ -11,3 +12,13 @@ def test_dispersion_logic():
assert params.compute("beta2_coefficients") == pytest.approx( assert params.compute("beta2_coefficients") == pytest.approx(
[-6.3772409974749684e-27, 5.116448086629504e-41] [-6.3772409974749684e-27, 5.116448086629504e-41]
) )
def test_relaxed_compute():
params = sc.Parameters(shape="gaussian", energy=1e-6, width=1e-12)
params.compute("peak_power")
with pytest.raises(EvaluatorError):
params = sc.Parameters(energy=1e-6, width=1e-12)
params.compute("peak_power")
params = sc.Parameters(energy=1e-6, width=1e-12)
assert params.compute("peak_power", strict=False) is None