From d08c62f569ca6fda2a15aeb5d01f745769d8e0f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 27 Jul 2023 10:34:54 +0200 Subject: [PATCH] more Parameters fixup --- datetimejson.py | 33 +++++++++ src/scgenerator/io.py | 23 ++++++ src/scgenerator/parameter.py | 116 ++++++++++++++++++------------- src/scgenerator/physics/units.py | 30 -------- tests/test_type_checker.py | 21 ++++++ 5 files changed, 144 insertions(+), 79 deletions(-) create mode 100644 datetimejson.py create mode 100644 src/scgenerator/io.py create mode 100644 tests/test_type_checker.py diff --git a/datetimejson.py b/datetimejson.py new file mode 100644 index 0000000..71fe3af --- /dev/null +++ b/datetimejson.py @@ -0,0 +1,33 @@ +import datetime +import json + + +class DatetimeEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (datetime.date, datetime.datetime)): + return obj.isoformat() + + +def print_and_return(obj): + for k, v in obj.items(): + if not isinstance(v, str): + continue + try: + dt = datetime.datetime.fromisoformat(v) + except Exception: + pass + try: + dt = datetime.date.fromisoformat(v) + except Exception: + pass + obj[k] = dt + return obj + + +d = dict(user=dict(joined=datetime.datetime.now()), other_user=datetime.date.today()) + +s = json.dumps(d, cls=DatetimeEncoder) +print(s) +print() + +print(json.loads(s, object_hook=print_and_return)) diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py new file mode 100644 index 0000000..f682929 --- /dev/null +++ b/src/scgenerator/io.py @@ -0,0 +1,23 @@ +import datetime +import json + + +class DatetimeEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (datetime.date, datetime.datetime)): + return obj.isoformat() + + +def decode_datetime_hook(obj): + for k, v in obj.items(): + if not isinstance(v, str): + continue + try: + dt = datetime.datetime.fromisoformat(v) + except Exception: + try: + dt = datetime.date.fromisoformat(v) + except Exception: + continue + obj[k] = dt + return obj diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index a2dc586..5abd9d0 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime as datetime_module +import json import os from copy import copy from dataclasses import dataclass, field, fields @@ -13,7 +14,9 @@ import numpy as np from scgenerator import utils from scgenerator.const import MANDATORY_PARAMETERS, __version__ +from scgenerator.errors import EvaluatorError from scgenerator.evaluator import Evaluator +from scgenerator.io import DatetimeEncoder, decode_datetime_hook from scgenerator.operators import Qualifier, SpecOperator from scgenerator.utils import update_path_name @@ -432,6 +435,14 @@ class Parameters: datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) version: str = Parameter(string) + @classmethod + def from_json(cls, s: str) -> Parameters: + return cls(**json.loads(s, object_hook=decode_datetime_hook)) + + @classmethod + def load(cls, path: os.PathLike) -> Parameters: + return cls.from_json(Path(path).read_text()) + def __repr__(self) -> str: return "Parameter(" + ", ".join(self.__repr_list__()) + ")" @@ -457,19 +468,40 @@ class Parameters: ) object.__setattr__(self, k, v) + def to_json(self) -> str: + d = self.dump_dict() + return json.dumps(d, cls=DatetimeEncoder) + def get_evaluator(self): evaluator = Evaluator.default(self.full_field) evaluator.set(self._param_dico.copy()) return evaluator def dump_dict(self, add_metadata=True) -> dict[str, Any]: - param = Parameters.strip_params_dict(self._param_dico) + param = Parameters.strip_params_dict() if add_metadata: param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ return param def compute(self, p_name: str, *other_p_names: str) -> Any | tuple[Any]: + """ + compute a single or a set of value + + Parameters + ---------- + p_name : str + parameter to compute + other_p_names : str + other parameters to compute + + Returns + ------- + if other_p_names == () + returns the computed `p_name` directly + else + returns a tuple of corresponding values to (p_name, other_p_names[0], ...) + """ evaluator = self.get_evaluator() first = evaluator.compute(p_name) if other_p_names: @@ -478,10 +510,39 @@ class Parameters: return first def compile(self, exhaustive=False) -> Parameters: + """ + Computes missing parameters and returns them in a frozen `Parameters` instance + + Parameters + ---------- + exhaustive : bool, optional + if True, will compute more parameters than strictly necessary for a simulation. + Depending on the specifics of the model and how the parameters were specified, there + might be no difference between a normal compilation and an exhaustive one. + by default False + + Returns + ------- + Parameters + a new, frozen instance of the `Parameters` class. Attributes already specified by the + user are copied, alongside newly computed ones. + + Raises + ------ + ValueError + When all the necessary parameters cannot be computed, a `ValueError` is raised. In most + cases, this is due to underdetermination by the user. + """ to_compute = MANDATORY_PARAMETERS evaluator = self.get_evaluator() - for k in to_compute: - evaluator.compute(k) + try: + for k in to_compute: + evaluator.compute(k) + except EvaluatorError as e: + raise ValueError( + "Could not compile the parameter set. Most likely, " + f"an essential value is missing\n{e}" + ) from None if exhaustive: for p in self._p_names: if p not in evaluator.params: @@ -492,7 +553,6 @@ class Parameters: computed = self.__class__( **{k: v for k, v in evaluator.params.items() if k in self._p_names} ) - computed._frozen = True return computed def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str: @@ -506,16 +566,7 @@ class Parameters: max_right = max(len(el[1]) for el in p_pairs) return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs) - @classmethod - def all_parameters(cls) -> list[str]: - return [f.name for f in fields(cls)] - - @classmethod - def load(cls, path: os.PathLike) -> "Parameters": - return cls(**utils.load_toml(path)) - - @classmethod - def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]: + def strip_params_dict(self) -> dict[str, Any]: """prepares a dictionary for serialization. Some keys may not be preserved (dropped because they take a lot of space and can be exactly reconstructed) @@ -545,8 +596,8 @@ class Parameters: } types = (np.ndarray, float, int, str, list, tuple, dict, Path) out = {} - for key, value in dico.items(): - if key in forbiden_keys or key not in cls._p_names: + for key, value in self._param_dico.items(): + if key in forbiden_keys or key not in self._p_names: continue if not isinstance(value, types): continue @@ -565,36 +616,3 @@ class Parameters: del out["variable"] return out - - @property - def final_path(self) -> Path: - if self.output_path is not None: - return self.output_path.parent / update_path_name(self.output_path.name) - return None - - -if __name__ == "__main__": - numero = type_checker(int) - - @numero - def natural_number(name, n): - if n < 0: - raise ValueError(f"{name!r} must be positive") - - try: - numero("a", np.arange(45)) - except Exception as e: - print(e) - try: - natural_number("b", -1) - except Exception as e: - print(e) - try: - natural_number("c", 1.0) - except Exception as e: - print(e) - try: - natural_number("d", 1) - print("success !") - except Exception as e: - print(e) diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index 5180339..a52da7b 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -340,36 +340,6 @@ def beta2_coef(beta2_coefficients): return out -def standardize_dictionary(dico): - """convert lists of number and units into a float with SI units inside a dictionary - Parameters - ---------- - dico : a dictionary - Returns - ---------- - same dictionary with units converted - Example - ---------- - standardize_dictionary({"peak_power": [23, "kW"], "points": [1, 2, 3]}) - {"peak_power": 23000, "points": [1, 2, 3]}) - """ - for key, item in dico.items(): - if ( - isinstance(item, list) - and len(item) == 2 - and isinstance(item[0], (int, float)) - and isinstance(item[1], str) - ): - num, unit = item - fac = 1 - if len(unit) == 2: - fac = prefix[unit[0]] - elif unit == "bar": - fac = 1e5 - dico[key] = num * fac - return dico - - def to_WL(spectrum: np.ndarray, lambda_: np.ndarray) -> np.ndarray: """rescales the spectrum because of uneven binning when going from freq to wl diff --git a/tests/test_type_checker.py b/tests/test_type_checker.py new file mode 100644 index 0000000..e25f959 --- /dev/null +++ b/tests/test_type_checker.py @@ -0,0 +1,21 @@ +import numpy as np +import pytest + +from scgenerator.parameter import type_checker + + +def test_type_checker(): + numero = type_checker(int) + + @numero + def natural_number(name, n): + if n < 0: + raise ValueError(f"{name!r} must be positive") + + with pytest.raises(TypeError, match="of type"): + numero("a", np.arange(45)) + with pytest.raises(ValueError, match="positive"): + natural_number("b", -1) + with pytest.raises(TypeError, match="of type"): + natural_number("c", 1.0) + natural_number("d", 1)