new format_str option in vfield

This commit is contained in:
Benoît Sierro
2024-01-24 15:24:37 +01:00
parent 4e22d70193
commit 820dbbdea5
2 changed files with 30 additions and 2 deletions

View File

@@ -21,6 +21,7 @@ class Variable(Generic[T]):
all_nums: list[int]
suffix: str = ""
decimals: int = 4
format_str: str | None = None
formatter: Callable[[int], str] = field(init=False)
def __getstate__(self):
@@ -31,6 +32,9 @@ class Variable(Generic[T]):
self.__post_init__()
def __post_init__(self):
if self.format_str is not None:
self.formatter = lambda v: format(v, self.format_str)
else:
self.formatter = unit_formatter(self.suffix, self.decimals)
def __call__(self, id: int, local: bool = False) -> T:
@@ -61,6 +65,7 @@ class VariableParameter:
default: float | int | None = None
suffix: str = ""
decimals: int = 4
format_str: str | None = None
sync: VariableParameter | None = None
auto_sync: bool = False
default_sequence: Variable | None = field(init=False)
@@ -96,6 +101,7 @@ class VariableParameter:
place=self.place,
suffix=self.suffix,
decimals=self.decimals,
format_str=self.format_str,
all_nums=all_nums,
)
all_nums[self.place] = len(var_obj)
@@ -132,6 +138,7 @@ class VariableParameter:
place=self.place,
suffix=self.suffix,
decimals=self.decimals,
format_str=self.format_str,
all_nums=all_nums,
)
all_nums[self.place] = len(obj)
@@ -348,12 +355,18 @@ def vfield(
default: float | int | None = None,
suffix: str = "",
decimals: int = 4,
format_str: str | None = None,
sync: VariableParameter | None = None,
auto_sync: bool = False,
):
return field(
default=VariableParameter(
default, suffix, decimals, sync.default if sync is not None else None, auto_sync
default=default,
suffix=suffix,
decimals=decimals,
format_str=format_str,
sync=sync.default if sync is not None else None,
auto_sync=auto_sync,
)
)

View File

@@ -292,6 +292,21 @@ def test_unit_forammter():
assert fmt(np.zeros((2, 2, 2))) == "(((0, 0), (0, 0)), ((0, 0), (0, 0)))"
def test_format_str():
"""format_str takes precedence"""
@vdataclass
class Conf:
x: Variable = vfield(default=123456789)
y: Variable = vfield(default=123456789, decimals=2)
z: Variable = vfield(default=123456789, decimals=2, format_str="d")
conf = Conf()
assert conf.x.format(0) == "x=1.235e+08"
assert conf.y.format(0) == "y=1.2e+08"
assert conf.z.format(0) == "z=123456789"
def test_param_formatting():
"""formatting is always respected"""