dataclass completion for vdataclass

This commit is contained in:
2024-03-14 14:32:47 +01:00
parent 8c59f95a01
commit 187f90a72d

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import itertools import itertools
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar, TYPE_CHECKING
import numpy as np import numpy as np
@@ -51,6 +51,8 @@ class Variable(Generic[T]):
return self.values[id] return self.values[id]
def __repr__(self) -> str: def __repr__(self) -> str:
if len(self.values) == 1:
return repr(self.values[0])
return f"{self.__class__.__name__}({self.values})" return f"{self.__class__.__name__}({self.values})"
def __len__(self) -> int: def __len__(self) -> int:
@@ -401,38 +403,42 @@ def create_filter(_vfields: dict[str, VariableParameter]):
return filter return filter
def vdataclass(cls: type[T]) -> type[T]: if TYPE_CHECKING:
_vfields: dict[str, VariableParameter] = { vdataclass = dataclass
k: v.default else:
for k, v in cls.__dict__.items()
if isinstance(getattr(v, "default", None), VariableParameter)
}
if len(_vfields) == 0:
warnings.warn(
f"class {cls.__qualname__} doesn't contain any variable parameter,"
"building normal dataclass instead"
)
return dataclass(cls)
v_nums = [] def vdataclass(cls: type[T]) -> type[T]:
for v in _vfields.values(): _vfields: dict[str, VariableParameter] = {
if v.sync is not None: k: v.default
v.place = v.sync.place for k, v in cls.__dict__.items()
else: if isinstance(getattr(v, "default", None), VariableParameter)
v.place = len(v_nums) }
v_nums.append(None) if len(_vfields) == 0:
warnings.warn(
f"class {cls.__qualname__} doesn't contain any variable parameter,"
"building normal dataclass instead"
)
return dataclass(cls)
# put __variables_nums__ at the begining of the dict, as other fields need it v_nums = []
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__ for v in _vfields.values():
cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy) if v.sync is not None:
cls.__vfields__ = _vfields v.place = v.sync.place
else:
v.place = len(v_nums)
v_nums.append(None)
cls = dataclass(cls) # put __variables_nums__ at the begining of the dict, as other fields need it
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy)
cls.__vfields__ = _vfields
def _config_len(self): cls = dataclass(cls)
return np.prod(self.__variables_nums__)
setattr(cls, "filter", create_filter(_vfields)) def _config_len(self):
setattr(cls, "__len__", _config_len) return np.prod(self.__variables_nums__)
return cls setattr(cls, "filter", create_filter(_vfields))
setattr(cls, "__len__", _config_len)
return cls