improved Parameters machinery

This commit is contained in:
Benoît Sierro
2021-10-25 15:02:56 +02:00
parent 499ddaca5b
commit 3c197386b6
5 changed files with 63 additions and 48 deletions

View File

@@ -175,9 +175,10 @@ def merge(args):
def preview(args): def preview(args):
path = args.path for path in utils.simulations_list(args.path):
partial_plot(path) partial_plot(path)
plt.show() plt.show()
plt.close()
def prep_ray(): def prep_ray():

View File

@@ -95,17 +95,6 @@ class EvalStat:
priority: float = np.inf priority: float = np.inf
class pdict(dict):
"""a dictionary that cannot have any None value"""
def __setitem__(self, k, v):
if v is None:
if k in self:
del self[k]
else:
super().__setitem__(k, v)
class Evaluator: class Evaluator:
defaults: dict[str, Any] = {} defaults: dict[str, Any] = {}

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import datetime as datetime_module import datetime as datetime_module
import enum import enum
import os import os
@@ -8,14 +7,16 @@ import time
from copy import copy from copy import copy
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from functools import lru_cache from functools import lru_cache
from math import isnan
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar, Union
import numpy as np import numpy as np
from . import env, legacy, utils from . import env, legacy, utils
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__ from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
from .evaluator import Evaluator, pdict from .errors import EvaluatorError
from .evaluator import Evaluator
from .logger import get_logger from .logger import get_logger
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
from .utils import fiber_folder, update_path_name from .utils import fiber_folder, update_path_name
@@ -199,43 +200,51 @@ class Parameter:
converts a valid value (for example, str.lower), by default None converts a valid value (for example, str.lower), by default None
default : callable, optional default : callable, optional
factory function for a default value (for example, list), by default None factory function for a default value (for example, list), by default None
display_info : tuple[float, str], optional
a factor by which to multiply the value and a string to be appended as a suffix
when displaying the value
example : (1e-6, "MW") will mean the value 1.12e6 is displayed as '1.12MW'
""" """
self.__validator = validator
self.validator = validator
self.converter = converter self.converter = converter
self.default = default self.default = default
self.display_info = display_info self.display_info = display_info
self.value = None
def __set_name__(self, owner, name): def __set_name__(self, owner: Type[Parameters], name):
self.name = name self.name = name
try:
owner._p_names.add(self.name)
except AttributeError:
pass
if self.default is not None: if self.default is not None:
Evaluator.register_default_param(self.name, self.default) Evaluator.register_default_param(self.name, self.default)
VariationDescriptor.register_formatter(self.name, self.display) VariationDescriptor.register_formatter(self.name, self.display)
def __get__(self, instance, owner): def __get__(self, instance: Parameters, owner):
if instance is None: if instance is None:
return self return self
if self.name not in instance._param_dico: if self.name not in instance._param_dico:
instance._evaluator.compute(self.name) try:
return instance._param_dico[self.name] instance._evaluator.compute(self.name)
except EvaluatorError:
pass
return instance._param_dico.get(self.name)
# return instance.__dict__[self.name] # return instance.__dict__[self.name]
def __delete__(self, instance): def __delete__(self, instance):
raise AttributeError("Cannot delete parameter") raise AttributeError("Cannot delete parameter")
def __set__(self, instance, value): def __set__(self, instance: Parameters, value):
if isinstance(value, Parameter): if isinstance(value, Parameter):
defaut = None if self.default is None else copy(self.default) if self.default is not None:
instance._param_dico[self.name] = defaut instance._param_dico[self.name] = copy(self.default)
# instance.__dict__[self.name] = defaut
else: else:
if value is not None: is_value, value = self.validate(value)
if self.converter is not None: if is_value:
value = self.converter(value) instance._param_dico[self.name] = value
self.validator(self.name, value) else:
instance._param_dico[self.name] = value if self.name in instance._param_dico:
# instance.__dict__[self.name] = value del instance._param_dico[self.name]
def display(self, num: float) -> str: def display(self, num: float) -> str:
if self.display_info is None: if self.display_info is None:
@@ -247,6 +256,22 @@ class Parameter:
num_str = num_str[:-3] num_str = num_str[:-3]
return f"{num_str} {unit}" return f"{num_str} {unit}"
def validate(self: Parameter, v) -> tuple[bool, Any]:
if v is None:
is_value = False
try:
is_value = not isnan(v)
except TypeError:
is_value = True
if is_value:
if self.converter is not None:
v = self.converter(v)
self.__validator(self.name, v)
return is_value, v
def validator(self, name, value):
self.validate(value)
@dataclass(repr=False) @dataclass(repr=False)
class Parameters: class Parameters:
@@ -255,8 +280,9 @@ class Parameters:
""" """
# internal machinery # internal machinery
_param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False) _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
_evaluator: Evaluator = field(init=False, repr=False) _evaluator: Evaluator = field(init=False, repr=False)
_p_names: ClassVar[Set[str]] = set()
# root # root
name: str = Parameter(string, default="no name") name: str = Parameter(string, default="no name")
@@ -383,7 +409,9 @@ class Parameters:
def __repr_list__(self) -> Iterator[str]: def __repr_list__(self) -> Iterator[str]:
yield from (f"{k}={v}" for k, v in self.dump_dict().items()) yield from (f"{k}={v}" for k, v in self.dump_dict().items())
def dump_dict(self) -> dict[str, Any]: def dump_dict(self, compute=True) -> dict[str, Any]:
if compute:
self.compute_in_place()
param = Parameters.strip_params_dict(self._param_dico) param = Parameters.strip_params_dict(self._param_dico)
param["datetime"] = datetime_module.datetime.now() param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__ param["version"] = __version__
@@ -408,8 +436,8 @@ class Parameters:
def load(cls, path: os.PathLike) -> "Parameters": def load(cls, path: os.PathLike) -> "Parameters":
return cls(**utils.load_toml(path)) return cls(**utils.load_toml(path))
@staticmethod @classmethod
def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]: def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved """prepares a dictionary for serialization. Some keys may not be preserved
(dropped because they take a lot of space and can be exactly reconstructed) (dropped because they take a lot of space and can be exactly reconstructed)
@@ -439,7 +467,7 @@ class Parameters:
types = (np.ndarray, float, int, str, list, tuple, dict, Path) types = (np.ndarray, float, int, str, list, tuple, dict, Path)
out = {} out = {}
for key, value in dico.items(): for key, value in dico.items():
if key in forbiden_keys: if key in forbiden_keys or key not in cls._p_names:
continue continue
if not isinstance(value, types): if not isinstance(value, types):
continue continue
@@ -575,7 +603,7 @@ class Configuration:
def __validate_variable(self, vary_dict_list: list[dict[str, list]]): def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
for vary_dict in vary_dict_list: for vary_dict in vary_dict_list:
for k, v in vary_dict.items(): for k, v in vary_dict.items():
p = getattr(Parameters, k) p: Parameter = getattr(Parameters, k)
validator_list(p.validator)("variable " + k, v) validator_list(p.validator)("variable " + k, v)
if k not in VALID_VARIABLE: if k not in VALID_VARIABLE:
raise TypeError(f"{k!r} is not a valid variable parameter") raise TypeError(f"{k!r} is not a valid variable parameter")

View File

@@ -9,12 +9,8 @@ from typing import Any, Generator, Optional, Type, Union
import numpy as np import numpy as np
from .. import utils from .. import utils
from ..errors import EvaluatorError
from ..logger import get_logger from ..logger import get_logger
from ..operators import ( from ..operators import AbstractConservedQuantity, CurrentState
AbstractConservedQuantity,
CurrentState,
)
from ..parameter import Configuration, Parameters from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker from ..pbar import PBars, ProgressBarActor, progress_worker
@@ -97,9 +93,9 @@ class RK4IP:
self.store_num = len(self.z_targets) self.store_num = len(self.z_targets)
# Setup initial values for every physical quantity that we want to track # Setup initial values for every physical quantity that we want to track
try: if self.params.A_eff_arr is not None:
C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4) C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
except EvaluatorError: else:
C_to_A_factor = 1.0 C_to_A_factor = 1.0
z = self.z_targets.pop(0) z = self.z_targets.pop(0)
# Initial step size # Initial step size

View File

@@ -1072,8 +1072,9 @@ def annotate_fwhm(
def partial_plot(root: os.PathLike): def partial_plot(root: os.PathLike):
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
path = Path(root) path = Path(root)
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
fig.suptitle(path.name)
spec_list = sorted( spec_list = sorted(
path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0]) path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0])
) )