70 lines
2.0 KiB
Python
70 lines
2.0 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
from scgenerator import math, units
|
|
from scgenerator.evaluator import Evaluator, EvaluatorError, Rule
|
|
|
|
|
|
@pytest.fixture
def disk_rules() -> list[Rule]:
    """Rules linking a disk's radius, diameter, perimeter and area.

    Each ``Rule`` names the quantity it produces. Its inputs are taken from
    the lambda's parameter names, so those names must not change.
    The rules form cycles (e.g. radius -> diameter -> radius), and every
    quantity can be reached from any other one.
    """
    rules = [
        # radius <-> diameter
        Rule("radius", lambda diameter: diameter / 2),
        Rule("diameter", lambda radius: radius * 2),
        # diameter <-> perimeter
        Rule("diameter", lambda perimeter: perimeter / np.pi),
        Rule("perimeter", lambda diameter: diameter * np.pi),
        # radius <-> area
        Rule("area", lambda radius: np.pi * radius**2),
        Rule("radius", lambda area: np.sqrt(area / np.pi)),
    ]
    return rules
|
|
|
|
|
|
def test_trivial(disk_rules: list[Rule]):
    """Check rule-based derivation and the precedence of explicitly set values."""
    ev = Evaluator(*disk_rules)

    # Only the radius is known, so the area must come from the area rule:
    # pi * 5**2.
    ev.set(radius=5)
    area_from_radius = ev.compute("area")
    assert area_from_radius == pytest.approx(78.53981633974483)

    # Now set the area as well. The test expects both explicitly set values
    # to be returned unchanged. The area set here replaces the area computed
    # above, and radius is not re-derived from it (that would give
    # sqrt(5 / pi), not 5).
    ev.set(area=5)
    assert ev.compute("area") == 5
    assert ev.compute("radius") == 5
|
|
|
|
|
|
def test_simple():
    """Check that the default rule set derives ``t`` and ``w0`` from basic inputs."""
    ev = Evaluator.default()
    ev.set(wavelength=800e-9, t_num=1024, dt=5e-15)

    # The time grid must match scgenerator's own tspace helper for the same
    # t_num / dt.
    t_grid = ev.compute("t")
    assert t_grid == pytest.approx(math.tspace(t_num=1024, dt=5e-15))

    # w0 is expected to equal the 800 nm wavelength passed through
    # scgenerator's units.nm conversion (presumably to angular frequency;
    # confirm against scgenerator.units).
    w0 = ev.compute("w0")
    assert w0 == pytest.approx(units.nm(800))
|
|
|
|
|
|
def test_default_args_simple():
    """Check that a rule function's default arguments fill in unset inputs."""

    # The evaluator introspects this signature, so the parameter names
    # (a, b, c) and the default c=5 are part of what is being tested.
    def some_function(a: int, b: int, c: int = 5):
        return a + b + c

    ev = Evaluator(Rule("d", some_function))
    ev.set(a=1, b=3)

    # c is never set and no rule produces it. It only exists as a default
    # argument, so asking for it directly is expected to fail.
    with pytest.raises(EvaluatorError):
        ev.compute("c")
    # d = 1 + 3 + 5, using the default value of c.
    assert ev.compute("d") == 9

    # Drop previously computed results (presumably including d) so that d
    # is recomputed with the explicitly set c.
    ev.clear_computed()
    ev.set(c=10)
    assert ev.compute("c") == 10
    # d = 1 + 3 + 10
    assert ev.compute("d") == 14
|
|
|
|
|
|
def test_default_args_real():
    """Check that the default rules derive ``dt`` and ``t`` when ``dt`` is not given."""
    ev = Evaluator.default()
    params = dict(
        wavelength=1050e-9,
        peak_power=5000,
        width=1500e-15,
        wavelength_window=(800e-9, 1500e-9),
        t_num=2048,
    )
    ev.set(**params)

    # abs=0 with no rel makes pytest.approx demand exact equality. The
    # evaluator must therefore reproduce these helpers' results exactly.
    # dt is expected to come from the window's lower bound (800e-9) and the
    # wavelength (1050e-9).
    dt = ev.compute("dt")
    assert dt == pytest.approx(math.dt_from_min_wl(800e-9, 1050e-9), abs=0)

    # The time grid must match tspace built from the derived dt.
    t_grid = ev.compute("t")
    assert t_grid == pytest.approx(
        math.tspace(t_num=2048, dt=ev.compute("dt")), abs=0
    )
|