Parameters are now computed lazily
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -3,6 +3,7 @@
|
||||
**/*.npy
|
||||
|
||||
plots*
|
||||
/make_*.py
|
||||
Archive
|
||||
*.mp4
|
||||
*.png
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.2.3rules"
|
||||
__version__ = "0.2.4dev"
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -40,5 +40,5 @@ class EvaluatorError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoDefaultError(Exception):
|
||||
class NoDefaultError(EvaluatorError):
|
||||
pass
|
||||
|
||||
@@ -95,6 +95,17 @@ class EvalStat:
|
||||
priority: float = np.inf
|
||||
|
||||
|
||||
class pdict(dict):
|
||||
"""a dictionary that cannot have any None value"""
|
||||
|
||||
def __setitem__(self, k, v):
|
||||
if v is None:
|
||||
if k in self:
|
||||
del self[k]
|
||||
else:
|
||||
super().__setitem__(k, v)
|
||||
|
||||
|
||||
class Evaluator:
|
||||
defaults: dict[str, Any] = {}
|
||||
|
||||
@@ -231,7 +242,7 @@ class Evaluator:
|
||||
if param_name == target:
|
||||
value = returned_value
|
||||
break
|
||||
except (EvaluatorError, KeyError, NoDefaultError) as e:
|
||||
except EvaluatorError as e:
|
||||
error = e
|
||||
self.logger.debug(
|
||||
prefix + f"error using {rule.func.__name__} : {str(error).strip()}"
|
||||
@@ -269,7 +280,7 @@ class Evaluator:
|
||||
def validate_condition(self, rule: Rule) -> bool:
|
||||
try:
|
||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||
except (EvaluatorError, KeyError, NoDefaultError):
|
||||
except EvaluatorError:
|
||||
return False
|
||||
|
||||
def attempted_rules_str(self, target: str) -> str:
|
||||
|
||||
@@ -67,7 +67,7 @@ def convert_sim_folder(path: os.PathLike):
|
||||
processed_specs.add(descr)
|
||||
if (parent := descr.parent) is not None:
|
||||
new_params.prev_data_dir = str(new_paths[parent].final_path)
|
||||
save_parameters(new_params.prepare_for_dump(), new_params.final_path)
|
||||
save_parameters(new_params.dump_dict(), new_params.final_path)
|
||||
for spec_num in range(start_z, end_z):
|
||||
old_spec = old_path / SPECN_FN1.format(spec_num)
|
||||
if move_specs:
|
||||
|
||||
@@ -5,7 +5,7 @@ import enum
|
||||
import os
|
||||
import time
|
||||
from copy import copy
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from dataclasses import dataclass, field, fields
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
|
||||
@@ -14,7 +14,7 @@ import numpy as np
|
||||
|
||||
from . import env, legacy, utils
|
||||
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
|
||||
from .evaluator import Evaluator
|
||||
from .evaluator import Evaluator, pdict
|
||||
from .logger import get_logger
|
||||
from .operators import LinearOperator, NonLinearOperator
|
||||
from .utils import fiber_folder, update_path_name
|
||||
@@ -204,6 +204,7 @@ class Parameter:
|
||||
self.converter = converter
|
||||
self.default = default
|
||||
self.display_info = display_info
|
||||
self.value = None
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
self.name = name
|
||||
@@ -214,7 +215,10 @@ class Parameter:
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
return instance.__dict__[self.name]
|
||||
if self.name not in instance._param_dico:
|
||||
instance._evaluator.compute(self.name)
|
||||
return instance._param_dico[self.name]
|
||||
# return instance.__dict__[self.name]
|
||||
|
||||
def __delete__(self, instance):
|
||||
raise AttributeError("Cannot delete parameter")
|
||||
@@ -222,13 +226,15 @@ class Parameter:
|
||||
def __set__(self, instance, value):
|
||||
if isinstance(value, Parameter):
|
||||
defaut = None if self.default is None else copy(self.default)
|
||||
instance.__dict__[self.name] = defaut
|
||||
instance._param_dico[self.name] = defaut
|
||||
# instance.__dict__[self.name] = defaut
|
||||
else:
|
||||
if value is not None:
|
||||
if self.converter is not None:
|
||||
value = self.converter(value)
|
||||
self.validator(self.name, value)
|
||||
instance.__dict__[self.name] = value
|
||||
instance._param_dico[self.name] = value
|
||||
# instance.__dict__[self.name] = value
|
||||
|
||||
def display(self, num: float) -> str:
|
||||
if self.display_info is None:
|
||||
@@ -241,12 +247,16 @@ class Parameter:
|
||||
return f"{num_str} {unit}"
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class Parameters:
|
||||
"""
|
||||
This class defines each valid parameter's name, type and valid value.
|
||||
"""
|
||||
|
||||
# internal machinery
|
||||
_param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False)
|
||||
_evaluator: Evaluator = field(init=False, repr=False)
|
||||
|
||||
# root
|
||||
name: str = Parameter(string, default="no name")
|
||||
prev_data_dir: str = Parameter(string)
|
||||
@@ -348,40 +358,37 @@ class Parameters:
|
||||
L_D: float = Parameter(non_negative(float, int))
|
||||
L_NL: float = Parameter(non_negative(float, int))
|
||||
L_sol: float = Parameter(non_negative(float, int))
|
||||
dynamic_dispersion: bool = Parameter(boolean)
|
||||
adapt_step_size: bool = Parameter(boolean)
|
||||
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
beta_func: Callable[[float], list[float]] = Parameter(func_validator)
|
||||
gamma_func: Callable[[float], float] = Parameter(func_validator)
|
||||
|
||||
num: int = Parameter(non_negative(int))
|
||||
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
||||
version: str = Parameter(string)
|
||||
|
||||
def prepare_for_dump(self) -> dict[str, Any]:
|
||||
param = asdict(self)
|
||||
param = Parameters.strip_params_dict(param)
|
||||
def __post_init__(self):
|
||||
self._evaluator = Evaluator.default()
|
||||
self._evaluator.set(self._param_dico)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Parameter(" + ", ".join(f"{k}={v}" for k, v in self.dump_dict().items()) + ")"
|
||||
|
||||
def dump_dict(self) -> dict[str, Any]:
|
||||
param = Parameters.strip_params_dict(self._param_dico)
|
||||
param["datetime"] = datetime_module.datetime.now()
|
||||
param["version"] = __version__
|
||||
return param
|
||||
|
||||
def compute(self, to_compute: list[str] = MANDATORY_PARAMETERS):
|
||||
param_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||
evaluator = Evaluator.default()
|
||||
evaluator.set(**param_dict)
|
||||
results = [evaluator.compute(p_name) for p_name in to_compute]
|
||||
valid_fields = self.all_parameters()
|
||||
for k, v in evaluator.params.items():
|
||||
if k in valid_fields:
|
||||
setattr(self, k, v)
|
||||
return results
|
||||
def compute_in_place(self, *to_compute: str):
|
||||
if len(to_compute) == 0:
|
||||
to_compute = MANDATORY_PARAMETERS
|
||||
for k in to_compute:
|
||||
getattr(self, k)
|
||||
|
||||
def pformat(self) -> str:
|
||||
return "\n".join(
|
||||
f"{k} = {VariationDescriptor.format_value(k, v)}"
|
||||
for k, v in self.prepare_for_dump().items()
|
||||
f"{k} = {VariationDescriptor.format_value(k, v)}" for k, v in self.dump_dict().items()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -392,12 +399,6 @@ class Parameters:
|
||||
def load(cls, path: os.PathLike) -> "Parameters":
|
||||
return cls(**utils.load_toml(path))
|
||||
|
||||
@classmethod
|
||||
def load_and_compute(cls, path: os.PathLike) -> "Parameters":
|
||||
p = cls.load(path)
|
||||
p.compute()
|
||||
return p
|
||||
|
||||
@staticmethod
|
||||
def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]:
|
||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||
@@ -409,6 +410,8 @@ class Parameters:
|
||||
dictionary
|
||||
"""
|
||||
forbiden_keys = {
|
||||
"_param_dico",
|
||||
"_evaluator",
|
||||
"w_c",
|
||||
"w_power_fact",
|
||||
"field_0",
|
||||
|
||||
@@ -5,7 +5,6 @@ from numpy import e
|
||||
from numpy.fft import fft
|
||||
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
||||
from scipy.interpolate import interp1d
|
||||
from sympy import re
|
||||
|
||||
from .. import utils
|
||||
from ..cache import np_cache
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import scipy.special
|
||||
@@ -130,12 +130,12 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
|
||||
chi = np.zeros_like(lambda_) # = n^2 - 1
|
||||
if kind == 1:
|
||||
logger.debug("materials : using Sellmeier 1st kind equation")
|
||||
for b, c in zip(B, C):
|
||||
chi[ind] += temp_l ** 2 * b / (temp_l ** 2 - c)
|
||||
for b, c_ in zip(B, C):
|
||||
chi[ind] += temp_l ** 2 * b / (temp_l ** 2 - c_)
|
||||
elif kind == 2: # gives n-1
|
||||
logger.debug("materials : using Sellmeier 2nd kind equation")
|
||||
for b, c in zip(B, C):
|
||||
chi[ind] += b / (c - 1 / temp_l ** 2)
|
||||
for b, c_ in zip(B, C):
|
||||
chi[ind] += b / (c_ - 1 / temp_l ** 2)
|
||||
chi += const
|
||||
chi = (chi + 1) ** 2 - 1
|
||||
elif kind == 3: # Schott formula
|
||||
@@ -239,7 +239,10 @@ def ionization_rate_ADK(
|
||||
|
||||
omega_p = ionization_energy / hbar
|
||||
nstar = Z * np.sqrt(2.1787e-18 / ionization_energy)
|
||||
omega_t = lambda field: e * np.abs(field) / np.sqrt(2 * me * ionization_energy)
|
||||
|
||||
def omega_t(field):
|
||||
return e * np.abs(field) / np.sqrt(2 * me * ionization_energy)
|
||||
|
||||
Cnstar = 2 ** (2 * nstar) / (scipy.special.gamma(nstar + 1) ** 2)
|
||||
omega_pC = omega_p * Cnstar
|
||||
|
||||
|
||||
@@ -440,8 +440,7 @@ class Simulations:
|
||||
|
||||
def _run_available(self):
|
||||
for _, params in self.configuration:
|
||||
params.compute()
|
||||
utils.save_parameters(params.prepare_for_dump(), params.output_path)
|
||||
utils.save_parameters(params.dump_dict(), params.output_path)
|
||||
|
||||
self.new_sim(params)
|
||||
self.finish()
|
||||
@@ -694,8 +693,6 @@ def parallel_RK4IP(
|
||||
]:
|
||||
logger = get_logger(__name__)
|
||||
params = list(Configuration(config))
|
||||
for _, param in params:
|
||||
param.compute()
|
||||
n = len(params)
|
||||
z_num = params[0][1].z_num
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p
|
||||
|
||||
handles, _ = legend_axes.get_legend_handles_labels()
|
||||
|
||||
legend = legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace"))
|
||||
legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace"))
|
||||
|
||||
out_path = env.output_path()
|
||||
|
||||
@@ -261,7 +261,6 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p
|
||||
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
|
||||
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
|
||||
for style, (descriptor, params) in zip(cc, Configuration(config_path)):
|
||||
params.compute()
|
||||
yield style, descriptor.branch.formatted_descriptor(), params
|
||||
|
||||
|
||||
|
||||
@@ -134,7 +134,6 @@ class SimulationSeries:
|
||||
else:
|
||||
raise FileNotFoundError(f"No simulation in {path}")
|
||||
self.params = Parameters.load(self.path / PARAM_FN)
|
||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||
self.t = self.params.t
|
||||
self.w = self.params.w
|
||||
if self.params.prev_data_dir is not None:
|
||||
|
||||
Reference in New Issue
Block a user