more Parameters fixup
This commit is contained in:
33
datetimejson.py
Normal file
33
datetimejson.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
|
||||
class DatetimeEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, (datetime.date, datetime.datetime)):
|
||||
return obj.isoformat()
|
||||
|
||||
|
||||
def print_and_return(obj):
|
||||
for k, v in obj.items():
|
||||
if not isinstance(v, str):
|
||||
continue
|
||||
try:
|
||||
dt = datetime.datetime.fromisoformat(v)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
dt = datetime.date.fromisoformat(v)
|
||||
except Exception:
|
||||
pass
|
||||
obj[k] = dt
|
||||
return obj
|
||||
|
||||
|
||||
d = dict(user=dict(joined=datetime.datetime.now()), other_user=datetime.date.today())
|
||||
|
||||
s = json.dumps(d, cls=DatetimeEncoder)
|
||||
print(s)
|
||||
print()
|
||||
|
||||
print(json.loads(s, object_hook=print_and_return))
|
||||
23
src/scgenerator/io.py
Normal file
23
src/scgenerator/io.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
|
||||
class DatetimeEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, (datetime.date, datetime.datetime)):
|
||||
return obj.isoformat()
|
||||
|
||||
|
||||
def decode_datetime_hook(obj):
|
||||
for k, v in obj.items():
|
||||
if not isinstance(v, str):
|
||||
continue
|
||||
try:
|
||||
dt = datetime.datetime.fromisoformat(v)
|
||||
except Exception:
|
||||
try:
|
||||
dt = datetime.date.fromisoformat(v)
|
||||
except Exception:
|
||||
continue
|
||||
obj[k] = dt
|
||||
return obj
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as datetime_module
|
||||
import json
|
||||
import os
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field, fields
|
||||
@@ -13,7 +14,9 @@ import numpy as np
|
||||
|
||||
from scgenerator import utils
|
||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||
from scgenerator.errors import EvaluatorError
|
||||
from scgenerator.evaluator import Evaluator
|
||||
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
|
||||
from scgenerator.operators import Qualifier, SpecOperator
|
||||
from scgenerator.utils import update_path_name
|
||||
|
||||
@@ -432,6 +435,14 @@ class Parameters:
|
||||
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
||||
version: str = Parameter(string)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s: str) -> Parameters:
|
||||
return cls(**json.loads(s, object_hook=decode_datetime_hook))
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike) -> Parameters:
|
||||
return cls.from_json(Path(path).read_text())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
||||
|
||||
@@ -457,19 +468,40 @@ class Parameters:
|
||||
)
|
||||
object.__setattr__(self, k, v)
|
||||
|
||||
def to_json(self) -> str:
|
||||
d = self.dump_dict()
|
||||
return json.dumps(d, cls=DatetimeEncoder)
|
||||
|
||||
def get_evaluator(self):
|
||||
evaluator = Evaluator.default(self.full_field)
|
||||
evaluator.set(self._param_dico.copy())
|
||||
return evaluator
|
||||
|
||||
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||
param = Parameters.strip_params_dict(self._param_dico)
|
||||
param = Parameters.strip_params_dict()
|
||||
if add_metadata:
|
||||
param["datetime"] = datetime_module.datetime.now()
|
||||
param["version"] = __version__
|
||||
return param
|
||||
|
||||
def compute(self, p_name: str, *other_p_names: str) -> Any | tuple[Any]:
|
||||
"""
|
||||
compute a single or a set of value
|
||||
|
||||
Parameters
|
||||
----------
|
||||
p_name : str
|
||||
parameter to compute
|
||||
other_p_names : str
|
||||
other parameters to compute
|
||||
|
||||
Returns
|
||||
-------
|
||||
if other_p_names == ()
|
||||
returns the computed `p_name` directly
|
||||
else
|
||||
returns a tuple of corresponding values to (p_name, other_p_names[0], ...)
|
||||
"""
|
||||
evaluator = self.get_evaluator()
|
||||
first = evaluator.compute(p_name)
|
||||
if other_p_names:
|
||||
@@ -478,10 +510,39 @@ class Parameters:
|
||||
return first
|
||||
|
||||
def compile(self, exhaustive=False) -> Parameters:
|
||||
"""
|
||||
Computes missing parameters and returns them in a frozen `Parameters` instance
|
||||
|
||||
Parameters
|
||||
----------
|
||||
exhaustive : bool, optional
|
||||
if True, will compute more parameters than strictly necessary for a simulation.
|
||||
Depending on the specifics of the model and how the parameters were specified, there
|
||||
might be no difference between a normal compilation and an exhaustive one.
|
||||
by default False
|
||||
|
||||
Returns
|
||||
-------
|
||||
Parameters
|
||||
a new, frozen instance of the `Parameters` class. Attributes already specified by the
|
||||
user are copied, alongside newly computed ones.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
|
||||
cases, this is due to underdetermination by the user.
|
||||
"""
|
||||
to_compute = MANDATORY_PARAMETERS
|
||||
evaluator = self.get_evaluator()
|
||||
for k in to_compute:
|
||||
evaluator.compute(k)
|
||||
try:
|
||||
for k in to_compute:
|
||||
evaluator.compute(k)
|
||||
except EvaluatorError as e:
|
||||
raise ValueError(
|
||||
"Could not compile the parameter set. Most likely, "
|
||||
f"an essential value is missing\n{e}"
|
||||
) from None
|
||||
if exhaustive:
|
||||
for p in self._p_names:
|
||||
if p not in evaluator.params:
|
||||
@@ -492,7 +553,6 @@ class Parameters:
|
||||
computed = self.__class__(
|
||||
**{k: v for k, v in evaluator.params.items() if k in self._p_names}
|
||||
)
|
||||
computed._frozen = True
|
||||
return computed
|
||||
|
||||
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
||||
@@ -506,16 +566,7 @@ class Parameters:
|
||||
max_right = max(len(el[1]) for el in p_pairs)
|
||||
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
|
||||
|
||||
@classmethod
|
||||
def all_parameters(cls) -> list[str]:
|
||||
return [f.name for f in fields(cls)]
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike) -> "Parameters":
|
||||
return cls(**utils.load_toml(path))
|
||||
|
||||
@classmethod
|
||||
def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]:
|
||||
def strip_params_dict(self) -> dict[str, Any]:
|
||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||
(dropped because they take a lot of space and can be exactly reconstructed)
|
||||
|
||||
@@ -545,8 +596,8 @@ class Parameters:
|
||||
}
|
||||
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
||||
out = {}
|
||||
for key, value in dico.items():
|
||||
if key in forbiden_keys or key not in cls._p_names:
|
||||
for key, value in self._param_dico.items():
|
||||
if key in forbiden_keys or key not in self._p_names:
|
||||
continue
|
||||
if not isinstance(value, types):
|
||||
continue
|
||||
@@ -565,36 +616,3 @@ class Parameters:
|
||||
del out["variable"]
|
||||
|
||||
return out
|
||||
|
||||
@property
|
||||
def final_path(self) -> Path:
|
||||
if self.output_path is not None:
|
||||
return self.output_path.parent / update_path_name(self.output_path.name)
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
numero = type_checker(int)
|
||||
|
||||
@numero
|
||||
def natural_number(name, n):
|
||||
if n < 0:
|
||||
raise ValueError(f"{name!r} must be positive")
|
||||
|
||||
try:
|
||||
numero("a", np.arange(45))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
try:
|
||||
natural_number("b", -1)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
try:
|
||||
natural_number("c", 1.0)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
try:
|
||||
natural_number("d", 1)
|
||||
print("success !")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
@@ -340,36 +340,6 @@ def beta2_coef(beta2_coefficients):
|
||||
return out
|
||||
|
||||
|
||||
def standardize_dictionary(dico):
|
||||
"""convert lists of number and units into a float with SI units inside a dictionary
|
||||
Parameters
|
||||
----------
|
||||
dico : a dictionary
|
||||
Returns
|
||||
----------
|
||||
same dictionary with units converted
|
||||
Example
|
||||
----------
|
||||
standardize_dictionary({"peak_power": [23, "kW"], "points": [1, 2, 3]})
|
||||
{"peak_power": 23000, "points": [1, 2, 3]})
|
||||
"""
|
||||
for key, item in dico.items():
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) == 2
|
||||
and isinstance(item[0], (int, float))
|
||||
and isinstance(item[1], str)
|
||||
):
|
||||
num, unit = item
|
||||
fac = 1
|
||||
if len(unit) == 2:
|
||||
fac = prefix[unit[0]]
|
||||
elif unit == "bar":
|
||||
fac = 1e5
|
||||
dico[key] = num * fac
|
||||
return dico
|
||||
|
||||
|
||||
def to_WL(spectrum: np.ndarray, lambda_: np.ndarray) -> np.ndarray:
|
||||
"""rescales the spectrum because of uneven binning when going from freq to wl
|
||||
|
||||
|
||||
21
tests/test_type_checker.py
Normal file
21
tests/test_type_checker.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scgenerator.parameter import type_checker
|
||||
|
||||
|
||||
def test_type_checker():
|
||||
numero = type_checker(int)
|
||||
|
||||
@numero
|
||||
def natural_number(name, n):
|
||||
if n < 0:
|
||||
raise ValueError(f"{name!r} must be positive")
|
||||
|
||||
with pytest.raises(TypeError, match="of type"):
|
||||
numero("a", np.arange(45))
|
||||
with pytest.raises(ValueError, match="positive"):
|
||||
natural_number("b", -1)
|
||||
with pytest.raises(TypeError, match="of type"):
|
||||
natural_number("c", 1.0)
|
||||
natural_number("d", 1)
|
||||
Reference in New Issue
Block a user