cleanup with operators/value tracker/current_state

- current_state now always computes its derived values
- it implements a copy function
This commit is contained in:
Benoît Sierro
2023-03-24 09:34:40 +01:00
parent 504f40edd2
commit 2350979046
8 changed files with 95 additions and 164 deletions

View File

@@ -32,3 +32,4 @@ convention = "numpy"
[tool.black] [tool.black]
line-length = 100 line-length = 100

View File

@@ -312,7 +312,7 @@ default_rules: list[Rule] = [
Rule("w_num", len, ["w"]), Rule("w_num", len, ["w"]),
Rule("dw", lambda w: w[1] - w[0]), Rule("dw", lambda w: w[1] - w[0]),
Rule(["fft", "ifft"], utils.fft_functions, priorities=1), Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
Rule("interpolation_range", lambda dt: (2 * units.c * dt, 8e-6)), Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
# Pulse # Pulse
Rule("field_0", pulse.finalize_pulse), Rule("field_0", pulse.finalize_pulse),
Rule(["input_time", "input_field"], pulse.load_custom_field), Rule(["input_time", "input_field"], pulse.load_custom_field),
@@ -340,7 +340,6 @@ default_rules: list[Rule] = [
Rule("L_NL", pulse.L_NL), Rule("L_NL", pulse.L_NL),
Rule("L_sol", pulse.L_sol), Rule("L_sol", pulse.L_sol),
Rule("c_to_a_factor", lambda: 1, priorities=-1), Rule("c_to_a_factor", lambda: 1, priorities=-1),
Rule("c_to_a_factor", pulse.c_to_a_factor),
# Fiber Dispersion # Fiber Dispersion
Rule("w_for_disp", units.m, ["wl_for_disp"]), Rule("w_for_disp", units.m, ["wl_for_disp"]),
Rule("hr_w", fiber.delayed_raman_w), Rule("hr_w", fiber.delayed_raman_w),
@@ -419,6 +418,7 @@ envelope_rules = default_rules + [
# Pulse # Pulse
Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1), Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1),
Rule("pre_field_0", pulse.initial_field_envelope, priorities=1), Rule("pre_field_0", pulse.initial_field_envelope, priorities=1),
Rule("c_to_a_factor", pulse.c_to_a_factor),
# Dispersion # Dispersion
Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion), Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion),
Rule("beta2_coefficients", fiber.dispersion_coefficients), Rule("beta2_coefficients", fiber.dispersion_coefficients),

View File

