improved Parameters machinery
This commit is contained in:
@@ -175,9 +175,10 @@ def merge(args):
|
|||||||
|
|
||||||
|
|
||||||
def preview(args):
|
def preview(args):
|
||||||
path = args.path
|
for path in utils.simulations_list(args.path):
|
||||||
partial_plot(path)
|
partial_plot(path)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def prep_ray():
|
def prep_ray():
|
||||||
|
|||||||
@@ -95,17 +95,6 @@ class EvalStat:
|
|||||||
priority: float = np.inf
|
priority: float = np.inf
|
||||||
|
|
||||||
|
|
||||||
class pdict(dict):
|
|
||||||
"""a dictionary that cannot have any None value"""
|
|
||||||
|
|
||||||
def __setitem__(self, k, v):
|
|
||||||
if v is None:
|
|
||||||
if k in self:
|
|
||||||
del self[k]
|
|
||||||
else:
|
|
||||||
super().__setitem__(k, v)
|
|
||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
defaults: dict[str, Any] = {}
|
defaults: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
import datetime as datetime_module
|
import datetime as datetime_module
|
||||||
import enum
|
import enum
|
||||||
import os
|
import os
|
||||||
@@ -8,14 +7,16 @@ import time
|
|||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from math import isnan
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
|
from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from . import env, legacy, utils
|
from . import env, legacy, utils
|
||||||
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
|
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
|
||||||
from .evaluator import Evaluator, pdict
|
from .errors import EvaluatorError
|
||||||
|
from .evaluator import Evaluator
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
|
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
|
||||||
from .utils import fiber_folder, update_path_name
|
from .utils import fiber_folder, update_path_name
|
||||||
@@ -199,43 +200,51 @@ class Parameter:
|
|||||||
converts a valid value (for example, str.lower), by default None
|
converts a valid value (for example, str.lower), by default None
|
||||||
default : callable, optional
|
default : callable, optional
|
||||||
factory function for a default value (for example, list), by default None
|
factory function for a default value (for example, list), by default None
|
||||||
|
display_info : tuple[float, str], optional
|
||||||
|
a factor by which to multiply the value and a string to be appended as a suffix
|
||||||
|
when displaying the value
|
||||||
|
example : (1e-6, "MW") will mean the value 1.12e6 is displayed as '1.12MW'
|
||||||
"""
|
"""
|
||||||
|
self.__validator = validator
|
||||||
self.validator = validator
|
|
||||||
self.converter = converter
|
self.converter = converter
|
||||||
self.default = default
|
self.default = default
|
||||||
self.display_info = display_info
|
self.display_info = display_info
|
||||||
self.value = None
|
|
||||||
|
|
||||||
def __set_name__(self, owner, name):
|
def __set_name__(self, owner: Type[Parameters], name):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
try:
|
||||||
|
owner._p_names.add(self.name)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
if self.default is not None:
|
if self.default is not None:
|
||||||
Evaluator.register_default_param(self.name, self.default)
|
Evaluator.register_default_param(self.name, self.default)
|
||||||
VariationDescriptor.register_formatter(self.name, self.display)
|
VariationDescriptor.register_formatter(self.name, self.display)
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance: Parameters, owner):
|
||||||
if instance is None:
|
if instance is None:
|
||||||
return self
|
return self
|
||||||
if self.name not in instance._param_dico:
|
if self.name not in instance._param_dico:
|
||||||
|
try:
|
||||||
instance._evaluator.compute(self.name)
|
instance._evaluator.compute(self.name)
|
||||||
return instance._param_dico[self.name]
|
except EvaluatorError:
|
||||||
|
pass
|
||||||
|
return instance._param_dico.get(self.name)
|
||||||
# return instance.__dict__[self.name]
|
# return instance.__dict__[self.name]
|
||||||
|
|
||||||
def __delete__(self, instance):
|
def __delete__(self, instance):
|
||||||
raise AttributeError("Cannot delete parameter")
|
raise AttributeError("Cannot delete parameter")
|
||||||
|
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance: Parameters, value):
|
||||||
if isinstance(value, Parameter):
|
if isinstance(value, Parameter):
|
||||||
defaut = None if self.default is None else copy(self.default)
|
if self.default is not None:
|
||||||
instance._param_dico[self.name] = defaut
|
instance._param_dico[self.name] = copy(self.default)
|
||||||
# instance.__dict__[self.name] = defaut
|
|
||||||
else:
|
else:
|
||||||
if value is not None:
|
is_value, value = self.validate(value)
|
||||||
if self.converter is not None:
|
if is_value:
|
||||||
value = self.converter(value)
|
|
||||||
self.validator(self.name, value)
|
|
||||||
instance._param_dico[self.name] = value
|
instance._param_dico[self.name] = value
|
||||||
# instance.__dict__[self.name] = value
|
else:
|
||||||
|
if self.name in instance._param_dico:
|
||||||
|
del instance._param_dico[self.name]
|
||||||
|
|
||||||
def display(self, num: float) -> str:
|
def display(self, num: float) -> str:
|
||||||
if self.display_info is None:
|
if self.display_info is None:
|
||||||
@@ -247,6 +256,22 @@ class Parameter:
|
|||||||
num_str = num_str[:-3]
|
num_str = num_str[:-3]
|
||||||
return f"{num_str} {unit}"
|
return f"{num_str} {unit}"
|
||||||
|
|
||||||
|
def validate(self: Parameter, v) -> tuple[bool, Any]:
|
||||||
|
if v is None:
|
||||||
|
is_value = False
|
||||||
|
try:
|
||||||
|
is_value = not isnan(v)
|
||||||
|
except TypeError:
|
||||||
|
is_value = True
|
||||||
|
if is_value:
|
||||||
|
if self.converter is not None:
|
||||||
|
v = self.converter(v)
|
||||||
|
self.__validator(self.name, v)
|
||||||
|
return is_value, v
|
||||||
|
|
||||||
|
def validator(self, name, value):
|
||||||
|
self.validate(value)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(repr=False)
|
@dataclass(repr=False)
|
||||||
class Parameters:
|
class Parameters:
|
||||||
@@ -255,8 +280,9 @@ class Parameters:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# internal machinery
|
# internal machinery
|
||||||
_param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False)
|
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
||||||
_evaluator: Evaluator = field(init=False, repr=False)
|
_evaluator: Evaluator = field(init=False, repr=False)
|
||||||
|
_p_names: ClassVar[Set[str]] = set()
|
||||||
|
|
||||||
# root
|
# root
|
||||||
name: str = Parameter(string, default="no name")
|
name: str = Parameter(string, default="no name")
|
||||||
@@ -383,7 +409,9 @@ class Parameters:
|
|||||||
def __repr_list__(self) -> Iterator[str]:
|
def __repr_list__(self) -> Iterator[str]:
|
||||||
yield from (f"{k}={v}" for k, v in self.dump_dict().items())
|
yield from (f"{k}={v}" for k, v in self.dump_dict().items())
|
||||||
|
|
||||||
def dump_dict(self) -> dict[str, Any]:
|
def dump_dict(self, compute=True) -> dict[str, Any]:
|
||||||
|
if compute:
|
||||||
|
self.compute_in_place()
|
||||||
param = Parameters.strip_params_dict(self._param_dico)
|
param = Parameters.strip_params_dict(self._param_dico)
|
||||||
param["datetime"] = datetime_module.datetime.now()
|
param["datetime"] = datetime_module.datetime.now()
|
||||||
param["version"] = __version__
|
param["version"] = __version__
|
||||||
@@ -408,8 +436,8 @@ class Parameters:
|
|||||||
def load(cls, path: os.PathLike) -> "Parameters":
|
def load(cls, path: os.PathLike) -> "Parameters":
|
||||||
return cls(**utils.load_toml(path))
|
return cls(**utils.load_toml(path))
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]:
|
def strip_params_dict(cls, dico: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||||
(dropped because they take a lot of space and can be exactly reconstructed)
|
(dropped because they take a lot of space and can be exactly reconstructed)
|
||||||
|
|
||||||
@@ -439,7 +467,7 @@ class Parameters:
|
|||||||
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
types = (np.ndarray, float, int, str, list, tuple, dict, Path)
|
||||||
out = {}
|
out = {}
|
||||||
for key, value in dico.items():
|
for key, value in dico.items():
|
||||||
if key in forbiden_keys:
|
if key in forbiden_keys or key not in cls._p_names:
|
||||||
continue
|
continue
|
||||||
if not isinstance(value, types):
|
if not isinstance(value, types):
|
||||||
continue
|
continue
|
||||||
@@ -575,7 +603,7 @@ class Configuration:
|
|||||||
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
|
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
|
||||||
for vary_dict in vary_dict_list:
|
for vary_dict in vary_dict_list:
|
||||||
for k, v in vary_dict.items():
|
for k, v in vary_dict.items():
|
||||||
p = getattr(Parameters, k)
|
p: Parameter = getattr(Parameters, k)
|
||||||
validator_list(p.validator)("variable " + k, v)
|
validator_list(p.validator)("variable " + k, v)
|
||||||
if k not in VALID_VARIABLE:
|
if k not in VALID_VARIABLE:
|
||||||
raise TypeError(f"{k!r} is not a valid variable parameter")
|
raise TypeError(f"{k!r} is not a valid variable parameter")
|
||||||
|
|||||||
@@ -9,12 +9,8 @@ from typing import Any, Generator, Optional, Type, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .. import utils
|
from .. import utils
|
||||||
from ..errors import EvaluatorError
|
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..operators import (
|
from ..operators import AbstractConservedQuantity, CurrentState
|
||||||
AbstractConservedQuantity,
|
|
||||||
CurrentState,
|
|
||||||
)
|
|
||||||
from ..parameter import Configuration, Parameters
|
from ..parameter import Configuration, Parameters
|
||||||
from ..pbar import PBars, ProgressBarActor, progress_worker
|
from ..pbar import PBars, ProgressBarActor, progress_worker
|
||||||
|
|
||||||
@@ -97,9 +93,9 @@ class RK4IP:
|
|||||||
self.store_num = len(self.z_targets)
|
self.store_num = len(self.z_targets)
|
||||||
|
|
||||||
# Setup initial values for every physical quantity that we want to track
|
# Setup initial values for every physical quantity that we want to track
|
||||||
try:
|
if self.params.A_eff_arr is not None:
|
||||||
C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
|
C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
|
||||||
except EvaluatorError:
|
else:
|
||||||
C_to_A_factor = 1.0
|
C_to_A_factor = 1.0
|
||||||
z = self.z_targets.pop(0)
|
z = self.z_targets.pop(0)
|
||||||
# Initial step size
|
# Initial step size
|
||||||
|
|||||||
@@ -1072,8 +1072,9 @@ def annotate_fwhm(
|
|||||||
|
|
||||||
|
|
||||||
def partial_plot(root: os.PathLike):
|
def partial_plot(root: os.PathLike):
|
||||||
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
|
|
||||||
path = Path(root)
|
path = Path(root)
|
||||||
|
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 8))
|
||||||
|
fig.suptitle(path.name)
|
||||||
spec_list = sorted(
|
spec_list = sorted(
|
||||||
path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0])
|
path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0])
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user