dataclass completion for vdataclass
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import itertools
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar
|
||||
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -51,6 +51,8 @@ class Variable(Generic[T]):
|
||||
return self.values[id]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if len(self.values) == 1:
|
||||
return repr(self.values[0])
|
||||
return f"{self.__class__.__name__}({self.values})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -401,38 +403,42 @@ def create_filter(_vfields: dict[str, VariableParameter]):
|
||||
return filter
|
||||
|
||||
|
||||
def vdataclass(cls: type[T]) -> type[T]:
|
||||
_vfields: dict[str, VariableParameter] = {
|
||||
k: v.default
|
||||
for k, v in cls.__dict__.items()
|
||||
if isinstance(getattr(v, "default", None), VariableParameter)
|
||||
}
|
||||
if len(_vfields) == 0:
|
||||
warnings.warn(
|
||||
f"class {cls.__qualname__} doesn't contain any variable parameter,"
|
||||
"building normal dataclass instead"
|
||||
)
|
||||
return dataclass(cls)
|
||||
if TYPE_CHECKING:
|
||||
vdataclass = dataclass
|
||||
else:
|
||||
|
||||
v_nums = []
|
||||
for v in _vfields.values():
|
||||
if v.sync is not None:
|
||||
v.place = v.sync.place
|
||||
else:
|
||||
v.place = len(v_nums)
|
||||
v_nums.append(None)
|
||||
def vdataclass(cls: type[T]) -> type[T]:
|
||||
_vfields: dict[str, VariableParameter] = {
|
||||
k: v.default
|
||||
for k, v in cls.__dict__.items()
|
||||
if isinstance(getattr(v, "default", None), VariableParameter)
|
||||
}
|
||||
if len(_vfields) == 0:
|
||||
warnings.warn(
|
||||
f"class {cls.__qualname__} doesn't contain any variable parameter,"
|
||||
"building normal dataclass instead"
|
||||
)
|
||||
return dataclass(cls)
|
||||
|
||||
# put __variables_nums__ at the begining of the dict, as other fields need it
|
||||
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
|
||||
cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy)
|
||||
cls.__vfields__ = _vfields
|
||||
v_nums = []
|
||||
for v in _vfields.values():
|
||||
if v.sync is not None:
|
||||
v.place = v.sync.place
|
||||
else:
|
||||
v.place = len(v_nums)
|
||||
v_nums.append(None)
|
||||
|
||||
cls = dataclass(cls)
|
||||
# put __variables_nums__ at the begining of the dict, as other fields need it
|
||||
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
|
||||
cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy)
|
||||
cls.__vfields__ = _vfields
|
||||
|
||||
def _config_len(self):
|
||||
return np.prod(self.__variables_nums__)
|
||||
cls = dataclass(cls)
|
||||
|
||||
setattr(cls, "filter", create_filter(_vfields))
|
||||
setattr(cls, "__len__", _config_len)
|
||||
def _config_len(self):
|
||||
return np.prod(self.__variables_nums__)
|
||||
|
||||
return cls
|
||||
setattr(cls, "filter", create_filter(_vfields))
|
||||
setattr(cls, "__len__", _config_len)
|
||||
|
||||
return cls
|
||||
|
||||
Reference in New Issue
Block a user