413 lines
11 KiB
Python
413 lines
11 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
from scgenerator.variableparameters import (
|
|
Variable,
|
|
debase,
|
|
get_sequence,
|
|
unit_formatter,
|
|
vdataclass,
|
|
vfield,
|
|
)
|
|
|
|
|
|
def test_debase():
    """debase splits an integer into the digits of a mixed-radix number system."""
    cases = [
        ((586, (10, 10, 10)), (5, 8, 6)),  # decimal
        ((57, (2, 2, 2, 2, 2, 2)), (1, 1, 1, 0, 0, 1)),  # binary
        ((38, (8, 4, 2)), (4, 3, 0)),  # 38 == 4 * (4*2) + 3 * (2) + 0 * 1
    ]
    for (number, bases), digits in cases:
        assert debase(number, bases) == digits

    # 785 does not fit in bases (5, 1, 6); with strict=False the leading
    # digit is allowed to exceed its base (130 > 5)
    with pytest.raises(ValueError):
        debase(785, (5, 1, 6))
    assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5)

    assert debase(0, (1, 1, 1), strict=True) == (0, 0, 0)
|
|
|
|
|
|
def test_get_sequence():
    """get_sequence turns every supported specification into a numpy array.

    Forms exercised here: a scalar, a flat list of values, a
    ``[[start, stop, num]]`` range (optionally with a "geometric" or
    "geometric0" spacing), and arrays whose rows are kept as values.
    """

    def check(seq, dtype, reference):
        # both the container type and the element dtype are part of the contract
        assert isinstance(seq, np.ndarray) and seq.dtype == dtype
        assert np.all(seq == reference)

    # equal start and stop: the single value is repeated `num` times
    repeated = get_sequence([[0, 0, 20]])
    assert len(repeated) == 20
    assert all(item == 0 for item in repeated)

    # explicit values are kept as-is
    check(get_sequence([1, 2, 3]), int, np.array([1, 2, 3]))

    # [[start, stop, num]] matches np.linspace ...
    check(get_sequence([[1, 2, 3]]), float, np.linspace(1, 2, 3))
    # ... and here, where the evenly spaced values are all integers, the dtype is int
    check(get_sequence([[0, 4, 5]]), int, np.arange(5))
    check(get_sequence([[3012, 3031, 20]]), int, np.arange(3012, 3032))

    # a bare scalar becomes a one-element array
    check(get_sequence(0), int, np.array([0]))

    # a 2D array keeps its shape
    assert get_sequence(np.zeros((1, 2))).shape == (1, 2)

    check(
        get_sequence([[1e-12, 1e-5, 8, "geometric"]]),
        float,
        np.geomspace(1e-12, 1e-5, 8),
    )
    # "geometric0" puts 0 in front of a geometric sequence one element shorter
    check(
        get_sequence([[1e-12, 1e-5, 9, "geometric0"]]),
        float,
        np.concatenate(((0,), np.geomspace(1e-12, 1e-5, 8))),
    )

    rows = get_sequence(np.arange(10).reshape(5, 2))
    assert rows[0][1] == 1
    assert rows[4][0] == 8

    # malformed range specifications are rejected
    invalid_specs = [
        ([[1, 2, 3, "logarithm"]], ValueError),
        ([[1, 2, 3, 4]], ValueError),
        ([[0, 1, "linear"]], TypeError),
        ([[0, 1, int]], TypeError),
    ]
    for spec, error in invalid_specs:
        with pytest.raises(error):
            get_sequence(spec)
|
|
|
|
|
|
def test_constant_number():
    """Scalar defaults are constants: they add no combinations and only index 0 is valid."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=3)
        y: Variable = vfield(default=4)

    conf = Conf()
    assert len(conf) == 1

    expected = {"x": 3, "y": 4}
    for name, value in expected.items():
        assert getattr(conf, name)(0) == value
    for name in expected:
        with pytest.raises(ValueError):
            getattr(conf, name)(1)
|
|
|
|
|
|
def test_non_numbers():
    """Fields may hold non-numeric values such as strings."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=["Hello", "World"])
        y: Variable = vfield()

    conf = Conf(y=("a", "b", "c"))
    assert conf.x(0) == "Hello"
    for index, letter in enumerate(("a", "b", "c")):
        assert conf.y(index) == letter

    # x varies slowest: index 3 == 1 * len(y) + 0 selects x[1]
    assert conf.x(3) == "World"
