original variable parameters
This commit is contained in:
425
src/scgenerator/variableparameters.py
Normal file
425
src/scgenerator/variableparameters.py
Normal file
@@ -0,0 +1,425 @@
|
|||||||
|
import itertools
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Iterator, Protocol, Sequence, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def debase(num: int, base: tuple[int, ...]) -> tuple[int, ...]:
|
||||||
|
indices = []
|
||||||
|
_, *base = base
|
||||||
|
for i in range(len(base)):
|
||||||
|
ind, num = divmod(num, np.prod(base[i:]))
|
||||||
|
indices.append(ind)
|
||||||
|
indices.append(num)
|
||||||
|
return tuple(indices)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_condition(v: slice | range | int | tuple, num: int) -> tuple[int, ...]:
|
||||||
|
if isinstance(v, slice):
|
||||||
|
return tuple(range(num)[v])
|
||||||
|
elif isinstance(v, Sequence):
|
||||||
|
return tuple(v)
|
||||||
|
elif isinstance(v, int):
|
||||||
|
return (v,)
|
||||||
|
raise TypeError(f"condition {v!r} of type {type(v)} not valid")
|
||||||
|
|
||||||
|
|
||||||
|
def get_geometric(id, b_min, b_max, b_num):
|
||||||
|
if b_num == 1:
|
||||||
|
return b_min
|
||||||
|
return 0 if id == 0 else b_min * (b_max / b_min) ** ((id - 1) / (b_num - 2))
|
||||||
|
|
||||||
|
|
||||||
|
def get_linear(id, d_min, d_max, d_num) -> float:
|
||||||
|
if d_num == 1:
|
||||||
|
return d_min
|
||||||
|
return d_min + (d_max - d_min) * (id / (d_num - 1))
|
||||||
|
|
||||||
|
|
||||||
|
def int_linear(id, d_min, d_max, d_num) -> int:
|
||||||
|
if d_num == 1:
|
||||||
|
return d_min
|
||||||
|
return d_min + id * ((d_max - d_min) // (d_num - 1))
|
||||||
|
|
||||||
|
|
||||||
|
def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str]:
|
||||||
|
if not unit:
|
||||||
|
|
||||||
|
def formatter(val):
|
||||||
|
return f"{val:.{decimals}g}"
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def formatter(val):
|
||||||
|
base, true_exp = format(val, f".{3+decimals}e").split("e")
|
||||||
|
true_exp = int(true_exp)
|
||||||
|
exp, mult = divmod(true_exp, 3)
|
||||||
|
prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else ""
|
||||||
|
return f"{float(base)*10**mult:.{decimals}f}{prefix}{unit}"
|
||||||
|
|
||||||
|
return formatter
|
||||||
|
|
||||||
|
|
||||||
|
class Variable(Protocol[T]):
|
||||||
|
place: int
|
||||||
|
|
||||||
|
def __call__(self, _: int, local: bool = False) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
def format(self, id: int, local: bool = False):
|
||||||
|
return f"{self.name}={self.formatter(self(id, local))}"
|
||||||
|
|
||||||
|
def parse(self, _: str) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
def filter(self, local_id) -> tuple[list[int], tuple[int, ...]]:
|
||||||
|
"""
|
||||||
|
filter based on the desired local id.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
local_id : int
|
||||||
|
index of the desired value
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[int]
|
||||||
|
list of all global indices point to instances of this variable parameter that have the
|
||||||
|
same value
|
||||||
|
tuple[int, ...]
|
||||||
|
resulting shape formed by all other variable parameters
|
||||||
|
"""
|
||||||
|
shape = self.all_nums.copy()
|
||||||
|
shape[self.place] = 1
|
||||||
|
indices = [
|
||||||
|
i
|
||||||
|
for i, indices in enumerate(itertools.product(*(range(k) for k in self.all_nums)))
|
||||||
|
if indices[self.place] == local_id
|
||||||
|
]
|
||||||
|
return indices, tuple(shape)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FuncVariable(Variable[float]):
|
||||||
|
func: Callable[[int, float, float, int], float]
|
||||||
|
args: tuple[float, float, int]
|
||||||
|
all_nums: list[int]
|
||||||
|
place: int
|
||||||
|
name: str
|
||||||
|
suffix: str = ""
|
||||||
|
decimals: int = 4
|
||||||
|
formatter: Callable[[int], str] = field(init=False)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return {k: v for k, v in self.__dict__.items() if k != "formatter"}
|
||||||
|
|
||||||
|
def __setstate__(self, d):
|
||||||
|
self.__dict__ |= d
|
||||||
|
self.__post_init__()
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.formatter = unit_formatter(self.suffix, self.decimals)
|
||||||
|
|
||||||
|
def __call__(self, id: int, local: bool = False) -> float:
|
||||||
|
if not local:
|
||||||
|
if not isinstance(id, (int, np.integer)):
|
||||||
|
raise TypeError(f"id {id!r} is not an integer")
|
||||||
|
if None in self.all_nums:
|
||||||
|
raise ValueError("at least one VariableParameter has not been configured")
|
||||||
|
id = debase(id, self.all_nums)[self.place]
|
||||||
|
if id >= self.args[2]:
|
||||||
|
raise ValueError(
|
||||||
|
f"id {id} is too big for variable parameter of size {self.args[2]}"
|
||||||
|
)
|
||||||
|
return self.func(id, *self.args)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}({', '.join(format(el) for el in self.args)})"
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.args[2]
|
||||||
|
|
||||||
|
def parse(self, s: str):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ArrayVariable(Variable[T]):
|
||||||
|
values: Sequence[T]
|
||||||
|
all_nums: list[int]
|
||||||
|
place: int
|
||||||
|
name: str
|
||||||
|
suffix: str
|
||||||
|
decimals: int
|
||||||
|
|
||||||
|
def __call__(self, id: int, local: bool = False) -> T:
|
||||||
|
if not local:
|
||||||
|
id = debase(id, self.all_nums)[self.place]
|
||||||
|
if id >= len(self.values):
|
||||||
|
raise ValueError(
|
||||||
|
f"id {id} is too big for variable parameter of size {len(self.values)}"
|
||||||
|
)
|
||||||
|
return self.values[id]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def format(self, _: int, local: bool = False) -> str:
|
||||||
|
return self.formatted
|
||||||
|
|
||||||
|
def parse(self, _: str) -> float:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Constant(Variable[T]):
|
||||||
|
value: float | int
|
||||||
|
place: int
|
||||||
|
name: str
|
||||||
|
suffix: str
|
||||||
|
decimals: int
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.formatted = f"{self.name}={unit_formatter(self.suffix, self.decimals)(self.value)}"
|
||||||
|
|
||||||
|
def __call__(self, _: int, local: bool = False) -> T:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def local(self, _: int) -> T:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def format(self, _: int, local: bool = False) -> str:
|
||||||
|
return self.formatted
|
||||||
|
|
||||||
|
def parse(self, _: str) -> float:
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(unsafe_hash=True)
|
||||||
|
class VariableParameter:
|
||||||
|
func: Callable[[int, float, float, int], float] = get_linear
|
||||||
|
default: float | int | None = None
|
||||||
|
suffix: str = ""
|
||||||
|
decimals: int = 4
|
||||||
|
default_callable: Variable | None = field(init=False)
|
||||||
|
place: int | None = field(default=None, init=False)
|
||||||
|
public_name: str = field(init=False)
|
||||||
|
private_name: str = field(init=False)
|
||||||
|
|
||||||
|
def __set_name__(self, owner: type, name: str):
|
||||||
|
self.public_name = name
|
||||||
|
self.private_name = "_" + name
|
||||||
|
if self.default is not None:
|
||||||
|
self.default_callable = Constant(
|
||||||
|
value=self.default,
|
||||||
|
place=self.place,
|
||||||
|
name=self.public_name,
|
||||||
|
suffix=self.suffix,
|
||||||
|
decimals=self.decimals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.default_callable = None
|
||||||
|
|
||||||
|
def __set__(self, instance, value):
|
||||||
|
all_nums = instance.__variables_nums__
|
||||||
|
|
||||||
|
if value is self:
|
||||||
|
if self.default_callable is None:
|
||||||
|
raise self._error_no_default()
|
||||||
|
var_obj = self.default_callable
|
||||||
|
var_num = 1
|
||||||
|
elif isinstance(value, (float, int, complex)):
|
||||||
|
var_obj = Constant(value, self.place, self.public_name, self.suffix, self.decimals)
|
||||||
|
var_num = 1
|
||||||
|
elif isinstance(value, Sequence):
|
||||||
|
if isinstance(value[0], Sequence):
|
||||||
|
value = value[0]
|
||||||
|
var_obj = FuncVariable(
|
||||||
|
func=self.func,
|
||||||
|
value=value,
|
||||||
|
all_nums=all_nums,
|
||||||
|
place=self.place,
|
||||||
|
name=self.public_name,
|
||||||
|
suffix=self.suffix,
|
||||||
|
decimals=self.decimals,
|
||||||
|
)
|
||||||
|
var_num = value[2]
|
||||||
|
else:
|
||||||
|
var_obj = ArrayVariable(
|
||||||
|
values=value,
|
||||||
|
place=self.place,
|
||||||
|
name=self.public_name,
|
||||||
|
suffix=self.suffix,
|
||||||
|
decimals=self.decimals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"value {value!r} of type {type(value)} not recognized")
|
||||||
|
|
||||||
|
all_nums[self.place] = var_num
|
||||||
|
instance.__dict__[self.private_name] = var_obj
|
||||||
|
|
||||||
|
def __get__(self, instance, _):
|
||||||
|
if instance is None:
|
||||||
|
return self
|
||||||
|
if self.private_name not in instance.__dict__:
|
||||||
|
if self.default_callable is None:
|
||||||
|
raise self._error_no_default()
|
||||||
|
return self.default_callable
|
||||||
|
|
||||||
|
return instance.__dict__[self.private_name]
|
||||||
|
|
||||||
|
def _error_no_default(self) -> ValueError:
|
||||||
|
return ValueError(
|
||||||
|
f"Variable parameter {self.public_name!r} has not been initialized "
|
||||||
|
"and no default has been set"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Filter:
|
||||||
|
indices: list[int]
|
||||||
|
shape: tuple[int, ...]
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
yield from self.indices
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
filters and reshapes an array that has already been filtered
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
arr : Sequence | np.ndarray, shape (..., n, ...)
|
||||||
|
at index `axis` ^
|
||||||
|
array to reshape
|
||||||
|
axis : int, optional
|
||||||
|
which axis to reshape, by default 0
|
||||||
|
squeeze : bool, optional
|
||||||
|
squeeze the array to get rid of axes of size 1
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n`
|
||||||
|
"""
|
||||||
|
arr = np.asarray(arr)
|
||||||
|
arr = arr[*((slice(None),) * axis), self.indices]
|
||||||
|
return self.reshape(arr, axis, squeeze)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.indices)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.indices[key]
|
||||||
|
|
||||||
|
def reshape(
|
||||||
|
self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
reshapes an array that has already been filtered
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
arr : Sequence | np.ndarray, shape (..., n, ...)
|
||||||
|
at index `axis` ^
|
||||||
|
array to reshape
|
||||||
|
axis : int, optional
|
||||||
|
which axis to reshape
|
||||||
|
squeeze : bool, optional
|
||||||
|
squeeze the array to get rid of axes of size 1
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n`
|
||||||
|
"""
|
||||||
|
arr = np.asanyarray(arr)
|
||||||
|
if axis < 0:
|
||||||
|
axis = len(self.shape) + axis
|
||||||
|
arr = arr.reshape(*arr.shape[:axis], *self.shape, *arr.shape[axis + 1 :])
|
||||||
|
if squeeze:
|
||||||
|
arr = arr.squeeze()
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
|
def vfield(
|
||||||
|
func: Callable[[int, float, float, int], float] = get_linear,
|
||||||
|
default: float | int | None = None,
|
||||||
|
suffix: str = "",
|
||||||
|
decimals: int = 4,
|
||||||
|
):
|
||||||
|
return field(default=VariableParameter(func, default, suffix, decimals))
|
||||||
|
|
||||||
|
|
||||||
|
def create_filter(_vfields: dict[str, VariableParameter]):
|
||||||
|
def filter(self, **kwargs):
|
||||||
|
"""
|
||||||
|
call Config.filter(param_a=0, param_b=(2, 4, 6))
|
||||||
|
"""
|
||||||
|
if any(k not in _vfields for k in kwargs):
|
||||||
|
extra = set(kwargs) - set(_vfields)
|
||||||
|
raise NameError(f"class {self.__class__.__name__} has no attribute(s) {extra}")
|
||||||
|
all_local_indices = {
|
||||||
|
k: tuple(range(p_num)) for k, p_num in zip(_vfields, self.__variables_nums__)
|
||||||
|
}
|
||||||
|
conditions = {
|
||||||
|
k: normalize_condition(v, all_local_indices[k][-1] + 1) for k, v in kwargs.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
conditions = list((all_local_indices | conditions).values())
|
||||||
|
all_indices = []
|
||||||
|
for global_index, local_indices in enumerate(
|
||||||
|
itertools.product(*all_local_indices.values())
|
||||||
|
):
|
||||||
|
if all(j in cond for j, cond in zip(local_indices, conditions)):
|
||||||
|
all_indices.append(global_index)
|
||||||
|
|
||||||
|
shape = tuple(len(s) for s in conditions)
|
||||||
|
|
||||||
|
return Filter(all_indices, shape)
|
||||||
|
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
def vdataclass(cls: type[T]) -> type[T]:
|
||||||
|
_vfields = {
|
||||||
|
k: v.default
|
||||||
|
for k, v in cls.__dict__.items()
|
||||||
|
if isinstance(getattr(v, "default", None), VariableParameter)
|
||||||
|
}
|
||||||
|
if len(_vfields) == 0:
|
||||||
|
warnings.warn(
|
||||||
|
f"class {cls.__qualname__} doesn't contain any variable parameter,"
|
||||||
|
"building normal dataclass instead"
|
||||||
|
)
|
||||||
|
return dataclass(cls)
|
||||||
|
# put __variables_nums__ at the begining of the dict, as other fields need it
|
||||||
|
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
|
||||||
|
cls.__variables_nums__ = field(
|
||||||
|
init=False, repr=False, default_factory=(lambda: [None] * len(_vfields))
|
||||||
|
)
|
||||||
|
|
||||||
|
cls = dataclass(cls)
|
||||||
|
for i, v in enumerate(_vfields.values()):
|
||||||
|
v.place = i
|
||||||
|
|
||||||
|
def _config_len(self):
|
||||||
|
return np.prod(self.__variables_nums__)
|
||||||
|
|
||||||
|
setattr(cls, "filter", create_filter(_vfields))
|
||||||
|
setattr(cls, "__len__", _config_len)
|
||||||
|
|
||||||
|
return cls
|
||||||
64
tests/test_variableparameters.py
Normal file
64
tests/test_variableparameters.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from dataclasses import field
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scgenerator.variableparameters import Variable, vdataclass, vfield
|
||||||
|
|
||||||
|
|
||||||
|
def test_vfield_identification():
|
||||||
|
"""only vfields and not normal fields count"""
|
||||||
|
|
||||||
|
class Conf:
|
||||||
|
x: int
|
||||||
|
|
||||||
|
with pytest.warns(UserWarning):
|
||||||
|
vdataclass(Conf)
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf2:
|
||||||
|
x: Variable = vfield(default=5)
|
||||||
|
|
||||||
|
assert hasattr(Conf2, "filter")
|
||||||
|
|
||||||
|
|
||||||
|
def test_constant():
|
||||||
|
"""classes with constant fields don't increase size and always return the constant"""
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=3)
|
||||||
|
y: Variable = vfield(default=4)
|
||||||
|
|
||||||
|
conf = Conf()
|
||||||
|
assert len(conf) == 1
|
||||||
|
assert conf.x(0) == 3
|
||||||
|
assert conf.y(0) == 4
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.x(1)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.y(1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decoration():
|
||||||
|
"""proper construction of the class"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_name_error():
|
||||||
|
"""name errors are raised when accessing inexistent variable (direct, filter, ...)"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_defaults():
|
||||||
|
"""default values are respected, of proper type and of size one"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_formatting():
|
||||||
|
"""formatting is always respected"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_param_parsing():
|
||||||
|
"""parsing a string or path returns the correct values"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_parsing():
|
||||||
|
"""parsing a string returns the exact id corresponding to the string/path"""
|
||||||
Reference in New Issue
Block a user