change: better params saving

This commit is contained in:
Benoît Sierro
2023-08-23 11:23:04 +02:00
parent 03f5b9c23d
commit 8c84487465

View File

@@ -218,6 +218,7 @@ class Parameter:
converter: Callable = None,
default=None,
display_info: tuple[float, str] = None,
can_pickle: bool = True,
):
"""
Single parameter
@@ -241,6 +242,7 @@ class Parameter:
self.converter = converter
self.default = default
self.display_info = display_info
self.can_pickle = can_pickle
def __set_name__(self, owner: Type[Parameters], name):
self.name = name
@@ -407,11 +409,11 @@ class Parameters:
worker_num: int = Parameter(positive(int))
# computed
linear_operator: SpecOperator = Parameter(is_function)
nonlinear_operator: SpecOperator = Parameter(is_function)
conserved_quantity: Qualifier = Parameter(is_function)
fft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function)
ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function)
linear_operator: SpecOperator = Parameter(is_function, can_pickle=False)
nonlinear_operator: SpecOperator = Parameter(is_function, can_pickle=False)
conserved_quantity: Qualifier = Parameter(is_function, can_pickle=False)
fft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False)
ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False)
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
beta2: float = Parameter(type_checker(int, float))
@@ -457,12 +459,11 @@ class Parameters:
yield from (f"{k}={v}" for k, v in self.dump_dict().items())
def __getstate__(self) -> dict[str, Any]:
return self.dump_dict(add_metadata=False)
return {k: v for k, v in self._param_dico.items() if getattr(self.__class__, k).can_pickle}
def __setstate__(self, dumped_dict: dict[str, Any]):
self._param_dico = dict()
for k, v in dumped_dict.items():
setattr(self, k, v)
def __setstate__(self, param_dico: dict[str, Any]):
self._param_dico = param_dico
self.frozen = False
def __setattr__(self, k, v):
if self.frozen and not k.endswith("_file"):
@@ -649,9 +650,12 @@ class Parameters:
"alpha",
"gamma_arr",
"effective_area_arr",
"input_time",
"input_field",
"nonlinear_op",
"linear_op",
"c_to_a_factor",
"hr_w",
}
types = (np.ndarray, float, int, str, list, tuple, Path, DataFile)
c = deepcopy if copy else lambda x: x