added parameter copy

This commit is contained in:
Benoît Sierro
2023-08-03 13:42:45 +02:00
parent 18839d4528
commit 7b6e33ca0f

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import datetime as datetime_module import datetime as datetime_module
import json import json
import os import os
from copy import copy from copy import copy, deepcopy
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from functools import lru_cache, wraps from functools import lru_cache, wraps
from math import isnan from math import isnan
@@ -468,9 +468,12 @@ class Parameters:
) )
object.__setattr__(self, k, v) object.__setattr__(self, k, v)
def copy(self) -> Parameters:
return Parameters(**deepcopy(self.strip_params_dict()))
def to_json(self) -> str: def to_json(self) -> str:
d = self.dump_dict() d = self.dump_dict()
return json.dumps(d, cls=DatetimeEncoder) return json.dumps(d, cls=DatetimeEncoder, default=list)
def get_evaluator(self): def get_evaluator(self):
evaluator = Evaluator.default(self.full_field) evaluator = Evaluator.default(self.full_field)
@@ -478,7 +481,7 @@ class Parameters:
return evaluator return evaluator
def dump_dict(self, add_metadata=True) -> dict[str, Any]: def dump_dict(self, add_metadata=True) -> dict[str, Any]:
param = Parameters.strip_params_dict() param = self.strip_params_dict()
if add_metadata: if add_metadata:
param["datetime"] = datetime_module.datetime.now() param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__ param["version"] = __version__
@@ -579,15 +582,15 @@ class Parameters:
max_right = max(len(el[1]) for el in p_pairs) max_right = max(len(el[1]) for el in p_pairs)
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs) return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
def strip_params_dict(self) -> dict[str, Any]: def strip_params_dict(self, copy=False) -> dict[str, Any]:
""" """
prepares a dictionary for serialization. Some keys may not be preserved prepares a dictionary for serialization. Some keys may not be preserved
(dropped because they take a lot of space and can be exactly reconstructed) (dropped because they take a lot of space and can be exactly reconstructed)
Parameters Parameters
---------- ----------
dico : dict copy : bool, optional
dictionary whether to deepcopy each value, by default False
""" """
forbiden_keys = { forbiden_keys = {
"_param_dico", "_param_dico",
@@ -608,25 +611,19 @@ class Parameters:
"linear_op", "linear_op",
"c_to_a_factor", "c_to_a_factor",
} }
types = (np.ndarray, float, int, str, list, tuple, dict, Path) types = (np.ndarray, float, int, str, list, tuple, Path)
c = deepcopy if copy else lambda x: x
out = {} out = {}
for key, value in self._param_dico.items(): for key, value in self._param_dico.items():
if key in forbiden_keys or key not in self._p_names: if key in forbiden_keys or key not in self._p_names:
continue continue
if not isinstance(value, types): if not isinstance(value, types):
continue continue
if isinstance(value, dict):
out[key] = Parameters.strip_params_dict(value)
elif isinstance(value, Path): elif isinstance(value, Path):
out[key] = str(value) out[key] = str(value)
elif isinstance(value, np.ndarray) and value.dtype == complex: elif isinstance(value, np.ndarray) and value.dtype == complex:
continue continue
elif isinstance(value, np.ndarray):
out[key] = value.tolist()
else: else:
out[key] = value out[key] = c(value)
if "variable" in out and len(out["variable"]) == 0:
del out["variable"]
return out return out