dataclass completion for vdataclass
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar
|
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -51,6 +51,8 @@ class Variable(Generic[T]):
|
|||||||
return self.values[id]
|
return self.values[id]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
if len(self.values) == 1:
|
||||||
|
return repr(self.values[0])
|
||||||
return f"{self.__class__.__name__}({self.values})"
|
return f"{self.__class__.__name__}({self.values})"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@@ -401,38 +403,42 @@ def create_filter(_vfields: dict[str, VariableParameter]):
|
|||||||
return filter
|
return filter
|
||||||
|
|
||||||
|
|
||||||
def vdataclass(cls: type[T]) -> type[T]:
|
if TYPE_CHECKING:
|
||||||
_vfields: dict[str, VariableParameter] = {
|
vdataclass = dataclass
|
||||||
k: v.default
|
else:
|
||||||
for k, v in cls.__dict__.items()
|
|
||||||
if isinstance(getattr(v, "default", None), VariableParameter)
|
|
||||||
}
|
|
||||||
if len(_vfields) == 0:
|
|
||||||
warnings.warn(
|
|
||||||
f"class {cls.__qualname__} doesn't contain any variable parameter,"
|
|
||||||
"building normal dataclass instead"
|
|
||||||
)
|
|
||||||
return dataclass(cls)
|
|
||||||
|
|
||||||
v_nums = []
|
def vdataclass(cls: type[T]) -> type[T]:
|
||||||
for v in _vfields.values():
|
_vfields: dict[str, VariableParameter] = {
|
||||||
if v.sync is not None:
|
k: v.default
|
||||||
v.place = v.sync.place
|
for k, v in cls.__dict__.items()
|
||||||
else:
|
if isinstance(getattr(v, "default", None), VariableParameter)
|
||||||
v.place = len(v_nums)
|
}
|
||||||
v_nums.append(None)
|
if len(_vfields) == 0:
|
||||||
|
warnings.warn(
|
||||||
|
f"class {cls.__qualname__} doesn't contain any variable parameter,"
|
||||||
|
"building normal dataclass instead"
|
||||||
|
)
|
||||||
|
return dataclass(cls)
|
||||||
|
|
||||||
# put __variables_nums__ at the begining of the dict, as other fields need it
|
v_nums = []
|
||||||
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
|
for v in _vfields.values():
|
||||||
cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy)
|
if v.sync is not None:
|
||||||
cls.__vfields__ = _vfields
|
v.place = v.sync.place
|
||||||
|
else:
|
||||||
|
v.place = len(v_nums)
|
||||||
|
v_nums.append(None)
|
||||||
|
|
||||||
cls = dataclass(cls)
|
# put __variables_nums__ at the begining of the dict, as other fields need it
|
||||||
|
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
|
||||||
|
cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy)
|
||||||
|
cls.__vfields__ = _vfields
|
||||||
|
|
||||||
def _config_len(self):
|
cls = dataclass(cls)
|
||||||
return np.prod(self.__variables_nums__)
|
|
||||||
|
|
||||||
setattr(cls, "filter", create_filter(_vfields))
|
def _config_len(self):
|
||||||
setattr(cls, "__len__", _config_len)
|
return np.prod(self.__variables_nums__)
|
||||||
|
|
||||||
return cls
|
setattr(cls, "filter", create_filter(_vfields))
|
||||||
|
setattr(cls, "__len__", _config_len)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|||||||
Reference in New Issue
Block a user