vdataclass implemented
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -14,6 +14,7 @@ Archive
|
|||||||
*.png
|
*.png
|
||||||
*.pdf
|
*.pdf
|
||||||
__pycache__
|
__pycache__
|
||||||
|
.*_cache
|
||||||
*.egg-info
|
*.egg-info
|
||||||
*sim_data*
|
*sim_data*
|
||||||
tmp*
|
tmp*
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "scgenerator"
|
name = "scgenerator"
|
||||||
version = "0.3.17"
|
version = "0.3.18"
|
||||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||||
@@ -38,3 +38,6 @@ line-length = 100
|
|||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
skip = ["__init__.py"]
|
skip = ["__init__.py"]
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
pythonVersion = "3.11"
|
||||||
|
|||||||
@@ -1,20 +1,49 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Iterator, Protocol, Sequence, TypeVar
|
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
def debase(num: int, base: tuple[int, ...]) -> tuple[int, ...]:
|
def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]:
|
||||||
|
"""
|
||||||
|
decomboses a number into its digits in a variable base
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
num : int
|
||||||
|
number to decompose
|
||||||
|
base : tuple[int, ...]
|
||||||
|
base to use (see examples below)
|
||||||
|
strict : bool, optional
|
||||||
|
raise an error if the number so big that the first digit would be `>= base[0]` (the default)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[int, ...]
|
||||||
|
digits
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
```
|
||||||
|
assert debase(586, (10, 10, 10)) == (5, 8, 6) # decimal
|
||||||
|
assert debase(57, (2, 2, 2, 2, 2, 2)) == (1, 1, 1, 0, 0, 1) # binary
|
||||||
|
assert debase(38, (8, 4, 2)) == (4, 3, 0) # 38 == 4 * (4*2) + 3 * (2) + 0 * 1
|
||||||
|
debase(785, (5, 1, 6)) # error because `strict=True`
|
||||||
|
assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5) # 130 >= 5
|
||||||
|
```
|
||||||
|
"""
|
||||||
indices = []
|
indices = []
|
||||||
_, *base = base
|
for i in range(len(base) - 1):
|
||||||
for i in range(len(base)):
|
ind, num = divmod(num, np.prod(base[i + 1 :]))
|
||||||
ind, num = divmod(num, np.prod(base[i:]))
|
|
||||||
indices.append(ind)
|
indices.append(ind)
|
||||||
indices.append(num)
|
indices.append(num)
|
||||||
|
if strict and indices[0] >= base[0]:
|
||||||
|
raise ValueError(f"{num} too big for base {base!r} in strict mode")
|
||||||
return tuple(indices)
|
return tuple(indices)
|
||||||
|
|
||||||
|
|
||||||
@@ -46,7 +75,9 @@ def int_linear(id, d_min, d_max, d_num) -> int:
|
|||||||
return d_min + id * ((d_max - d_min) // (d_num - 1))
|
return d_min + id * ((d_max - d_min) // (d_num - 1))
|
||||||
|
|
||||||
|
|
||||||
def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str]:
|
def unit_formatter(
|
||||||
|
unit: str, decimals: int = 1, vmin: float | None = 1e-28
|
||||||
|
) -> Callable[[float | int], str]:
|
||||||
if not unit:
|
if not unit:
|
||||||
|
|
||||||
def formatter(val):
|
def formatter(val):
|
||||||
@@ -55,67 +86,58 @@ def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str]
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
def formatter(val):
|
def formatter(val):
|
||||||
|
if vmin is not None and abs(val) < vmin:
|
||||||
|
return f"0{unit}"
|
||||||
base, true_exp = format(val, f".{3+decimals}e").split("e")
|
base, true_exp = format(val, f".{3+decimals}e").split("e")
|
||||||
true_exp = int(true_exp)
|
true_exp = int(true_exp)
|
||||||
exp, mult = divmod(true_exp, 3)
|
exp, mult = divmod(true_exp, 3)
|
||||||
|
if exp < -8:
|
||||||
|
mult -= -3 * (8 + exp)
|
||||||
|
exp = -8
|
||||||
|
elif exp > 8:
|
||||||
|
mult += 3 * (exp - 8)
|
||||||
|
exp = 8
|
||||||
|
|
||||||
prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else ""
|
prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else ""
|
||||||
return f"{float(base)*10**mult:.{decimals}f}{prefix}{unit}"
|
return f"{float(base)*10**mult:.{decimals}g}{prefix}{unit}"
|
||||||
|
|
||||||
return formatter
|
return formatter
|
||||||
|
|
||||||
|
|
||||||
class Variable(Protocol[T]):
|
def sequence_from_specs(
|
||||||
place: int
|
start: float | int, stop: float | int, num: int, kind: str = "linear"
|
||||||
|
) -> np.ndarray:
|
||||||
|
if isinstance(start, int) and isinstance(stop, int) and kind == "linear":
|
||||||
|
step = (stop - start) // (num - 1)
|
||||||
|
if step == 0:
|
||||||
|
return sequence_from_specs(float(start), float(stop), num)
|
||||||
|
return np.arange(start, stop + step, step)
|
||||||
|
if kind == "linear":
|
||||||
|
return np.linspace(start, stop, num)
|
||||||
|
elif kind == "geometric":
|
||||||
|
return np.geomspace(start, stop, num)
|
||||||
|
elif kind == "geometric0":
|
||||||
|
num -= 1
|
||||||
|
return np.concatenate(((0,), np.geomspace(start, stop, num)))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"kind {kind!r} not recognized")
|
||||||
|
|
||||||
def __call__(self, _: int, local: bool = False) -> T:
|
|
||||||
...
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def get_sequence(value) -> np.ndarray:
|
||||||
...
|
if isinstance(value, (Sequence, np.ndarray)):
|
||||||
|
if isinstance(value[0], Sequence):
|
||||||
def __len__(self) -> int:
|
value = sequence_from_specs(*value[0])
|
||||||
...
|
return np.asanyarray(value)
|
||||||
|
else:
|
||||||
def format(self, id: int, local: bool = False):
|
return np.array([value])
|
||||||
return f"{self.name}={self.formatter(self(id, local))}"
|
|
||||||
|
|
||||||
def parse(self, _: str) -> T:
|
|
||||||
...
|
|
||||||
|
|
||||||
def filter(self, local_id) -> tuple[list[int], tuple[int, ...]]:
|
|
||||||
"""
|
|
||||||
filter based on the desired local id.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
local_id : int
|
|
||||||
index of the desired value
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
list[int]
|
|
||||||
list of all global indices point to instances of this variable parameter that have the
|
|
||||||
same value
|
|
||||||
tuple[int, ...]
|
|
||||||
resulting shape formed by all other variable parameters
|
|
||||||
"""
|
|
||||||
shape = self.all_nums.copy()
|
|
||||||
shape[self.place] = 1
|
|
||||||
indices = [
|
|
||||||
i
|
|
||||||
for i, indices in enumerate(itertools.product(*(range(k) for k in self.all_nums)))
|
|
||||||
if indices[self.place] == local_id
|
|
||||||
]
|
|
||||||
return indices, tuple(shape)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FuncVariable(Variable[float]):
|
class Variable(Generic[T]):
|
||||||
func: Callable[[int, float, float, int], float]
|
|
||||||
args: tuple[float, float, int]
|
|
||||||
all_nums: list[int]
|
|
||||||
place: int
|
|
||||||
name: str
|
name: str
|
||||||
|
values: Sequence[T]
|
||||||
|
place: int
|
||||||
|
all_nums: list[int]
|
||||||
suffix: str = ""
|
suffix: str = ""
|
||||||
decimals: int = 4
|
decimals: int = 4
|
||||||
formatter: Callable[[int], str] = field(init=False)
|
formatter: Callable[[int], str] = field(init=False)
|
||||||
@@ -130,82 +152,30 @@ class FuncVariable(Variable[float]):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.formatter = unit_formatter(self.suffix, self.decimals)
|
self.formatter = unit_formatter(self.suffix, self.decimals)
|
||||||
|
|
||||||
def __call__(self, id: int, local: bool = False) -> float:
|
def __call__(self, id: int, local: bool = False) -> T:
|
||||||
|
if id < 0:
|
||||||
|
raise IndexError("negative indices are not allowed")
|
||||||
if not local:
|
if not local:
|
||||||
if not isinstance(id, (int, np.integer)):
|
if not isinstance(id, (int, np.integer)):
|
||||||
raise TypeError(f"id {id!r} is not an integer")
|
raise TypeError(f"id {id!r} is not an integer")
|
||||||
if None in self.all_nums:
|
if None in self.all_nums:
|
||||||
raise ValueError("at least one VariableParameter has not been configured")
|
raise ValueError("at least one VariableParameter has not been configured")
|
||||||
id = debase(id, self.all_nums)[self.place]
|
id = debase(id, self.all_nums)[self.place]
|
||||||
if id >= self.args[2]:
|
if id >= len(self):
|
||||||
raise ValueError(
|
raise IndexError(f"id {id} is too big for variable parameter of size {len(self)}")
|
||||||
f"id {id} is too big for variable parameter of size {self.args[2]}"
|
return self.values[id]
|
||||||
)
|
|
||||||
return self.func(id, *self.args)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}({', '.join(format(el) for el in self.args)})"
|
return f"{self.__class__.__name__}({self.values})"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self.args[2]
|
return len(self.values)
|
||||||
|
|
||||||
def parse(self, s: str):
|
def parse(self, s: str):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
def format(self, id: int, local: bool = False):
|
||||||
@dataclass
|
return f"{self.name}={self.formatter(self(id, local))}"
|
||||||
class ArrayVariable(Variable[T]):
|
|
||||||
values: Sequence[T]
|
|
||||||
all_nums: list[int]
|
|
||||||
place: int
|
|
||||||
name: str
|
|
||||||
suffix: str
|
|
||||||
decimals: int
|
|
||||||
|
|
||||||
def __call__(self, id: int, local: bool = False) -> T:
|
|
||||||
if not local:
|
|
||||||
id = debase(id, self.all_nums)[self.place]
|
|
||||||
if id >= len(self.values):
|
|
||||||
raise ValueError(
|
|
||||||
f"id {id} is too big for variable parameter of size {len(self.values)}"
|
|
||||||
)
|
|
||||||
return self.values[id]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def format(self, _: int, local: bool = False) -> str:
|
|
||||||
return self.formatted
|
|
||||||
|
|
||||||
def parse(self, _: str) -> float:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Constant(Variable[T]):
|
|
||||||
value: float | int
|
|
||||||
place: int
|
|
||||||
name: str
|
|
||||||
suffix: str
|
|
||||||
decimals: int
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.formatted = f"{self.name}={unit_formatter(self.suffix, self.decimals)(self.value)}"
|
|
||||||
|
|
||||||
def __call__(self, _: int, local: bool = False) -> T:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def local(self, _: int) -> T:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
def format(self, _: int, local: bool = False) -> str:
|
|
||||||
return self.formatted
|
|
||||||
|
|
||||||
def parse(self, _: str) -> float:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(unsafe_hash=True)
|
@dataclass(unsafe_hash=True)
|
||||||
@@ -214,70 +184,55 @@ class VariableParameter:
|
|||||||
default: float | int | None = None
|
default: float | int | None = None
|
||||||
suffix: str = ""
|
suffix: str = ""
|
||||||
decimals: int = 4
|
decimals: int = 4
|
||||||
default_callable: Variable | None = field(init=False)
|
default_sequence: Variable | None = field(init=False)
|
||||||
place: int | None = field(default=None, init=False)
|
place: int | None = field(default=None, init=False)
|
||||||
public_name: str = field(init=False)
|
public_name: str = field(init=False)
|
||||||
private_name: str = field(init=False)
|
private_name: str = field(init=False)
|
||||||
|
|
||||||
def __set_name__(self, owner: type, name: str):
|
def __set_name__(self, owner: type, name: str):
|
||||||
self.public_name = name
|
self.public_name = name
|
||||||
self.private_name = "_" + name
|
self.private_name = "_variable__" + name
|
||||||
if self.default is not None:
|
if self.default is not None:
|
||||||
self.default_callable = Constant(
|
self.default_sequence = get_sequence(self.default)
|
||||||
value=self.default,
|
|
||||||
place=self.place,
|
|
||||||
name=self.public_name,
|
|
||||||
suffix=self.suffix,
|
|
||||||
decimals=self.decimals,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.default_callable = None
|
self.default_sequence = None
|
||||||
|
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance, value):
|
||||||
all_nums = instance.__variables_nums__
|
all_nums = instance.__variables_nums__
|
||||||
|
|
||||||
if value is self:
|
if value is self:
|
||||||
if self.default_callable is None:
|
if self.default_sequence is None:
|
||||||
raise self._error_no_default()
|
raise self._error_no_default()
|
||||||
var_obj = self.default_callable
|
all_nums[self.place] = len(self.default_sequence)
|
||||||
var_num = 1
|
return
|
||||||
elif isinstance(value, (float, int, complex)):
|
|
||||||
var_obj = Constant(value, self.place, self.public_name, self.suffix, self.decimals)
|
|
||||||
var_num = 1
|
|
||||||
elif isinstance(value, Sequence):
|
|
||||||
if isinstance(value[0], Sequence):
|
|
||||||
value = value[0]
|
|
||||||
var_obj = FuncVariable(
|
|
||||||
func=self.func,
|
|
||||||
value=value,
|
|
||||||
all_nums=all_nums,
|
|
||||||
place=self.place,
|
|
||||||
name=self.public_name,
|
|
||||||
suffix=self.suffix,
|
|
||||||
decimals=self.decimals,
|
|
||||||
)
|
|
||||||
var_num = value[2]
|
|
||||||
else:
|
|
||||||
var_obj = ArrayVariable(
|
|
||||||
values=value,
|
|
||||||
place=self.place,
|
|
||||||
name=self.public_name,
|
|
||||||
suffix=self.suffix,
|
|
||||||
decimals=self.decimals,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise TypeError(f"value {value!r} of type {type(value)} not recognized")
|
|
||||||
|
|
||||||
all_nums[self.place] = var_num
|
var_obj = Variable(
|
||||||
|
name=self.public_name,
|
||||||
|
values=get_sequence(value),
|
||||||
|
place=self.place,
|
||||||
|
suffix=self.suffix,
|
||||||
|
decimals=self.decimals,
|
||||||
|
all_nums=all_nums,
|
||||||
|
)
|
||||||
|
all_nums[self.place] = len(var_obj)
|
||||||
instance.__dict__[self.private_name] = var_obj
|
instance.__dict__[self.private_name] = var_obj
|
||||||
|
|
||||||
def __get__(self, instance, _):
|
def __get__(self, instance, _):
|
||||||
if instance is None:
|
if instance is None:
|
||||||
return self
|
return self
|
||||||
if self.private_name not in instance.__dict__:
|
if self.private_name not in instance.__dict__:
|
||||||
if self.default_callable is None:
|
all_nums = instance.__variables_nums__
|
||||||
|
if self.default_sequence is None:
|
||||||
raise self._error_no_default()
|
raise self._error_no_default()
|
||||||
return self.default_callable
|
obj = Variable(
|
||||||
|
name=self.public_name,
|
||||||
|
values=self.default_sequence,
|
||||||
|
place=self.place,
|
||||||
|
suffix=self.suffix,
|
||||||
|
decimals=self.decimals,
|
||||||
|
all_nums=all_nums,
|
||||||
|
)
|
||||||
|
all_nums[self.place] = len(obj)
|
||||||
|
instance.__dict__[self.private_name] = obj
|
||||||
|
|
||||||
return instance.__dict__[self.private_name]
|
return instance.__dict__[self.private_name]
|
||||||
|
|
||||||
@@ -371,7 +326,7 @@ def create_filter(_vfields: dict[str, VariableParameter]):
|
|||||||
"""
|
"""
|
||||||
if any(k not in _vfields for k in kwargs):
|
if any(k not in _vfields for k in kwargs):
|
||||||
extra = set(kwargs) - set(_vfields)
|
extra = set(kwargs) - set(_vfields)
|
||||||
raise NameError(f"class {self.__class__.__name__} has no attribute(s) {extra}")
|
raise AttributeError(f"class {self.__class__.__name__} has no attribute(s) {extra}")
|
||||||
all_local_indices = {
|
all_local_indices = {
|
||||||
k: tuple(range(p_num)) for k, p_num in zip(_vfields, self.__variables_nums__)
|
k: tuple(range(p_num)) for k, p_num in zip(_vfields, self.__variables_nums__)
|
||||||
}
|
}
|
||||||
@@ -411,6 +366,7 @@ def vdataclass(cls: type[T]) -> type[T]:
|
|||||||
cls.__variables_nums__ = field(
|
cls.__variables_nums__ = field(
|
||||||
init=False, repr=False, default_factory=(lambda: [None] * len(_vfields))
|
init=False, repr=False, default_factory=(lambda: [None] * len(_vfields))
|
||||||
)
|
)
|
||||||
|
cls.__vfields__ = _vfields
|
||||||
|
|
||||||
cls = dataclass(cls)
|
cls = dataclass(cls)
|
||||||
for i, v in enumerate(_vfields.values()):
|
for i, v in enumerate(_vfields.values()):
|
||||||
|
|||||||
@@ -1,27 +1,68 @@
|
|||||||
from dataclasses import field
|
import numpy as np
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from scgenerator.variableparameters import Variable, vdataclass, vfield
|
from scgenerator.variableparameters import (
|
||||||
|
Variable,
|
||||||
|
debase,
|
||||||
|
get_sequence,
|
||||||
|
unit_formatter,
|
||||||
|
vdataclass,
|
||||||
|
vfield,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_vfield_identification():
|
def test_debase():
|
||||||
"""only vfields and not normal fields count"""
|
assert debase(586, (10, 10, 10)) == (5, 8, 6) # decimal
|
||||||
|
assert debase(57, (2, 2, 2, 2, 2, 2)) == (1, 1, 1, 0, 0, 1) # binary
|
||||||
class Conf:
|
assert debase(38, (8, 4, 2)) == (4, 3, 0) # 38 == 4 * (4*2) + 3 * (2) + 0 * 1
|
||||||
x: int
|
with pytest.raises(ValueError):
|
||||||
|
debase(785, (5, 1, 6))
|
||||||
with pytest.warns(UserWarning):
|
assert debase(785, (5, 1, 6), strict=False) == (130, 0, 5)
|
||||||
vdataclass(Conf)
|
|
||||||
|
|
||||||
@vdataclass
|
|
||||||
class Conf2:
|
|
||||||
x: Variable = vfield(default=5)
|
|
||||||
|
|
||||||
assert hasattr(Conf2, "filter")
|
|
||||||
|
|
||||||
|
|
||||||
def test_constant():
|
def test_get_sequence():
|
||||||
|
s = get_sequence([1, 2, 3])
|
||||||
|
assert isinstance(s, np.ndarray) and s.dtype == int
|
||||||
|
assert np.all(s == np.array([1, 2, 3]))
|
||||||
|
|
||||||
|
s = get_sequence([[1, 2, 3]])
|
||||||
|
assert isinstance(s, np.ndarray) and s.dtype == float
|
||||||
|
assert np.all(s == np.linspace(1, 2, 3))
|
||||||
|
|
||||||
|
s = get_sequence([[0, 4, 5]])
|
||||||
|
assert isinstance(s, np.ndarray) and s.dtype == int
|
||||||
|
assert np.all(s == np.arange(5))
|
||||||
|
|
||||||
|
s = get_sequence([[3012, 3031, 20]])
|
||||||
|
assert isinstance(s, np.ndarray) and s.dtype == int
|
||||||
|
assert np.all(s == np.arange(3012, 3032))
|
||||||
|
|
||||||
|
s = get_sequence(0)
|
||||||
|
assert isinstance(s, np.ndarray) and s.dtype == int
|
||||||
|
assert np.all(s == np.array([0]))
|
||||||
|
|
||||||
|
s = get_sequence([[1e-12, 1e-5, 8, "geometric"]])
|
||||||
|
assert isinstance(s, np.ndarray) and s.dtype == float
|
||||||
|
assert np.all(s == np.geomspace(1e-12, 1e-5, 8))
|
||||||
|
|
||||||
|
s = get_sequence([[1e-12, 1e-5, 9, "geometric0"]])
|
||||||
|
assert isinstance(s, np.ndarray) and s.dtype == float
|
||||||
|
assert np.all(s == np.concatenate(((0,), np.geomspace(1e-12, 1e-5, 8))))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
get_sequence([[1, 2, 3, "logarithm"]])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
get_sequence([[1, 2, 3, 4]])
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
get_sequence([[0, 1, "linear"]])
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
get_sequence([[0, 1, int]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_constant_number():
|
||||||
"""classes with constant fields don't increase size and always return the constant"""
|
"""classes with constant fields don't increase size and always return the constant"""
|
||||||
|
|
||||||
@vdataclass
|
@vdataclass
|
||||||
@@ -40,21 +81,183 @@ def test_constant():
|
|||||||
conf.y(1)
|
conf.y(1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_constant_list():
|
||||||
|
"""classes with constant fields don't increase size and always return the constant"""
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield(default=[2, 0])
|
||||||
|
|
||||||
|
conf = Conf()
|
||||||
|
assert len(conf) == 6
|
||||||
|
assert conf.x(0) == 1
|
||||||
|
assert conf.y(0) == 2
|
||||||
|
assert conf.x(len(conf) - 1) == 7
|
||||||
|
assert conf.y(len(conf) - 1) == 0
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.x(6)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.y(6)
|
||||||
|
|
||||||
|
|
||||||
def test_decoration():
|
def test_decoration():
|
||||||
"""proper construction of the class"""
|
"""proper construction of the class"""
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield(default=[2, 0])
|
||||||
|
|
||||||
|
assert hasattr(Conf, "filter")
|
||||||
|
assert hasattr(Conf, "__vfields__")
|
||||||
|
assert hasattr(Conf, "__len__")
|
||||||
|
assert list(Conf.__vfields__["x"].default_sequence) == [1, 2, 7]
|
||||||
|
|
||||||
|
assert not hasattr(Conf, "__variables_nums__")
|
||||||
|
assert hasattr(Conf(), "__variables_nums__")
|
||||||
|
|
||||||
|
|
||||||
def test_name_error():
|
def test_name_error():
|
||||||
"""name errors are raised when accessing inexistent variable (direct, filter, ...)"""
|
"""name/attribute errors are raised when accessing inexistent variable (direct, filter, ...)"""
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield(default=[2, 0])
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
Conf(z=5)
|
||||||
|
|
||||||
|
conf = Conf(y=5)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
conf.filter(z=45)
|
||||||
|
|
||||||
|
|
||||||
def test_defaults():
|
def test_assignment():
|
||||||
"""default values are respected, of proper type and of size one"""
|
"""assigned values properly update the whole instance"""
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield(default=[2, 0])
|
||||||
|
|
||||||
|
conf = Conf(x=5, y=0)
|
||||||
|
assert len(conf) == 1
|
||||||
|
assert conf.x(0) == 5
|
||||||
|
assert conf.y(0) == 0
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.x(6)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.y(6)
|
||||||
|
conf = Conf(x=[[0, 2, 5]], y=(0.1, 0.2))
|
||||||
|
assert len(conf) == 10
|
||||||
|
assert conf.x(0) == 0
|
||||||
|
assert conf.x(5) == 1.0
|
||||||
|
assert conf.y(0) == 0.1
|
||||||
|
assert conf.y(2) == 0.1
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
conf.x(-5)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
conf.x(5, local=True)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.y(10)
|
||||||
|
|
||||||
|
conf.y = np.linspace(-1, 2, 7)
|
||||||
|
assert len(conf) == 7 * 5
|
||||||
|
assert conf.x(13) == 0.5
|
||||||
|
assert conf.y(13) == 2.0
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
del conf.y
|
||||||
|
|
||||||
|
conf.x = 0
|
||||||
|
assert len(conf) == 7
|
||||||
|
assert conf.x(5) == 0
|
||||||
|
assert conf.y(5) == 1.5
|
||||||
|
|
||||||
|
|
||||||
def test_formatting():
|
def test_unit_forammter():
|
||||||
|
fmt = unit_formatter("s", 4)
|
||||||
|
|
||||||
|
assert fmt(5e-29) == "0s"
|
||||||
|
assert fmt(5e27) == "5000Ys"
|
||||||
|
|
||||||
|
fmt = unit_formatter("s", 4, vmin=None)
|
||||||
|
assert fmt(5e-29) == "5e-05ys"
|
||||||
|
|
||||||
|
|
||||||
|
def test_param_formatting():
|
||||||
"""formatting is always respected"""
|
"""formatting is always respected"""
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
|
||||||
|
z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
|
||||||
|
w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")
|
||||||
|
|
||||||
|
conf = Conf()
|
||||||
|
assert conf.x.format(0) == "x=1"
|
||||||
|
assert conf.y.format(0) == "y=2"
|
||||||
|
assert conf.y.format(1) == "y=2.5"
|
||||||
|
assert conf.y.format(2) == "y=2.2e-05"
|
||||||
|
assert conf.z.format(0) == "z=7.5ns"
|
||||||
|
assert conf.w.format(0) == "w=7.568GHz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_no_default():
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf = Conf()
|
||||||
|
conf = Conf(y=5)
|
||||||
|
assert conf.x(0) == 1
|
||||||
|
assert conf.y(0) == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter():
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield()
|
||||||
|
y: Variable = vfield(default=[2, 2.456546, 0.000022305], decimals=2)
|
||||||
|
z: Variable = vfield(default=7.5e-9, decimals=3, suffix="s")
|
||||||
|
w: Variable = vfield(default=7568.4e6, decimals=4, suffix="Hz")
|
||||||
|
|
||||||
|
conf = Conf(x=(1, 2, 3))
|
||||||
|
f = conf.filter(x=1, y=(0, 1))
|
||||||
|
assert len(f) == 2
|
||||||
|
assert list(f) == [3, 4]
|
||||||
|
|
||||||
|
conf = Conf(x=[[0, 1, 56]], w=[0, 1, 2, 3, 4])
|
||||||
|
assert len(conf) == 56 * 5 * 3
|
||||||
|
f = conf.filter(x=slice(None, None, 5), y=1, w=(0, 1, 2))
|
||||||
|
assert len(f) == 12 * 3
|
||||||
|
assert all(conf.y(i) == Conf.y.default_sequence[1] for i in f)
|
||||||
|
|
||||||
|
mat = np.random.rand(56 * 5 * 3)
|
||||||
|
assert f(mat, squeeze=False).shape == (12, 1, 1, 3)
|
||||||
|
assert f(mat).shape == (12, 3)
|
||||||
|
|
||||||
|
mat = np.random.rand(70, 12)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
assert f(mat, squeeze=False).shape == (12, 1, 1, 3)
|
||||||
|
assert f(mat.ravel(), squeeze=False).shape == (12, 1, 1, 3)
|
||||||
|
assert f(mat.ravel()).shape == (12, 3)
|
||||||
|
|
||||||
|
mat = np.random.rand(71, 12)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
assert f(mat, axis=0).shape == (12, 3)
|
||||||
|
|
||||||
|
mat = np.random.rand(70, 12, 11)
|
||||||
|
assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11)
|
||||||
|
|
||||||
|
|
||||||
def test_param_parsing():
|
def test_param_parsing():
|
||||||
"""parsing a string or path returns the correct values"""
|
"""parsing a string or path returns the correct values"""
|
||||||
@@ -62,3 +265,7 @@ def test_param_parsing():
|
|||||||
|
|
||||||
def test_config_parsing():
|
def test_config_parsing():
|
||||||
"""parsing a string returns the exact id corresponding to the string/path"""
|
"""parsing a string returns the exact id corresponding to the string/path"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_repr():
|
||||||
|
"""print a config object correctly"""
|
||||||
|
|||||||
Reference in New Issue
Block a user