Files
scgenerator/tests/test_variableparameters.py
2023-10-10 11:41:38 +02:00

272 lines
7.3 KiB
Python

import numpy as np
import pytest
from scgenerator.variableparameters import (
Variable,
debase,
get_sequence,
unit_formatter,
vdataclass,
vfield,
)
def test_debase():
    """debase decomposes an integer into digits of a (possibly mixed) radix."""
    # (number, bases, expected digits)
    cases = [
        (586, (10, 10, 10), (5, 8, 6)),  # decimal
        (57, (2, 2, 2, 2, 2, 2), (1, 1, 1, 0, 0, 1)),  # binary
        (38, (8, 4, 2), (4, 3, 0)),  # 38 == 4 * (4*2) + 3 * (2) + 0 * 1
    ]
    for number, bases, digits in cases:
        assert debase(number, bases) == digits

    # By default this input is rejected with ValueError ...
    with pytest.raises(ValueError):
        debase(785, (5, 1, 6))
    # ... while with strict=False the leading digit (130) may exceed its base.
    assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5)
def test_get_sequence():
    """get_sequence expands lists, scalars and [[start, stop, num(, kind)]] specs.

    Values are compared with ``np.testing.assert_array_equal`` rather than
    ``np.all(a == b)``: the latter broadcasts, so a result of the wrong length
    (or an empty result compared against a one-element array) could pass.
    """
    # an explicit list of ints is returned as an int array
    s = get_sequence([1, 2, 3])
    assert isinstance(s, np.ndarray) and s.dtype == int
    np.testing.assert_array_equal(s, np.array([1, 2, 3]))

    # [[start, stop, num]] is a linspace; non-integer points give a float array
    s = get_sequence([[1, 2, 3]])
    assert isinstance(s, np.ndarray) and s.dtype == float
    np.testing.assert_array_equal(s, np.linspace(1, 2, 3))

    # when every linspace point is an integer, the result is an int array
    s = get_sequence([[0, 4, 5]])
    assert isinstance(s, np.ndarray) and s.dtype == int
    np.testing.assert_array_equal(s, np.arange(5))

    s = get_sequence([[3012, 3031, 20]])
    assert isinstance(s, np.ndarray) and s.dtype == int
    np.testing.assert_array_equal(s, np.arange(3012, 3032))

    # a bare scalar becomes a one-element sequence
    s = get_sequence(0)
    assert isinstance(s, np.ndarray) and s.dtype == int
    np.testing.assert_array_equal(s, np.array([0]))

    # "geometric" spacing maps to np.geomspace
    s = get_sequence([[1e-12, 1e-5, 8, "geometric"]])
    assert isinstance(s, np.ndarray) and s.dtype == float
    np.testing.assert_array_equal(s, np.geomspace(1e-12, 1e-5, 8))

    # "geometric0" prepends 0 to a geomspace of num - 1 points
    s = get_sequence([[1e-12, 1e-5, 9, "geometric0"]])
    assert isinstance(s, np.ndarray) and s.dtype == float
    np.testing.assert_array_equal(
        s, np.concatenate(((0,), np.geomspace(1e-12, 1e-5, 8)))
    )

    # unknown spacing kinds and malformed specs are rejected
    with pytest.raises(ValueError):
        get_sequence([[1, 2, 3, "logarithm"]])
    with pytest.raises(ValueError):
        get_sequence([[1, 2, 3, 4]])
    with pytest.raises(TypeError):
        get_sequence([[0, 1, "linear"]])
    with pytest.raises(TypeError):
        get_sequence([[0, 1, int]])
def test_constant_number():
    """Scalar defaults yield a single configuration.

    Index 0 returns each constant; index 1 is out of range and raises ValueError.
    """

    @vdataclass
    class Conf:
        x: Variable = vfield(default=3)
        y: Variable = vfield(default=4)

    cfg = Conf()
    assert len(cfg) == 1
    for field_name, constant in (("x", 3), ("y", 4)):
        assert getattr(cfg, field_name)(0) == constant
    for field_name in ("x", "y"):
        with pytest.raises(ValueError):
            getattr(cfg, field_name)(1)
def test_constant_list():
    """List defaults multiply the size of an instance.

    x has 3 values and y has 2, so there are 3 * 2 == 6 configurations.
    Index 0 maps to the first value of each field, and index len(conf) - 1
    maps to the last. Index 6 (== len(conf)) raises ValueError.
    """

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    conf = Conf()
    assert len(conf) == 6
    assert conf.x(0) == 1
    assert conf.y(0) == 2
    assert conf.x(len(conf) - 1) == 7
    assert conf.y(len(conf) - 1) == 0
    with pytest.raises(ValueError):
        conf.x(6)
    with pytest.raises(ValueError):
        conf.y(6)
