more Parameters fixup

This commit is contained in:
Benoît Sierro
2023-07-27 10:34:54 +02:00
parent d72409f339
commit d08c62f569
5 changed files with 144 additions and 79 deletions

33
datetimejson.py Normal file
View File

@@ -0,0 +1,33 @@
import datetime
import json
class DatetimeEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (datetime.date, datetime.datetime)):
return obj.isoformat()
def print_and_return(obj):
for k, v in obj.items():
if not isinstance(v, str):
continue
try:
dt = datetime.datetime.fromisoformat(v)
except Exception:
pass
try:
dt = datetime.date.fromisoformat(v)
except Exception:
pass
obj[k] = dt
return obj
d = dict(user=dict(joined=datetime.datetime.now()), other_user=datetime.date.today())
s = json.dumps(d, cls=DatetimeEncoder)
print(s)
print()
print(json.loads(s, object_hook=print_and_return))

23
src/scgenerator/io.py Normal file
View File

@@ -0,0 +1,23 @@
import datetime
import json
class DatetimeEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (datetime.date, datetime.datetime)):
return obj.isoformat()
def decode_datetime_hook(obj):
for k, v in obj.items():
if not isinstance(v, str):
continue
try:
dt = datetime.datetime.fromisoformat(v)
except Exception:
try:
dt = datetime.date.fromisoformat(v)
except Exception:
continue
obj[k] = dt
return obj

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import datetime as datetime_module
import json
import os
from copy import copy
from dataclasses import dataclass, field, fields
@@ -13,7 +14,9 @@ import numpy as np
from scgenerator import utils
from scgenerator.const import MANDATORY_PARAMETERS, __version__
from scgenerator.errors import EvaluatorError
from scgenerator.evaluator import Evaluator
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
from scgenerator.operators import Qualifier, SpecOperator
from scgenerator.utils import update_path_name
@@ -432,6 +435,14 @@ class Parameters:
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string)
@classmethod
def from_json(cls, s: str) -> Parameters:
return cls(**json.loads(s, object_hook=decode_datetime_hook))
@classmethod
def load(cls, path: os.PathLike) -> Parameters:
return cls.from_json(Path(path).read_text())
def __repr__(self) -> str:
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
@@ -457,19 +468,40 @@ class Parameters:
)
object.__setattr__(self, k, v)
def to_json(self) -> str:
d = self.dump_dict()
return json.dumps(d, cls=DatetimeEncoder)
def get_evaluator(self):
evaluator = Evaluator.default(self.full_field)
evaluator.set(self._param_dico.copy())
return evaluator
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
param = Parameters.strip_params_dict(self._param_dico)
param = Parameters.strip_params_dict()
if add_metadata:
param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__
return param
def compute(self, p_name: str, *other_p_names: str) -> Any | tuple[Any]:
"""
compute a single or a set of value
Parameters
----------
p_name : str
parameter to compute
other_p_names : str
other parameters to compute
Returns
-------
if other_p_names == ()
returns the computed `p_name` directly
else
returns a tuple of corresponding values to (p_name, other_p_names[0], ...)
"""
evaluator = self.get_evaluator()
first = evaluator.compute(p_name)
if other_p_names:
@@ -478,10 +510,39 @@ class Parameters:
return first
def compile(self, exhaustive=False) -> Parameters:
"""
Computes missing parameters and returns them in a frozen `Parameters` instance
Parameters
----------
exhaustive : bool, optional
if True, will compute more parameters than strictly necessary for a simulation.
Depending on the specifics of the model and how the parameters were specified, there
might be no difference between a normal compilation and an exhaustive one.
by default False
Returns
-------
Parameters
a new, frozen instance of the `Parameters` class. Attributes already specified by the
user are copied, alongside newly computed ones.
Raises
------
ValueError
When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
cases, this is due to underdetermination by the user.
"""
to_compute = MANDATORY_PARAMETERS
evaluator = self.get_evaluator()
try:
for k in to_compute:
evaluator.compute(k)
except EvaluatorError as e:
raise ValueError(
"Could not compile the parameter set. Most likely, "
f"an essential value is missing\n{e}"
) from None
if exhaustive:
for p in self._p_names:
if p not in evaluator.params:
@@ -492,7 +553,6 @@ class Parameters:
computed = self.__class__(
**{k: v for k, v in evaluator.params.items() if k in self._p_names}
)
computed._frozen = True
return computed
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
@@ -506,16 +566,7 @@ class Parameters:
max_right = max(len(el[1]) for el in p_pairs)
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
@classmethod
def all_parameters(cls) -> list[str]:
return [f.name for f in fields(cls)]
@classmethod
def load(cls, path: os.PathLike) -> "Parameters":
return cls(**utils.load_toml(path))
@classmethod
def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]:
def strip_params_dict(self) -> dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved
(dropped because they take a lot of space and can be exactly reconstructed)
@@ -545,8 +596,8 @@ class Parameters:
}
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
out = {}
for key, value in dico.items():
if key in forbiden_keys or key not in cls._p_names:
for key, value in self._param_dico.items():
if key in forbiden_keys or key not in self._p_names:
continue
if not isinstance(value, types):
continue
@@ -565,36 +616,3 @@ class Parameters:
del out["variable"]
return out
@property
def final_path(self) -> Path:
if self.output_path is not None:
return self.output_path.parent / update_path_name(self.output_path.name)
return None
if __name__ == "__main__":
numero = type_checker(int)
@numero
def natural_number(name, n):
if n < 0:
raise ValueError(f"{name!r} must be positive")
try:
numero("a", np.arange(45))
except Exception as e:
print(e)
try:
natural_number("b", -1)
except Exception as e:
print(e)
try:
natural_number("c", 1.0)
except Exception as e:
print(e)
try:
natural_number("d", 1)
print("success !")
except Exception as e:
print(e)

View File

@@ -340,36 +340,6 @@ def beta2_coef(beta2_coefficients):
return out
def standardize_dictionary(dico):
"""convert lists of number and units into a float with SI units inside a dictionary
Parameters
----------
dico : a dictionary
Returns
----------
same dictionary with units converted
Example
----------
standardize_dictionary({"peak_power": [23, "kW"], "points": [1, 2, 3]})
{"peak_power": 23000, "points": [1, 2, 3]})
"""
for key, item in dico.items():
if (
isinstance(item, list)
and len(item) == 2
and isinstance(item[0], (int, float))
and isinstance(item[1], str)
):
num, unit = item
fac = 1
if len(unit) == 2:
fac = prefix[unit[0]]
elif unit == "bar":
fac = 1e5
dico[key] = num * fac
return dico
def to_WL(spectrum: np.ndarray, lambda_: np.ndarray) -> np.ndarray:
"""rescales the spectrum because of uneven binning when going from freq to wl

View File

@@ -0,0 +1,21 @@
import numpy as np
import pytest
from scgenerator.parameter import type_checker
def test_type_checker():
numero = type_checker(int)
@numero
def natural_number(name, n):
if n < 0:
raise ValueError(f"{name!r} must be positive")
with pytest.raises(TypeError, match="of type"):
numero("a", np.arange(45))
with pytest.raises(ValueError, match="positive"):
natural_number("b", -1)
with pytest.raises(TypeError, match="of type"):
natural_number("c", 1.0)
natural_number("d", 1)