|
|
|
|
|
|
def test_constant_list():
    """List-valued defaults multiply the number of combinations.

    The instance enumerates the cartesian product of all fields (3 * 2 * 5
    here). A 2D array default contributes one row per combination.
    """

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])
        z: Variable = vfield(default=np.arange(10).reshape(5, 2))

    conf = Conf()
    assert len(conf) == 3 * 2 * 5
    last = len(conf) - 1

    # the first and last combinations pick the first and last value of every field
    assert conf.x(0) == 1
    assert conf.y(0) == 2
    assert conf.x(last) == 7
    assert conf.y(last) == 0
    assert list(conf.z(0)) == [0, 1]
    assert list(conf.z(last)) == [8, 9]

    # index 30 == len(conf) is one past the end
    with pytest.raises(ValueError):
        conf.x(30)
    with pytest.raises(ValueError):
        conf.y(30)
|
|
|
|
|
|
def test_simple_synchronize():
    """A field declared with ``sync=other`` steps through its values together with ``other``."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0, 0], sync=x)

    conf = Conf()
    for index, value in enumerate([1, 2, 7]):
        assert conf.x(index) == value
    for index, value in enumerate([2, 0, 0]):
        assert conf.y(index) == value

    # the two fields share one axis of length 3, so index 3 is out of range for both
    with pytest.raises(ValueError):
        conf.y(3)
    with pytest.raises(ValueError):
        conf.x(3)
|
|
|
|
|
|
def test_sync_strict():
    """Synchronized fields must have matching lengths before they can be used."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0], sync=x)

    conf = Conf()
    assert conf.x(2) == 7
    # y (2 values) is synced to x (3 values): reading y fails
    with pytest.raises(ValueError):
        conf.y(0)

    # bringing x down to y's length resolves the mismatch
    conf.x = [78, 55]
    assert len(conf) == 2
    assert conf.y(1) == 0
    assert conf.x(1) == 55

    # assigning y a length that differs from x is rejected right away
    with pytest.raises(ValueError):
        conf.y = [1, 2, 3]

    conf.x = [11, 22, 33, 44, 55]
    conf.y = [0, 1, 33, 111, 1111]
    assert conf.x(2) == 33 and conf.y(2) == 33
    assert len(conf) == 5
    assert len(conf.x) == 5
    assert len(conf.y) == 5
