added auto_sync feature

This commit is contained in:
Benoît Sierro
2023-12-11 10:45:24 +01:00
parent 91545d39bb
commit e5c37f3155
3 changed files with 74 additions and 2 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "scgenerator" name = "scgenerator"
version = "0.3.25" version = "0.3.26"
description = "Simulate nonlinear pulse propagation in optical fibers" description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md" readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]

View File

@@ -62,6 +62,7 @@ class VariableParameter:
suffix: str = "" suffix: str = ""
decimals: int = 4 decimals: int = 4
sync: VariableParameter | None = None sync: VariableParameter | None = None
auto_sync: bool = False
default_sequence: Variable | None = field(init=False) default_sequence: Variable | None = field(init=False)
place: int | None = field(default=None, init=False) place: int | None = field(default=None, init=False)
public_name: str = field(init=False) public_name: str = field(init=False)
@@ -75,6 +76,9 @@ class VariableParameter:
else: else:
self.default_sequence = None self.default_sequence = None
if self.sync is not None and self.auto_sync:
raise ValueError("Cannot handle both `sync` and `auto_sync` at the same time")
def __set__(self, instance, value): def __set__(self, instance, value):
all_nums = instance.__variables_nums__ all_nums = instance.__variables_nums__
if value is self: if value is self:
@@ -96,6 +100,7 @@ class VariableParameter:
) )
all_nums[self.place] = len(var_obj) all_nums[self.place] = len(var_obj)
instance.__dict__[self.private_name] = var_obj instance.__dict__[self.private_name] = var_obj
_recheck_auto_sync(instance)
def __get__(self, instance, _): def __get__(self, instance, _):
if instance is None: if instance is None:
@@ -105,6 +110,17 @@ class VariableParameter:
return instance.__dict__[self.private_name] return instance.__dict__[self.private_name]
def _re_base(self, instance, new_base: int | None = None):
if self.private_name not in instance.__dict__: # not all fields set yet
return
var_obj = instance.__dict__[self.private_name]
if new_base is None:
var_obj.place = self.place
instance.__variables_nums__[self.place] = len(var_obj)
else:
var_obj.place = new_base
instance.__variables_nums__[self.place] = 1
def _create_default(self, instance): def _create_default(self, instance):
all_nums = instance.__variables_nums__ all_nums = instance.__variables_nums__
if self.default_sequence is None: if self.default_sequence is None:
@@ -204,6 +220,25 @@ class Filter:
return arr return arr
def _recheck_auto_sync(instance):
vfields: dict[str, VariableParameter] = instance.__vfields__
for _field in vfields.values():
if not _field.auto_sync:
continue
num = instance.__variables_nums__[_field.place]
try:
new_place = instance.__variables_nums__.index(num)
except ValueError:
continue
if new_place != _field.place:
_field._re_base(instance, new_place)
else:
_field._re_base(instance)
def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]: def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]:
""" """
decomboses a number into its digits in a variable base decomboses a number into its digits in a variable base
@@ -303,10 +338,11 @@ def vfield(
suffix: str = "", suffix: str = "",
decimals: int = 4, decimals: int = 4,
sync: VariableParameter | None = None, sync: VariableParameter | None = None,
auto_sync: bool = False,
): ):
return field( return field(
default=VariableParameter( default=VariableParameter(
default, suffix, decimals, sync.default if sync is not None else None default, suffix, decimals, sync.default if sync is not None else None, auto_sync
) )
) )

View File

@@ -342,3 +342,39 @@ def test_filter():
mat = np.random.rand(70, 12, 11) mat = np.random.rand(70, 12, 11)
assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11) assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11)
def test_auto_sync():
@vdataclass
class Conf:
x: Variable = vfield()
y: Variable = vfield(default=[7, 8, 9])
z: Variable = vfield(default=7.5e-9, auto_sync=True)
w: Variable = vfield(default=7568.4e6, auto_sync=True)
conf = Conf([1, 2, 3])
assert len(conf) == 9
@vdataclass
class Conf:
x: Variable = vfield()
y: Variable = vfield(default=[7, 8, 9])
z: Variable = vfield(default=7.5e-9, auto_sync=True)
conf = Conf(1, z=[1, 2, 3])
assert len(conf) == 3
conf.z = [0, 5]
assert len(conf) == 6
assert conf.y(0) == 7
assert conf.y(1) == 7
assert conf.y(2) == 8
assert conf.y(3) == 8
assert conf.y(4) == 9
assert conf.y(5) == 9
assert conf.z(0) == 0
assert conf.z(1) == 5
assert conf.z(2) == 0
assert conf.z(3) == 5
assert conf.z(4) == 0
assert conf.z(5) == 5