added parameter copy

This commit is contained in:
Benoît Sierro
2023-08-03 13:42:45 +02:00
parent 18839d4528
commit 7b6e33ca0f

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import datetime as datetime_module
import json
import os
from copy import copy
from copy import copy, deepcopy
from dataclasses import dataclass, field, fields
from functools import lru_cache, wraps
from math import isnan
@@ -468,9 +468,12 @@ class Parameters:
)
object.__setattr__(self, k, v)
def copy(self) -> Parameters:
return Parameters(**deepcopy(self.strip_params_dict()))
def to_json(self) -> str:
d = self.dump_dict()
return json.dumps(d, cls=DatetimeEncoder)
return json.dumps(d, cls=DatetimeEncoder, default=list)
def get_evaluator(self):
evaluator = Evaluator.default(self.full_field)
@@ -478,7 +481,7 @@ class Parameters:
return evaluator
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
param = Parameters.strip_params_dict()
param = self.strip_params_dict()
if add_metadata:
param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__
@@ -579,15 +582,15 @@ class Parameters:
max_right = max(len(el[1]) for el in p_pairs)
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
def strip_params_dict(self) -> dict[str, Any]:
def strip_params_dict(self, copy=False) -> dict[str, Any]:
"""
prepares a dictionary for serialization. Some keys may not be preserved
(dropped because they take a lot of space and can be exactly reconstructed)
Parameters
----------
dico : dict
dictionary
copy : bool, optional
whether to deepcopy each value, by default False
"""
forbiden_keys = {
"_param_dico",
@@ -608,25 +611,19 @@ class Parameters:
"linear_op",
"c_to_a_factor",
}
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
types = (np.ndarray, float, int, str, list, tuple, Path)
c = deepcopy if copy else lambda x: x
out = {}
for key, value in self._param_dico.items():
if key in forbiden_keys or key not in self._p_names:
continue
if not isinstance(value, types):
continue
if isinstance(value, dict):
out[key] = Parameters.strip_params_dict(value)
elif isinstance(value, Path):
out[key] = str(value)
elif isinstance(value, np.ndarray) and value.dtype == complex:
continue
elif isinstance(value, np.ndarray):
out[key] = value.tolist()
else:
out[key] = value
if "variable" in out and len(out["variable"]) == 0:
del out["variable"]
out[key] = c(value)
return out