|
|
|
|
|
|
def test_decoration():
    """vdataclass adds filter/length support to the class and records its fields."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    for attribute in ("filter", "__vfields__", "__len__"):
        assert hasattr(Conf, attribute)
    assert list(Conf.__vfields__["x"].default_sequence) == [1, 2, 7]

    # __variables_nums__ exists on instances but not on the class itself
    assert not hasattr(Conf, "__variables_nums__")
    instance = Conf()
    assert hasattr(instance, "__variables_nums__")
|
|
|
|
|
|
def test_name_error():
    """Unknown field names are rejected.

    Passing one to the constructor raises TypeError; passing one to
    ``filter`` raises AttributeError.
    """

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    with pytest.raises(TypeError):
        Conf(z=5)

    conf = Conf(y=5)
    with pytest.raises(AttributeError):
        conf.filter(z=45)
|
|
|
|
|
|
def test_assignment():
    """Values assigned at construction or afterwards reshape the whole instance."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    # scalars: a single combination, only index 0 is valid
    conf = Conf(x=5, y=0)
    assert len(conf) == 1
    assert conf.x(0) == 5
    assert conf.y(0) == 0
    for name in ("x", "y"):
        with pytest.raises(ValueError):
            getattr(conf, name)(6)

    # x spans 0..2 in 5 steps, y has 2 values; x varies slowest
    conf = Conf(x=[[0, 2, 5]], y=(0.1, 0.2))
    assert len(conf) == 10
    assert conf.x(0) == 0
    assert conf.x(5) == 1.0  # 5 // 2 == 2 -> x[2]
    assert conf.y(0) == 0.1
    assert conf.y(2) == 0.1  # 2 % 2 == 0 -> y[0]

    # negative indices and out-of-range local indices raise IndexError,
    # a global index equal to len(conf) raises ValueError
    with pytest.raises(IndexError):
        conf.x(-5)
    with pytest.raises(IndexError):
        conf.x(5, local=True)
    with pytest.raises(ValueError):
        conf.y(10)

    # reassigning y to 7 values gives 5 * 7 combinations
    conf.y = np.linspace(-1, 2, 7)
    assert len(conf) == 7 * 5
    assert conf.x(13) == 0.5  # 13 // 7 == 1 -> x[1]
    assert conf.y(13) == 2.0  # 13 % 7 == 6 -> y[6]

    # fields cannot be deleted
    with pytest.raises(AttributeError):
        del conf.y

    # collapsing x to a scalar leaves only the 7 values of y
    conf.x = 0
    assert len(conf) == 7
    assert conf.x(5) == 0
    assert conf.y(5) == 1.5
|
|
|
|
|
|
def test_unit_forammter():
    """unit_formatter builds a callable that renders values with SI prefixes.

    The first argument is the unit suffix and the second the number of
    significant digits. Very small values are shown as 0 unless
    ``vmin=None`` is passed. Sequences and arrays are rendered as (nested)
    tuples.
    """
    seconds = unit_formatter("s", 4)
    assert seconds(5e-29) == "0s"
    assert seconds(5e27) == "5000Ys"

    unclamped = unit_formatter("s", 4, vmin=None)
    assert unclamped(5e-29) == "5e-05ys"

    # "%" renders the value multiplied by 100
    percent_4 = unit_formatter("%", 4)
    for value, text in [
        (1.000001235, "100%"),
        (0.01235, "1.235%"),
        (0.001235, "0.1235%"),
        (0.0001235, "0.01235%"),
        (0.00001235, "0.001235%"),
        (0.000001235, "0.0001235%"),
        (0.0000001235, "1.235e-05%"),
        (0.00000001235, "1.235e-06%"),
        (0.000000001235, "1.235e-07%"),
    ]:
        assert percent_4(value) == text

    percent_2 = unit_formatter("%", 2)
    for value, text in [
        (0.01235, "1.2%"),
        (0.001235, "0.12%"),
        (1.000001235, "1e+02%"),
        (0.000001235, "0.00012%"),
        (0.000000001235, "1.2e-07%"),
    ]:
        assert percent_2(value) == text

    plain = unit_formatter("", 4)
    for value, text in [
        ((4, 2), "(4, 2)"),
        ([4, 2], "(4, 2)"),
        (np.array([4, 2]), "(4, 2)"),
        (np.zeros((2, 2, 2)), "(((0, 0), (0, 0)), ((0, 0), (0, 0)))"),
    ]:
        assert plain(value) == text
|
|
|
|
|
|
def test_format_str():
    """An explicit ``format_str`` overrides ``decimals`` when formatting a value."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=123456789)
        y: Variable = vfield(default=123456789, decimals=2)
        z: Variable = vfield(default=123456789, decimals=2, format_str="d")

    conf = Conf()
    expected = {
        "x": "x=1.235e+08",  # no formatting options
        "y": "y=1.2e+08",  # decimals=2
        "z": "z=123456789",  # format_str="d" wins over decimals=2
    }
    for name, text in expected.items():
        assert getattr(conf, name).format(0) == text
|
|
|
|
|
|
def test_param_formatting():
    """``field.format(i)`` renders "name=value", honoring decimals and unit suffixes."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
        z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
        w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")

    conf = Conf()
    expectations = [
        ("x", 0, "x=1"),
        ("y", 0, "y=2"),
        ("y", 1, "y=2.5"),
        ("y", 2, "y=2.2e-05"),
        ("z", 0, "z=7.5ns"),  # suffix combined with an SI prefix
        ("w", 0, "w=7.568GHz"),
    ]
    for name, index, text in expectations:
        assert getattr(conf, name).format(index) == text
