From 3c197386b6544e9335c34ee17e19a403e93193dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 25 Oct 2021 15:02:56 +0200 Subject: [PATCH] improved Parameters machinery --- src/scgenerator/cli/cli.py | 7 +-- src/scgenerator/evaluator.py | 11 ---- src/scgenerator/parameter.py | 80 +++++++++++++++++++---------- src/scgenerator/physics/simulate.py | 10 ++-- src/scgenerator/plotting.py | 3 +- 5 files changed, 63 insertions(+), 48 deletions(-) diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index ab6f8fc..550f5f9 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -175,9 +175,10 @@ def merge(args): def preview(args): - path = args.path - partial_plot(path) - plt.show() + for path in utils.simulations_list(args.path): + partial_plot(path) + plt.show() + plt.close() def prep_ray(): diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 048e7a7..a3bb4aa 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -95,17 +95,6 @@ class EvalStat: priority: float = np.inf -class pdict(dict): - """a dictionary that cannot have any None value""" - - def __setitem__(self, k, v): - if v is None: - if k in self: - del self[k] - else: - super().__setitem__(k, v) - - class Evaluator: defaults: dict[str, Any] = {} diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 6c358c9..c209792 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -1,6 +1,5 @@ from __future__ import annotations - import datetime as datetime_module import enum import os @@ -8,14 +7,16 @@ import time from copy import copy from dataclasses import dataclass, field, fields from functools import lru_cache +from math import isnan from pathlib import Path -from typing import Any, Callable, Iterable, Iterator, TypeVar, Union +from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar, Union import numpy as np from . import env, legacy, utils from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__ -from .evaluator import Evaluator, pdict +from .errors import EvaluatorError +from .evaluator import Evaluator from .logger import get_logger from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator from .utils import fiber_folder, update_path_name @@ -199,43 +200,51 @@ class Parameter: converts a valid value (for example, str.lower), by default None default : callable, optional factory function for a default value (for example, list), by default None + display_info : tuple[float, str], optional + a factor by which to multiply the value and a string to be appended as a suffix + when displaying the value + example : (1e-6, "MW") will mean the value 1.12e6 is displayed as '1.12MW' """ - - self.validator = validator + self.__validator = validator self.converter = converter self.default = default self.display_info = display_info - self.value = None - def __set_name__(self, owner, name): + def __set_name__(self, owner: Type[Parameters], name): self.name = name + try: + owner._p_names.add(self.name) + except AttributeError: + pass if self.default is not None: Evaluator.register_default_param(self.name, self.default) VariationDescriptor.register_formatter(self.name, self.display) - def __get__(self, instance, owner): + def __get__(self, instance: Parameters, owner): if instance is None: return self if self.name not in instance._param_dico: - instance._evaluator.compute(self.name) - return instance._param_dico[self.name] + try: + instance._evaluator.compute(self.name) + except EvaluatorError: + pass + return instance._param_dico.get(self.name) # return instance.__dict__[self.name] def __delete__(self, instance): raise AttributeError("Cannot delete parameter") - def __set__(self, instance, value): + def __set__(self, instance: Parameters, value): if isinstance(value, Parameter): - defaut = None if self.default is None else copy(self.default) - instance._param_dico[self.name] = defaut - # instance.__dict__[self.name] = defaut + if self.default is not None: + instance._param_dico[self.name] = copy(self.default) else: - if value is not None: - if self.converter is not None: - value = self.converter(value) - self.validator(self.name, value) - instance._param_dico[self.name] = value - # instance.__dict__[self.name] = value + is_value, value = self.validate(value) + if is_value: + instance._param_dico[self.name] = value + else: + if self.name in instance._param_dico: + del instance._param_dico[self.name] def display(self, num: float) -> str: if self.display_info is None: @@ -247,6 +256,22 @@ class Parameter: num_str = num_str[:-3] return f"{num_str} {unit}" + def validate(self: Parameter, v) -> tuple[bool, Any]: + if v is None: + is_value = False + try: + is_value = not isnan(v) + except TypeError: + is_value = True + if is_value: + if self.converter is not None: + v = self.converter(v) + self.__validator(self.name, v) + return is_value, v + + def validator(self, name, value): + self.validate(value) + @dataclass(repr=False) class Parameters: @@ -255,8 +280,9 @@ class Parameters: """ # internal machinery - _param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False) + _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False) _evaluator: Evaluator = field(init=False, repr=False) + _p_names: ClassVar[Set[str]] = set() # root name: str = Parameter(string, default="no name") @@ -383,7 +409,9 @@ class Parameters: def __repr_list__(self) -> Iterator[str]: yield from (f"{k}={v}" for k, v in self.dump_dict().items()) - def dump_dict(self) -> dict[str, Any]: + def dump_dict(self, compute=True) -> dict[str, Any]: + if compute: + self.compute_in_place() param = Parameters.strip_params_dict(self._param_dico) param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ @@ -408,8 +436,8 @@ class Parameters: def load(cls, path: os.PathLike) -> "Parameters": return cls(**utils.load_toml(path)) - @staticmethod - def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]: + @classmethod + def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]: """prepares a dictionary for serialization. Some keys may not be preserved (dropped because they take a lot of space and can be exactly reconstructed) @@ -439,7 +467,7 @@ class Parameters: types = (np.ndarray, float, int, str, list, tuple, dict, Path) out = {} for key, value in dico.items(): - if key in forbiden_keys: + if key in forbiden_keys or key not in cls._p_names: continue if not isinstance(value, types): continue @@ -575,7 +603,7 @@ class Configuration: def __validate_variable(self, vary_dict_list: list[dict[str, list]]): for vary_dict in vary_dict_list: for k, v in vary_dict.items(): - p = getattr(Parameters, k) + p: Parameter = getattr(Parameters, k) validator_list(p.validator)("variable " + k, v) if k not in VALID_VARIABLE: raise TypeError(f"{k!r} is not a valid variable parameter") diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 5b242e3..67ebf6d 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -9,12 +9,8 @@ from typing import Any, Generator, Optional, Type, Union import numpy as np from .. import utils -from ..errors import EvaluatorError from ..logger import get_logger -from ..operators import ( - AbstractConservedQuantity, - CurrentState, -) +from ..operators import AbstractConservedQuantity, CurrentState from ..parameter import Configuration, Parameters from ..pbar import PBars, ProgressBarActor, progress_worker @@ -97,9 +93,9 @@ class RK4IP: self.store_num = len(self.z_targets) # Setup initial values for every physical quantity that we want to track - try: + if self.params.A_eff_arr is not None: C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4) - except EvaluatorError: + else: C_to_A_factor = 1.0 z = self.z_targets.pop(0) # Initial step size diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 16e9ace..08fdb48 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -1072,8 +1072,9 @@ def annotate_fwhm( def partial_plot(root: os.PathLike): - fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8)) path = Path(root) + fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8)) + fig.suptitle(path.name) spec_list = sorted( path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0]) )