diff --git a/src/scgenerator/variableparameters.py b/src/scgenerator/variableparameters.py new file mode 100644 index 0000000..9799bcb --- /dev/null +++ b/src/scgenerator/variableparameters.py @@ -0,0 +1,425 @@ +import itertools +import warnings +from dataclasses import dataclass, field +from typing import Callable, Iterator, Protocol, Sequence, TypeVar + +import numpy as np + +T = TypeVar("T") + + +def debase(num: int, base: tuple[int, ...]) -> tuple[int, ...]: + indices = [] + _, *base = base + for i in range(len(base)): + ind, num = divmod(num, np.prod(base[i:])) + indices.append(ind) + indices.append(num) + return tuple(indices) + + +def normalize_condition(v: slice | range | int | tuple, num: int) -> tuple[int, ...]: + if isinstance(v, slice): + return tuple(range(num)[v]) + elif isinstance(v, Sequence): + return tuple(v) + elif isinstance(v, int): + return (v,) + raise TypeError(f"condition {v!r} of type {type(v)} not valid") + + +def get_geometric(id, b_min, b_max, b_num): + if b_num == 1: + return b_min + return 0 if id == 0 else b_min * (b_max / b_min) ** ((id - 1) / (b_num - 2)) + + +def get_linear(id, d_min, d_max, d_num) -> float: + if d_num == 1: + return d_min + return d_min + (d_max - d_min) * (id / (d_num - 1)) + + +def int_linear(id, d_min, d_max, d_num) -> int: + if d_num == 1: + return d_min + return d_min + id * ((d_max - d_min) // (d_num - 1)) + + +def unit_formatter(unit: str, decimals: int = 1) -> Callable[[float | int], str]: + if not unit: + + def formatter(val): + return f"{val:.{decimals}g}" + + else: + + def formatter(val): + base, true_exp = format(val, f".{3+decimals}e").split("e") + true_exp = int(true_exp) + exp, mult = divmod(true_exp, 3) + prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else "" + return f"{float(base)*10**mult:.{decimals}f}{prefix}{unit}" + + return formatter + + +class Variable(Protocol[T]): + place: int + + def __call__(self, _: int, local: bool = False) -> T: + ... + + def __repr__(self) -> str: + ... + + def __len__(self) -> int: + ... + + def format(self, id: int, local: bool = False): + return f"{self.name}={self.formatter(self(id, local))}" + + def parse(self, _: str) -> T: + ... + + def filter(self, local_id) -> tuple[list[int], tuple[int, ...]]: + """ + filter based on the desired local id. + + Parameters + ---------- + local_id : int + index of the desired value + + Returns + ------- + list[int] + list of all global indices point to instances of this variable parameter that have the + same value + tuple[int, ...] + resulting shape formed by all other variable parameters + """ + shape = self.all_nums.copy() + shape[self.place] = 1 + indices = [ + i + for i, indices in enumerate(itertools.product(*(range(k) for k in self.all_nums))) + if indices[self.place] == local_id + ] + return indices, tuple(shape) + + +@dataclass +class FuncVariable(Variable[float]): + func: Callable[[int, float, float, int], float] + args: tuple[float, float, int] + all_nums: list[int] + place: int + name: str + suffix: str = "" + decimals: int = 4 + formatter: Callable[[int], str] = field(init=False) + + def __getstate__(self): + return {k: v for k, v in self.__dict__.items() if k != "formatter"} + + def __setstate__(self, d): + self.__dict__ |= d + self.__post_init__() + + def __post_init__(self): + self.formatter = unit_formatter(self.suffix, self.decimals) + + def __call__(self, id: int, local: bool = False) -> float: + if not local: + if not isinstance(id, (int, np.integer)): + raise TypeError(f"id {id!r} is not an integer") + if None in self.all_nums: + raise ValueError("at least one VariableParameter has not been configured") + id = debase(id, self.all_nums)[self.place] + if id >= self.args[2]: + raise ValueError( + f"id {id} is too big for variable parameter of size {self.args[2]}" + ) + return self.func(id, *self.args) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({', '.join(format(el) for el in self.args)})" + + def __len__(self) -> int: + return self.args[2] + + def parse(self, s: str): + return NotImplemented + + +@dataclass +class ArrayVariable(Variable[T]): + values: Sequence[T] + all_nums: list[int] + place: int + name: str + suffix: str + decimals: int + + def __call__(self, id: int, local: bool = False) -> T: + if not local: + id = debase(id, self.all_nums)[self.place] + if id >= len(self.values): + raise ValueError( + f"id {id} is too big for variable parameter of size {len(self.values)}" + ) + return self.values[id] + + def __len__(self) -> int: + return 1 + + def format(self, _: int, local: bool = False) -> str: + return self.formatted + + def parse(self, _: str) -> float: + return NotImplemented + + +@dataclass +class Constant(Variable[T]): + value: float | int + place: int + name: str + suffix: str + decimals: int + + def __post_init__(self): + self.formatted = f"{self.name}={unit_formatter(self.suffix, self.decimals)(self.value)}" + + def __call__(self, _: int, local: bool = False) -> T: + return self.value + + def __len__(self) -> int: + return 1 + + def local(self, _: int) -> T: + return self.value + + def format(self, _: int, local: bool = False) -> str: + return self.formatted + + def parse(self, _: str) -> float: + return NotImplemented + + +@dataclass(unsafe_hash=True) +class VariableParameter: + func: Callable[[int, float, float, int], float] = get_linear + default: float | int | None = None + suffix: str = "" + decimals: int = 4 + default_callable: Variable | None = field(init=False) + place: int | None = field(default=None, init=False) + public_name: str = field(init=False) + private_name: str = field(init=False) + + def __set_name__(self, owner: type, name: str): + self.public_name = name + self.private_name = "_" + name + if self.default is not None: + self.default_callable = Constant( + value=self.default, + place=self.place, + name=self.public_name, + suffix=self.suffix, + decimals=self.decimals, + ) + else: + self.default_callable = None + + def __set__(self, instance, value): + all_nums = instance.__variables_nums__ + + if value is self: + if self.default_callable is None: + raise self._error_no_default() + var_obj = self.default_callable + var_num = 1 + elif isinstance(value, (float, int, complex)): + var_obj = Constant(value, self.place, self.public_name, self.suffix, self.decimals) + var_num = 1 + elif isinstance(value, Sequence): + if isinstance(value[0], Sequence): + value = value[0] + var_obj = FuncVariable( + func=self.func, + value=value, + all_nums=all_nums, + place=self.place, + name=self.public_name, + suffix=self.suffix, + decimals=self.decimals, + ) + var_num = value[2] + else: + var_obj = ArrayVariable( + values=value, + place=self.place, + name=self.public_name, + suffix=self.suffix, + decimals=self.decimals, + ) + else: + raise TypeError(f"value {value!r} of type {type(value)} not recognized") + + all_nums[self.place] = var_num + instance.__dict__[self.private_name] = var_obj + + def __get__(self, instance, _): + if instance is None: + return self + if self.private_name not in instance.__dict__: + if self.default_callable is None: + raise self._error_no_default() + return self.default_callable + + return instance.__dict__[self.private_name] + + def _error_no_default(self) -> ValueError: + return ValueError( + f"Variable parameter {self.public_name!r} has not been initialized " + "and no default has been set" + ) + + +@dataclass +class Filter: + indices: list[int] + shape: tuple[int, ...] + + def __iter__(self) -> Iterator[int]: + yield from self.indices + + def __call__( + self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True + ) -> np.ndarray: + """ + filters and reshapes an array that has already been filtered + + Parameters + ---------- + arr : Sequence | np.ndarray, shape (..., n, ...) + at index `axis` ^ + array to reshape + axis : int, optional + which axis to reshape, by default 0 + squeeze : bool, optional + squeeze the array to get rid of axes of size 1 + + Returns + ------- + arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n` + """ + arr = np.asarray(arr) + arr = arr[*((slice(None),) * axis), self.indices] + return self.reshape(arr, axis, squeeze) + + def __len__(self) -> int: + return len(self.indices) + + def __getitem__(self, key): + return self.indices[key] + + def reshape( + self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True + ) -> np.ndarray: + """ + reshapes an array that has already been filtered + + Parameters + ---------- + arr : Sequence | np.ndarray, shape (..., n, ...) + at index `axis` ^ + array to reshape + axis : int, optional + which axis to reshape + squeeze : bool, optional + squeeze the array to get rid of axes of size 1 + + Returns + ------- + arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n` + """ + arr = np.asanyarray(arr) + if axis < 0: + axis = len(self.shape) + axis + arr = arr.reshape(*arr.shape[:axis], *self.shape, *arr.shape[axis + 1 :]) + if squeeze: + arr = arr.squeeze() + return arr + + +def vfield( + func: Callable[[int, float, float, int], float] = get_linear, + default: float | int | None = None, + suffix: str = "", + decimals: int = 4, +): + return field(default=VariableParameter(func, default, suffix, decimals)) + + +def create_filter(_vfields: dict[str, VariableParameter]): + def filter(self, **kwargs): + """ + call Config.filter(param_a=0, param_b=(2, 4, 6)) + """ + if any(k not in _vfields for k in kwargs): + extra = set(kwargs) - set(_vfields) + raise NameError(f"class {self.__class__.__name__} has no attribute(s) {extra}") + all_local_indices = { + k: tuple(range(p_num)) for k, p_num in zip(_vfields, self.__variables_nums__) + } + conditions = { + k: normalize_condition(v, all_local_indices[k][-1] + 1) for k, v in kwargs.items() + } + + conditions = list((all_local_indices | conditions).values()) + all_indices = [] + for global_index, local_indices in enumerate( + itertools.product(*all_local_indices.values()) + ): + if all(j in cond for j, cond in zip(local_indices, conditions)): + all_indices.append(global_index) + + shape = tuple(len(s) for s in conditions) + + return Filter(all_indices, shape) + + return filter + + +def vdataclass(cls: type[T]) -> type[T]: + _vfields = { + k: v.default + for k, v in cls.__dict__.items() + if isinstance(getattr(v, "default", None), VariableParameter) + } + if len(_vfields) == 0: + warnings.warn( + f"class {cls.__qualname__} doesn't contain any variable parameter," + "building normal dataclass instead" + ) + return dataclass(cls) + # put __variables_nums__ at the begining of the dict, as other fields need it + cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__ + cls.__variables_nums__ = field( + init=False, repr=False, default_factory=(lambda: [None] * len(_vfields)) + ) + + cls = dataclass(cls) + for i, v in enumerate(_vfields.values()): + v.place = i + + def _config_len(self): + return np.prod(self.__variables_nums__) + + setattr(cls, "filter", create_filter(_vfields)) + setattr(cls, "__len__", _config_len) + + return cls diff --git a/tests/test_variableparameters.py b/tests/test_variableparameters.py new file mode 100644 index 0000000..9e39e13 --- /dev/null +++ b/tests/test_variableparameters.py @@ -0,0 +1,64 @@ +from dataclasses import field + +import pytest + +from scgenerator.variableparameters import Variable, vdataclass, vfield + + +def test_vfield_identification(): + """only vfields and not normal fields count""" + + class Conf: + x: int + + with pytest.warns(UserWarning): + vdataclass(Conf) + + @vdataclass + class Conf2: + x: Variable = vfield(default=5) + + assert hasattr(Conf2, "filter") + + +def test_constant(): + """classes with constant fields don't increase size and always return the constant""" + + @vdataclass + class Conf: + x: Variable = vfield(default=3) + y: Variable = vfield(default=4) + + conf = Conf() + assert len(conf) == 1 + assert conf.x(0) == 3 + assert conf.y(0) == 4 + + with pytest.raises(ValueError): + conf.x(1) + with pytest.raises(ValueError): + conf.y(1) + + +def test_decoration(): + """proper construction of the class""" + + +def test_name_error(): + """name errors are raised when accessing inexistent variable (direct, filter, ...)""" + + +def test_defaults(): + """default values are respected, of proper type and of size one""" + + +def test_formatting(): + """formatting is always respected""" + + +def test_param_parsing(): + """parsing a string or path returns the correct values""" + + +def test_config_parsing(): + """parsing a string returns the exact id corresponding to the string/path"""