From 7b6e33ca0f7cb9c14257b2f81e6470bc204f99d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 3 Aug 2023 13:42:45 +0200 Subject: [PATCH] added parameter copy --- src/scgenerator/parameter.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 2223fff..71318ba 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -3,7 +3,7 @@ from __future__ import annotations import datetime as datetime_module import json import os -from copy import copy +from copy import copy, deepcopy from dataclasses import dataclass, field, fields from functools import lru_cache, wraps from math import isnan @@ -468,9 +468,12 @@ class Parameters: ) object.__setattr__(self, k, v) + def copy(self) -> Parameters: + return Parameters(**deepcopy(self.strip_params_dict())) + def to_json(self) -> str: d = self.dump_dict() - return json.dumps(d, cls=DatetimeEncoder) + return json.dumps(d, cls=DatetimeEncoder, default=list) def get_evaluator(self): evaluator = Evaluator.default(self.full_field) @@ -478,7 +481,7 @@ class Parameters: return evaluator def dump_dict(self, add_metadata=True) -> dict[str, Any]: - param = Parameters.strip_params_dict() + param = self.strip_params_dict() if add_metadata: param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ @@ -579,15 +582,15 @@ class Parameters: max_right = max(len(el[1]) for el in p_pairs) return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs) - def strip_params_dict(self) -> dict[str, Any]: + def strip_params_dict(self, copy=False) -> dict[str, Any]: """ prepares a dictionary for serialization. Some keys may not be preserved (dropped because they take a lot of space and can be exactly reconstructed) Parameters ---------- - dico : dict - dictionary + copy : bool, optional + whether to deepcopy each value, by default False """ forbiden_keys = { "_param_dico", @@ -608,25 +611,19 @@ class Parameters: "linear_op", "c_to_a_factor", } - types = (np.ndarray, float, int, str, list, tuple, dict, Path) + types = (np.ndarray, float, int, str, list, tuple, Path) + c = deepcopy if copy else lambda x: x out = {} for key, value in self._param_dico.items(): if key in forbiden_keys or key not in self._p_names: continue if not isinstance(value, types): continue - if isinstance(value, dict): - out[key] = Parameters.strip_params_dict(value) elif isinstance(value, Path): out[key] = str(value) elif isinstance(value, np.ndarray) and value.dtype == complex: continue - elif isinstance(value, np.ndarray): - out[key] = value.tolist() else: - out[key] = value - - if "variable" in out and len(out["variable"]) == 0: - del out["variable"] + out[key] = c(value) return out