@@ -5,8 +5,9 @@ Nothing except the solver should depend on this file
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable from typing import Any, Callable
import numpy as np import numpy as np
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
@@ -20,25 +21,25 @@ class CurrentState:
length: float length: float
z: float z: float
current_step_size: float current_step_size: float
step: int conversion_factor: np.ndarray | float
conversion_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray] converter: Callable[[np.ndarray], np.ndarray]
__spectrum: np.ndarray stats: dict[str, Any]
__spec2: np.ndarray spectrum: np.ndarray
__field: np.ndarray spec2: np.ndarray
__field2: np.ndarray field: np.ndarray
field2: np.ndarray
__slots__ = [ __slots__ = [
"length", "length",
"z", "z",
"current_step_size", "current_step_size",
"step",
"conversion_factor", "conversion_factor",
"converter", "converter",
"_CurrentState__spectrum", "spectrum",
"_CurrentState__spec2", "spectrum2",
"_CurrentState__field", "field",
"_CurrentState__field2", "field2",
"stats",
] ]
def __init__( def __init__(
@@ -46,18 +47,31 @@ class CurrentState:
length: float, length: float,
z: float, z: float,
current_step_size: float, current_step_size: float,
step: int,
spectrum: np.ndarray, spectrum: np.ndarray,
conversion_factor: np.ndarray, conversion_factor: np.ndarray | float,
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft, converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft,
spectrum2: np.ndarray | None = None,
field: np.ndarray | None = None,
field2: np.ndarray | None = None,
stats: dict[str, Any] | None = None,
): ):
self.length = length self.length = length
self.z = z self.z = z
self.current_step_size = current_step_size self.current_step_size = current_step_size
self.step = step
self.conversion_factor = conversion_factor self.conversion_factor = conversion_factor
self.converter = converter self.converter = converter
self.spectrum = spectrum
if spectrum2 is None and field is None and field2 is None:
self.set_spectrum(spectrum)
elif any(el is None for el in (spectrum2, field, field2)):
raise ValueError(
"You must provide either all three of (spectrum2, field, field2) or none of them"
)
else:
self.spectrum2 = spectrum2
self.field = field
self.field2 = field2
self.stats = stats or {}
@property @property
def z_ratio(self) -> float: def z_ratio(self) -> float:
@@ -67,125 +81,25 @@ class CurrentState:
def actual_spectrum(self) -> np.ndarray: def actual_spectrum(self) -> np.ndarray:
return self.conversion_factor * self.spectrum return self.conversion_factor * self.spectrum
@property def set_spectrum(self, new_spectrum: np.ndarray):
def spectrum(self) -> np.ndarray: self.spectrum = new_spectrum
return self.__spectrum self.spectrum2 = math.abs2(self.spectrum)
self.field = self.converter(self.spectrum)
@spectrum.setter self.field2 = math.abs2(self.field)
def spectrum(self, new_value: np.ndarray):
self.__spectrum = new_value
self.__spec2 = None
self.__field = None
self.__field2 = None
@property
def spec2(self) -> np.ndarray:
if self.__spec2 is None:
self.__spec2 = math.abs2(self.spectrum)
return self.__spec2
@property
def field(self) -> np.ndarray:
if self.__field is None:
self.__field = self.converter(self.spectrum)
return self.__field
@property
def field2(self) -> np.ndarray:
if self.__field2 is None:
self.__field2 = math.abs2(self.field)
return self.__field2
def force_values(self, spec2: np.ndarray, field: np.ndarray, field2: np.ndarray):
"""force these values instead of recomputing them
Parameters
----------
spectrum : np.ndarray
spectrum
spec2 : np.ndarray
|spectrum|^2
field : np.ndarray
field = converter(spectrum)
field2 : np.ndarray
|field|^2
"""
self.__spec2 = spec2
self.__field = field
self.__field2 = field2
def replace(self, new_spectrum: np.ndarray) -> CurrentState:
"""returns a new state with new attributes"""
return CurrentState(
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
step=self.step,
conversion_factor=self.conversion_factor,
converter=self.converter,
spectrum=new_spectrum,
)
def with_params(self, **params) -> CurrentState:
"""returns a new CurrentState with modified params, except for the solution"""
my_params = dict(
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
step=self.step,
conversion_factor=self.conversion_factor,
converter=self.converter,
)
new_state = CurrentState(spectrum=self.__spectrum, **(my_params | params))
new_state.force_values(self.spec2, self.field, self.field2)
return new_state
def copy(self) -> CurrentState: def copy(self) -> CurrentState:
new = CurrentState( return CurrentState(
length=self.length, self.length,
z=self.z, self.z,
current_step_size=self.current_step_size, self.current_step_size,
step=self.step, self.spectrum.copy(),
conversion_factor=self.conversion_factor, self.conversion_factor,
converter=self.converter, self.converter,
spectrum=self.__spectrum, self.spectrum2.copy(),
self.field.copy(),
self.field2.copy(),
deepcopy(self.stats),
) )
new.force_values(self.__spec2, self.__field, self.__field2)
return new
class ValueTracker(ABC):
def values(self) -> dict[str, float]:
return {}
def all_values(self) -> dict[str, float]:
out = self.values()
for operator in vars(self).values():
if isinstance(operator, ValueTracker):
out = operator.all_values() | out
return out
def __repr__(self) -> str:
value_pair_list = list(vars(self).items())
if len(value_pair_list) == 0:
value_pair_str_list = ""
elif len(value_pair_list) == 1:
value_pair_str_list = [self.__value_repr(value_pair_list[0][0], value_pair_list[0][1])]
else:
value_pair_str_list = [k + "=" + self.__value_repr(k, v) for k, v in value_pair_list]
return self.__class__.__name__ + "(" + ", ".join(value_pair_str_list) + ")"
def __value_repr(self, k: str, v) -> str:
if k.endswith("_const") and isinstance(v, (list, np.ndarray, tuple)):
return repr(v[0])
return repr(v)
class Operator(ValueTracker):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
pass
class NoOpTime(Operator): class NoOpTime(Operator):
@@ -540,7 +454,6 @@ class ConstantWaveVector(AbstractWaveVector):
dispersion_ind: np.ndarray, dispersion_ind: np.ndarray,
w_order: np.ndarray, w_order: np.ndarray,
): ):
self.beta_arr = np.zeros(w_num, dtype=float) self.beta_arr = np.zeros(w_num, dtype=float)
self.beta_arr[dispersion_ind] = fiber.beta(w_for_disp, n_op())[2:-2] self.beta_arr[dispersion_ind] = fiber.beta(w_for_disp, n_op())[2:-2]
left_ind, *_, right_ind = np.nonzero(self.beta_arr[w_order])[0] left_ind, *_, right_ind = np.nonzero(self.beta_arr[w_order])[0]
@@ -817,7 +730,6 @@ class VariableScalarGamma(AbstractGamma):
class Plasma(Operator): class Plasma(Operator):
mat_plasma: plasma.Plasma mat_plasma: plasma.Plasma
gas_op: AbstractGas gas_op: AbstractGas
ionization_fraction = 0.0
def __init__(self, dt: float, w: np.ndarray, gas_op: AbstractGas): def __init__(self, dt: float, w: np.ndarray, gas_op: AbstractGas):
self.gas_op = gas_op self.gas_op = gas_op
@@ -827,12 +739,9 @@ class Plasma(Operator):
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
N0 = self.gas_op.number_density(state) N0 = self.gas_op.number_density(state)
plasma_info = self.mat_plasma(state.field, N0) plasma_info = self.mat_plasma(state.field, N0)
self.ionization_fraction = plasma_info.electron_density[-1] / N0 state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0
return self.factor_out * np.fft.rfft(plasma_info.polarization) return self.factor_out * np.fft.rfft(plasma_info.polarization)
def values(self) -> dict[str, float]:
return dict(ionization_fraction=self.ionization_fraction)
class NoPlasma(NoOpFreq, Plasma): class NoPlasma(NoOpFreq, Plasma):
pass pass
@@ -863,7 +772,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss( return pulse.photon_number_with_loss(
state.spec2, state.spectrum2,
self.w, self.w,
self.dw, self.dw,
self.gamma_op(state), self.gamma_op(state),
@@ -879,7 +788,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
self.gamma_op = gamma_op self.gamma_op = gamma_op
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state)) return pulse.photon_number(state.spectrum2, self.w, self.dw, self.gamma_op(state))
class EnergyLoss(AbstractConservedQuantity): class EnergyLoss(AbstractConservedQuantity):

