dataclass completion for vdataclass
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar
|
from typing import Callable, Generic, Iterator, ParamSpec, Sequence, TypeVar, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -51,6 +51,8 @@ class Variable(Generic[T]):
|
|||||||
return self.values[id]
|
return self.values[id]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
if len(self.values) == 1:
|
||||||
|
return repr(self.values[0])
|
||||||
return f"{self.__class__.__name__}({self.values})"
|
return f"{self.__class__.__name__}({self.values})"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@@ -401,7 +403,11 @@ def create_filter(_vfields: dict[str, VariableParameter]):
|
|||||||
return filter
|
return filter
|
||||||
|
|
||||||
|
|
||||||
def vdataclass(cls: type[T]) -> type[T]:
|
if TYPE_CHECKING:
|
||||||
|
vdataclass = dataclass
|
||||||
|
else:
|
||||||
|
|
||||||
|
def vdataclass(cls: type[T]) -> type[T]:
|
||||||
_vfields: dict[str, VariableParameter] = {
|
_vfields: dict[str, VariableParameter] = {
|
||||||
k: v.default
|
k: v.default
|
||||||
for k, v in cls.__dict__.items()
|
for k, v in cls.__dict__.items()
|
||||||
|
|||||||
Reference in New Issue
Block a user