vdataclass implemented

This commit is contained in:
Benoît Sierro
2023-10-10 11:41:38 +02:00
parent 5bf2a080e5
commit c04d6d667e
4 changed files with 349 additions and 182 deletions

1
.gitignore vendored
View File

@@ -14,6 +14,7 @@ Archive
*.png *.png
*.pdf *.pdf
__pycache__ __pycache__
.*_cache
*.egg-info *.egg-info
*sim_data* *sim_data*
tmp* tmp*

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "scgenerator" name = "scgenerator"
version = "0.3.17" version = "0.3.18"
description = "Simulate nonlinear pulse propagation in optical fibers" description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md" readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
@@ -38,3 +38,6 @@ line-length = 100
[tool.isort] [tool.isort]
profile = "black" profile = "black"
skip = ["__init__.py"] skip = ["__init__.py"]
[tool.pyright]
pythonVersion = "3.11"

View File

@@ -1,20 +1,49 @@
import itertools import itertools
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Iterator, Protocol, Sequence, TypeVar from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar
import numpy as np import numpy as np
T = TypeVar("T") T = TypeVar("T")
P = ParamSpec("P")
def debase(num: int, base: tuple[int, ...]) -> tuple[int, ...]: def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]:
"""
decomboses a number into its digits in a variable base
Parameters
----------
num : int
number to decompose
base : tuple[int, ...]
base to use (see examples below)
strict : bool, optional
raise an error if the number so big that the first digit would be `>= base[0]` (the default)
Returns
-------
tuple[int, ...]
digits
Examples
--------
```
assert debase(586, (10, 10, 10)) == (5, 8, 6) # decimal
assert debase(57, (2, 2, 2, 2, 2, 2)) == (1, 1, 1, 0, 0, 1) # binary
assert debase(38, (8, 4, 2)) == (4, 3, 0) # 38 == 4 * (4*2) + 3 * (2) + 0 * 1
debase(785, (5, 1, 6)) # error because `strict=True`
assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5) # 130 >= 5
```
"""
indices = [] indices = []
_, *base = base for i in range(len(base) - 1):
for i in range(len(base)): ind, num = divmod(num, np.prod(base[i + 1 :]))
ind, num = divmod(num, np.prod(base[i:]))
indices.append(ind) indices.append(ind)
indices.append(num) indices.append(num)
if strict and indices[0] >= base[0]:
raise ValueError(f"{num} too big for base {base!r} in strict mode")
return tuple(indices) return tuple(indices)
@@ -46,7 +75,9 @@ def int_linear(id, d_min, d_max, d_num) -> int:
return d_min + id * ((d_max - d_min) // (d_num - 1)) return d_min + id * ((d_max - d_min) // (d_num - 1))
def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str]: def unit_formatter(
unit: str, decimals: int = 1, vmin: float | None = 1e-28
) -> Callable[[float | int], str]:
if not unit: if not unit:
def formatter(val): def formatter(val):
@@ -55,67 +86,58 @@ def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str]
else: else:
def formatter(val): def formatter(val):
if vmin is not None and abs(val) < vmin:
return f"0{unit}"
base, true_exp = format(val, f".{3+decimals}e").split("e") base, true_exp = format(val, f".{3+decimals}e").split("e")
true_exp = int(true_exp) true_exp = int(true_exp)
exp, mult = divmod(true_exp, 3) exp, mult = divmod(true_exp, 3)
if exp < -8:
mult -= -3 * (8 + exp)
exp = -8
elif exp > 8:
mult += 3 * (exp - 8)
exp = 8
prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else "" prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else ""
return f"{float(base)*10**mult:.{decimals}f}{prefix}{unit}" return f"{float(base)*10**mult:.{decimals}g}{prefix}{unit}"
return formatter return formatter
class Variable(Protocol[T]): def sequence_from_specs(
place: int start: float | int, stop: float | int, num: int, kind: str = "linear"
) -> np.ndarray:
if isinstance(start, int) and isinstance(stop, int) and kind == "linear":
step = (stop - start) // (num - 1)
if step == 0:
return sequence_from_specs(float(start), float(stop), num)
return np.arange(start, stop + step, step)
if kind == "linear":
return np.linspace(start, stop, num)
elif kind == "geometric":
return np.geomspace(start, stop, num)
elif kind == "geometric0":
num -= 1
return np.concatenate(((0,), np.geomspace(start, stop, num)))
else:
raise ValueError(f"kind {kind!r} not recognized")
def __call__(self, _: int, local: bool = False) -> T:
...
def __repr__(self) -> str: def get_sequence(value) -> np.ndarray:
... if isinstance(value, (Sequence, np.ndarray)):
if isinstance(value[0], Sequence):
def __len__(self) -> int: value = sequence_from_specs(*value[0])
... return np.asanyarray(value)
else:
def format(self, id: int, local: bool = False): return np.array([value])
return f"{self.name}={self.formatter(self(id, local))}"
def parse(self, _: str) -> T:
...
def filter(self, local_id) -> tuple[list[int], tuple[int, ...]]:
"""
filter based on the desired local id.
Parameters
----------
local_id : int
index of the desired value
Returns
-------
list[int]
list of all global indices point to instances of this variable parameter that have the
same value
tuple[int, ...]
resulting shape formed by all other variable parameters
"""
shape = self.all_nums.copy()
shape[self.place] = 1
indices = [
i
for i, indices in enumerate(itertools.product(*(range(k) for k in self.all_nums)))
if indices[self.place] == local_id
]
return indices, tuple(shape)
@dataclass @dataclass
class FuncVariable(Variable[float]): class Variable(Generic[T]):
func: Callable[[int, float, float, int], float]
args: tuple[float, float, int]
all_nums: list[int]
place: int
name: str name: str
values: Sequence[T]
place: int
all_nums: list[int]
suffix: str = "" suffix: str = ""
decimals: int = 4 decimals: int = 4
formatter: Callable[[int], str] = field(init=False) formatter: Callable[[int], str] = field(init=False)
@@ -130,82 +152,30 @@ class FuncVariable(Variable[float]):
def __post_init__(self): def __post_init__(self):
self.formatter = unit_formatter(self.suffix, self.decimals) self.formatter = unit_formatter(self.suffix, self.decimals)
def __call__(self, id: int, local: bool = False) -> float: def __call__(self, id: int, local: bool = False) -> T:
if id < 0:
raise IndexError("negative indices are not allowed")
if not local: if not local:
if not isinstance(id, (int, np.integer)): if not isinstance(id, (int, np.integer)):
raise TypeError(f"id {id!r} is not an integer") raise TypeError(f"id {id!r} is not an integer")
if None in self.all_nums: if None in self.all_nums:
raise ValueError("at least one VariableParameter has not been configured") raise ValueError("at least one VariableParameter has not been configured")
id = debase(id, self.all_nums)[self.place] id = debase(id, self.all_nums)[self.place]
if id >= self.args[2]: if id >= len(self):
raise ValueError( raise IndexError(f"id {id} is too big for variable parameter of size {len(self)}")
f"id {id} is too big for variable parameter of size {self.args[2]}" return self.values[id]
)
return self.func(id, *self.args)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}({', '.join(format(el) for el in self.args)})" return f"{self.__class__.__name__}({self.values})"
def __len__(self) -> int: def __len__(self) -> int:
return self.args[2] return len(self.values)
def parse(self, s: str): def parse(self, s: str):
return NotImplemented return NotImplemented
def format(self, id: int, local: bool = False):
@dataclass return f"{self.name}={self.formatter(self(id, local))}"
class ArrayVariable(Variable[T]):
values: Sequence[T]
all_nums: list[int]
place: int
name: str
suffix: str
decimals: int
def __call__(self, id: int, local: bool = False) -> T:
if not local:
id = debase(id, self.all_nums)[self.place]
if id >= len(self.values):
raise ValueError(
f"id {id} is too big for variable parameter of size {len(self.values)}"
)
return self.values[id]
def __len__(self) -> int:
return 1
def format(self, _: int, local: bool = False) -> str:
return self.formatted
def parse(self, _: str) -> float:
return NotImplemented
@dataclass
class Constant(Variable[T]):
value: float | int
place: int
name: str
suffix: str
decimals: int
def __post_init__(self):
self.formatted = f"{self.name}={unit_formatter(self.suffix, self.decimals)(self.value)}"
def __call__(self, _: int, local: bool = False) -> T:
return self.value
def __len__(self) -> int:
return 1
def local(self, _: int) -> T:
return self.value
def format(self, _: int, local: bool = False) -> str:
return self.formatted
def parse(self, _: str) -> float:
return NotImplemented
@dataclass(unsafe_hash=True) @dataclass(unsafe_hash=True)
@@ -214,70 +184,55 @@ class VariableParameter:
default: float | int | None = None default: float | int | None = None
suffix: str = "" suffix: str = ""
decimals: int = 4 decimals: int = 4
default_callable: Variable | None = field(init=False) default_sequence: Variable | None = field(init=False)
place: int | None = field(default=None, init=False) place: int | None = field(default=None, init=False)
public_name: str = field(init=False) public_name: str = field(init=False)
private_name: str = field(init=False) private_name: str = field(init=False)
def __set_name__(self, owner: type, name: str): def __set_name__(self, owner: type, name: str):
self.public_name = name self.public_name = name
self.private_name = "_" + name self.private_name = "_variable__" + name
if self.default is not None: if self.default is not None:
self.default_callable = Constant( self.default_sequence = get_sequence(self.default)
value=self.default,
place=self.place,
name=self.public_name,
suffix=self.suffix,
decimals=self.decimals,
)
else: else:
self.default_callable = None self.default_sequence = None
def __set__(self, instance, value): def __set__(self, instance, value):
all_nums = instance.__variables_nums__ all_nums = instance.__variables_nums__
if value is self: if value is self:
if self.default_callable is None: if self.default_sequence is None:
raise self._error_no_default() raise self._error_no_default()
var_obj = self.default_callable all_nums[self.place] = len(self.default_sequence)
var_num = 1 return
elif isinstance(value, (float, int, complex)):
var_obj = Constant(value, self.place, self.public_name, self.suffix, self.decimals)
var_num = 1
elif isinstance(value, Sequence):
if isinstance(value[0], Sequence):
value = value[0]
var_obj = FuncVariable(
func=self.func,
value=value,
all_nums=all_nums,
place=self.place,
name=self.public_name,
suffix=self.suffix,
decimals=self.decimals,
)
var_num = value[2]
else:
var_obj = ArrayVariable(
values=value,
place=self.place,
name=self.public_name,
suffix=self.suffix,
decimals=self.decimals,
)
else:
raise TypeError(f"value {value!r} of type {type(value)} not recognized")
all_nums[self.place] = var_num var_obj = Variable(
name=self.public_name,
values=get_sequence(value),
place=self.place,
suffix=self.suffix,
decimals=self.decimals,
all_nums=all_nums,
)
all_nums[self.place] = len(var_obj)
instance.__dict__[self.private_name] = var_obj instance.__dict__[self.private_name] = var_obj
def __get__(self, instance, _): def __get__(self, instance, _):
if instance is None: if instance is None:
return self return self
if self.private_name not in instance.__dict__: if self.private_name not in instance.__dict__:
if self.default_callable is None: all_nums = instance.__variables_nums__
if self.default_sequence is None:
raise self._error_no_default() raise self._error_no_default()
return self.default_callable obj = Variable(
name=self.public_name,
values=self.default_sequence,
place=self.place,
suffix=self.suffix,
decimals=self.decimals,
all_nums=all_nums,
)
all_nums[self.place] = len(obj)
instance.__dict__[self.private_name] = obj
return instance.__dict__[self.private_name] return instance.__dict__[self.private_name]
@@ -371,7 +326,7 @@ def create_filter(_vfields: dict[str, VariableParameter]):
""" """
if any(k not in _vfields for k in kwargs): if any(k not in _vfields for k in kwargs):
extra = set(kwargs) - set(_vfields) extra = set(kwargs) - set(_vfields)
raise NameError(f"class {self.__class__.__name__} has no attribute(s) {extra}") raise AttributeError(f"class {self.__class__.__name__} has no attribute(s) {extra}")
all_local_indices = { all_local_indices = {
k: tuple(range(p_num)) for k, p_num in zip(_vfields, self.__variables_nums__) k: tuple(range(p_num)) for k, p_num in zip(_vfields, self.__variables_nums__)
} }
@@ -411,6 +366,7 @@ def vdataclass(cls: type[T]) -> type[T]:
cls.__variables_nums__ = field( cls.__variables_nums__ = field(
init=False, repr=False, default_factory=(lambda: [None] * len(_vfields)) init=False, repr=False, default_factory=(lambda: [None] * len(_vfields))
) )
cls.__vfields__ = _vfields
cls = dataclass(cls) cls = dataclass(cls)
for i, v in enumerate(_vfields.values()): for i, v in enumerate(_vfields.values()):