View File

@@ -408,7 +408,7 @@ class Parameters:
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray)) gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray)) A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
spectrum_factor: float = Parameter(type_checker(float)) spectrum_factor: float = Parameter(type_checker(float))
c_to_a_factor: np.ndarray = Parameter(type_checker(float, np.ndarray)) c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray))
w: np.ndarray = Parameter(type_checker(np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray))
l: np.ndarray = Parameter(type_checker(np.ndarray)) l: np.ndarray = Parameter(type_checker(np.ndarray))
w_c: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray))

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import Iterator, Type from typing import Any, Iterator, Type
import numba import numba
import numpy as np import numpy as np
@@ -14,7 +14,6 @@ from scgenerator.operators import (
CurrentState, CurrentState,
LinearOperator, LinearOperator,
NonLinearOperator, NonLinearOperator,
ValueTracker,
) )
from scgenerator.utils import get_arg_names from scgenerator.utils import get_arg_names
@@ -55,12 +54,12 @@ class IntegratorFactory:
return cls(**kwargs) return cls(**kwargs)
class Integrator(ValueTracker): class Integrator:
linear_operator: LinearOperator linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator nonlinear_operator: NonLinearOperator
state: CurrentState state: CurrentState
target_error: float target_error: float
_tracked_values: dict[str, float] _tracked_values: dict[float, dict[str, Any]]
logger: logging.Logger logger: logging.Logger
__factory: IntegratorFactory = IntegratorFactory() __factory: IntegratorFactory = IntegratorFactory()
order = 4 order = 4
@@ -109,16 +108,13 @@ class Integrator(ValueTracker):
tracked values tracked values
""" """
return self._tracked_values | dict(z=self.state.z, step=self.state.step) return self._tracked_values | dict(z=self.state.z, step=self.state.step)
def record_tracked_values(self):
self._tracked_values = super().all_values()
def nl(self, spectrum: np.ndarray) -> np.ndarray: def nl(self, spectrum: np.ndarray) -> np.ndarray:
return self.nonlinear_operator(self.state.replace(spectrum)) return self.nonlinear_operator(self.state.replace(spectrum))
def accept_step( def accept_step(
self, new_state: CurrentState, previous_step_size: float, next_step_size: float self, new_state: CurrentState, previous_step_size: float, next_step_size: float
) -> CurrentState: ):
self.state = new_state self.state = new_state
self.state.current_step_size = next_step_size self.state.current_step_size = next_step_size
self.state.z += previous_step_size self.state.z += previous_step_size

