added strict flag to Parameters.compute
This commit is contained in:
@@ -534,7 +534,9 @@ class Parameters:
|
|||||||
param["version"] = __version__
|
param["version"] = __version__
|
||||||
return param
|
return param
|
||||||
|
|
||||||
def compute(self, p_name: str, *other_p_names: str, **with_values: Any) -> Any | tuple[Any]:
|
def compute(
|
||||||
|
self, p_name: str, *other_p_names: str, strict: bool = True, **with_values: Any
|
||||||
|
) -> Any | tuple[Any]:
|
||||||
"""
|
"""
|
||||||
compute a single or a set of value
|
compute a single or a set of value
|
||||||
|
|
||||||
@@ -544,6 +546,8 @@ class Parameters:
|
|||||||
parameter to compute0
|
parameter to compute0
|
||||||
other_p_names : str, positional only
|
other_p_names : str, positional only
|
||||||
other parameters to compute
|
other parameters to compute
|
||||||
|
strict : bool, optional
|
||||||
|
raise an exception when something cannot be computed, by default True
|
||||||
with_values : Any, keyword only
|
with_values : Any, keyword only
|
||||||
compute the desired parameters as if self is updated with `with_values`.
|
compute the desired parameters as if self is updated with `with_values`.
|
||||||
|
|
||||||
@@ -565,9 +569,22 @@ class Parameters:
|
|||||||
"""
|
"""
|
||||||
evaluator = self.get_evaluator()
|
evaluator = self.get_evaluator()
|
||||||
evaluator.set(**with_values)
|
evaluator.set(**with_values)
|
||||||
first = evaluator.compute(p_name)
|
if strict:
|
||||||
|
|
||||||
|
def comp(_pn):
|
||||||
|
return evaluator.compute(_pn)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def comp(_pn):
|
||||||
|
try:
|
||||||
|
return evaluator.compute(_pn)
|
||||||
|
except EvaluatorError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first = comp(p_name)
|
||||||
if other_p_names:
|
if other_p_names:
|
||||||
return (first, *(evaluator.compute(p) for p in other_p_names))
|
return (first, *(comp(p) for p in other_p_names))
|
||||||
else:
|
else:
|
||||||
return first
|
return first
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import scgenerator as sc
|
import scgenerator as sc
|
||||||
|
from scgenerator.evaluator import EvaluatorError
|
||||||
|
|
||||||
|
|
||||||
def test_dispersion_logic():
|
def test_dispersion_logic():
|
||||||
@@ -11,3 +12,13 @@ def test_dispersion_logic():
|
|||||||
assert params.compute("beta2_coefficients") == pytest.approx(
|
assert params.compute("beta2_coefficients") == pytest.approx(
|
||||||
[-6.3772409974749684e-27, 5.116448086629504e-41]
|
[-6.3772409974749684e-27, 5.116448086629504e-41]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_relaxed_compute():
|
||||||
|
params = sc.Parameters(shape="gaussian", energy=1e-6, width=1e-12)
|
||||||
|
params.compute("peak_power")
|
||||||
|
with pytest.raises(EvaluatorError):
|
||||||
|
params = sc.Parameters(energy=1e-6, width=1e-12)
|
||||||
|
params.compute("peak_power")
|
||||||
|
params = sc.Parameters(energy=1e-6, width=1e-12)
|
||||||
|
assert params.compute("peak_power", strict=False) is None
|
||||||
|
|||||||
Reference in New Issue
Block a user