diff --git a/src/scgenerator/variableparameters.py b/src/scgenerator/variableparameters.py index bbf08cd..5bb3d3e 100644 --- a/src/scgenerator/variableparameters.py +++ b/src/scgenerator/variableparameters.py @@ -21,6 +21,7 @@ class Variable(Generic[T]): all_nums: list[int] suffix: str = "" decimals: int = 4 + format_str: str | None = None formatter: Callable[[int], str] = field(init=False) def __getstate__(self): @@ -31,7 +32,10 @@ class Variable(Generic[T]): self.__post_init__() def __post_init__(self): - self.formatter = unit_formatter(self.suffix, self.decimals) + if self.format_str is not None: + self.formatter = lambda v: format(v, self.format_str) + else: + self.formatter = unit_formatter(self.suffix, self.decimals) def __call__(self, id: int, local: bool = False) -> T: if id < 0: @@ -61,6 +65,7 @@ class VariableParameter: default: float | int | None = None suffix: str = "" decimals: int = 4 + format_str: str | None = None sync: VariableParameter | None = None auto_sync: bool = False default_sequence: Variable | None = field(init=False) @@ -96,6 +101,7 @@ class VariableParameter: place=self.place, suffix=self.suffix, decimals=self.decimals, + format_str=self.format_str, all_nums=all_nums, ) all_nums[self.place] = len(var_obj) @@ -132,6 +138,7 @@ class VariableParameter: place=self.place, suffix=self.suffix, decimals=self.decimals, + format_str=self.format_str, all_nums=all_nums, ) all_nums[self.place] = len(obj) @@ -348,12 +355,18 @@ def vfield( default: float | int | None = None, suffix: str = "", decimals: int = 4, + format_str: str | None = None, sync: VariableParameter | None = None, auto_sync: bool = False, ): return field( default=VariableParameter( - default, suffix, decimals, sync.default if sync is not None else None, auto_sync + default=default, + suffix=suffix, + decimals=decimals, + format_str=format_str, + sync=sync.default if sync is not None else None, + auto_sync=auto_sync, ) ) diff --git a/tests/test_variableparameters.py b/tests/test_variableparameters.py index b46958f..c09a264 100644 --- a/tests/test_variableparameters.py +++ b/tests/test_variableparameters.py @@ -292,6 +292,21 @@ def test_unit_forammter(): assert fmt(np.zeros((2, 2, 2))) == "(((0, 0), (0, 0)), ((0, 0), (0, 0)))" +def test_format_str(): + """format_str takes precedence""" + + @vdataclass + class Conf: + x: Variable = vfield(default=123456789) + y: Variable = vfield(default=123456789, decimals=2) + z: Variable = vfield(default=123456789, decimals=2, format_str="d") + + conf = Conf() + assert conf.x.format(0) == "x=1.235e+08" + assert conf.y.format(0) == "y=1.2e+08" + assert conf.z.format(0) == "z=123456789" + + def test_param_formatting(): """formatting is always respected"""