View File

@@ -1,27 +1,68 @@
from dataclasses import field import numpy as np
import pytest import pytest
from scgenerator.variableparameters import Variable, vdataclass, vfield from scgenerator.variableparameters import (
Variable,
debase,
get_sequence,
unit_formatter,
vdataclass,
vfield,
)
def test_vfield_identification(): def test_debase():
"""only vfields and not normal fields count""" assert debase(586, (10, 10, 10)) == (5, 8, 6) # decimal
assert debase(57, (2, 2, 2, 2, 2, 2)) == (1, 1, 1, 0, 0, 1) # binary
class Conf: assert debase(38, (8, 4, 2)) == (4, 3, 0) # 38 == 4 * (4*2) + 3 * (2) + 0 * 1
x: int with pytest.raises(ValueError):
debase(785, (5, 1, 6))
with pytest.warns(UserWarning): assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5)
vdataclass(Conf)
@vdataclass
class Conf2:
x: Variable = vfield(default=5)
assert hasattr(Conf2, "filter")
def test_constant(): def test_get_sequence():
s = get_sequence([1, 2, 3])
assert isinstance(s, np.ndarray) and s.dtype == int
assert np.all(s == np.array([1, 2, 3]))
s = get_sequence([[1, 2, 3]])
assert isinstance(s, np.ndarray) and s.dtype == float
assert np.all(s == np.linspace(1, 2, 3))
s = get_sequence([[0, 4, 5]])
assert isinstance(s, np.ndarray) and s.dtype == int
assert np.all(s == np.arange(5))
s = get_sequence([[3012, 3031, 20]])
assert isinstance(s, np.ndarray) and s.dtype == int
assert np.all(s == np.arange(3012, 3032))
s = get_sequence(0)
assert isinstance(s, np.ndarray) and s.dtype == int
assert np.all(s == np.array([0]))
s = get_sequence([[1e-12, 1e-5, 8, "geometric"]])
assert isinstance(s, np.ndarray) and s.dtype == float
assert np.all(s == np.geomspace(1e-12, 1e-5, 8))
s = get_sequence([[1e-12, 1e-5, 9, "geometric0"]])
assert isinstance(s, np.ndarray) and s.dtype == float
assert np.all(s == np.concatenate(((0,), np.geomspace(1e-12, 1e-5, 8))))
with pytest.raises(ValueError):
get_sequence([[1, 2, 3, "logarithm"]])
with pytest.raises(ValueError):
get_sequence([[1, 2, 3, 4]])
with pytest.raises(TypeError):
get_sequence([[0, 1, "linear"]])
with pytest.raises(TypeError):
get_sequence([[0, 1, int]])
def test_constant_number():
"""classes with constant fields don't increase size and always return the constant""" """classes with constant fields don't increase size and always return the constant"""
@vdataclass @vdataclass
@@ -40,21 +81,183 @@ def test_constant():
conf.y(1) conf.y(1)
def test_constant_list():
"""classes with constant fields don't increase size and always return the constant"""
@vdataclass
class Conf:
x: Variable = vfield(default=[1, 2, 7])
y: Variable = vfield(default=[2, 0])
conf = Conf()
assert len(conf) == 6
assert conf.x(0) == 1
assert conf.y(0) == 2
assert conf.x(len(conf) - 1) == 7
assert conf.y(len(conf) - 1) == 0
with pytest.raises(ValueError):
conf.x(6)
with pytest.raises(ValueError):
conf.y(6)
def test_decoration(): def test_decoration():
"""proper construction of the class""" """proper construction of the class"""
@vdataclass
class Conf:
x: Variable = vfield(default=[1, 2, 7])
y: Variable = vfield(default=[2, 0])
assert hasattr(Conf, "filter")
assert hasattr(Conf, "__vfields__")
assert hasattr(Conf, "__len__")
assert list(Conf.__vfields__["x"].default_sequence) == [1, 2, 7]
assert not hasattr(Conf, "__variables_nums__")
assert hasattr(Conf(), "__variables_nums__")
def test_name_error(): def test_name_error():
"""name errors are raised when accessing inexistent variable (direct, filter, ...)""" """name/attribute errors are raised when accessing inexistent variable (direct, filter, ...)"""
@vdataclass
class Conf:
x: Variable = vfield(default=[1, 2, 7])
y: Variable = vfield(default=[2, 0])
with pytest.raises(TypeError):
Conf(z=5)
conf = Conf(y=5)
with pytest.raises(AttributeError):
conf.filter(z=45)
def test_defaults(): def test_assignment():
"""default values are respected, of proper type and of size one""" """assigned values properly update the whole instance"""
@vdataclass
class Conf:
x: Variable = vfield(default=[1, 2, 7])
y: Variable = vfield(default=[2, 0])
conf = Conf(x=5, y=0)
assert len(conf) == 1
assert conf.x(0) == 5
assert conf.y(0) == 0
with pytest.raises(ValueError):
conf.x(6)
with pytest.raises(ValueError):
conf.y(6)
conf = Conf(x=[[0, 2, 5]], y=(0.1, 0.2))
assert len(conf) == 10
assert conf.x(0) == 0
assert conf.x(5) == 1.0
assert conf.y(0) == 0.1
assert conf.y(2) == 0.1
with pytest.raises(IndexError):
conf.x(-5)
with pytest.raises(IndexError):
conf.x(5, local=True)
with pytest.raises(ValueError):
conf.y(10)
conf.y = np.linspace(-1, 2, 7)
assert len(conf) == 7 * 5
assert conf.x(13) == 0.5
assert conf.y(13) == 2.0
with pytest.raises(AttributeError):
del conf.y
conf.x = 0
assert len(conf) == 7
assert conf.x(5) == 0
assert conf.y(5) == 1.5
def test_formatting(): def test_unit_forammter():
fmt = unit_formatter("s", 4)
assert fmt(5e-29) == "0s"
assert fmt(5e27) == "5000Ys"
fmt = unit_formatter("s", 4, vmin=None)
assert fmt(5e-29) == "5e-05ys"
def test_param_formatting():
"""formatting is always respected""" """formatting is always respected"""
@vdataclass
class Conf:
x: Variable = vfield(default=[1, 2, 7])
y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")
conf = Conf()
assert conf.x.format(0) == "x=1"
assert conf.y.format(0) == "y=2"
assert conf.y.format(1) == "y=2.5"
assert conf.y.format(2) == "y=2.2e-05"
assert conf.z.format(0) == "z=7.5ns"
assert conf.w.format(0) == "w=7.568GHz"
def test_error_no_default():
@vdataclass
class Conf:
x: Variable = vfield(default=[1, 2, 7])
y: Variable = vfield()
with pytest.raises(ValueError):
conf = Conf()
conf = Conf(y=5)
assert conf.x(0) == 1
assert conf.y(0) == 5
def test_filter():
@vdataclass
class Conf:
x: Variable = vfield()
y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")
conf = Conf(x=(1, 2, 3))
f = conf.filter(x=1, y=(0, 1))
assert len(f) == 2
assert list(f) == [3, 4]
conf = Conf(x=[[0, 1, 56]], w=[0, 1, 2, 3, 4])
assert len(conf) == 56 * 5 * 3
f = conf.filter(x=slice(None, None, 5), y=1, w=(0, 1, 2))
assert len(f) == 12 * 3
assert all(conf.y(i) == Conf.y.default_sequence[1] for i in f)
mat = np.random.rand(56 * 5 * 3)
assert f(mat, squeeze=False).shape == (12, 1, 1, 3)
assert f(mat).shape == (12, 3)
mat = np.random.rand(70, 12)
with pytest.raises(IndexError):
assert f(mat, squeeze=False).shape == (12, 1, 1, 3)
assert f(mat.ravel(), squeeze=False).shape == (12, 1, 1, 3)
assert f(mat.ravel()).shape == (12, 3)
mat = np.random.rand(71, 12)
with pytest.raises(IndexError):
assert f(mat, axis=0).shape == (12, 3)
mat = np.random.rand(70, 12, 11)
assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11)
def test_param_parsing(): def test_param_parsing():
"""parsing a string or path returns the correct values""" """parsing a string or path returns the correct values"""
@@ -62,3 +265,7 @@ def test_param_parsing():
def test_config_parsing(): def test_config_parsing():
"""parsing a string returns the exact id corresponding to the string/path""" """parsing a string returns the exact id corresponding to the string/path"""
def test_repr():
"""print a config object correctly"""