|
|
|
|
|
|
def test_error_no_default():
    """A field declared without a default must be given a value at construction."""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield()

    with pytest.raises(ValueError):
        Conf()

    conf = Conf(y=5)
    assert (conf.x(0), conf.y(0)) == (1, 5)
|
|
|
|
|
|
def test_filter():
    """filter selects a subset of combinations and reshapes data accordingly.

    A filter has a length, iterates over flat indices into the instance, and
    can be called on an array whose first axis enumerates every combination.
    """

    @vdataclass
    class Conf:
        x: Variable = vfield()
        y: Variable = vfield(default=[2, 2.456546, 0.000022305])
        z: Variable = vfield(default=7.5e-9)
        w: Variable = vfield(default=7568.4e6)

    conf = Conf(x=(1, 2, 3))
    f = conf.filter(x=1, y=(0, 1))
    assert len(f) == 2
    # flat index = x_index * len(y) + y_index (z and w are constants)
    assert list(f) == [3, 4]

    conf = Conf(x=[[0, 1, 56]], w=[0, 1, 2, 3, 4])
    n_total = len(conf)
    assert n_total == 56 * 5 * 3
    f = conf.filter(x=slice(None, None, 5), y=1, w=(0, 1, 2))
    # every 5th of the 56 x values -> 12 values, times 3 selected w values
    assert len(f) == 12 * 3
    assert all(conf.y(i) == Conf.y.default_sequence[1] for i in f)

    # only output shapes are checked, so deterministic zeros are enough
    data = np.zeros(n_total)
    # squeeze=False keeps one axis per field (x, y, z, w); the default drops length-1 axes
    assert f(data, squeeze=False).shape == (12, 1, 1, 3)
    assert f(data).shape == (12, 3)

    # a 2D array whose first axis is not len(conf) long is rejected ...
    grid = np.zeros((70, 12))  # 70 * 12 == len(conf)
    with pytest.raises(IndexError):
        f(grid, squeeze=False)
    # ... while the same data flattened is accepted
    assert f(grid.ravel(), squeeze=False).shape == (12, 1, 1, 3)
    assert f(grid.ravel()).shape == (12, 3)

    # likewise when axis=0 is requested explicitly
    with pytest.raises(IndexError):
        f(np.zeros((71, 12)), axis=0)

    # axes after the first are carried through unchanged
    assert f(np.zeros((n_total, 11)), squeeze=False).shape == (12, 1, 1, 3, 11)
|
|
|
|
|
|
def test_auto_sync():
    """auto_sync fields join another field of matching length instead of adding an axis."""

    @vdataclass
    class Conf:
        x: Variable = vfield()
        y: Variable = vfield(default=[7, 8, 9])
        z: Variable = vfield(default=7.5e-9, auto_sync=True)
        w: Variable = vfield(default=7568.4e6, auto_sync=True)

    # constant auto_sync fields leave the size at len(x) * len(y)
    conf = Conf([1, 2, 3])
    assert len(conf) == 9

    @vdataclass
    class Conf:
        x: Variable = vfield()
        y: Variable = vfield(default=[7, 8, 9])
        z: Variable = vfield(default=7.5e-9, auto_sync=True)

    # z has as many values as y (presumably the field it syncs with), so it
    # does not multiply the size
    conf = Conf(1, z=[1, 2, 3])
    assert len(conf) == 3

    # with a different length, z becomes an independent, fastest-varying axis
    conf.z = [0, 5]
    assert len(conf) == 6
    for index, value in enumerate([7, 7, 8, 8, 9, 9]):
        assert conf.y(index) == value
    for index, value in enumerate([0, 5, 0, 5, 0, 5]):
        assert conf.z(index) == value
|