added parameter copy
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import datetime as datetime_module
|
||||
import json
|
||||
import os
|
||||
from copy import copy
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass, field, fields
|
||||
from functools import lru_cache, wraps
|
||||
from math import isnan
|
||||
@@ -468,9 +468,12 @@ class Parameters:
|
||||
)
|
||||
object.__setattr__(self, k, v)
|
||||
|
||||
def copy(self) -> Parameters:
|
||||
return Parameters(**deepcopy(self.strip_params_dict()))
|
||||
|
||||
def to_json(self) -> str:
|
||||
d = self.dump_dict()
|
||||
return json.dumps(d, cls=DatetimeEncoder)
|
||||
return json.dumps(d, cls=DatetimeEncoder, default=list)
|
||||
|
||||
def get_evaluator(self):
|
||||
evaluator = Evaluator.default(self.full_field)
|
||||
@@ -478,7 +481,7 @@ class Parameters:
|
||||
return evaluator
|
||||
|
||||
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||
param = Parameters.strip_params_dict()
|
||||
param = self.strip_params_dict()
|
||||
if add_metadata:
|
||||
param["datetime"] = datetime_module.datetime.now()
|
||||
param["version"] = __version__
|
||||
@@ -579,15 +582,15 @@ class Parameters:
|
||||
max_right = max(len(el[1]) for el in p_pairs)
|
||||
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
|
||||
|
||||
def strip_params_dict(self) -> dict[str, Any]:
|
||||
def strip_params_dict(self, copy=False) -> dict[str, Any]:
|
||||
"""
|
||||
prepares a dictionary for serialization. Some keys may not be preserved
|
||||
(dropped because they take a lot of space and can be exactly reconstructed)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dico : dict
|
||||
dictionary
|
||||
copy : bool, optional
|
||||
whether to deepcopy each value, by default False
|
||||
"""
|
||||
forbiden_keys = {
|
||||
"_param_dico",
|
||||
@@ -608,25 +611,19 @@ class Parameters:
|
||||
"linear_op",
|
||||
"c_to_a_factor",
|
||||
}
|
||||
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
||||
types = (np.ndarray, float, int, str, list, tuple, Path)
|
||||
c = deepcopy if copy else lambda x: x
|
||||
out = {}
|
||||
for key, value in self._param_dico.items():
|
||||
if key in forbiden_keys or key not in self._p_names:
|
||||
continue
|
||||
if not isinstance(value, types):
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
out[key] = Parameters.strip_params_dict(value)
|
||||
elif isinstance(value, Path):
|
||||
out[key] = str(value)
|
||||
elif isinstance(value, np.ndarray) and value.dtype == complex:
|
||||
continue
|
||||
elif isinstance(value, np.ndarray):
|
||||
out[key] = value.tolist()
|
||||
else:
|
||||
out[key] = value
|
||||
|
||||
if "variable" in out and len(out["variable"]) == 0:
|
||||
del out["variable"]
|
||||
out[key] = c(value)
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user