diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index e500b7a..d727eff 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -218,6 +218,7 @@ class Parameter: converter: Callable = None, default=None, display_info: tuple[float, str] = None, + can_pickle: bool = True, ): """ Single parameter @@ -241,6 +242,7 @@ class Parameter: self.converter = converter self.default = default self.display_info = display_info + self.can_pickle = can_pickle def __set_name__(self, owner: Type[Parameters], name): self.name = name @@ -407,11 +409,11 @@ class Parameters: worker_num: int = Parameter(positive(int)) # computed - linear_operator: SpecOperator = Parameter(is_function) - nonlinear_operator: SpecOperator = Parameter(is_function) - conserved_quantity: Qualifier = Parameter(is_function) - fft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function) - ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function) + linear_operator: SpecOperator = Parameter(is_function, can_pickle=False) + nonlinear_operator: SpecOperator = Parameter(is_function, can_pickle=False) + conserved_quantity: Qualifier = Parameter(is_function, can_pickle=False) + fft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False) + ifft: Callable[[np.ndarray], np.ndarray] = Parameter(is_function, can_pickle=False) field_0: np.ndarray = Parameter(type_checker(np.ndarray)) spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) beta2: float = Parameter(type_checker(int, float)) @@ -457,12 +459,11 @@ class Parameters: yield from (f"{k}={v}" for k, v in self.dump_dict().items()) def __getstate__(self) -> dict[str, Any]: - return self.dump_dict(add_metadata=False) + return {k: v for k, v in self._param_dico.items() if getattr(self.__class__, k).can_pickle} - def __setstate__(self, dumped_dict: dict[str, Any]): - self._param_dico = dict() - for k, v in dumped_dict.items(): - setattr(self, k, v) + def __setstate__(self, param_dico: dict[str, Any]): + self._param_dico = param_dico + self.frozen = False def __setattr__(self, k, v): if self.frozen and not k.endswith("_file"): @@ -649,9 +650,12 @@ class Parameters: "alpha", "gamma_arr", "effective_area_arr", + "input_time", + "input_field", "nonlinear_op", "linear_op", "c_to_a_factor", + "hr_w", } types = (np.ndarray, float, int, str, list, tuple, Path, DataFile) c = deepcopy if copy else lambda x: x