From 187f90a72d567be66b87d524a9c0bc3c4f3fe196 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 14 Mar 2024 14:32:47 +0100 Subject: [PATCH] dataclass completion for vdataclass --- src/scgenerator/variableparameters.py | 66 +++++++++++++++------------ 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/src/scgenerator/variableparameters.py b/src/scgenerator/variableparameters.py index 5bb3d3e..4b8b316 100644 --- a/src/scgenerator/variableparameters.py +++ b/src/scgenerator/variableparameters.py @@ -3,7 +3,7 @@ from __future__ import annotations import itertools import warnings from dataclasses import dataclass, field -from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar +from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar, TYPE_CHECKING import numpy as np @@ -51,6 +51,8 @@ class Variable(Generic[T]): return self.values[id] def __repr__(self) -> str: + if len(self.values) == 1: + return repr(self.values[0]) return f"{self.__class__.__name__}({self.values})" def __len__(self) -> int: @@ -401,38 +403,42 @@ def create_filter(_vfields: dict[str, VariableParameter]): return filter -def vdataclass(cls: type[T]) -> type[T]: - _vfields: dict[str, VariableParameter] = { - k: v.default - for k, v in cls.__dict__.items() - if isinstance(getattr(v, "default", None), VariableParameter) - } - if len(_vfields) == 0: - warnings.warn( - f"class {cls.__qualname__} doesn't contain any variable parameter," - "building normal dataclass instead" - ) - return dataclass(cls) +if TYPE_CHECKING: + vdataclass = dataclass +else: - v_nums = [] - for v in _vfields.values(): - if v.sync is not None: - v.place = v.sync.place - else: - v.place = len(v_nums) - v_nums.append(None) + def vdataclass(cls: type[T]) -> type[T]: + _vfields: dict[str, VariableParameter] = { + k: v.default + for k, v in cls.__dict__.items() + if isinstance(getattr(v, "default", None), VariableParameter) + } + if len(_vfields) == 0: + warnings.warn( + f"class {cls.__qualname__} doesn't contain any variable parameter," + "building normal dataclass instead" + ) + return dataclass(cls) - # put __variables_nums__ at the begining of the dict, as other fields need it - cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__ - cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy) - cls.__vfields__ = _vfields + v_nums = [] + for v in _vfields.values(): + if v.sync is not None: + v.place = v.sync.place + else: + v.place = len(v_nums) + v_nums.append(None) - cls = dataclass(cls) + # put __variables_nums__ at the begining of the dict, as other fields need it + cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__ + cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy) + cls.__vfields__ = _vfields - def _config_len(self): - return np.prod(self.__variables_nums__) + cls = dataclass(cls) - setattr(cls, "filter", create_filter(_vfields)) - setattr(cls, "__len__", _config_len) + def _config_len(self): + return np.prod(self.__variables_nums__) - return cls + setattr(cls, "filter", create_filter(_vfields)) + setattr(cls, "__len__", _config_len) + + return cls