added parameter copy
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import datetime as datetime_module
|
import datetime as datetime_module
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from copy import copy
|
from copy import copy, deepcopy
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from math import isnan
|
from math import isnan
|
||||||
@@ -468,9 +468,12 @@ class Parameters:
|
|||||||
)
|
)
|
||||||
object.__setattr__(self, k, v)
|
object.__setattr__(self, k, v)
|
||||||
|
|
||||||
|
def copy(self) -> Parameters:
|
||||||
|
return Parameters(**deepcopy(self.strip_params_dict()))
|
||||||
|
|
||||||
def to_json(self) -> str:
|
def to_json(self) -> str:
|
||||||
d = self.dump_dict()
|
d = self.dump_dict()
|
||||||
return json.dumps(d, cls=DatetimeEncoder)
|
return json.dumps(d, cls=DatetimeEncoder, default=list)
|
||||||
|
|
||||||
def get_evaluator(self):
|
def get_evaluator(self):
|
||||||
evaluator = Evaluator.default(self.full_field)
|
evaluator = Evaluator.default(self.full_field)
|
||||||
@@ -478,7 +481,7 @@ class Parameters:
|
|||||||
return evaluator
|
return evaluator
|
||||||
|
|
||||||
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||||
param = Parameters.strip_params_dict()
|
param = self.strip_params_dict()
|
||||||
if add_metadata:
|
if add_metadata:
|
||||||
param["datetime"] = datetime_module.datetime.now()
|
param["datetime"] = datetime_module.datetime.now()
|
||||||
param["version"] = __version__
|
param["version"] = __version__
|
||||||
@@ -579,15 +582,15 @@ class Parameters:
|
|||||||
max_right = max(len(el[1]) for el in p_pairs)
|
max_right = max(len(el[1]) for el in p_pairs)
|
||||||
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
|
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
|
||||||
|
|
||||||
def strip_params_dict(self) -> dict[str, Any]:
|
def strip_params_dict(self, copy=False) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
prepares a dictionary for serialization. Some keys may not be preserved
|
prepares a dictionary for serialization. Some keys may not be preserved
|
||||||
(dropped because they take a lot of space and can be exactly reconstructed)
|
(dropped because they take a lot of space and can be exactly reconstructed)
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
dico : dict
|
copy : bool, optional
|
||||||
dictionary
|
whether to deepcopy each value, by default False
|
||||||
"""
|
"""
|
||||||
forbiden_keys = {
|
forbiden_keys = {
|
||||||
"_param_dico",
|
"_param_dico",
|
||||||
@@ -608,25 +611,19 @@ class Parameters:
|
|||||||
"linear_op",
|
"linear_op",
|
||||||
"c_to_a_factor",
|
"c_to_a_factor",
|
||||||
}
|
}
|
||||||
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
types = (np.ndarray, float, int, str, list, tuple, Path)
|
||||||
|
c = deepcopy if copy else lambda x: x
|
||||||
out = {}
|
out = {}
|
||||||
for key, value in self._param_dico.items():
|
for key, value in self._param_dico.items():
|
||||||
if key in forbiden_keys or key not in self._p_names:
|
if key in forbiden_keys or key not in self._p_names:
|
||||||
continue
|
continue
|
||||||
if not isinstance(value, types):
|
if not isinstance(value, types):
|
||||||
continue
|
continue
|
||||||
if isinstance(value, dict):
|
|
||||||
out[key] = Parameters.strip_params_dict(value)
|
|
||||||
elif isinstance(value, Path):
|
elif isinstance(value, Path):
|
||||||
out[key] = str(value)
|
out[key] = str(value)
|
||||||
elif isinstance(value, np.ndarray) and value.dtype == complex:
|
elif isinstance(value, np.ndarray) and value.dtype == complex:
|
||||||
continue
|
continue
|
||||||
elif isinstance(value, np.ndarray):
|
|
||||||
out[key] = value.tolist()
|
|
||||||
else:
|
else:
|
||||||
out[key] = value
|
out[key] = c(value)
|
||||||
|
|
||||||
if "variable" in out and len(out["variable"]) == 0:
|
|
||||||
del out["variable"]
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user