From c04d6d667e81135a1342db5b69068ee7e65461c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 10 Oct 2023 11:41:38 +0200 Subject: [PATCH] vdataclass implemented --- .gitignore | 1 + pyproject.toml | 5 +- src/scgenerator/variableparameters.py | 274 +++++++++++--------------- tests/test_variableparameters.py | 251 ++++++++++++++++++++--- 4 files changed, 349 insertions(+), 182 deletions(-) diff --git a/.gitignore b/.gitignore index 9a9f200..b67d33c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ Archive *.png *.pdf __pycache__ +.*_cache *.egg-info *sim_data* tmp* diff --git a/pyproject.toml b/pyproject.toml index 78fd877..255816a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.17" +version = "0.3.18" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] @@ -38,3 +38,6 @@ line-length = 100 [tool.isort] profile = "black" skip = ["__init__.py"] + +[tool.pyright] +pythonVersion = "3.11" diff --git a/src/scgenerator/variableparameters.py b/src/scgenerator/variableparameters.py index 9799bcb..fd9278e 100644 --- a/src/scgenerator/variableparameters.py +++ b/src/scgenerator/variableparameters.py @@ -1,20 +1,49 @@ import itertools import warnings from dataclasses import dataclass, field -from typing import Callable, Iterator, Protocol, Sequence, TypeVar +from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar import numpy as np T = TypeVar("T") +P = ParamSpec("P") -def debase(num: int, base: tuple[int, ...]) -> tuple[int, ...]: +def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]: + """ + decomposes a number into its digits in a variable base + + Parameters + ---------- + num : int + number to decompose + base : tuple[int, ...] 
+ base to use (see examples below) + strict : bool, optional + raise an error if the number is so big that the first digit would be `>= base[0]` (the default) + + Returns + ------- + tuple[int, ...] + digits + + Examples + -------- + ``` + assert debase(586, (10, 10, 10)) == (5, 8, 6) # decimal + assert debase(57, (2, 2, 2, 2, 2, 2)) == (1, 1, 1, 0, 0, 1) # binary + assert debase(38, (8, 4, 2)) == (4, 3, 0) # 38 == 4 * (4*2) + 3 * (2) + 0 * 1 + debase(785, (5, 1, 6)) # error because `strict=True` + assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5) # 130 >= 5 + ``` + """ indices = [] - _, *base = base - for i in range(len(base)): - ind, num = divmod(num, np.prod(base[i:])) + for i in range(len(base) - 1): + ind, num = divmod(num, np.prod(base[i + 1 :])) indices.append(ind) indices.append(num) + if strict and indices[0] >= base[0]: + raise ValueError(f"{num} too big for base {base!r} in strict mode") return tuple(indices) @@ -46,7 +75,9 @@ def int_linear(id, d_min, d_max, d_num) -> int: return d_min + id * ((d_max - d_min) // (d_num - 1)) -def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str]: +def unit_formatter( + unit: str, decimals: int = 1, vmin: float | None = 1e-28 +) -> Callable[[float | int], str]: if not unit: def formatter(val): @@ -55,67 +86,58 @@ def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str] else: def formatter(val): + if vmin is not None and abs(val) < vmin: + return f"0{unit}" base, true_exp = format(val, f".{3+decimals}e").split("e") true_exp = int(true_exp) exp, mult = divmod(true_exp, 3) + if exp < -8: + mult -= -3 * (8 + exp) + exp = -8 + elif exp > 8: + mult += 3 * (exp - 8) + exp = 8 + prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else "" - return f"{float(base)*10**mult:.{decimals}f}{prefix}{unit}" + return f"{float(base)*10**mult:.{decimals}g}{prefix}{unit}" return formatter -class Variable(Protocol[T]): - place: int +def sequence_from_specs( + start: float | int, stop: float | 
int, num: int, kind: str = "linear" +) -> np.ndarray: + if isinstance(start, int) and isinstance(stop, int) and kind == "linear": + step = (stop - start) // (num - 1) + if step == 0: + return sequence_from_specs(float(start), float(stop), num) + return np.arange(start, stop + step, step) + if kind == "linear": + return np.linspace(start, stop, num) + elif kind == "geometric": + return np.geomspace(start, stop, num) + elif kind == "geometric0": + num -= 1 + return np.concatenate(((0,), np.geomspace(start, stop, num))) + else: + raise ValueError(f"kind {kind!r} not recognized") - def __call__(self, _: int, local: bool = False) -> T: - ... - def __repr__(self) -> str: - ... - - def __len__(self) -> int: - ... - - def format(self, id: int, local: bool = False): - return f"{self.name}={self.formatter(self(id, local))}" - - def parse(self, _: str) -> T: - ... - - def filter(self, local_id) -> tuple[list[int], tuple[int, ...]]: - """ - filter based on the desired local id. - - Parameters - ---------- - local_id : int - index of the desired value - - Returns - ------- - list[int] - list of all global indices point to instances of this variable parameter that have the - same value - tuple[int, ...] 
- resulting shape formed by all other variable parameters - """ - shape = self.all_nums.copy() - shape[self.place] = 1 - indices = [ - i - for i, indices in enumerate(itertools.product(*(range(k) for k in self.all_nums))) - if indices[self.place] == local_id - ] - return indices, tuple(shape) +def get_sequence(value) -> np.ndarray: + if isinstance(value, (Sequence, np.ndarray)): + if isinstance(value[0], Sequence): + value = sequence_from_specs(*value[0]) + return np.asanyarray(value) + else: + return np.array([value]) @dataclass -class FuncVariable(Variable[float]): - func: Callable[[int, float, float, int], float] - args: tuple[float, float, int] - all_nums: list[int] - place: int +class Variable(Generic[T]): name: str + values: Sequence[T] + place: int + all_nums: list[int] suffix: str = "" decimals: int = 4 formatter: Callable[[int], str] = field(init=False) @@ -130,82 +152,30 @@ class FuncVariable(Variable[float]): def __post_init__(self): self.formatter = unit_formatter(self.suffix, self.decimals) - def __call__(self, id: int, local: bool = False) -> float: + def __call__(self, id: int, local: bool = False) -> T: + if id < 0: + raise IndexError("negative indices are not allowed") if not local: if not isinstance(id, (int, np.integer)): raise TypeError(f"id {id!r} is not an integer") if None in self.all_nums: raise ValueError("at least one VariableParameter has not been configured") id = debase(id, self.all_nums)[self.place] - if id >= self.args[2]: - raise ValueError( - f"id {id} is too big for variable parameter of size {self.args[2]}" - ) - return self.func(id, *self.args) + if id >= len(self): + raise IndexError(f"id {id} is too big for variable parameter of size {len(self)}") + return self.values[id] def __repr__(self) -> str: - return f"{self.__class__.__name__}({', '.join(format(el) for el in self.args)})" + return f"{self.__class__.__name__}({self.values})" def __len__(self) -> int: - return self.args[2] + return len(self.values) def parse(self, s: 
str): return NotImplemented - -@dataclass -class ArrayVariable(Variable[T]): - values: Sequence[T] - all_nums: list[int] - place: int - name: str - suffix: str - decimals: int - - def __call__(self, id: int, local: bool = False) -> T: - if not local: - id = debase(id, self.all_nums)[self.place] - if id >= len(self.values): - raise ValueError( - f"id {id} is too big for variable parameter of size {len(self.values)}" - ) - return self.values[id] - - def __len__(self) -> int: - return 1 - - def format(self, _: int, local: bool = False) -> str: - return self.formatted - - def parse(self, _: str) -> float: - return NotImplemented - - -@dataclass -class Constant(Variable[T]): - value: float | int - place: int - name: str - suffix: str - decimals: int - - def __post_init__(self): - self.formatted = f"{self.name}={unit_formatter(self.suffix, self.decimals)(self.value)}" - - def __call__(self, _: int, local: bool = False) -> T: - return self.value - - def __len__(self) -> int: - return 1 - - def local(self, _: int) -> T: - return self.value - - def format(self, _: int, local: bool = False) -> str: - return self.formatted - - def parse(self, _: str) -> float: - return NotImplemented + def format(self, id: int, local: bool = False): + return f"{self.name}={self.formatter(self(id, local))}" @dataclass(unsafe_hash=True) @@ -214,70 +184,55 @@ class VariableParameter: default: float | int | None = None suffix: str = "" decimals: int = 4 - default_callable: Variable | None = field(init=False) + default_sequence: Variable | None = field(init=False) place: int | None = field(default=None, init=False) public_name: str = field(init=False) private_name: str = field(init=False) def __set_name__(self, owner: type, name: str): self.public_name = name - self.private_name = "_" + name + self.private_name = "_variable__" + name if self.default is not None: - self.default_callable = Constant( - value=self.default, - place=self.place, - name=self.public_name, - suffix=self.suffix, - 
decimals=self.decimals, - ) + self.default_sequence = get_sequence(self.default) else: - self.default_callable = None + self.default_sequence = None def __set__(self, instance, value): all_nums = instance.__variables_nums__ - if value is self: - if self.default_callable is None: + if self.default_sequence is None: raise self._error_no_default() - var_obj = self.default_callable - var_num = 1 - elif isinstance(value, (float, int, complex)): - var_obj = Constant(value, self.place, self.public_name, self.suffix, self.decimals) - var_num = 1 - elif isinstance(value, Sequence): - if isinstance(value[0], Sequence): - value = value[0] - var_obj = FuncVariable( - func=self.func, - value=value, - all_nums=all_nums, - place=self.place, - name=self.public_name, - suffix=self.suffix, - decimals=self.decimals, - ) - var_num = value[2] - else: - var_obj = ArrayVariable( - values=value, - place=self.place, - name=self.public_name, - suffix=self.suffix, - decimals=self.decimals, - ) - else: - raise TypeError(f"value {value!r} of type {type(value)} not recognized") + all_nums[self.place] = len(self.default_sequence) + return - all_nums[self.place] = var_num + var_obj = Variable( + name=self.public_name, + values=get_sequence(value), + place=self.place, + suffix=self.suffix, + decimals=self.decimals, + all_nums=all_nums, + ) + all_nums[self.place] = len(var_obj) instance.__dict__[self.private_name] = var_obj def __get__(self, instance, _): if instance is None: return self if self.private_name not in instance.__dict__: - if self.default_callable is None: + all_nums = instance.__variables_nums__ + if self.default_sequence is None: raise self._error_no_default() - return self.default_callable + obj = Variable( + name=self.public_name, + values=self.default_sequence, + place=self.place, + suffix=self.suffix, + decimals=self.decimals, + all_nums=all_nums, + ) + all_nums[self.place] = len(obj) + instance.__dict__[self.private_name] = obj return instance.__dict__[self.private_name] @@ 
-371,7 +326,7 @@ def create_filter(_vfields: dict[str, VariableParameter]): """ if any(k not in _vfields for k in kwargs): extra = set(kwargs) - set(_vfields) - raise NameError(f"class {self.__class__.__name__} has no attribute(s) {extra}") + raise AttributeError(f"class {self.__class__.__name__} has no attribute(s) {extra}") all_local_indices = { k: tuple(range(p_num)) for k, p_num in zip(_vfields, self.__variables_nums__) } @@ -411,6 +366,7 @@ def vdataclass(cls: type[T]) -> type[T]: cls.__variables_nums__ = field( init=False, repr=False, default_factory=(lambda: [None] * len(_vfields)) ) + cls.__vfields__ = _vfields cls = dataclass(cls) for i, v in enumerate(_vfields.values()): diff --git a/tests/test_variableparameters.py b/tests/test_variableparameters.py index 9e39e13..36dff65 100644 --- a/tests/test_variableparameters.py +++ b/tests/test_variableparameters.py @@ -1,27 +1,68 @@ -from dataclasses import field - +import numpy as np import pytest -from scgenerator.variableparameters import Variable, vdataclass, vfield +from scgenerator.variableparameters import ( + Variable, + debase, + get_sequence, + unit_formatter, + vdataclass, + vfield, +) -def test_vfield_identification(): - """only vfields and not normal fields count""" - - class Conf: - x: int - - with pytest.warns(UserWarning): - vdataclass(Conf) - - @vdataclass - class Conf2: - x: Variable = vfield(default=5) - - assert hasattr(Conf2, "filter") +def test_debase(): + assert debase(586, (10, 10, 10)) == (5, 8, 6) # decimal + assert debase(57, (2, 2, 2, 2, 2, 2)) == (1, 1, 1, 0, 0, 1) # binary + assert debase(38, (8, 4, 2)) == (4, 3, 0) # 38 == 4 * (4*2) + 3 * (2) + 0 * 1 + with pytest.raises(ValueError): + debase(785, (5, 1, 6)) + assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5) -def test_constant(): +def test_get_sequence(): + s = get_sequence([1, 2, 3]) + assert isinstance(s, np.ndarray) and s.dtype == int + assert np.all(s == np.array([1, 2, 3])) + + s = get_sequence([[1, 2, 3]]) + assert 
isinstance(s, np.ndarray) and s.dtype == float + assert np.all(s == np.linspace(1, 2, 3)) + + s = get_sequence([[0, 4, 5]]) + assert isinstance(s, np.ndarray) and s.dtype == int + assert np.all(s == np.arange(5)) + + s = get_sequence([[3012, 3031, 20]]) + assert isinstance(s, np.ndarray) and s.dtype == int + assert np.all(s == np.arange(3012, 3032)) + + s = get_sequence(0) + assert isinstance(s, np.ndarray) and s.dtype == int + assert np.all(s == np.array([0])) + + s = get_sequence([[1e-12, 1e-5, 8, "geometric"]]) + assert isinstance(s, np.ndarray) and s.dtype == float + assert np.all(s == np.geomspace(1e-12, 1e-5, 8)) + + s = get_sequence([[1e-12, 1e-5, 9, "geometric0"]]) + assert isinstance(s, np.ndarray) and s.dtype == float + assert np.all(s == np.concatenate(((0,), np.geomspace(1e-12, 1e-5, 8)))) + + with pytest.raises(ValueError): + get_sequence([[1, 2, 3, "logarithm"]]) + + with pytest.raises(ValueError): + get_sequence([[1, 2, 3, 4]]) + + with pytest.raises(TypeError): + get_sequence([[0, 1, "linear"]]) + + with pytest.raises(TypeError): + get_sequence([[0, 1, int]]) + + +def test_constant_number(): """classes with constant fields don't increase size and always return the constant""" @vdataclass @@ -40,21 +81,183 @@ def test_constant(): conf.y(1) +def test_constant_list(): + """classes with constant fields don't increase size and always return the constant""" + + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield(default=[2, 0]) + + conf = Conf() + assert len(conf) == 6 + assert conf.x(0) == 1 + assert conf.y(0) == 2 + assert conf.x(len(conf) - 1) == 7 + assert conf.y(len(conf) - 1) == 0 + + with pytest.raises(ValueError): + conf.x(6) + with pytest.raises(ValueError): + conf.y(6) + + def test_decoration(): """proper construction of the class""" + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield(default=[2, 0]) + + assert hasattr(Conf, "filter") + assert hasattr(Conf, 
"__vfields__") + assert hasattr(Conf, "__len__") + assert list(Conf.__vfields__["x"].default_sequence) == [1, 2, 7] + + assert not hasattr(Conf, "__variables_nums__") + assert hasattr(Conf(), "__variables_nums__") + def test_name_error(): - """name errors are raised when accessing inexistent variable (direct, filter, ...)""" + """name/attribute errors are raised when accessing inexistent variable (direct, filter, ...)""" + + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield(default=[2, 0]) + + with pytest.raises(TypeError): + Conf(z=5) + + conf = Conf(y=5) + with pytest.raises(AttributeError): + conf.filter(z=45) -def test_defaults(): - """default values are respected, of proper type and of size one""" +def test_assignment(): + """assigned values properly update the whole instance""" + + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield(default=[2, 0]) + + conf = Conf(x=5, y=0) + assert len(conf) == 1 + assert conf.x(0) == 5 + assert conf.y(0) == 0 + + with pytest.raises(ValueError): + conf.x(6) + with pytest.raises(ValueError): + conf.y(6) + conf = Conf(x=[[0, 2, 5]], y=(0.1, 0.2)) + assert len(conf) == 10 + assert conf.x(0) == 0 + assert conf.x(5) == 1.0 + assert conf.y(0) == 0.1 + assert conf.y(2) == 0.1 + + with pytest.raises(IndexError): + conf.x(-5) + with pytest.raises(IndexError): + conf.x(5, local=True) + with pytest.raises(ValueError): + conf.y(10) + + conf.y = np.linspace(-1, 2, 7) + assert len(conf) == 7 * 5 + assert conf.x(13) == 0.5 + assert conf.y(13) == 2.0 + + with pytest.raises(AttributeError): + del conf.y + + conf.x = 0 + assert len(conf) == 7 + assert conf.x(5) == 0 + assert conf.y(5) == 1.5 -def test_formatting(): +def test_unit_forammter(): + fmt = unit_formatter("s", 4) + + assert fmt(5e-29) == "0s" + assert fmt(5e27) == "5000Ys" + + fmt = unit_formatter("s", 4, vmin=None) + assert fmt(5e-29) == "5e-05ys" + + +def test_param_formatting(): """formatting is 
always respected""" + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2) + z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s") + w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz") + + conf = Conf() + assert conf.x.format(0) == "x=1" + assert conf.y.format(0) == "y=2" + assert conf.y.format(1) == "y=2.5" + assert conf.y.format(2) == "y=2.2e-05" + assert conf.z.format(0) == "z=7.5ns" + assert conf.w.format(0) == "w=7.568GHz" + + +def test_error_no_default(): + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield() + + with pytest.raises(ValueError): + conf = Conf() + conf = Conf(y=5) + assert conf.x(0) == 1 + assert conf.y(0) == 5 + + +def test_filter(): + @vdataclass + class Conf: + x: Variable = vfield() + y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2) + z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s") + w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz") + + conf = Conf(x=(1, 2, 3)) + f = conf.filter(x=1, y=(0, 1)) + assert len(f) == 2 + assert list(f) == [3, 4] + + conf = Conf(x=[[0, 1, 56]], w=[0, 1, 2, 3, 4]) + assert len(conf) == 56 * 5 * 3 + f = conf.filter(x=slice(None, None, 5), y=1, w=(0, 1, 2)) + assert len(f) == 12 * 3 + assert all(conf.y(i) == Conf.y.default_sequence[1] for i in f) + + mat = np.random.rand(56 * 5 * 3) + assert f(mat, squeeze=False).shape == (12, 1, 1, 3) + assert f(mat).shape == (12, 3) + + mat = np.random.rand(70, 12) + with pytest.raises(IndexError): + assert f(mat, squeeze=False).shape == (12, 1, 1, 3) + assert f(mat.ravel(), squeeze=False).shape == (12, 1, 1, 3) + assert f(mat.ravel()).shape == (12, 3) + + mat = np.random.rand(71, 12) + with pytest.raises(IndexError): + assert f(mat, axis=0).shape == (12, 3) + + mat = np.random.rand(70, 12, 11) + assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11) + def 
test_param_parsing(): """parsing a string or path returns the correct values""" @@ -62,3 +265,7 @@ def test_param_parsing(): def test_config_parsing(): """parsing a string returns the exact id corresponding to the string/path""" + + +def test_repr(): + """print a config object correctly"""