"""Tests for ``scgenerator.variableparameters``.

Covers the ``debase``/``get_sequence``/``unit_formatter`` helpers and the
``vdataclass``/``vfield`` machinery (sizing, indexing, assignment,
formatting, filtering).
"""

import numpy as np
import pytest

from scgenerator.variableparameters import (
    Variable,
    debase,
    get_sequence,
    unit_formatter,
    vdataclass,
    vfield,
)


def test_debase():
    assert debase(586, (10, 10, 10)) == (5, 8, 6)  # decimal
    assert debase(57, (2, 2, 2, 2, 2, 2)) == (1, 1, 1, 0, 0, 1)  # binary
    assert debase(38, (8, 4, 2)) == (4, 3, 0)  # 38 == 4 * (4*2) + 3 * (2) + 0 * 1
    with pytest.raises(ValueError):
        debase(785, (5, 1, 6))
    # with strict=False the leading digit is allowed to overflow its base
    assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5)


def test_get_sequence():
    s = get_sequence([1, 2, 3])
    assert isinstance(s, np.ndarray) and s.dtype == int
    assert np.all(s == np.array([1, 2, 3]))

    s = get_sequence([[1, 2, 3]])
    assert isinstance(s, np.ndarray) and s.dtype == float
    assert np.all(s == np.linspace(1, 2, 3))

    s = get_sequence([[0, 4, 5]])
    assert isinstance(s, np.ndarray) and s.dtype == int
    assert np.all(s == np.arange(5))

    s = get_sequence([[3012, 3031, 20]])
    assert isinstance(s, np.ndarray) and s.dtype == int
    assert np.all(s == np.arange(3012, 3032))

    s = get_sequence(0)
    assert isinstance(s, np.ndarray) and s.dtype == int
    assert np.all(s == np.array([0]))

    s = get_sequence([[1e-12, 1e-5, 8, "geometric"]])
    assert isinstance(s, np.ndarray) and s.dtype == float
    assert np.all(s == np.geomspace(1e-12, 1e-5, 8))

    s = get_sequence([[1e-12, 1e-5, 9, "geometric0"]])
    assert isinstance(s, np.ndarray) and s.dtype == float
    assert np.all(s == np.concatenate(((0,), np.geomspace(1e-12, 1e-5, 8))))

    with pytest.raises(ValueError):
        get_sequence([[1, 2, 3, "logarithm"]])
    with pytest.raises(ValueError):
        get_sequence([[1, 2, 3, 4]])
    with pytest.raises(TypeError):
        get_sequence([[0, 1, "linear"]])
    with pytest.raises(TypeError):
        get_sequence([[0, 1, int]])


def test_constant_number():
    """classes with constant fields don't increase size and always return the constant"""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=3)
        y: Variable = vfield(default=4)

    conf = Conf()
    assert len(conf) == 1
    assert conf.x(0) == 3
    assert conf.y(0) == 4
    with pytest.raises(ValueError):
        conf.x(1)
    with pytest.raises(ValueError):
        conf.y(1)


def test_constant_list():
    """list defaults make the instance size the product of the list lengths"""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    conf = Conf()
    assert len(conf) == 6  # 3 values of x * 2 values of y
    assert conf.x(0) == 1
    assert conf.y(0) == 2
    assert conf.x(len(conf) - 1) == 7
    assert conf.y(len(conf) - 1) == 0
    with pytest.raises(ValueError):
        conf.x(6)
    with pytest.raises(ValueError):
        conf.y(6)


def test_decoration():
    """proper construction of the class"""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    assert hasattr(Conf, "filter")
    assert hasattr(Conf, "__vfields__")
    assert hasattr(Conf, "__len__")
    assert list(Conf.__vfields__["x"].default_sequence) == [1, 2, 7]
    # per-instance bookkeeping must not leak onto the class itself
    assert not hasattr(Conf, "__variables_nums__")
    assert hasattr(Conf(), "__variables_nums__")


def test_name_error():
    """name/attribute errors are raised when accessing inexistent variable (direct, filter, ...)"""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    with pytest.raises(TypeError):
        Conf(z=5)
    conf = Conf(y=5)
    with pytest.raises(AttributeError):
        conf.filter(z=45)


