From 235097904667d57f4ab611bcbd3a9d8238e5bb53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Fri, 24 Mar 2023 09:34:40 +0100 Subject: [PATCH] cleanup with operators/value tracker/current_state - current_state now always computes its derived values - it implements a copy function --- pyproject.toml | 1 + src/scgenerator/evaluator.py | 4 +- src/scgenerator/operators.py | 189 ++++++++--------------------- src/scgenerator/parameter.py | 2 +- src/scgenerator/solver.py | 14 +-- testing/configs/Chang2011Fig2.toml | 5 +- testing/test_full_field.py | 13 +- tests/test_current_state.py | 31 +++++ 8 files changed, 95 insertions(+), 164 deletions(-) create mode 100644 tests/test_current_state.py diff --git a/pyproject.toml b/pyproject.toml index 6c67852..4b0ab77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,3 +32,4 @@ convention = "numpy" [tool.black] line-length = 100 + diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 2c9fa07..cac8c40 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -312,7 +312,7 @@ default_rules: list[Rule] = [ Rule("w_num", len, ["w"]), Rule("dw", lambda w: w[1] - w[0]), Rule(["fft", "ifft"], utils.fft_functions, priorities=1), - Rule("interpolation_range", lambda dt: (2 * units.c * dt, 8e-6)), + Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)), # Pulse Rule("field_0", pulse.finalize_pulse), Rule(["input_time", "input_field"], pulse.load_custom_field), @@ -340,7 +340,6 @@ default_rules: list[Rule] = [ Rule("L_NL", pulse.L_NL), Rule("L_sol", pulse.L_sol), Rule("c_to_a_factor", lambda: 1, priorities=-1), - Rule("c_to_a_factor", pulse.c_to_a_factor), # Fiber Dispersion Rule("w_for_disp", units.m, ["wl_for_disp"]), Rule("hr_w", fiber.delayed_raman_w), @@ -419,6 +418,7 @@ envelope_rules = default_rules + [ # Pulse Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1), Rule("pre_field_0", pulse.initial_field_envelope, priorities=1), + Rule("c_to_a_factor", pulse.c_to_a_factor), # Dispersion Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion), Rule("beta2_coefficients", fiber.dispersion_coefficients), diff --git a/src/scgenerator/operators.py b/src/scgenerator/operators.py index d9e6236..db55660 100644 --- a/src/scgenerator/operators.py +++ b/src/scgenerator/operators.py @@ -5,8 +5,9 @@ Nothing except the solver should depend on this file from __future__ import annotations from abc import ABC, abstractmethod +from copy import deepcopy from dataclasses import dataclass -from typing import Callable +from typing import Any, Callable import numpy as np from scipy.interpolate import interp1d @@ -20,25 +21,25 @@ class CurrentState: length: float z: float current_step_size: float - step: int - conversion_factor: np.ndarray + conversion_factor: np.ndarray | float converter: Callable[[np.ndarray], np.ndarray] - __spectrum: np.ndarray - __spec2: np.ndarray - __field: np.ndarray - __field2: np.ndarray + stats: dict[str, Any] + spectrum: np.ndarray + spec2: np.ndarray + field: np.ndarray + field2: np.ndarray __slots__ = [ "length", "z", "current_step_size", - "step", "conversion_factor", "converter", - "_CurrentState__spectrum", - "_CurrentState__spec2", - "_CurrentState__field", - "_CurrentState__field2", + "spectrum", + "spectrum2", + "field", + "field2", + "stats", ] def __init__( @@ -46,18 +47,31 @@ class CurrentState: length: float, z: float, current_step_size: float, - step: int, spectrum: np.ndarray, - conversion_factor: np.ndarray, + conversion_factor: np.ndarray | float, converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft, + spectrum2: np.ndarray | None = None, + field: np.ndarray | None = None, + field2: np.ndarray | None = None, + stats: dict[str, Any] | None = None, ): self.length = length self.z = z self.current_step_size = current_step_size - self.step = step self.conversion_factor = conversion_factor self.converter = converter - self.spectrum = spectrum + + if spectrum2 is None and field is None and field2 is None: + self.set_spectrum(spectrum) + elif any(el is None for el in (spectrum2, field, field2)): + raise ValueError( + "You must provide either all three of (spectrum2, field, field2) or none of them" + ) + else: + self.spectrum2 = spectrum2 + self.field = field + self.field2 = field2 + self.stats = stats or {} @property def z_ratio(self) -> float: @@ -67,125 +81,25 @@ class CurrentState: def actual_spectrum(self) -> np.ndarray: return self.conversion_factor * self.spectrum - @property - def spectrum(self) -> np.ndarray: - return self.__spectrum - - @spectrum.setter - def spectrum(self, new_value: np.ndarray): - self.__spectrum = new_value - self.__spec2 = None - self.__field = None - self.__field2 = None - - @property - def spec2(self) -> np.ndarray: - if self.__spec2 is None: - self.__spec2 = math.abs2(self.spectrum) - return self.__spec2 - - @property - def field(self) -> np.ndarray: - if self.__field is None: - self.__field = self.converter(self.spectrum) - return self.__field - - @property - def field2(self) -> np.ndarray: - if self.__field2 is None: - self.__field2 = math.abs2(self.field) - return self.__field2 - - def force_values(self, spec2: np.ndarray, field: np.ndarray, field2: np.ndarray): - """force these values instead of recomputing them - - Parameters - ---------- - spectrum : np.ndarray - spectrum - spec2 : np.ndarray - |spectrum|^2 - field : np.ndarray - field = converter(spectrum) - field2 : np.ndarray - |field|^2 - """ - self.__spec2 = spec2 - self.__field = field - self.__field2 = field2 - - def replace(self, new_spectrum: np.ndarray) -> CurrentState: - """returns a new state with new attributes""" - return CurrentState( - length=self.length, - z=self.z, - current_step_size=self.current_step_size, - step=self.step, - conversion_factor=self.conversion_factor, - converter=self.converter, - spectrum=new_spectrum, - ) - - def with_params(self, **params) -> CurrentState: - """returns a new CurrentState with modified params, except for the solution""" - my_params = dict( - length=self.length, - z=self.z, - current_step_size=self.current_step_size, - step=self.step, - conversion_factor=self.conversion_factor, - converter=self.converter, - ) - new_state = CurrentState(spectrum=self.__spectrum, **(my_params | params)) - new_state.force_values(self.spec2, self.field, self.field2) - return new_state + def set_spectrum(self, new_spectrum: np.ndarray): + self.spectrum = new_spectrum + self.spectrum2 = math.abs2(self.spectrum) + self.field = self.converter(self.spectrum) + self.field2 = math.abs2(self.field) def copy(self) -> CurrentState: - new = CurrentState( - length=self.length, - z=self.z, - current_step_size=self.current_step_size, - step=self.step, - conversion_factor=self.conversion_factor, - converter=self.converter, - spectrum=self.__spectrum, + return CurrentState( + self.length, + self.z, + self.current_step_size, + self.spectrum.copy(), + self.conversion_factor, + self.converter, + self.spectrum2.copy(), + self.field.copy(), + self.field2.copy(), + deepcopy(self.stats), ) - new.force_values(self.__spec2, self.__field, self.__field2) - return new - - -class ValueTracker(ABC): - def values(self) -> dict[str, float]: - return {} - - def all_values(self) -> dict[str, float]: - out = self.values() - for operator in vars(self).values(): - if isinstance(operator, ValueTracker): - out = operator.all_values() | out - return out - - def __repr__(self) -> str: - value_pair_list = list(vars(self).items()) - if len(value_pair_list) == 0: - value_pair_str_list = "" - elif len(value_pair_list) == 1: - value_pair_str_list = [self.__value_repr(value_pair_list[0][0], value_pair_list[0][1])] - else: - value_pair_str_list = [k + "=" + self.__value_repr(k, v) for k, v in value_pair_list] - - return self.__class__.__name__ + "(" + ", ".join(value_pair_str_list) + ")" - - def __value_repr(self, k: str, v) -> str: - if k.endswith("_const") and isinstance(v, (list, np.ndarray, tuple)): - return repr(v[0]) - return repr(v) - - -class Operator(ValueTracker): - @abstractmethod - def __call__(self, state: CurrentState) -> np.ndarray: - pass class NoOpTime(Operator): @@ -540,7 +454,6 @@ class ConstantWaveVector(AbstractWaveVector): dispersion_ind: np.ndarray, w_order: np.ndarray, ): - self.beta_arr = np.zeros(w_num, dtype=float) self.beta_arr[dispersion_ind] = fiber.beta(w_for_disp, n_op())[2:-2] left_ind, *_, right_ind = np.nonzero(self.beta_arr[w_order])[0] @@ -817,7 +730,6 @@ class VariableScalarGamma(AbstractGamma): class Plasma(Operator): mat_plasma: plasma.Plasma gas_op: AbstractGas - ionization_fraction = 0.0 def __init__(self, dt: float, w: np.ndarray, gas_op: AbstractGas): self.gas_op = gas_op @@ -827,12 +739,9 @@ class Plasma(Operator): def __call__(self, state: CurrentState) -> np.ndarray: N0 = self.gas_op.number_density(state) plasma_info = self.mat_plasma(state.field, N0) - self.ionization_fraction = plasma_info.electron_density[-1] / N0 + state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0 return self.factor_out * np.fft.rfft(plasma_info.polarization) - def values(self) -> dict[str, float]: - return dict(ionization_fraction=self.ionization_fraction) - class NoPlasma(NoOpFreq, Plasma): pass @@ -863,7 +772,7 @@ class PhotonNumberLoss(AbstractConservedQuantity): def __call__(self, state: CurrentState) -> float: return pulse.photon_number_with_loss( - state.spec2, + state.spectrum2, self.w, self.dw, self.gamma_op(state), @@ -879,7 +788,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity): self.gamma_op = gamma_op def __call__(self, state: CurrentState) -> float: - return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state)) + return pulse.photon_number(state.spectrum2, self.w, self.dw, self.gamma_op(state)) class EnergyLoss(AbstractConservedQuantity): diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 51d66fa..edf96ab 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -408,7 +408,7 @@ class Parameters: gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray)) A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray)) spectrum_factor: float = Parameter(type_checker(float)) - c_to_a_factor: np.ndarray = Parameter(type_checker(float, np.ndarray)) + c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray)) l: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray)) diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index 8ca7054..63fd726 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from abc import abstractmethod from collections import defaultdict -from typing import Iterator, Type +from typing import Any, Iterator, Type import numba import numpy as np @@ -14,7 +14,6 @@ from scgenerator.operators import ( CurrentState, LinearOperator, NonLinearOperator, - ValueTracker, ) from scgenerator.utils import get_arg_names @@ -55,12 +54,12 @@ class IntegratorFactory: return cls(**kwargs) -class Integrator(ValueTracker): +class Integrator: linear_operator: LinearOperator nonlinear_operator: NonLinearOperator state: CurrentState target_error: float - _tracked_values: dict[str, float] + _tracked_values: dict[float, dict[str, Any]] logger: logging.Logger __factory: IntegratorFactory = IntegratorFactory() order = 4 @@ -109,16 +108,13 @@ class Integrator(ValueTracker): tracked values """ return self._tracked_values | dict(z=self.state.z, step=self.state.step) - - def record_tracked_values(self): - self._tracked_values = super().all_values() - + def nl(self, spectrum: np.ndarray) -> np.ndarray: return self.nonlinear_operator(self.state.replace(spectrum)) def accept_step( self, new_state: CurrentState, previous_step_size: float, next_step_size: float - ) -> CurrentState: + ): self.state = new_state self.state.current_step_size = next_step_size self.state.z += previous_step_size diff --git a/testing/configs/Chang2011Fig2.toml b/testing/configs/Chang2011Fig2.toml index 50ad148..8d4fa7e 100644 --- a/testing/configs/Chang2011Fig2.toml +++ b/testing/configs/Chang2011Fig2.toml @@ -2,7 +2,7 @@ name = "/Users/benoitsierro/tests/test_sc/Chang2011Fig2" wavelength = 800e-9 shape = "gaussian" -energy = 2.5e-6 +energy = 2.5e-7 width = 30e-15 core_radius = 10e-6 @@ -11,9 +11,8 @@ gas_name = "argon" pressure = 3.2e5 length = 0.1 -interpolation_range = [120e-9, 3000e-9] full_field = true +photoionization = false dt = 0.04e-15 t_num = 32768 z_num = 128 -step_size = 10e-6 diff --git a/testing/test_full_field.py b/testing/test_full_field.py index 0218dcd..f48cf36 100644 --- a/testing/test_full_field.py +++ b/testing/test_full_field.py @@ -1,19 +1,14 @@ -import warnings -import numpy as np -import rediscache import scgenerator as sc from customfunc.app import PlotApp -from scipy.interpolate import interp1d from tqdm import tqdm # warnings.filterwarnings("error") -@rediscache.rcache def get_specs(params: dict): p = sc.Parameters(**params) sim = sc.RK4IP(p) - return [s[-1] for s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict() + return [s.actual_spectrum for _, s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict() def main(): @@ -25,7 +20,7 @@ def main(): rt = sc.PlotRange(-500, 500, "fs") x, o, ext = rs.sort_axis(params.w) vmin = -50 - with PlotApp(i=(int, 0, params.z_num - 1)) as app: + with PlotApp(i=range(params.z_num)) as app: spec_ax = app[0] spec_ax.set_xlabel(rs.unit.label) field_ax = app[1] @@ -42,8 +37,8 @@ def main(): @app.cache def compute(i): - xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params) - x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params, log=True) + xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params=params) + x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params=params, log=True) # spec = np.where(spec > vmin, spec, vmin) field2 = sc.abs2(field) bot, top = sc.math.envelope_ind(field2) diff --git a/tests/test_current_state.py b/tests/test_current_state.py new file mode 100644 index 0000000..0767879 --- /dev/null +++ b/tests/test_current_state.py @@ -0,0 +1,31 @@ +import numpy as np +import pytest + +from scgenerator.operators import CurrentState + + +def test_creation(): + x = (np.linspace(0, 1, 128, dtype=complex),) + cs = CurrentState(1.0, 0, 0.1, x, 1.0) + + assert cs.converter is np.fft.ifft + assert cs.stats == {} + assert np.allclose(cs.spectrum2, np.abs(np.fft.ifft(x)) ** 2) + + with pytest.raises(ValueError): + cs = CurrentState(1.0, 0, 0.0, x, 1.0, spectrum2=np.abs(x) ** 3) + + cs = CurrentState(1.0, 0, 0.1, x, 1.0, spectrum2=x.copy(), field=x.copy(), field2=x.copy()) + + assert np.allclose(cs.spectrum2, cs.spectrum) + assert np.allclose(cs.spectrum, cs.field) + assert np.allclose(cs.field, cs.field2) + + +def test_copy(): + x = (np.linspace(0, 1, 128, dtype=complex),) + cs = CurrentState(1.0, 0, 0.1, x, 1.0) + cs2 = cs.copy() + + assert cs.spectrum is not cs2.spectrum + assert np.all(cs.field2 == cs2.field2)