From e5c37f3155de9122fcfb50872f28b1f80d946d9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 11 Dec 2023 10:45:24 +0100 Subject: [PATCH] added auto_sync feature --- pyproject.toml | 2 +- src/scgenerator/variableparameters.py | 38 ++++++++++++++++++++++++++- tests/test_variableparameters.py | 36 +++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 86fedb3..0c8a146 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.25" +version = "0.3.26" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] diff --git a/src/scgenerator/variableparameters.py b/src/scgenerator/variableparameters.py index b66d683..7177774 100644 --- a/src/scgenerator/variableparameters.py +++ b/src/scgenerator/variableparameters.py @@ -62,6 +62,7 @@ class VariableParameter: suffix: str = "" decimals: int = 4 sync: VariableParameter | None = None + auto_sync: bool = False default_sequence: Variable | None = field(init=False) place: int | None = field(default=None, init=False) public_name: str = field(init=False) @@ -75,6 +76,9 @@ class VariableParameter: else: self.default_sequence = None + if self.sync is not None and self.auto_sync: + raise ValueError("Cannot handle both `sync` and `auto_sync` at the same time") + def __set__(self, instance, value): all_nums = instance.__variables_nums__ if value is self: @@ -96,6 +100,7 @@ class VariableParameter: ) all_nums[self.place] = len(var_obj) instance.__dict__[self.private_name] = var_obj + _recheck_auto_sync(instance) def __get__(self, instance, _): if instance is None: @@ -105,6 +110,17 @@ class VariableParameter: return instance.__dict__[self.private_name] + def _re_base(self, instance, new_base: int | None = None): + if self.private_name not in instance.__dict__: # not all fields set yet + return + var_obj = instance.__dict__[self.private_name] + if new_base is None: + var_obj.place = self.place + instance.__variables_nums__[self.place] = len(var_obj) + else: + var_obj.place = new_base + instance.__variables_nums__[self.place] = 1 + def _create_default(self, instance): all_nums = instance.__variables_nums__ if self.default_sequence is None: @@ -204,6 +220,25 @@ class Filter: return arr +def _recheck_auto_sync(instance): + vfields: dict[str, VariableParameter] = instance.__vfields__ + + for _field in vfields.values(): + if not _field.auto_sync: + continue + + num = instance.__variables_nums__[_field.place] + + try: + new_place = instance.__variables_nums__.index(num) + except ValueError: + continue + if new_place != _field.place: + _field._re_base(instance, new_place) + else: + _field._re_base(instance) + + def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]: """ decomboses a number into its digits in a variable base @@ -303,10 +338,11 @@ def vfield( suffix: str = "", decimals: int = 4, sync: VariableParameter | None = None, + auto_sync: bool = False, ): return field( default=VariableParameter( - default, suffix, decimals, sync.default if sync is not None else None + default, suffix, decimals, sync.default if sync is not None else None, auto_sync ) ) diff --git a/tests/test_variableparameters.py b/tests/test_variableparameters.py index 42906ee..bd1d87a 100644 --- a/tests/test_variableparameters.py +++ b/tests/test_variableparameters.py @@ -342,3 +342,39 @@ def test_filter(): mat = np.random.rand(70, 12, 11) assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11) + + +def test_auto_sync(): + @vdataclass + class Conf: + x: Variable = vfield() + y: Variable = vfield(default=[7, 8, 9]) + z: Variable = vfield(default=7.5e-9, auto_sync=True) + w: Variable = vfield(default=7568.4e6, auto_sync=True) + + conf = Conf([1, 2, 3]) + assert len(conf) == 9 + + @vdataclass + class Conf: + x: Variable = vfield() + y: Variable = vfield(default=[7, 8, 9]) + z: Variable = vfield(default=7.5e-9, auto_sync=True) + + conf = Conf(1, z=[1, 2, 3]) + assert len(conf) == 3 + + conf.z = [0, 5] + assert len(conf) == 6 + assert conf.y(0) == 7 + assert conf.y(1) == 7 + assert conf.y(2) == 8 + assert conf.y(3) == 8 + assert conf.y(4) == 9 + assert conf.y(5) == 9 + assert conf.z(0) == 0 + assert conf.z(1) == 5 + assert conf.z(2) == 0 + assert conf.z(3) == 5 + assert conf.z(4) == 0 + assert conf.z(5) == 5