improved Parameters machinery

This commit is contained in:
Benoît Sierro
2021-10-25 15:02:56 +02:00
parent 499ddaca5b
commit 3c197386b6
5 changed files with 63 additions and 48 deletions

View File

@@ -175,9 +175,10 @@ def merge(args):
def preview(args):
path = args.path
partial_plot(path)
plt.show()
for path in utils.simulations_list(args.path):
partial_plot(path)
plt.show()
plt.close()
def prep_ray():

View File

@@ -95,17 +95,6 @@ class EvalStat:
priority: float = np.inf
class pdict(dict):
"""a dictionary that cannot have any None value"""
def __setitem__(self, k, v):
if v is None:
if k in self:
del self[k]
else:
super().__setitem__(k, v)
class Evaluator:
defaults: dict[str, Any] = {}

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import datetime as datetime_module
import enum
import os
@@ -8,14 +7,16 @@ import time
from copy import copy
from dataclasses import dataclass, field, fields
from functools import lru_cache
from math import isnan
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar, Union
import numpy as np
from . import env, legacy, utils
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
from .evaluator import Evaluator, pdict
from .errors import EvaluatorError
from .evaluator import Evaluator
from .logger import get_logger
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
from .utils import fiber_folder, update_path_name
@@ -199,43 +200,51 @@ class Parameter:
converts a valid value (for example, str.lower), by default None
default : callable, optional
factory function for a default value (for example, list), by default None
display_info : tuple[float, str], optional
a factor by which to multiply the value and a string to be appended as a suffix
when displaying the value
example : (1e-6, "MW") will mean the value 1.12e6 is displayed as '1.12MW'
"""
self.validator = validator
self.__validator = validator
self.converter = converter
self.default = default
self.display_info = display_info
self.value = None
def __set_name__(self, owner, name):
def __set_name__(self, owner: Type[Parameters], name):
self.name = name
try:
owner._p_names.add(self.name)
except AttributeError:
pass
if self.default is not None:
Evaluator.register_default_param(self.name, self.default)
VariationDescriptor.register_formatter(self.name, self.display)
def __get__(self, instance, owner):
def __get__(self, instance: Parameters, owner):
if instance is None:
return self
if self.name not in instance._param_dico:
instance._evaluator.compute(self.name)
return instance._param_dico[self.name]
try:
instance._evaluator.compute(self.name)
except EvaluatorError:
pass
return instance._param_dico.get(self.name)
# return instance.__dict__[self.name]
def __delete__(self, instance):
raise AttributeError("Cannot delete parameter")
def __set__(self, instance, value):
def __set__(self, instance: Parameters, value):
if isinstance(value, Parameter):
defaut = None if self.default is None else copy(self.default)
instance._param_dico[self.name] = defaut
# instance.__dict__[self.name] = defaut
if self.default is not None:
instance._param_dico[self.name] = copy(self.default)
else:
if value is not None:
if self.converter is not None:
value = self.converter(value)
self.validator(self.name, value)
instance._param_dico[self.name] = value
# instance.__dict__[self.name] = value
is_value, value = self.validate(value)
if is_value:
instance._param_dico[self.name] = value
else:
if self.name in instance._param_dico:
del instance._param_dico[self.name]
def display(self, num: float) -> str:
if self.display_info is None:
@@ -247,6 +256,22 @@ class Parameter:
num_str = num_str[:-3]
return f"{num_str} {unit}"
def validate(self: Parameter, v) -> tuple[bool, Any]:
if v is None:
is_value = False
try:
is_value = not isnan(v)
except TypeError:
is_value = True
if is_value:
if self.converter is not None:
v = self.converter(v)
self.__validator(self.name, v)
return is_value, v
def validator(self, name, value):
self.validate(value)
@dataclass(repr=False)
class Parameters:
@@ -255,8 +280,9 @@ class Parameters:
"""
# internal machinery
_param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False)
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
_evaluator: Evaluator = field(init=False, repr=False)
_p_names: ClassVar[Set[str]] = set()
# root
name: str = Parameter(string, default="no name")
@@ -383,7 +409,9 @@ class Parameters:
def __repr_list__(self) -> Iterator[str]:
yield from (f"{k}={v}" for k, v in self.dump_dict().items())
def dump_dict(self) -> dict[str, Any]:
def dump_dict(self, compute=True) -> dict[str, Any]:
if compute:
self.compute_in_place()
param = Parameters.strip_params_dict(self._param_dico)
param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__
@@ -408,8 +436,8 @@ class Parameters:
def load(cls, path: os.PathLike) -> "Parameters":
return cls(**utils.load_toml(path))
@staticmethod
def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]:
@classmethod
def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved
(dropped because they take a lot of space and can be exactly reconstructed)
@@ -439,7 +467,7 @@ class Parameters:
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
out = {}
for key, value in dico.items():
if key in forbiden_keys:
if key in forbiden_keys or key not in cls._p_names:
continue
if not isinstance(value, types):
continue
@@ -575,7 +603,7 @@ class Configuration:
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
for vary_dict in vary_dict_list:
for k, v in vary_dict.items():
p = getattr(Parameters, k)
p: Parameter = getattr(Parameters, k)
validator_list(p.validator)("variable " + k, v)
if k not in VALID_VARIABLE:
raise TypeError(f"{k!r} is not a valid variable parameter")

View File

@@ -9,12 +9,8 @@ from typing import Any, Generator, Optional, Type, Union
import numpy as np
from .. import utils
from ..errors import EvaluatorError
from ..logger import get_logger
from ..operators import (
AbstractConservedQuantity,
CurrentState,
)
from ..operators import AbstractConservedQuantity, CurrentState
from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker
@@ -97,9 +93,9 @@ class RK4IP:
self.store_num = len(self.z_targets)
# Setup initial values for every physical quantity that we want to track
try:
if self.params.A_eff_arr is not None:
C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
except EvaluatorError:
else:
C_to_A_factor = 1.0
z = self.z_targets.pop(0)
# Initial step size

View File

@@ -1072,8 +1072,9 @@ def annotate_fwhm(
def partial_plot(root: os.PathLike):
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
path = Path(root)
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
fig.suptitle(path.name)
spec_list = sorted(
path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0])
)