View File

@@ -2,7 +2,7 @@ name = "/Users/benoitsierro/tests/test_sc/Chang2011Fig2"
wavelength = 800e-9 wavelength = 800e-9
shape = "gaussian" shape = "gaussian"
energy = 2.5e-6 energy = 2.5e-7
width = 30e-15 width = 30e-15
core_radius = 10e-6 core_radius = 10e-6
@@ -11,9 +11,8 @@ gas_name = "argon"
pressure = 3.2e5 pressure = 3.2e5
length = 0.1 length = 0.1
interpolation_range = [120e-9, 3000e-9]
full_field = true full_field = true
photoionization = false
dt = 0.04e-15 dt = 0.04e-15
t_num = 32768 t_num = 32768
z_num = 128 z_num = 128
step_size = 10e-6

View File

@@ -1,19 +1,14 @@
import warnings
import numpy as np
import rediscache
import scgenerator as sc import scgenerator as sc
from customfunc.app import PlotApp from customfunc.app import PlotApp
from scipy.interpolate import interp1d
from tqdm import tqdm from tqdm import tqdm
# warnings.filterwarnings("error") # warnings.filterwarnings("error")
@rediscache.rcache
def get_specs(params: dict): def get_specs(params: dict):
p = sc.Parameters(**params) p = sc.Parameters(**params)
sim = sc.RK4IP(p) sim = sc.RK4IP(p)
return [s[-1] for s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict() return [s.actual_spectrum for _, s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict()
def main(): def main():
@@ -25,7 +20,7 @@ def main():
rt = sc.PlotRange(-500, 500, "fs") rt = sc.PlotRange(-500, 500, "fs")
x, o, ext = rs.sort_axis(params.w) x, o, ext = rs.sort_axis(params.w)
vmin = -50 vmin = -50
with PlotApp(i=(int, 0, params.z_num - 1)) as app: with PlotApp(i=range(params.z_num)) as app:
spec_ax = app[0] spec_ax = app[0]
spec_ax.set_xlabel(rs.unit.label) spec_ax.set_xlabel(rs.unit.label)
field_ax = app[1] field_ax = app[1]
@@ -42,8 +37,8 @@ def main():
@app.cache @app.cache
def compute(i): def compute(i):
xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params) xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params=params)
x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params, log=True) x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params=params, log=True)
# spec = np.where(spec > vmin, spec, vmin) # spec = np.where(spec > vmin, spec, vmin)
field2 = sc.abs2(field) field2 = sc.abs2(field)
bot, top = sc.math.envelope_ind(field2) bot, top = sc.math.envelope_ind(field2)

View File

@@ -0,0 +1,31 @@
import numpy as np
import pytest
from scgenerator.operators import CurrentState
def test_creation():
x = (np.linspace(0, 1, 128, dtype=complex),)
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
assert cs.converter is np.fft.ifft
assert cs.stats == {}
assert np.allclose(cs.spectrum2, np.abs(np.fft.ifft(x)) ** 2)
with pytest.raises(ValueError):
cs = CurrentState(1.0, 0, 0.0, x, 1.0, spectrum2=np.abs(x) ** 3)
cs = CurrentState(1.0, 0, 0.1, x, 1.0, spectrum2=x.copy(), field=x.copy(), field2=x.copy())
assert np.allclose(cs.spectrum2, cs.spectrum)
assert np.allclose(cs.spectrum, cs.field)
assert np.allclose(cs.field, cs.field2)
def test_copy():
x = (np.linspace(0, 1, 128, dtype=complex),)
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
cs2 = cs.copy()
assert cs.spectrum is not cs2.spectrum
assert np.all(cs.field2 == cs2.field2)