def test_assignment():
    """assigned values properly update the whole instance"""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 0])

    conf = Conf(x=5, y=0)
    assert len(conf) == 1
    assert conf.x(0) == 5
    assert conf.y(0) == 0
    with pytest.raises(ValueError):
        conf.x(6)
    with pytest.raises(ValueError):
        conf.y(6)

    conf = Conf(x=[[0, 2, 5]], y=(0.1, 0.2))
    assert len(conf) == 10
    assert conf.x(0) == 0
    assert conf.x(5) == 1.0
    assert conf.y(0) == 0.1
    assert conf.y(2) == 0.1
    with pytest.raises(IndexError):
        conf.x(-5)
    with pytest.raises(IndexError):
        conf.x(5, local=True)
    with pytest.raises(ValueError):
        conf.y(10)

    # reassigning one field must resize the whole instance
    conf.y = np.linspace(-1, 2, 7)
    assert len(conf) == 7 * 5
    assert conf.x(13) == 0.5
    assert conf.y(13) == 2.0
    with pytest.raises(AttributeError):
        del conf.y

    conf.x = 0
    assert len(conf) == 7
    assert conf.x(5) == 0
    assert conf.y(5) == 1.5


def test_unit_formatter():
    fmt = unit_formatter("s", 4)
    assert fmt(5e-29) == "0s"
    assert fmt(5e27) == "5000Ys"

    fmt = unit_formatter("s", 4, vmin=None)
    assert fmt(5e-29) == "5e-05ys"


def test_param_formatting():
    """formatting is always respected"""

    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
        z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
        w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")

    conf = Conf()
    assert conf.x.format(0) == "x=1"
    assert conf.y.format(0) == "y=2"
    assert conf.y.format(1) == "y=2.5"
    assert conf.y.format(2) == "y=2.2e-05"
    assert conf.z.format(0) == "z=7.5ns"
    assert conf.w.format(0) == "w=7.568GHz"


def test_error_no_default():
    @vdataclass
    class Conf:
        x: Variable = vfield(default=[1, 2, 7])
        y: Variable = vfield()

    # a field without default must be supplied at construction time
    with pytest.raises(ValueError):
        Conf()

    conf = Conf(y=5)
    assert conf.x(0) == 1
    assert conf.y(0) == 5


def test_filter():
    @vdataclass
    class Conf:
        x: Variable = vfield()
        y: Variable = vfield(default=[2, 2.456546, 0.000022305])
        z: Variable = vfield(default=7.5e-9)
        w: Variable = vfield(default=7568.4e6)

    conf = Conf(x=(1, 2, 3))
    f = conf.filter(x=1, y=(0, 1))
    assert len(f) == 2
    assert list(f) == [3, 4]

    conf = Conf(x=[[0, 1, 56]], w=[0, 1, 2, 3, 4])
    assert len(conf) == 56 * 5 * 3
    f = conf.filter(x=slice(None, None, 5), y=1, w=(0, 1, 2))
    assert len(f) == 12 * 3
    assert all(conf.y(i) == Conf.y.default_sequence[1] for i in f)

    mat = np.random.rand(56 * 5 * 3)
    assert f(mat, squeeze=False).shape == (12, 1, 1, 3)
    assert f(mat).shape == (12, 3)

    # 70 * 12 == len(conf), but the filter only indexes a flat leading axis
    mat = np.random.rand(70, 12)
    with pytest.raises(IndexError):
        f(mat, squeeze=False)
    assert f(mat.ravel(), squeeze=False).shape == (12, 1, 1, 3)
    assert f(mat.ravel()).shape == (12, 3)

    mat = np.random.rand(71, 12)
    with pytest.raises(IndexError):
        f(mat, axis=0)

    # trailing axes beyond the indexed one are preserved
    mat = np.random.rand(70, 12, 11)
    assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11)


def test_param_parsing():
    """parsing a string or path returns the correct values"""
    # placeholder: skip explicitly rather than pass silently
    pytest.skip("not implemented yet")


def test_config_parsing():
    """parsing a string returns the exact id corresponding to the string/path"""
    pytest.skip("not implemented yet")


def test_repr():
    """print a config object correctly"""
    pytest.skip("not implemented yet")