more Parameters fixup
This commit is contained in:
33
datetimejson.py
Normal file
33
datetimejson.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class DatetimeEncoder(json.JSONEncoder):
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, (datetime.date, datetime.datetime)):
|
||||||
|
return obj.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def print_and_return(obj):
|
||||||
|
for k, v in obj.items():
|
||||||
|
if not isinstance(v, str):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
dt = datetime.datetime.fromisoformat(v)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
dt = datetime.date.fromisoformat(v)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
obj[k] = dt
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
d = dict(user=dict(joined=datetime.datetime.now()), other_user=datetime.date.today())
|
||||||
|
|
||||||
|
s = json.dumps(d, cls=DatetimeEncoder)
|
||||||
|
print(s)
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(json.loads(s, object_hook=print_and_return))
|
||||||
23
src/scgenerator/io.py
Normal file
23
src/scgenerator/io.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class DatetimeEncoder(json.JSONEncoder):
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, (datetime.date, datetime.datetime)):
|
||||||
|
return obj.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_datetime_hook(obj):
|
||||||
|
for k, v in obj.items():
|
||||||
|
if not isinstance(v, str):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
dt = datetime.datetime.fromisoformat(v)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
dt = datetime.date.fromisoformat(v)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
obj[k] = dt
|
||||||
|
return obj
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime as datetime_module
|
import datetime as datetime_module
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
@@ -13,7 +14,9 @@ import numpy as np
|
|||||||
|
|
||||||
from scgenerator import utils
|
from scgenerator import utils
|
||||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||||
|
from scgenerator.errors import EvaluatorError
|
||||||
from scgenerator.evaluator import Evaluator
|
from scgenerator.evaluator import Evaluator
|
||||||
|
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
|
||||||
from scgenerator.operators import Qualifier, SpecOperator
|
from scgenerator.operators import Qualifier, SpecOperator
|
||||||
from scgenerator.utils import update_path_name
|
from scgenerator.utils import update_path_name
|
||||||
|
|
||||||
@@ -432,6 +435,14 @@ class Parameters:
|
|||||||
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
||||||
version: str = Parameter(string)
|
version: str = Parameter(string)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, s: str) -> Parameters:
|
||||||
|
return cls(**json.loads(s, object_hook=decode_datetime_hook))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: os.PathLike) -> Parameters:
|
||||||
|
return cls.from_json(Path(path).read_text())
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
||||||
|
|
||||||
@@ -457,19 +468,40 @@ class Parameters:
|
|||||||
)
|
)
|
||||||
object.__setattr__(self, k, v)
|
object.__setattr__(self, k, v)
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
d = self.dump_dict()
|
||||||
|
return json.dumps(d, cls=DatetimeEncoder)
|
||||||
|
|
||||||
def get_evaluator(self):
|
def get_evaluator(self):
|
||||||
evaluator = Evaluator.default(self.full_field)
|
evaluator = Evaluator.default(self.full_field)
|
||||||
evaluator.set(self._param_dico.copy())
|
evaluator.set(self._param_dico.copy())
|
||||||
return evaluator
|
return evaluator
|
||||||
|
|
||||||
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
def dump_dict(self, add_metadata=True) -> dict[str, Any]:
|
||||||
param = Parameters.strip_params_dict(self._param_dico)
|
param = Parameters.strip_params_dict()
|
||||||
if add_metadata:
|
if add_metadata:
|
||||||
param["datetime"] = datetime_module.datetime.now()
|
param["datetime"] = datetime_module.datetime.now()
|
||||||
param["version"] = __version__
|
param["version"] = __version__
|
||||||
return param
|
return param
|
||||||
|
|
||||||
def compute(self, p_name: str, *other_p_names: str) -> Any | tuple[Any]:
|
def compute(self, p_name: str, *other_p_names: str) -> Any | tuple[Any]:
|
||||||
|
"""
|
||||||
|
compute a single or a set of value
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
p_name : str
|
||||||
|
parameter to compute
|
||||||
|
other_p_names : str
|
||||||
|
other parameters to compute
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
if other_p_names == ()
|
||||||
|
returns the computed `p_name` directly
|
||||||
|
else
|
||||||
|
returns a tuple of corresponding values to (p_name, other_p_names[0], ...)
|
||||||
|
"""
|
||||||
evaluator = self.get_evaluator()
|
evaluator = self.get_evaluator()
|
||||||
first = evaluator.compute(p_name)
|
first = evaluator.compute(p_name)
|
||||||
if other_p_names:
|
if other_p_names:
|
||||||
@@ -478,10 +510,39 @@ class Parameters:
|
|||||||
return first
|
return first
|
||||||
|
|
||||||
def compile(self, exhaustive=False) -> Parameters:
|
def compile(self, exhaustive=False) -> Parameters:
|
||||||
|
"""
|
||||||
|
Computes missing parameters and returns them in a frozen `Parameters` instance
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
exhaustive : bool, optional
|
||||||
|
if True, will compute more parameters than strictly necessary for a simulation.
|
||||||
|
Depending on the specifics of the model and how the parameters were specified, there
|
||||||
|
might be no difference between a normal compilation and an exhaustive one.
|
||||||
|
by default False
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Parameters
|
||||||
|
a new, frozen instance of the `Parameters` class. Attributes already specified by the
|
||||||
|
user are copied, alongside newly computed ones.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
|
||||||
|
cases, this is due to underdetermination by the user.
|
||||||
|
"""
|
||||||
to_compute = MANDATORY_PARAMETERS
|
to_compute = MANDATORY_PARAMETERS
|
||||||
evaluator = self.get_evaluator()
|
evaluator = self.get_evaluator()
|
||||||
for k in to_compute:
|
try:
|
||||||
evaluator.compute(k)
|
for k in to_compute:
|
||||||
|
evaluator.compute(k)
|
||||||
|
except EvaluatorError as e:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not compile the parameter set. Most likely, "
|
||||||
|
f"an essential value is missing\n{e}"
|
||||||
|
) from None
|
||||||
if exhaustive:
|
if exhaustive:
|
||||||
for p in self._p_names:
|
for p in self._p_names:
|
||||||
if p not in evaluator.params:
|
if p not in evaluator.params:
|
||||||
@@ -492,7 +553,6 @@ class Parameters:
|
|||||||
computed = self.__class__(
|
computed = self.__class__(
|
||||||
**{k: v for k, v in evaluator.params.items() if k in self._p_names}
|
**{k: v for k, v in evaluator.params.items() if k in self._p_names}
|
||||||
)
|
)
|
||||||
computed._frozen = True
|
|
||||||
return computed
|
return computed
|
||||||
|
|
||||||
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
||||||
@@ -506,16 +566,7 @@ class Parameters:
|
|||||||
max_right = max(len(el[1]) for el in p_pairs)
|
max_right = max(len(el[1]) for el in p_pairs)
|
||||||
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
|
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
|
||||||
|
|
||||||
@classmethod
|
def strip_params_dict(self) -> dict[str, Any]:
|
||||||
def all_parameters(cls) -> list[str]:
|
|
||||||
return [f.name for f in fields(cls)]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, path: os.PathLike) -> "Parameters":
|
|
||||||
return cls(**utils.load_toml(path))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||||
(dropped because they take a lot of space and can be exactly reconstructed)
|
(dropped because they take a lot of space and can be exactly reconstructed)
|
||||||
|
|
||||||
@@ -545,8 +596,8 @@ class Parameters:
|
|||||||
}
|
}
|
||||||
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
||||||
out = {}
|
out = {}
|
||||||
for key, value in dico.items():
|
for key, value in self._param_dico.items():
|
||||||
if key in forbiden_keys or key not in cls._p_names:
|
if key in forbiden_keys or key not in self._p_names:
|
||||||
continue
|
continue
|
||||||
if not isinstance(value, types):
|
if not isinstance(value, types):
|
||||||
continue
|
continue
|
||||||
@@ -565,36 +616,3 @@ class Parameters:
|
|||||||
del out["variable"]
|
del out["variable"]
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@property
|
|
||||||
def final_path(self) -> Path:
|
|
||||||
if self.output_path is not None:
|
|
||||||
return self.output_path.parent / update_path_name(self.output_path.name)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
numero = type_checker(int)
|
|
||||||
|
|
||||||
@numero
|
|
||||||
def natural_number(name, n):
|
|
||||||
if n < 0:
|
|
||||||
raise ValueError(f"{name!r} must be positive")
|
|
||||||
|
|
||||||
try:
|
|
||||||
numero("a", np.arange(45))
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
try:
|
|
||||||
natural_number("b", -1)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
try:
|
|
||||||
natural_number("c", 1.0)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
try:
|
|
||||||
natural_number("d", 1)
|
|
||||||
print("success !")
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|||||||
@@ -340,36 +340,6 @@ def beta2_coef(beta2_coefficients):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def standardize_dictionary(dico):
|
|
||||||
"""convert lists of number and units into a float with SI units inside a dictionary
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
dico : a dictionary
|
|
||||||
Returns
|
|
||||||
----------
|
|
||||||
same dictionary with units converted
|
|
||||||
Example
|
|
||||||
----------
|
|
||||||
standardize_dictionary({"peak_power": [23, "kW"], "points": [1, 2, 3]})
|
|
||||||
{"peak_power": 23000, "points": [1, 2, 3]})
|
|
||||||
"""
|
|
||||||
for key, item in dico.items():
|
|
||||||
if (
|
|
||||||
isinstance(item, list)
|
|
||||||
and len(item) == 2
|
|
||||||
and isinstance(item[0], (int, float))
|
|
||||||
and isinstance(item[1], str)
|
|
||||||
):
|
|
||||||
num, unit = item
|
|
||||||
fac = 1
|
|
||||||
if len(unit) == 2:
|
|
||||||
fac = prefix[unit[0]]
|
|
||||||
elif unit == "bar":
|
|
||||||
fac = 1e5
|
|
||||||
dico[key] = num * fac
|
|
||||||
return dico
|
|
||||||
|
|
||||||
|
|
||||||
def to_WL(spectrum: np.ndarray, lambda_: np.ndarray) -> np.ndarray:
|
def to_WL(spectrum: np.ndarray, lambda_: np.ndarray) -> np.ndarray:
|
||||||
"""rescales the spectrum because of uneven binning when going from freq to wl
|
"""rescales the spectrum because of uneven binning when going from freq to wl
|
||||||
|
|
||||||
|
|||||||
21
tests/test_type_checker.py
Normal file
21
tests/test_type_checker.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scgenerator.parameter import type_checker
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_checker():
|
||||||
|
numero = type_checker(int)
|
||||||
|
|
||||||
|
@numero
|
||||||
|
def natural_number(name, n):
|
||||||
|
if n < 0:
|
||||||
|
raise ValueError(f"{name!r} must be positive")
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="of type"):
|
||||||
|
numero("a", np.arange(45))
|
||||||
|
with pytest.raises(ValueError, match="positive"):
|
||||||
|
natural_number("b", -1)
|
||||||
|
with pytest.raises(TypeError, match="of type"):
|
||||||
|
natural_number("c", 1.0)
|
||||||
|
natural_number("d", 1)
|
||||||
Reference in New Issue
Block a user