From aebc0cef85b23b0791c4399dc98be0e3f365bcca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 21 Oct 2021 15:46:39 +0200 Subject: [PATCH] Parameters are now computed lazily --- .gitignore | 1 + src/scgenerator/const.py | 2 +- src/scgenerator/errors.py | 2 +- src/scgenerator/evaluator.py | 15 ++++++- src/scgenerator/legacy.py | 2 +- src/scgenerator/parameter.py | 63 +++++++++++++++------------- src/scgenerator/physics/fiber.py | 1 - src/scgenerator/physics/materials.py | 15 ++++--- src/scgenerator/physics/simulate.py | 5 +-- src/scgenerator/scripts/__init__.py | 3 +- src/scgenerator/spectra.py | 1 - 11 files changed, 61 insertions(+), 49 deletions(-) diff --git a/.gitignore b/.gitignore index 3fc7b58..1e04fc0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ **/*.npy plots* +/make_*.py Archive *.mp4 *.png diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index 5a68d20..e6ecafe 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -1,4 +1,4 @@ -__version__ = "0.2.3rules" +__version__ = "0.2.4dev" from typing import Any diff --git a/src/scgenerator/errors.py b/src/scgenerator/errors.py index 8f211e6..3b1d790 100644 --- a/src/scgenerator/errors.py +++ b/src/scgenerator/errors.py @@ -40,5 +40,5 @@ class EvaluatorError(Exception): pass -class NoDefaultError(Exception): +class NoDefaultError(EvaluatorError): pass diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 1372c54..db82c03 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -95,6 +95,17 @@ class EvalStat: priority: float = np.inf +class pdict(dict): + """a dictionary that cannot have any None value""" + + def __setitem__(self, k, v): + if v is None: + if k in self: + del self[k] + else: + super().__setitem__(k, v) + + class Evaluator: defaults: dict[str, Any] = {} @@ -231,7 +242,7 @@ class Evaluator: if param_name == target: value = returned_value break - except (EvaluatorError, KeyError, NoDefaultError) as e: + except EvaluatorError as e: error = e self.logger.debug( prefix + f"error using {rule.func.__name__} : {str(error).strip()}" @@ -269,7 +280,7 @@ class Evaluator: def validate_condition(self, rule: Rule) -> bool: try: return all(self.compute(k) == v for k, v in rule.conditions.items()) - except (EvaluatorError, KeyError, NoDefaultError): + except EvaluatorError: return False def attempted_rules_str(self, target: str) -> str: diff --git a/src/scgenerator/legacy.py b/src/scgenerator/legacy.py index c812edd..e2cd9f0 100644 --- a/src/scgenerator/legacy.py +++ b/src/scgenerator/legacy.py @@ -67,7 +67,7 @@ def convert_sim_folder(path: os.PathLike): processed_specs.add(descr) if (parent := descr.parent) is not None: new_params.prev_data_dir = str(new_paths[parent].final_path) - save_parameters(new_params.prepare_for_dump(), new_params.final_path) + save_parameters(new_params.dump_dict(), new_params.final_path) for spec_num in range(start_z, end_z): old_spec = old_path / SPECN_FN1.format(spec_num) if move_specs: diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 2a3e4fa..5871509 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -5,7 +5,7 @@ import enum import os import time from copy import copy -from dataclasses import asdict, dataclass, fields +from dataclasses import dataclass, field, fields from functools import lru_cache from pathlib import Path from typing import Any, Callable, Iterable, Iterator, TypeVar, Union @@ -14,7 +14,7 @@ import numpy as np from . import env, legacy, utils from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__ -from .evaluator import Evaluator +from .evaluator import Evaluator, pdict from .logger import get_logger from .operators import LinearOperator, NonLinearOperator from .utils import fiber_folder, update_path_name @@ -204,6 +204,7 @@ class Parameter: self.converter = converter self.default = default self.display_info = display_info + self.value = None def __set_name__(self, owner, name): self.name = name @@ -214,7 +215,10 @@ class Parameter: def __get__(self, instance, owner): if instance is None: return self - return instance.__dict__[self.name] + if self.name not in instance._param_dico: + instance._evaluator.compute(self.name) + return instance._param_dico[self.name] + # return instance.__dict__[self.name] def __delete__(self, instance): raise AttributeError("Cannot delete parameter") @@ -222,13 +226,15 @@ class Parameter: def __set__(self, instance, value): if isinstance(value, Parameter): defaut = None if self.default is None else copy(self.default) - instance.__dict__[self.name] = defaut + instance._param_dico[self.name] = defaut + # instance.__dict__[self.name] = defaut else: if value is not None: if self.converter is not None: value = self.converter(value) self.validator(self.name, value) - instance.__dict__[self.name] = value + instance._param_dico[self.name] = value + # instance.__dict__[self.name] = value def display(self, num: float) -> str: if self.display_info is None: @@ -241,12 +247,16 @@ class Parameter: return f"{num_str} {unit}" -@dataclass +@dataclass(repr=False) class Parameters: """ This class defines each valid parameter's name, type and valid value. """ + # internal machinery + _param_dico: pdict[str, Any] = field(init=False, default_factory=pdict, repr=False) + _evaluator: Evaluator = field(init=False, repr=False) + # root name: str = Parameter(string, default="no name") prev_data_dir: str = Parameter(string) @@ -348,40 +358,37 @@ class Parameters: L_D: float = Parameter(non_negative(float, int)) L_NL: float = Parameter(non_negative(float, int)) L_sol: float = Parameter(non_negative(float, int)) - dynamic_dispersion: bool = Parameter(boolean) adapt_step_size: bool = Parameter(boolean) hr_w: np.ndarray = Parameter(type_checker(np.ndarray)) z_targets: np.ndarray = Parameter(type_checker(np.ndarray)) const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) - beta_func: Callable[[float], list[float]] = Parameter(func_validator) - gamma_func: Callable[[float], float] = Parameter(func_validator) num: int = Parameter(non_negative(int)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) version: str = Parameter(string) - def prepare_for_dump(self) -> dict[str, Any]: - param = asdict(self) - param = Parameters.strip_params_dict(param) + def __post_init__(self): + self._evaluator = Evaluator.default() + self._evaluator.set(self._param_dico) + + def __repr__(self) -> str: + return "Parameter(" + ", ".join(f"{k}={v}" for k, v in self.dump_dict().items()) + ")" + + def dump_dict(self) -> dict[str, Any]: + param = Parameters.strip_params_dict(self._param_dico) param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ return param - def compute(self, to_compute: list[str] = MANDATORY_PARAMETERS): - param_dict = {k: v for k, v in asdict(self).items() if v is not None} - evaluator = Evaluator.default() - evaluator.set(**param_dict) - results = [evaluator.compute(p_name) for p_name in to_compute] - valid_fields = self.all_parameters() - for k, v in evaluator.params.items(): - if k in valid_fields: - setattr(self, k, v) - return results + def compute_in_place(self, *to_compute: str): + if len(to_compute) == 0: + to_compute = MANDATORY_PARAMETERS + for k in to_compute: + getattr(self, k) def pformat(self) -> str: return "\n".join( - f"{k} = {VariationDescriptor.format_value(k, v)}" - for k, v in self.prepare_for_dump().items() + f"{k} = {VariationDescriptor.format_value(k, v)}" for k, v in self.dump_dict().items() ) @classmethod @@ -392,12 +399,6 @@ class Parameters: def load(cls, path: os.PathLike) -> "Parameters": return cls(**utils.load_toml(path)) - @classmethod - def load_and_compute(cls, path: os.PathLike) -> "Parameters": - p = cls.load(path) - p.compute() - return p - @staticmethod def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]: """prepares a dictionary for serialization. Some keys may not be preserved @@ -409,6 +410,8 @@ class Parameters: dictionary """ forbiden_keys = { + "_param_dico", + "_evaluator", "w_c", "w_power_fact", "field_0", diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 779a83b..3997fd3 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -5,7 +5,6 @@ from numpy import e from numpy.fft import fft from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from scipy.interpolate import interp1d -from sympy import re from .. import utils from ..cache import np_cache diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index 43dee9a..218b3ac 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Callable import numpy as np import scipy.special @@ -130,12 +130,12 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None): chi = np.zeros_like(lambda_) # = n^2 - 1 if kind == 1: logger.debug("materials : using Sellmeier 1st kind equation") - for b, c in zip(B, C): - chi[ind] += temp_l ** 2 * b / (temp_l ** 2 - c) + for b, c_ in zip(B, C): + chi[ind] += temp_l ** 2 * b / (temp_l ** 2 - c_) elif kind == 2: # gives n-1 logger.debug("materials : using Sellmeier 2nd kind equation") - for b, c in zip(B, C): - chi[ind] += b / (c - 1 / temp_l ** 2) + for b, c_ in zip(B, C): + chi[ind] += b / (c_ - 1 / temp_l ** 2) chi += const chi = (chi + 1) ** 2 - 1 elif kind == 3: # Schott formula @@ -239,7 +239,10 @@ def ionization_rate_ADK( omega_p = ionization_energy / hbar nstar = Z * np.sqrt(2.1787e-18 / ionization_energy) - omega_t = lambda field: e * np.abs(field) / np.sqrt(2 * me * ionization_energy) + + def omega_t(field): + return e * np.abs(field) / np.sqrt(2 * me * ionization_energy) + Cnstar = 2 ** (2 * nstar) / (scipy.special.gamma(nstar + 1) ** 2) omega_pC = omega_p * Cnstar diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 947731a..4a208c1 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -440,8 +440,7 @@ class Simulations: def _run_available(self): for _, params in self.configuration: - params.compute() - utils.save_parameters(params.prepare_for_dump(), params.output_path) + utils.save_parameters(params.dump_dict(), params.output_path) self.new_sim(params) self.finish() @@ -694,8 +693,6 @@ def parallel_RK4IP( ]: logger = get_logger(__name__) params = list(Configuration(config)) - for _, param in params: - param.compute() n = len(params) z_num = params[0][1].z_num diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index f223f10..af28fc2 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -239,7 +239,7 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p handles, _ = legend_axes.get_legend_handles_labels() - legend = legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace")) + legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace")) out_path = env.output_path() @@ -261,7 +261,6 @@ def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], p def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]: cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) for style, (descriptor, params) in zip(cc, Configuration(config_path)): - params.compute() yield style, descriptor.branch.formatted_descriptor(), params diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 915118a..d27f793 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -134,7 +134,6 @@ class SimulationSeries: else: raise FileNotFoundError(f"No simulation in {path}") self.params = Parameters.load(self.path / PARAM_FN) - self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"]) self.t = self.params.t self.w = self.params.w if self.params.prev_data_dir is not None: