improved Parameters machinery
This commit is contained in:
@@ -175,9 +175,10 @@ def merge(args):
|
||||
|
||||
|
||||
def preview(args):
|
||||
path = args.path
|
||||
for path in utils.simulations_list(args.path):
|
||||
partial_plot(path)
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
def prep_ray():
|
||||
|
||||
@@ -95,17 +95,6 @@ class EvalStat:
|
||||
priority: float = np.inf
|
||||
|
||||
|
||||
class pdict(dict):
|
||||
"""a dictionary that cannot have any None value"""
|
||||
|
||||
def __setitem__(self, k, v):
|
||||
if v is None:
|
||||
if k in self:
|
||||
del self[k]
|
||||
else:
|
||||
super().__setitem__(k, v)
|
||||
|
||||
|
||||
class Evaluator:
|
||||
defaults: dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import datetime as datetime_module
|
||||
import enum
|
||||
import os
|
||||
@@ -8,14 +7,16 @@ import time
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field, fields
|
||||
from functools import lru_cache
|
||||
from math import isnan
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
|
||||
from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import env, legacy, utils
|
||||
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
|
||||
from .evaluator import Evaluator, pdict
|
||||
from .errors import EvaluatorError
|
||||
from .evaluator import Evaluator
|
||||
from .logger import get_logger
|
||||
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
|
||||
from .utils import fiber_folder, update_path_name
|
||||
@@ -199,43 +200,51 @@ class Parameter:
|
||||
converts a valid value (for example, str.lower), by default None
|
||||
default : callable, optional
|
||||
factory function for a default value (for example, list), by default None
|
||||
display_info : tuple[float, str], optional
|
||||
a factor by which to multiply the value and a string to be appended as a suffix
|
||||
when displaying the value
|
||||
example : (1e-6, "MW") will mean the value 1.12e6 is displayed as '1.12MW'
|
||||
"""
|
||||
|
||||
self.validator = validator
|
||||
self.__validator = validator
|
||||
self.converter = converter
|
||||
self.default = default
|
||||
self.display_info = display_info
|
||||
self.value = None
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
def __set_name__(self, owner: Type[Parameters], name):
|
||||
self.name = name
|
||||
try:
|
||||
owner._p_names.add(self.name)
|
||||
except AttributeError:
|
||||
pass
|
||||
if self.default is not None:
|
||||
Evaluator.register_default_param(self.name, self.default)
|
||||
VariationDescriptor.register_formatter(self.name, self.display)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
def __get__(self, instance: Parameters, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
if self.name not in instance._param_dico:
|
||||
try:
|
||||
instance._evaluator.compute(self.name)
|
||||
return instance._param_dico[self.name]
|
||||
except EvaluatorError:
|
||||
pass
|
||||
return instance._param_dico.get(self.name)
|
||||
# return instance.__dict__[self.name]
|
||||
|
||||
def __delete__(self, instance):
|
||||
raise AttributeError("Cannot delete parameter")
|
||||
|
||||
def __set__(self, instance, value):
|
||||
def __set__(self, instance: Parameters, value):
|
||||
if isinstance(value, Parameter):
|
||||
defaut = None if self.default is None else copy(self.default)
|
||||
instance._param_dico[self.name] = defaut
|
||||
# instance.__dict__[self.name] = defaut
|
||||
if self.default is not None:
|
||||
instance._param_dico[self.name] = copy(self.default)
|
||||
else:
|
||||
if value is not None:
|
||||
if self.converter is not None:
|
||||
value = self.converter(value)
|
||||
self.validator(self.name, value)
|
||||
is_value, value = self.validate(value)
|
||||
if is_value:
|
||||
instance._param_dico[self.name] = value
|
||||
# instance.__dict__[self.name] = value
|
||||
else:
|
||||
if self.name in instance._param_dico:
|
||||
del instance._param_dico[self.name]
|
||||
|
||||
def display(self, num: float) -> str:
|
||||
if self.display_info is None:
|
||||
@@ -247,6 +256,22 @@ class Parameter:
|
||||
num_str = num_str[:-3]
|
||||
return f"{num_str} {unit}"
|
||||
|
||||
def validate(self: Parameter, v) -> tuple[bool, Any]:
|
||||
if v is None:
|
||||
is_value = False
|
||||
try:
|
||||
is_value = not isnan(v)
|
||||
except TypeError:
|
||||
is_value = True
|
||||
if is_value:
|
||||
if self.converter is not None:
|
||||
v = self.converter(v)
|
||||
self.__validator(self.name, v)
|
||||
return is_value, v
|
||||
|
||||
def validator(self, name, value):
|
||||
self.validate(value)
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class Parameters:
|
||||
@@ -255,8 +280,9 @@ class Parameters:
|
||||
"""
|
||||
|
||||
# internal machinery
|
||||
_param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False)
|
||||
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
||||
_evaluator: Evaluator = field(init=False, repr=False)
|
||||
_p_names: ClassVar[Set[str]] = set()
|
||||
|
||||
# root
|
||||
name: str = Parameter(string, default="no name")
|
||||
@@ -383,7 +409,9 @@ class Parameters:
|
||||
def __repr_list__(self) -> Iterator[str]:
|
||||
yield from (f"{k}={v}" for k, v in self.dump_dict().items())
|
||||
|
||||
def dump_dict(self) -> dict[str, Any]:
|
||||
def dump_dict(self, compute=True) -> dict[str, Any]:
|
||||
if compute:
|
||||
self.compute_in_place()
|
||||
param = Parameters.strip_params_dict(self._param_dico)
|
||||
param["datetime"] = datetime_module.datetime.now()
|
||||
param["version"] = __version__
|
||||
@@ -408,8 +436,8 @@ class Parameters:
|
||||
def load(cls, path: os.PathLike) -> "Parameters":
|
||||
return cls(**utils.load_toml(path))
|
||||
|
||||
@staticmethod
|
||||
def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]:
|
||||
@classmethod
|
||||
def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]:
|
||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||
(dropped because they take a lot of space and can be exactly reconstructed)
|
||||
|
||||
@@ -439,7 +467,7 @@ class Parameters:
|
||||
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
||||
out = {}
|
||||
for key, value in dico.items():
|
||||
if key in forbiden_keys:
|
||||
if key in forbiden_keys or key not in cls._p_names:
|
||||
continue
|
||||
if not isinstance(value, types):
|
||||
continue
|
||||
@@ -575,7 +603,7 @@ class Configuration:
|
||||
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
|
||||
for vary_dict in vary_dict_list:
|
||||
for k, v in vary_dict.items():
|
||||
p = getattr(Parameters, k)
|
||||
p: Parameter = getattr(Parameters, k)
|
||||
validator_list(p.validator)("variable " + k, v)
|
||||
if k not in VALID_VARIABLE:
|
||||
raise TypeError(f"{k!r} is not a valid variable parameter")
|
||||
|
||||
@@ -9,12 +9,8 @@ from typing import Any, Generator, Optional, Type, Union
|
||||
import numpy as np
|
||||
|
||||
from .. import utils
|
||||
from ..errors import EvaluatorError
|
||||
from ..logger import get_logger
|
||||
from ..operators import (
|
||||
AbstractConservedQuantity,
|
||||
CurrentState,
|
||||
)
|
||||
from ..operators import AbstractConservedQuantity, CurrentState
|
||||
from ..parameter import Configuration, Parameters
|
||||
from ..pbar import PBars, ProgressBarActor, progress_worker
|
||||
|
||||
@@ -97,9 +93,9 @@ class RK4IP:
|
||||
self.store_num = len(self.z_targets)
|
||||
|
||||
# Setup initial values for every physical quantity that we want to track
|
||||
try:
|
||||
if self.params.A_eff_arr is not None:
|
||||
C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
|
||||
except EvaluatorError:
|
||||
else:
|
||||
C_to_A_factor = 1.0
|
||||
z = self.z_targets.pop(0)
|
||||
# Initial step size
|
||||
|
||||
@@ -1072,8 +1072,9 @@ def annotate_fwhm(
|
||||
|
||||
|
||||
def partial_plot(root: os.PathLike):
|
||||
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
|
||||
path = Path(root)
|
||||
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
|
||||
fig.suptitle(path.name)
|
||||
spec_list = sorted(
|
||||
path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0])
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user