added synced parameters
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "scgenerator"
|
name = "scgenerator"
|
||||||
version = "0.3.20"
|
version = "0.3.21"
|
||||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||||
@@ -25,6 +25,10 @@ dependencies = [
|
|||||||
"pydantic-settings",
|
"pydantic-settings",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
cli = ["click"]
|
||||||
|
test = ["pytest"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
ignore = ["E741"]
|
ignore = ["E741"]
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@@ -11,6 +13,196 @@ T = TypeVar("T")
|
|||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Variable(Generic[T]):
|
||||||
|
name: str
|
||||||
|
values: Sequence[T]
|
||||||
|
place: int
|
||||||
|
all_nums: list[int]
|
||||||
|
suffix: str = ""
|
||||||
|
decimals: int = 4
|
||||||
|
formatter: Callable[[int], str] = field(init=False)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return {k: v for k, v in self.__dict__.items() if k != "formatter"}
|
||||||
|
|
||||||
|
def __setstate__(self, d):
|
||||||
|
self.__dict__ |= d
|
||||||
|
self.__post_init__()
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.formatter = unit_formatter(self.suffix, self.decimals)
|
||||||
|
|
||||||
|
def __call__(self, id: int, local: bool = False) -> T:
|
||||||
|
if id < 0:
|
||||||
|
raise IndexError("negative indices are not allowed")
|
||||||
|
if not local:
|
||||||
|
if not isinstance(id, (int, np.integer)):
|
||||||
|
raise TypeError(f"id {id!r} is not an integer")
|
||||||
|
if None in self.all_nums:
|
||||||
|
raise ValueError("at least one VariableParameter has not been configured")
|
||||||
|
id = debase(id, self.all_nums)[self.place]
|
||||||
|
if id >= len(self):
|
||||||
|
raise IndexError(f"id {id} is too big for variable parameter of size {len(self)}")
|
||||||
|
return self.values[id]
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}({self.values})"
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.values)
|
||||||
|
|
||||||
|
def format(self, id: int, local: bool = False):
|
||||||
|
return f"{self.name}={self.formatter(self(id, local))}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(unsafe_hash=True)
|
||||||
|
class VariableParameter:
|
||||||
|
default: float | int | None = None
|
||||||
|
suffix: str = ""
|
||||||
|
decimals: int = 4
|
||||||
|
sync: VariableParameter | None = None
|
||||||
|
default_sequence: Variable | None = field(init=False)
|
||||||
|
place: int | None = field(default=None, init=False)
|
||||||
|
public_name: str = field(init=False)
|
||||||
|
private_name: str = field(init=False)
|
||||||
|
|
||||||
|
def __set_name__(self, owner: type, name: str):
|
||||||
|
self.public_name = name
|
||||||
|
self.private_name = "_variable__" + name
|
||||||
|
if self.default is not None:
|
||||||
|
self.default_sequence = get_sequence(self.default)
|
||||||
|
else:
|
||||||
|
self.default_sequence = None
|
||||||
|
|
||||||
|
def __set__(self, instance, value):
|
||||||
|
all_nums = instance.__variables_nums__
|
||||||
|
if value is self:
|
||||||
|
if self.default_sequence is None:
|
||||||
|
raise self._error_no_default()
|
||||||
|
all_nums[self.place] = len(self.default_sequence)
|
||||||
|
return
|
||||||
|
|
||||||
|
sequence = get_sequence(value)
|
||||||
|
|
||||||
|
self._check_sync(instance, sequence)
|
||||||
|
var_obj = Variable(
|
||||||
|
name=self.public_name,
|
||||||
|
values=sequence,
|
||||||
|
place=self.place,
|
||||||
|
suffix=self.suffix,
|
||||||
|
decimals=self.decimals,
|
||||||
|
all_nums=all_nums,
|
||||||
|
)
|
||||||
|
all_nums[self.place] = len(var_obj)
|
||||||
|
instance.__dict__[self.private_name] = var_obj
|
||||||
|
|
||||||
|
def __get__(self, instance, _):
|
||||||
|
if instance is None:
|
||||||
|
return self
|
||||||
|
if self.private_name not in instance.__dict__:
|
||||||
|
self._create_default(instance)
|
||||||
|
|
||||||
|
return instance.__dict__[self.private_name]
|
||||||
|
|
||||||
|
def _create_default(self, instance):
|
||||||
|
all_nums = instance.__variables_nums__
|
||||||
|
if self.default_sequence is None:
|
||||||
|
raise self._error_no_default()
|
||||||
|
self._check_sync(instance, self.default_sequence)
|
||||||
|
obj = Variable(
|
||||||
|
name=self.public_name,
|
||||||
|
values=self.default_sequence,
|
||||||
|
place=self.place,
|
||||||
|
suffix=self.suffix,
|
||||||
|
decimals=self.decimals,
|
||||||
|
all_nums=all_nums,
|
||||||
|
)
|
||||||
|
all_nums[self.place] = len(obj)
|
||||||
|
instance.__dict__[self.private_name] = obj
|
||||||
|
|
||||||
|
def _check_sync(self, instance, sequence):
|
||||||
|
if self.sync is not None and (this_len := len(sequence)) != (
|
||||||
|
other_len := len(getattr(instance, self.sync.private_name))
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"sequence of len {this_len} doesn't match syncronized sequence of len {other_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _error_no_default(self) -> ValueError:
|
||||||
|
return ValueError(
|
||||||
|
f"Variable parameter {self.public_name!r} has not been initialized "
|
||||||
|
"and no default has been set"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Filter:
|
||||||
|
indices: list[int]
|
||||||
|
shape: tuple[int, ...]
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
yield from self.indices
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
filters and reshapes an array that has already been filtered
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
arr : Sequence | np.ndarray, shape (..., n, ...)
|
||||||
|
at index `axis` ^
|
||||||
|
array to reshape
|
||||||
|
axis : int, optional
|
||||||
|
which axis to reshape, by default 0
|
||||||
|
squeeze : bool, optional
|
||||||
|
squeeze the array to get rid of axes of size 1
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n`
|
||||||
|
"""
|
||||||
|
arr = np.asarray(arr)
|
||||||
|
arr = arr[*((slice(None),) * axis), self.indices]
|
||||||
|
return self.reshape(arr, axis, squeeze)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.indices)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.indices[key]
|
||||||
|
|
||||||
|
def reshape(
|
||||||
|
self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
reshapes an array that has already been filtered
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
arr : Sequence | np.ndarray, shape (..., n, ...)
|
||||||
|
at index `axis` ^
|
||||||
|
array to reshape
|
||||||
|
axis : int, optional
|
||||||
|
which axis to reshape
|
||||||
|
squeeze : bool, optional
|
||||||
|
squeeze the array to get rid of axes of size 1
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n`
|
||||||
|
"""
|
||||||
|
arr = np.asanyarray(arr)
|
||||||
|
if axis < 0:
|
||||||
|
axis = len(self.shape) + axis
|
||||||
|
arr = arr.reshape(*arr.shape[:axis], *self.shape, *arr.shape[axis + 1 :])
|
||||||
|
if squeeze:
|
||||||
|
arr = arr.squeeze()
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]:
|
def debase(num: int, base: tuple[int, ...], strict: bool = True) -> tuple[int, ...]:
|
||||||
"""
|
"""
|
||||||
decomboses a number into its digits in a variable base
|
decomboses a number into its digits in a variable base
|
||||||
@@ -105,191 +297,17 @@ def get_sequence(value) -> np.ndarray:
|
|||||||
return np.array([value])
|
return np.array([value])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Variable(Generic[T]):
|
|
||||||
name: str
|
|
||||||
values: Sequence[T]
|
|
||||||
place: int
|
|
||||||
all_nums: list[int]
|
|
||||||
suffix: str = ""
|
|
||||||
decimals: int = 4
|
|
||||||
formatter: Callable[[int], str] = field(init=False)
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
return {k: v for k, v in self.__dict__.items() if k != "formatter"}
|
|
||||||
|
|
||||||
def __setstate__(self, d):
|
|
||||||
self.__dict__ |= d
|
|
||||||
self.__post_init__()
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.formatter = unit_formatter(self.suffix, self.decimals)
|
|
||||||
|
|
||||||
def __call__(self, id: int, local: bool = False) -> T:
|
|
||||||
if id < 0:
|
|
||||||
raise IndexError("negative indices are not allowed")
|
|
||||||
if not local:
|
|
||||||
if not isinstance(id, (int, np.integer)):
|
|
||||||
raise TypeError(f"id {id!r} is not an integer")
|
|
||||||
if None in self.all_nums:
|
|
||||||
raise ValueError("at least one VariableParameter has not been configured")
|
|
||||||
id = debase(id, self.all_nums)[self.place]
|
|
||||||
if id >= len(self):
|
|
||||||
raise IndexError(f"id {id} is too big for variable parameter of size {len(self)}")
|
|
||||||
return self.values[id]
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"{self.__class__.__name__}({self.values})"
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.values)
|
|
||||||
|
|
||||||
def parse(self, s: str):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def format(self, id: int, local: bool = False):
|
|
||||||
return f"{self.name}={self.formatter(self(id, local))}"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(unsafe_hash=True)
|
|
||||||
class VariableParameter:
|
|
||||||
func: Callable[[int, float, float, int], float] = get_linear
|
|
||||||
default: float | int | None = None
|
|
||||||
suffix: str = ""
|
|
||||||
decimals: int = 4
|
|
||||||
default_sequence: Variable | None = field(init=False)
|
|
||||||
place: int | None = field(default=None, init=False)
|
|
||||||
public_name: str = field(init=False)
|
|
||||||
private_name: str = field(init=False)
|
|
||||||
|
|
||||||
def __set_name__(self, owner: type, name: str):
|
|
||||||
self.public_name = name
|
|
||||||
self.private_name = "_variable__" + name
|
|
||||||
if self.default is not None:
|
|
||||||
self.default_sequence = get_sequence(self.default)
|
|
||||||
else:
|
|
||||||
self.default_sequence = None
|
|
||||||
|
|
||||||
def __set__(self, instance, value):
|
|
||||||
all_nums = instance.__variables_nums__
|
|
||||||
if value is self:
|
|
||||||
if self.default_sequence is None:
|
|
||||||
raise self._error_no_default()
|
|
||||||
all_nums[self.place] = len(self.default_sequence)
|
|
||||||
return
|
|
||||||
|
|
||||||
var_obj = Variable(
|
|
||||||
name=self.public_name,
|
|
||||||
values=get_sequence(value),
|
|
||||||
place=self.place,
|
|
||||||
suffix=self.suffix,
|
|
||||||
decimals=self.decimals,
|
|
||||||
all_nums=all_nums,
|
|
||||||
)
|
|
||||||
all_nums[self.place] = len(var_obj)
|
|
||||||
instance.__dict__[self.private_name] = var_obj
|
|
||||||
|
|
||||||
def __get__(self, instance, _):
|
|
||||||
if instance is None:
|
|
||||||
return self
|
|
||||||
if self.private_name not in instance.__dict__:
|
|
||||||
all_nums = instance.__variables_nums__
|
|
||||||
if self.default_sequence is None:
|
|
||||||
raise self._error_no_default()
|
|
||||||
obj = Variable(
|
|
||||||
name=self.public_name,
|
|
||||||
values=self.default_sequence,
|
|
||||||
place=self.place,
|
|
||||||
suffix=self.suffix,
|
|
||||||
decimals=self.decimals,
|
|
||||||
all_nums=all_nums,
|
|
||||||
)
|
|
||||||
all_nums[self.place] = len(obj)
|
|
||||||
instance.__dict__[self.private_name] = obj
|
|
||||||
|
|
||||||
return instance.__dict__[self.private_name]
|
|
||||||
|
|
||||||
def _error_no_default(self) -> ValueError:
|
|
||||||
return ValueError(
|
|
||||||
f"Variable parameter {self.public_name!r} has not been initialized "
|
|
||||||
"and no default has been set"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Filter:
|
|
||||||
indices: list[int]
|
|
||||||
shape: tuple[int, ...]
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[int]:
|
|
||||||
yield from self.indices
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
filters and reshapes an array that has already been filtered
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
arr : Sequence | np.ndarray, shape (..., n, ...)
|
|
||||||
at index `axis` ^
|
|
||||||
array to reshape
|
|
||||||
axis : int, optional
|
|
||||||
which axis to reshape, by default 0
|
|
||||||
squeeze : bool, optional
|
|
||||||
squeeze the array to get rid of axes of size 1
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n`
|
|
||||||
"""
|
|
||||||
arr = np.asarray(arr)
|
|
||||||
arr = arr[*((slice(None),) * axis), self.indices]
|
|
||||||
return self.reshape(arr, axis, squeeze)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.indices)
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.indices[key]
|
|
||||||
|
|
||||||
def reshape(
|
|
||||||
self, arr: Sequence | np.ndarray, axis: int = 0, squeeze: bool = True
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
reshapes an array that has already been filtered
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
arr : Sequence | np.ndarray, shape (..., n, ...)
|
|
||||||
at index `axis` ^
|
|
||||||
array to reshape
|
|
||||||
axis : int, optional
|
|
||||||
which axis to reshape
|
|
||||||
squeeze : bool, optional
|
|
||||||
squeeze the array to get rid of axes of size 1
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
arr : np.ndarray, shape (..., *shape, ...) where `np.prod(shape) == n`
|
|
||||||
"""
|
|
||||||
arr = np.asanyarray(arr)
|
|
||||||
if axis < 0:
|
|
||||||
axis = len(self.shape) + axis
|
|
||||||
arr = arr.reshape(*arr.shape[:axis], *self.shape, *arr.shape[axis + 1 :])
|
|
||||||
if squeeze:
|
|
||||||
arr = arr.squeeze()
|
|
||||||
return arr
|
|
||||||
|
|
||||||
|
|
||||||
def vfield(
|
def vfield(
|
||||||
func: Callable[[int, float, float, int], float] = get_linear,
|
|
||||||
default: float | int | None = None,
|
default: float | int | None = None,
|
||||||
suffix: str = "",
|
suffix: str = "",
|
||||||
decimals: int = 4,
|
decimals: int = 4,
|
||||||
|
sync: VariableParameter | None = None,
|
||||||
):
|
):
|
||||||
return field(default=VariableParameter(func, default, suffix, decimals))
|
return field(
|
||||||
|
default=VariableParameter(
|
||||||
|
default, suffix, decimals, sync.default if sync is not None else None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_filter(_vfields: dict[str, VariableParameter]):
|
def create_filter(_vfields: dict[str, VariableParameter]):
|
||||||
@@ -323,7 +341,7 @@ def create_filter(_vfields: dict[str, VariableParameter]):
|
|||||||
|
|
||||||
|
|
||||||
def vdataclass(cls: type[T]) -> type[T]:
|
def vdataclass(cls: type[T]) -> type[T]:
|
||||||
_vfields = {
|
_vfields: dict[str, VariableParameter] = {
|
||||||
k: v.default
|
k: v.default
|
||||||
for k, v in cls.__dict__.items()
|
for k, v in cls.__dict__.items()
|
||||||
if isinstance(getattr(v, "default", None), VariableParameter)
|
if isinstance(getattr(v, "default", None), VariableParameter)
|
||||||
@@ -334,16 +352,21 @@ def vdataclass(cls: type[T]) -> type[T]:
|
|||||||
"building normal dataclass instead"
|
"building normal dataclass instead"
|
||||||
)
|
)
|
||||||
return dataclass(cls)
|
return dataclass(cls)
|
||||||
|
|
||||||
|
v_nums = []
|
||||||
|
for v in _vfields.values():
|
||||||
|
if v.sync is not None:
|
||||||
|
v.place = v.sync.place
|
||||||
|
else:
|
||||||
|
v.place = len(v_nums)
|
||||||
|
v_nums.append(None)
|
||||||
|
|
||||||
# put __variables_nums__ at the begining of the dict, as other fields need it
|
# put __variables_nums__ at the begining of the dict, as other fields need it
|
||||||
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
|
cls.__annotations__ = dict(__variables_nums__="list[int]") | cls.__annotations__
|
||||||
cls.__variables_nums__ = field(
|
cls.__variables_nums__ = field(init=False, repr=False, default_factory=v_nums.copy)
|
||||||
init=False, repr=False, default_factory=(lambda: [None] * len(_vfields))
|
|
||||||
)
|
|
||||||
cls.__vfields__ = _vfields
|
cls.__vfields__ = _vfields
|
||||||
|
|
||||||
cls = dataclass(cls)
|
cls = dataclass(cls)
|
||||||
for i, v in enumerate(_vfields.values()):
|
|
||||||
v.place = i
|
|
||||||
|
|
||||||
def _config_len(self):
|
def _config_len(self):
|
||||||
return np.prod(self.__variables_nums__)
|
return np.prod(self.__variables_nums__)
|
||||||
|
|||||||
@@ -102,6 +102,53 @@ def test_constant_list():
|
|||||||
conf.y(6)
|
conf.y(6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_synchronize():
|
||||||
|
"""synchronize 2 variable fields"""
|
||||||
|
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield(default=[2, 0, 0], sync=x)
|
||||||
|
|
||||||
|
conf = Conf()
|
||||||
|
assert conf.x(0) == 1
|
||||||
|
assert conf.x(1) == 2
|
||||||
|
assert conf.x(2) == 7
|
||||||
|
assert conf.y(0) == 2
|
||||||
|
assert conf.y(1) == 0
|
||||||
|
assert conf.y(2) == 0
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.y(3)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.x(3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_strict():
|
||||||
|
@vdataclass
|
||||||
|
class Conf:
|
||||||
|
x: Variable = vfield(default=[1, 2, 7])
|
||||||
|
y: Variable = vfield(default=[2, 0], sync=x)
|
||||||
|
|
||||||
|
conf = Conf()
|
||||||
|
assert conf.x(2) == 7
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.y(0)
|
||||||
|
|
||||||
|
conf.x = [78, 55]
|
||||||
|
assert len(conf) == 2
|
||||||
|
assert conf.y(1) == 0
|
||||||
|
assert conf.x(1) == 55
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
conf.y = [1, 2, 3]
|
||||||
|
|
||||||
|
conf.x = [11, 22, 33, 44, 55]
|
||||||
|
conf.y = [0, 1, 33, 111, 1111]
|
||||||
|
assert conf.x(2) == conf.y(2) == 33
|
||||||
|
assert len(conf) == len(conf.x) == len(conf.y) == 5
|
||||||
|
|
||||||
|
|
||||||
def test_decoration():
|
def test_decoration():
|
||||||
"""proper construction of the class"""
|
"""proper construction of the class"""
|
||||||
|
|
||||||
@@ -275,15 +322,3 @@ def test_filter():
|
|||||||
|
|
||||||
mat = np.random.rand(70, 12, 11)
|
mat = np.random.rand(70, 12, 11)
|
||||||
assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11)
|
assert f(mat.reshape(len(conf), 11), squeeze=False).shape == (12, 1, 1, 3, 11)
|
||||||
|
|
||||||
|
|
||||||
def test_param_parsing():
|
|
||||||
"""parsing a string or path returns the correct values"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_config_parsing():
|
|
||||||
"""parsing a string returns the exact id corresponding to the string/path"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_repr():
|
|
||||||
"""print a config object correctly"""
|
|
||||||
|
|||||||
Reference in New Issue
Block a user