def test_decoration():
    """vdataclass installs the expected class attributes and per-instance state."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    for attr_name in ("filter", "__vfields__", "__len__"):
        assert hasattr(Conf, attr_name)
    x_vfield = Conf.__vfields__["x"]
    assert list(x_vfield.default_sequence) == [1, 2, 7]
    # __variables_nums__ exists on instances but not on the class itself
    assert not hasattr(Conf, "__variables_nums__")
    assert hasattr(Conf(), "__variables_nums__")
def test_name_error():
    """Unknown field names are rejected.

    Passing one to the constructor raises TypeError; passing one to
    ``filter`` raises AttributeError.
    """

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    with pytest.raises(TypeError):
        Conf(z=5)
    conf = Conf(y=5)
    with pytest.raises(AttributeError):
        conf.filter(z=45)
def test_assignment():
    """Assigning to fields, at construction or afterwards, rebuilds the index space."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    # scalar overrides collapse the instance to a single configuration
    cfg = Conf(x=5, y=0)
    assert len(cfg) == 1
    assert cfg.x(0) == 5
    assert cfg.y(0) == 0
    with pytest.raises(ValueError):
        cfg.x(6)
    with pytest.raises(ValueError):
        cfg.y(6)

    # x from a [[start, stop, num]] spec (5 values) and y from a tuple (2 values)
    cfg = Conf(x=[[0, 2, 5]], y=(0.1, 0.2))
    assert len(cfg) == 10
    assert cfg.x(0) == 0
    assert cfg.x(5) == 1.0
    assert cfg.y(0) == 0.1
    assert cfg.y(2) == 0.1
    with pytest.raises(IndexError):
        cfg.x(-5)
    # presumably local=True indexes x's own 5 values, so 5 is out of range
    with pytest.raises(IndexError):
        cfg.x(5, local=True)
    with pytest.raises(ValueError):
        cfg.y(10)

    # reassigning y after construction resizes the instance to 7 * 5
    cfg.y = np.linspace(-1, 2, 7)
    assert len(cfg) == 7 * 5
    assert cfg.x(13) == 0.5
    assert cfg.y(13) == 2.0
    # fields cannot be deleted
    with pytest.raises(AttributeError):
        del cfg.y

    # making x a scalar leaves only y's 7 values
    cfg.x = 0
    assert len(cfg) == 7
    assert cfg.x(5) == 0
    assert cfg.y(5) == 1.5
def test_unit_forammter():
    """unit_formatter renders values with an SI prefix followed by the unit.

    NOTE(review): the function name misspells "formatter"; it is kept so the
    pytest node id stays stable.
    """
    default_fmt = unit_formatter("s", 4)
    # With the default vmin, 5e-29 prints as 0.
    # Values beyond the largest prefix stay in Y.
    expected_text = {5e-29: "0s", 5e27: "5000Ys"}
    for value, text in expected_text.items():
        assert default_fmt(value) == text

    # with vmin=None the tiny value is rendered with the smallest prefix (y)
    unclamped_fmt = unit_formatter("s", 4, vmin=None)
    assert unclamped_fmt(5e-29) == "5e-05ys"
def test_param_formatting():
    """Each field's ``format`` honours its ``decimals`` and ``suffix`` settings."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
        z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
        w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")

    cfg = Conf()
    # (field name, index, expected rendering)
    checks = [
        ("x", 0, "x=1"),
        ("y", 0, "y=2"),
        ("y", 1, "y=2.5"),
        ("y", 2, "y=2.2e-05"),
        ("z", 0, "z=7.5ns"),
        ("w", 0, "w=7.568GHz"),
    ]
    for field_name, index, rendered in checks:
        assert getattr(cfg, field_name).format(index) == rendered
def test_error_no_default():
    """A vfield declared without a default must be supplied at construction.

    Omitting it raises ValueError.
    """

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield()

    # the constructor is expected to raise, so its result is never bound
    with pytest.raises(ValueError):
        Conf()
    conf = Conf(y=5)
    assert conf.x(0) == 1
    assert conf.y(0) == 5
def test_filter():
    """``filter`` selects a sub-grid of configurations and reshapes data onto it."""

    @vdataclass
    class Conf:
        x: Variable = vfield()
        y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
        z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
        w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")

    # Only shapes are checked, so data values are irrelevant.
    # Seed the RNG anyway so that runs are reproducible.
    rng = np.random.default_rng(0)

    conf = Conf(x=(1, 2, 3))
    f = conf.filter(x=1, y=(0, 1))
    assert len(f) == 2
    assert list(f) == [3, 4]

    conf = Conf(x=[[0, 1, 56]], w=[0, 1, 2, 3, 4])
    assert len(conf) == 56 * 5 * 3
    # every 5th x (12 values), a single y, and 3 of the 5 w values
    f = conf.filter(x=slice(None, None, 5), y=1, w=(0, 1, 2))
    assert len(f) == 12 * 3
    assert all(conf.y(i) == Conf.y.default_sequence[1] for i in f)

    # Flat data of length len(conf) is reshaped onto the filtered (x, y, z, w) grid.
    # By default the singleton y and z axes are squeezed out.
    mat = rng.random(len(conf))
    assert f(mat, squeeze=False).shape == (12, 1, 1, 3)
    assert f(mat).shape == (12, 3)

    # 70 * 12 == len(conf), but the 2-D array is rejected as-is (presumably
    # because its leading axis, 70, is indexed). The flattened copy is accepted.
    # Note: a shape assertion inside pytest.raises would never run, so the
    # call is made bare.
    mat = rng.random((70, 12))
    with pytest.raises(IndexError):
        f(mat, squeeze=False)
    assert f(mat.ravel(), squeeze=False).shape == (12, 1, 1, 3)
    assert f(mat.ravel()).shape == (12, 3)

    # leading axis of the wrong length (71 != len(conf))
    mat = rng.random((71, 12))
    with pytest.raises(IndexError):
        f(mat, axis=0)

    # trailing axes after the configuration axis are carried through
    mat = rng.random((70, 12, 11))
    assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11)
def test_param_parsing():
    """TODO: parsing a string or path returns the correct values."""
    # Placeholder with no assertions. Skip explicitly so the report shows it
    # as unimplemented instead of a misleading pass.
    pytest.skip("not implemented yet")
def test_config_parsing():
    """TODO: parsing a string returns the exact id corresponding to the string/path."""
    # Placeholder with no assertions. Skip explicitly so the report shows it
    # as unimplemented instead of a misleading pass.
    pytest.skip("not implemented yet")
def test_repr():
    """TODO: print a config object correctly."""
    # Placeholder with no assertions. Skip explicitly so the report shows it
    # as unimplemented instead of a misleading pass.
    pytest.skip("not implemented yet")