Parameters are now computed lazily

This commit is contained in:
Benoît Sierro
2021-10-21 15:46:39 +02:00
parent f9f7f2b234
commit aebc0cef85
11 changed files with 61 additions and 49 deletions

1
.gitignore vendored
View File

@@ -3,6 +3,7 @@
**/*.npy
plots*
/make_*.py
Archive
*.mp4
*.png

View File

@@ -1,4 +1,4 @@
__version__ = "0.2.3rules"
__version__ = "0.2.4dev"
from typing import Any

View File

@@ -40,5 +40,5 @@ class EvaluatorError(Exception):
pass
class NoDefaultError(Exception):
class NoDefaultError(EvaluatorError):
pass

View File

@@ -95,6 +95,17 @@ class EvalStat:
priority: float = np.inf
class pdict(dict):
"""a dictionary that cannot have any None value"""
def __setitem__(self, k, v):
if v is None:
if k in self:
del self[k]
else:
super().__setitem__(k, v)
class Evaluator:
defaults: dict[str, Any] = {}
@@ -231,7 +242,7 @@ class Evaluator:
if param_name == target:
value = returned_value
break
except (EvaluatorError, KeyError, NoDefaultError) as e:
except EvaluatorError as e:
error = e
self.logger.debug(
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
@@ -269,7 +280,7 @@ class Evaluator:
def validate_condition(self, rule: Rule) -> bool:
try:
return all(self.compute(k) == v for k, v in rule.conditions.items())
except (EvaluatorError, KeyError, NoDefaultError):
except EvaluatorError:
return False
def attempted_rules_str(self, target: str) -> str:

View File

@@ -67,7 +67,7 @@ def convert_sim_folder(path: os.PathLike):
processed_specs.add(descr)
if (parent := descr.parent) is not None:
new_params.prev_data_dir = str(new_paths[parent].final_path)
save_parameters(new_params.prepare_for_dump(), new_params.final_path)
save_parameters(new_params.dump_dict(), new_params.final_path)
for spec_num in range(start_z, end_z):
old_spec = old_path / SPECN_FN1.format(spec_num)
if move_specs:

View File

@@ -5,7 +5,7 @@ import enum
import os
import time
from copy import copy
from dataclasses import asdict, dataclass, fields
from dataclasses import dataclass, field, fields
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
@@ -14,7 +14,7 @@ import numpy as np
from . import env, legacy, utils
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
from .evaluator import Evaluator
from .evaluator import Evaluator, pdict
from .logger import get_logger
from .operators import LinearOperator, NonLinearOperator
from .utils import fiber_folder, update_path_name
@@ -204,6 +204,7 @@ class Parameter:
self.converter = converter
self.default = default
self.display_info = display_info
self.value = None
def __set_name__(self, owner, name):
self.name = name
@@ -214,7 +215,10 @@ class Parameter:
def __get__(self, instance, owner):
if instance is None:
return self
return instance.__dict__[self.name]
if self.name not in instance._param_dico:
instance._evaluator.compute(self.name)
return instance._param_dico[self.name]
# return instance.__dict__[self.name]
def __delete__(self, instance):
raise AttributeError("Cannot delete parameter")
@@ -222,13 +226,15 @@ class Parameter:
def __set__(self, instance, value):
if isinstance(value, Parameter):
defaut = None if self.default is None else copy(self.default)
instance.__dict__[self.name] = defaut
instance._param_dico[self.name] = defaut
# instance.__dict__[self.name] = defaut
else:
if value is not None:
if self.converter is not None:
value = self.converter(value)
self.validator(self.name, value)
instance.__dict__[self.name] = value
instance._param_dico[self.name] = value
# instance.__dict__[self.name] = value
def display(self, num: float) -> str:
if self.display_info is None:
@@ -241,12 +247,16 @@ class Parameter:
return f"{num_str} {unit}"
@dataclass
@dataclass(repr=False)
class Parameters:
"""
This class defines each valid parameter's name, type and valid value.
"""
# internal machinery
_param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False)
_evaluator: Evaluator = field(init=False, repr=False)
# root
name: str = Parameter(string, default="no name")
prev_data_dir: str = Parameter(string)
@@ -348,40 +358,37 @@ class Parameters:
L_D: float = Parameter(non_negative(float, int))
L_NL: float = Parameter(non_negative(float, int))
L_sol: float = Parameter(non_negative(float, int))
dynamic_dispersion: bool = Parameter(boolean)
adapt_step_size: bool = Parameter(boolean)
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
beta_func: Callable[[float], list[float]] = Parameter(func_validator)
gamma_func: Callable[[float], float] = Parameter(func_validator)
num: int = Parameter(non_negative(int))
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string)
def prepare_for_dump(self) -> dict[str, Any]:
param = asdict(self)
param = Parameters.strip_params_dict(param)
def __post_init__(self):
self._evaluator = Evaluator.default()
self._evaluator.set(self._param_dico)
def __repr__(self) -> str:
return "Parameter(" + ", ".join(f"{k}={v}" for k, v in self.dump_dict().items()) + ")"
def dump_dict(self) -> dict[str, Any]:
param = Parameters.strip_params_dict(self._param_dico)
param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__
return param
def compute(self, to_compute: list[str] = MANDATORY_PARAMETERS):
param_dict = {k: v for k, v in asdict(self).items() if v is not None}
evaluator = Evaluator.default()
evaluator.set(**param_dict)
results = [evaluator.compute(p_name) for p_name in to_compute]
valid_fields = self.all_parameters()
for k, v in evaluator.params.items():
if k in valid_fields:
setattr(self, k, v)
return results
def compute_in_place(self, *to_compute: str):
if len(to_compute) == 0:
to_compute = MANDATORY_PARAMETERS
for k in to_compute:
getattr(self, k)
def pformat(self) -> str:
return "\n".join(
f"{k} = {VariationDescriptor.format_value(k, v)}"
for k, v in self.prepare_for_dump().items()
f"{k} = {VariationDescriptor.format_value(k, v)}" for k, v in self.dump_dict().items()
)
@classmethod
@@ -392,12 +399,6 @@ class Parameters:
def load(cls, path: os.PathLike) -> "Parameters":
return cls(**utils.load_toml(path))
@classmethod
def load_and_compute(cls, path: os.PathLike) -> "Parameters":
p = cls.load(path)
p.compute()
return p
@staticmethod
def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved
@@ -409,6 +410,8 @@ class Parameters:
dictionary
"""
forbiden_keys = {
"_param_dico",
"_evaluator",
"w_c",
"w_power_fact",
"field_0",

View File

@@ -5,7 +5,6 @@ from numpy import e
from numpy.fft import fft
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d
from sympy import re
from .. import utils
from ..cache import np_cache

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable
from typing import Callable
import numpy as np
import scipy.special
@@ -130,12 +130,12 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
chi = np.zeros_like(lambda_) # = n^2 - 1
if kind == 1:
logger.debug("materials : using Sellmeier 1st kind equation")
for b, c in zip(B, C):
chi[ind] += temp_l ** 2 * b / (temp_l ** 2 - c)
for b, c_ in zip(B, C):
chi[ind] += temp_l ** 2 * b / (temp_l ** 2 - c_)
elif kind == 2: # gives n-1
logger.debug("materials : using Sellmeier 2nd kind equation")
for b, c in zip(B, C):
chi[ind] += b / (c - 1 / temp_l ** 2)
for b, c_ in zip(B, C):
chi[ind] += b / (c_ - 1 / temp_l ** 2)
chi += const
chi = (chi + 1) ** 2 - 1
elif kind == 3: # Schott formula
@@ -239,7 +239,10 @@ def ionization_rate_ADK(
omega_p = ionization_energy / hbar
nstar = Z * np.sqrt(2.1787e-18 / ionization_energy)
omega_t = lambda field: e * np.abs(field) / np.sqrt(2 * me * ionization_energy)
def omega_t(field):
return e * np.abs(field) / np.sqrt(2 * me * ionization_energy)
Cnstar = 2 ** (2 * nstar) / (scipy.special.gamma(nstar + 1) ** 2)
omega_pC = omega_p * Cnstar

View File

@@ -440,8 +440,7 @@ class Simulations:
def _run_available(self):
for _, params in self.configuration:
params.compute()
utils.save_parameters(params.prepare_for_dump(), params.output_path)
utils.save_parameters(params.dump_dict(), params.output_path)
self.new_sim(params)
self.finish()
@@ -694,8 +693,6 @@ def parallel_RK4IP(
]:
logger = get_logger(__name__)
params = list(Configuration(config))
for _, param in params:
param.compute()
n = len(params)
z_num = params[0][1].z_num

View File

@@ -239,7 +239,7 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p
handles, _ = legend_axes.get_legend_handles_labels()
legend = legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace"))
legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace"))
out_path = env.output_path()
@@ -261,7 +261,6 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
for style, (descriptor, params) in zip(cc, Configuration(config_path)):
params.compute()
yield style, descriptor.branch.formatted_descriptor(), params

View File

@@ -134,7 +134,6 @@ class SimulationSeries:
else:
raise FileNotFoundError(f"No simulation in {path}")
self.params = Parameters.load(self.path / PARAM_FN)
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
self.t = self.params.t
self.w = self.params.w
if self.params.prev_data_dir is not None: