From c153ef3af97945278a2c256e6f26cfda2de6a1c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Wed, 18 Oct 2023 08:52:31 +0200 Subject: [PATCH] added synced parameters --- pyproject.toml | 6 +- src/scgenerator/variableparameters.py | 395 ++++++++++++++------------ tests/test_variableparameters.py | 59 +++- 3 files changed, 261 insertions(+), 199 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7abcdd1..13de4ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.20" +version = "0.3.21" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] @@ -25,6 +25,10 @@ dependencies = [ "pydantic-settings", ] +[project.optional-dependencies] +cli = ["click"] +test = ["pytest"] + [tool.ruff] line-length = 100 ignore = ["E741"] diff --git a/src/scgenerator/variableparameters.py b/src/scgenerator/variableparameters.py index db6f05d..c516474 100644 --- a/src/scgenerator/variableparameters.py +++ b/src/scgenerator/variableparameters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import warnings from dataclasses import dataclass, field @@ -11,6 +13,196 @@ T = TypeVar("T") P = ParamSpec("P") +@dataclass +class Variable(Generic[T]): + name: str + values: Sequence[T] + place: int + all_nums: list[int] + suffix: str = "" + decimals: int = 4 + formatter: Callable[[int], str] = field(init=False) + + def __getstate__(self): + return {k: v for k, v in self.__dict__.items() if k != "formatter"} + + def __setstate__(self, d): + self.__dict__ |= d + self.__post_init__() + + def __post_init__(self): + self.formatter = unit_formatter(self.suffix, self.decimals) + + def __call__(self, id: int, local: bool = False) -> T: + if id < 0: + raise IndexError("negative indices are not allowed") + if not local: + if not isinstance(id, (int, np.integer)): + raise TypeError(f"id {id!r} is not an integer") + if None in self.all_nums: + raise ValueError("at least one VariableParameter has not been configured") + id = debase(id, self.all_nums)[self.place] + if id >= len(self): + raise IndexError(f"id {id} is too big for variable parameter of size {len(self)}") + return self.values[id] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.values})" + + def __len__(self) -> int: + return len(self.values) + + def format(self, id: int, local: bool = False): + return f"{self.name}={self.formatter(self(id, local))}" + + +@dataclass(unsafe_hash=True) +class VariableParameter: + default: float | int | None = None + suffix: str = "" + decimals: int = 4 + sync: VariableParameter | None = None + default_sequence: Variable | None = field(init=False) + place: int | None = field(default=None, init=False) + public_name: str = field(init=False) + private_name: str = field(init=False) + + def __set_name__(self, owner: type, name: str): + self.public_name = name + self.private_name = "_variable__" + name + if self.default is not None: + self.default_sequence = get_sequence(self.default) + else: + self.default_sequence = None + + def __set__(self, instance, value): + all_nums = instance.__variables_nums__ + if value is self: + if self.default_sequence is None: + raise self._error_no_default() + all_nums[self.place] = len(self.default_sequence) + return + + sequence = get_sequence(value) + + self._check_sync(instance, sequence) + var_obj = Variable( + name=self.public_name, + values=sequence, + place=self.place, + suffix=self.suffix, + decimals=self.decimals, + all_nums=all_nums, + ) + all_nums[self.place] = len(var_obj) + instance.__dict__[self.private_name] = var_obj + + def __get__(self, instance, _): + if instance is None: + return self + if self.private_name not in instance.__dict__: + self._create_default(instance) + + return instance.__dict__[self.private_name] + + def _create_default(self, instance): + all_nums = instance.__variables_nums__ + if self.default_sequence is None: + raise self._error_no_default() + self._check_sync(instance, self.default_sequence) + obj = Variable( + name=self.public_name, + values=self.default_sequence, + place=self.place, + suffix=self.suffix, + decimals=self.decimals, + all_nums=all_nums, + ) + all_nums[self.place] = len(obj) + instance.__dict__[self.private_name] = obj + + def _check_sync(self, instance, sequence): + if self.sync is not None and (this_len := len(sequence)) != ( + other_len := len(getattr(instance, self.sync.private_name)) + ): + raise ValueError( + f"sequence of len {this_len} doesn't match syncronized sequence of len {other_len}" + ) + + def _error_no_default(self) -> ValueError: + return ValueError( + f"Variable parameter {self.public_name!r} has not been initialized " + "and no default has been set" + ) + + +@dataclass +class Filter: + indices: list[int] + shape: tuple[int, ...] + + def __iter__(self) -> Iterator[int]: + yield from self.indices + + def __call__( + self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True + ) -> np.ndarray: + """ + filters and reshapes an array that has already been filtered + + Parameters + ---------- + arr : Sequence | np.ndarray, shape (..., n, ...) + at index `axis` ^ + array to reshape + axis : int, optional + which axis to reshape, by default 0 + squeeze : bool, optional + squeeze the array to get rid of axes of size 1 + + Returns + ------- + arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n` + """ + arr = np.asarray(arr) + arr = arr[*((slice(None),) * axis), self.indices] + return self.reshape(arr, axis, squeeze) + + def __len__(self) -> int: + return len(self.indices) + + def __getitem__(self, key): + return self.indices[key] + + def reshape( + self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True + ) -> np.ndarray: + """ + reshapes an array that has already been filtered + + Parameters + ---------- + arr : Sequence | np.ndarray, shape (..., n, ...) + at index `axis` ^ + array to reshape + axis : int, optional + which axis to reshape + squeeze : bool, optional + squeeze the array to get rid of axes of size 1 + + Returns + ------- + arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n` + """ + arr = np.asanyarray(arr) + if axis < 0: + axis = len(self.shape) + axis + arr = arr.reshape(*arr.shape[:axis], *self.shape, *arr.shape[axis + 1 :]) + if squeeze: + arr = arr.squeeze() + return arr + + def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]: """ decomboses a number into its digits in a variable base @@ -105,191 +297,17 @@ def get_sequence(value) -> np.ndarray: return np.array([value]) -@dataclass -class Variable(Generic[T]): - name: str - values: Sequence[T] - place: int - all_nums: list[int] - suffix: str = "" - decimals: int = 4 - formatter: Callable[[int], str] = field(init=False) - - def __getstate__(self): - return {k: v for k, v in self.__dict__.items() if k != "formatter"} - - def __setstate__(self, d): - self.__dict__ |= d - self.__post_init__() - - def __post_init__(self): - self.formatter = unit_formatter(self.suffix, self.decimals) - - def __call__(self, id: int, local: bool = False) -> T: - if id < 0: - raise IndexError("negative indices are not allowed") - if not local: - if not isinstance(id, (int, np.integer)): - raise TypeError(f"id {id!r} is not an integer") - if None in self.all_nums: - raise ValueError("at least one VariableParameter has not been configured") - id = debase(id, self.all_nums)[self.place] - if id >= len(self): - raise IndexError(f"id {id} is too big for variable parameter of size {len(self)}") - return self.values[id] - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.values})" - - def __len__(self) -> int: - return len(self.values) - - def parse(self, s: str): - return NotImplemented - - def format(self, id: int, local: bool = False): - return f"{self.name}={self.formatter(self(id, local))}" - - -@dataclass(unsafe_hash=True) -class VariableParameter: - func: Callable[[int, float, float, int], float] = get_linear - default: float | int | None = None - suffix: str = "" - decimals: int = 4 - default_sequence: Variable | None = field(init=False) - place: int | None = field(default=None, init=False) - public_name: str = field(init=False) - private_name: str = field(init=False) - - def __set_name__(self, owner: type, name: str): - self.public_name = name - self.private_name = "_variable__" + name - if self.default is not None: - self.default_sequence = get_sequence(self.default) - else: - self.default_sequence = None - - def __set__(self, instance, value): - all_nums = instance.__variables_nums__ - if value is self: - if self.default_sequence is None: - raise self._error_no_default() - all_nums[self.place] = len(self.default_sequence) - return - - var_obj = Variable( - name=self.public_name, - values=get_sequence(value), - place=self.place, - suffix=self.suffix, - decimals=self.decimals, - all_nums=all_nums, - ) - all_nums[self.place] = len(var_obj) - instance.__dict__[self.private_name] = var_obj - - def __get__(self, instance, _): - if instance is None: - return self - if self.private_name not in instance.__dict__: - all_nums = instance.__variables_nums__ - if self.default_sequence is None: - raise self._error_no_default() - obj = Variable( - name=self.public_name, - values=self.default_sequence, - place=self.place, - suffix=self.suffix, - decimals=self.decimals, - all_nums=all_nums, - ) - all_nums[self.place] = len(obj) - instance.__dict__[self.private_name] = obj - - return instance.__dict__[self.private_name] - - def _error_no_default(self) -> ValueError: - return ValueError( - f"Variable parameter {self.public_name!r} has not been initialized " - "and no default has been set" - ) - - -@dataclass -class Filter: - indices: list[int] - shape: tuple[int, ...] - - def __iter__(self) -> Iterator[int]: - yield from self.indices - - def __call__( - self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True - ) -> np.ndarray: - """ - filters and reshapes an array that has already been filtered - - Parameters - ---------- - arr : Sequence | np.ndarray, shape (..., n, ...) - at index `axis` ^ - array to reshape - axis : int, optional - which axis to reshape, by default 0 - squeeze : bool, optional - squeeze the array to get rid of axes of size 1 - - Returns - ------- - arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n` - """ - arr = np.asarray(arr) - arr = arr[*((slice(None),) * axis), self.indices] - return self.reshape(arr, axis, squeeze) - - def __len__(self) -> int: - return len(self.indices) - - def __getitem__(self, key): - return self.indices[key] - - def reshape( - self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True - ) -> np.ndarray: - """ - reshapes an array that has already been filtered - - Parameters - ---------- - arr : Sequence | np.ndarray, shape (..., n, ...) - at index `axis` ^ - array to reshape - axis : int, optional - which axis to reshape - squeeze : bool, optional - squeeze the array to get rid of axes of size 1 - - Returns - ------- - arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n` - """ - arr = np.asanyarray(arr) - if axis < 0: - axis = len(self.shape) + axis - arr = arr.reshape(*arr.shape[:axis], *self.shape, *arr.shape[axis + 1 :]) - if squeeze: - arr = arr.squeeze() - return arr - - def vfield( - func: Callable[[int, float, float, int], float] = get_linear, default: float | int | None = None, suffix: str = "", decimals: int = 4, + sync: VariableParameter | None = None, ): - return field(default=VariableParameter(func, default, suffix, decimals)) + return field( + default=VariableParameter( + default, suffix, decimals, sync.default if sync is not None else None + ) + ) def create_filter(_vfields: dict[str, VariableParameter]): @@ -323,7 +341,7 @@ def create_filter(_vfields: dict[str, VariableParameter]): def vdataclass(cls: type[T]) -> type[T]: - _vfields = { + _vfields: dict[str, VariableParameter] = { k: v.default for k, v in cls.__dict__.items() if isinstance(getattr(v, "default", None), VariableParameter) @@ -334,16 +352,21 @@ def vdataclass(cls: type[T]) -> type[T]: "building normal dataclass instead" ) return dataclass(cls) + + v_nums = [] + for v in _vfields.values(): + if v.sync is not None: + v.place = v.sync.place + else: + v.place = len(v_nums) + v_nums.append(None) + # put __variables_nums__ at the begining of the dict, as other fields need it cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__ - cls.__variables_nums__ = field( - init=False, repr=False, default_factory=(lambda: [None] * len(_vfields)) - ) + cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy) cls.__vfields__ = _vfields cls = dataclass(cls) - for i, v in enumerate(_vfields.values()): - v.place = i def _config_len(self): return np.prod(self.__variables_nums__) diff --git a/tests/test_variableparameters.py b/tests/test_variableparameters.py index c3dd917..b95581e 100644 --- a/tests/test_variableparameters.py +++ b/tests/test_variableparameters.py @@ -102,6 +102,53 @@ def test_constant_list(): conf.y(6) +def test_simple_synchronize(): + """synchronize 2 variable fields""" + + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield(default=[2, 0, 0], sync=x) + + conf = Conf() + assert conf.x(0) == 1 + assert conf.x(1) == 2 + assert conf.x(2) == 7 + assert conf.y(0) == 2 + assert conf.y(1) == 0 + assert conf.y(2) == 0 + + with pytest.raises(ValueError): + conf.y(3) + with pytest.raises(ValueError): + conf.x(3) + + +def test_sync_strict(): + @vdataclass + class Conf: + x: Variable = vfield(default=[1, 2, 7]) + y: Variable = vfield(default=[2, 0], sync=x) + + conf = Conf() + assert conf.x(2) == 7 + with pytest.raises(ValueError): + conf.y(0) + + conf.x = [78, 55] + assert len(conf) == 2 + assert conf.y(1) == 0 + assert conf.x(1) == 55 + + with pytest.raises(ValueError): + conf.y = [1, 2, 3] + + conf.x = [11, 22, 33, 44, 55] + conf.y = [0, 1, 33, 111, 1111] + assert conf.x(2) == conf.y(2) == 33 + assert len(conf) == len(conf.x) == len(conf.y) == 5 + + def test_decoration(): """proper construction of the class""" @@ -275,15 +322,3 @@ def test_filter(): mat = np.random.rand(70, 12, 11) assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11) - - -def test_param_parsing(): - """parsing a string or path returns the correct values""" - - -def test_config_parsing(): - """parsing a string returns the exact id corresponding to the string/path""" - - -def test_repr(): - """print a config object correctly"""