change: better params saving
This commit is contained in:
@@ -218,6 +218,7 @@ class Parameter:
|
|||||||
converter: Callable = None,
|
converter: Callable = None,
|
||||||
default=None,
|
default=None,
|
||||||
display_info: tuple[float, str] = None,
|
display_info: tuple[float, str] = None,
|
||||||
|
can_pickle: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Single parameter
|
Single parameter
|
||||||
@@ -241,6 +242,7 @@ class Parameter:
|
|||||||
self.converter = converter
|
self.converter = converter
|
||||||
self.default = default
|
self.default = default
|
||||||
self.display_info = display_info
|
self.display_info = display_info
|
||||||
|
self.can_pickle = can_pickle
|
||||||
|
|
||||||
def __set_name__(self, owner: Type[Parameters], name):
|
def __set_name__(self, owner: Type[Parameters], name):
|
||||||
self.name = name
|
self.name = name
|
||||||
@@ -407,11 +409,11 @@ class Parameters:
|
|||||||
worker_num: int = Parameter(positive(int))
|
worker_num: int = Parameter(positive(int))
|
||||||
|
|
||||||
# computed
|
# computed
|
||||||
linear_operator: SpecOperator = Parameter(is_function)
|
linear_operator: SpecOperator = Parameter(is_function, can_pickle=False)
|
||||||
nonlinear_operator: SpecOperator = Parameter(is_function)
|
nonlinear_operator: SpecOperator = Parameter(is_function, can_pickle=False)
|
||||||
conserved_quantity: Qualifier = Parameter(is_function)
|
conserved_quantity: Qualifier = Parameter(is_function, can_pickle=False)
|
||||||
fft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function)
|
fft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False)
|
||||||
ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function)
|
ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False)
|
||||||
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
beta2: float = Parameter(type_checker(int, float))
|
beta2: float = Parameter(type_checker(int, float))
|
||||||
@@ -457,12 +459,11 @@ class Parameters:
|
|||||||
yield from (f"{k}={v}" for k, v in self.dump_dict().items())
|
yield from (f"{k}={v}" for k, v in self.dump_dict().items())
|
||||||
|
|
||||||
def __getstate__(self) -> dict[str, Any]:
|
def __getstate__(self) -> dict[str, Any]:
|
||||||
return self.dump_dict(add_metadata=False)
|
return {k: v for k, v in self._param_dico.items() if getattr(self.__class__, k).can_pickle}
|
||||||
|
|
||||||
def __setstate__(self, dumped_dict: dict[str, Any]):
|
def __setstate__(self, param_dico: dict[str, Any]):
|
||||||
self._param_dico = dict()
|
self._param_dico = param_dico
|
||||||
for k, v in dumped_dict.items():
|
self.frozen = False
|
||||||
setattr(self, k, v)
|
|
||||||
|
|
||||||
def __setattr__(self, k, v):
|
def __setattr__(self, k, v):
|
||||||
if self.frozen and not k.endswith("_file"):
|
if self.frozen and not k.endswith("_file"):
|
||||||
@@ -649,9 +650,12 @@ class Parameters:
|
|||||||
"alpha",
|
"alpha",
|
||||||
"gamma_arr",
|
"gamma_arr",
|
||||||
"effective_area_arr",
|
"effective_area_arr",
|
||||||
|
"input_time",
|
||||||
|
"input_field",
|
||||||
"nonlinear_op",
|
"nonlinear_op",
|
||||||
"linear_op",
|
"linear_op",
|
||||||
"c_to_a_factor",
|
"c_to_a_factor",
|
||||||
|
"hr_w",
|
||||||
}
|
}
|
||||||
types = (np.ndarray, float, int, str, list, tuple, Path, DataFile)
|
types = (np.ndarray, float, int, str, list, tuple, Path, DataFile)
|
||||||
c = deepcopy if copy else lambda x: x
|
c = deepcopy if copy else lambda x: x
|
||||||
|
|||||||
Reference in New Issue
Block a user