cleanup with operators/value tracker/current_state

- current_state now always computes its derived values
- it implements a copy function
This commit is contained in:
Benoît Sierro
2023-03-24 09:34:40 +01:00
parent 504f40edd2
commit 2350979046
8 changed files with 95 additions and 164 deletions

View File

@@ -32,3 +32,4 @@ convention = "numpy"
[tool.black]
line-length = 100

View File

@@ -312,7 +312,7 @@ default_rules: list[Rule] = [
Rule("w_num", len, ["w"]),
Rule("dw", lambda w: w[1] - w[0]),
Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
Rule("interpolation_range", lambda dt: (2 * units.c * dt, 8e-6)),
Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
# Pulse
Rule("field_0", pulse.finalize_pulse),
Rule(["input_time", "input_field"], pulse.load_custom_field),
@@ -340,7 +340,6 @@ default_rules: list[Rule] = [
Rule("L_NL", pulse.L_NL),
Rule("L_sol", pulse.L_sol),
Rule("c_to_a_factor", lambda: 1, priorities=-1),
Rule("c_to_a_factor", pulse.c_to_a_factor),
# Fiber Dispersion
Rule("w_for_disp", units.m, ["wl_for_disp"]),
Rule("hr_w", fiber.delayed_raman_w),
@@ -419,6 +418,7 @@ envelope_rules = default_rules + [
# Pulse
Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1),
Rule("pre_field_0", pulse.initial_field_envelope, priorities=1),
Rule("c_to_a_factor", pulse.c_to_a_factor),
# Dispersion
Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion),
Rule("beta2_coefficients", fiber.dispersion_coefficients),

View File

@@ -5,8 +5,9 @@ Nothing except the solver should depend on this file
from __future__ import annotations
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable
from typing import Any, Callable
import numpy as np
from scipy.interpolate import interp1d
@@ -20,25 +21,25 @@ class CurrentState:
length: float
z: float
current_step_size: float
step: int
conversion_factor: np.ndarray
conversion_factor: np.ndarray | float
converter: Callable[[np.ndarray], np.ndarray]
__spectrum: np.ndarray
__spec2: np.ndarray
__field: np.ndarray
__field2: np.ndarray
stats: dict[str, Any]
spectrum: np.ndarray
spec2: np.ndarray
field: np.ndarray
field2: np.ndarray
__slots__ = [
"length",
"z",
"current_step_size",
"step",
"conversion_factor",
"converter",
"_CurrentState__spectrum",
"_CurrentState__spec2",
"_CurrentState__field",
"_CurrentState__field2",
"spectrum",
"spectrum2",
"field",
"field2",
"stats",
]
def __init__(
@@ -46,18 +47,31 @@ class CurrentState:
length: float,
z: float,
current_step_size: float,
step: int,
spectrum: np.ndarray,
conversion_factor: np.ndarray,
conversion_factor: np.ndarray | float,
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft,
spectrum2: np.ndarray | None = None,
field: np.ndarray | None = None,
field2: np.ndarray | None = None,
stats: dict[str, Any] | None = None,
):
self.length = length
self.z = z
self.current_step_size = current_step_size
self.step = step
self.conversion_factor = conversion_factor
self.converter = converter
self.spectrum = spectrum
if spectrum2 is None and field is None and field2 is None:
self.set_spectrum(spectrum)
elif any(el is None for el in (spectrum2, field, field2)):
raise ValueError(
"You must provide either all three of (spectrum2, field, field2) or none of them"
)
else:
self.spectrum2 = spectrum2
self.field = field
self.field2 = field2
self.stats = stats or {}
@property
def z_ratio(self) -> float:
@@ -67,125 +81,25 @@ class CurrentState:
def actual_spectrum(self) -> np.ndarray:
return self.conversion_factor * self.spectrum
@property
def spectrum(self) -> np.ndarray:
return self.__spectrum
@spectrum.setter
def spectrum(self, new_value: np.ndarray):
self.__spectrum = new_value
self.__spec2 = None
self.__field = None
self.__field2 = None
@property
def spec2(self) -> np.ndarray:
if self.__spec2 is None:
self.__spec2 = math.abs2(self.spectrum)
return self.__spec2
@property
def field(self) -> np.ndarray:
if self.__field is None:
self.__field = self.converter(self.spectrum)
return self.__field
@property
def field2(self) -> np.ndarray:
if self.__field2 is None:
self.__field2 = math.abs2(self.field)
return self.__field2
def force_values(self, spec2: np.ndarray, field: np.ndarray, field2: np.ndarray):
"""force these values instead of recomputing them
Parameters
----------
spectrum : np.ndarray
spectrum
spec2 : np.ndarray
|spectrum|^2
field : np.ndarray
field = converter(spectrum)
field2 : np.ndarray
|field|^2
"""
self.__spec2 = spec2
self.__field = field
self.__field2 = field2
def replace(self, new_spectrum: np.ndarray) -> CurrentState:
"""returns a new state with new attributes"""
return CurrentState(
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
step=self.step,
conversion_factor=self.conversion_factor,
converter=self.converter,
spectrum=new_spectrum,
)
def with_params(self, **params) -> CurrentState:
"""returns a new CurrentState with modified params, except for the solution"""
my_params = dict(
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
step=self.step,
conversion_factor=self.conversion_factor,
converter=self.converter,
)
new_state = CurrentState(spectrum=self.__spectrum, **(my_params | params))
new_state.force_values(self.spec2, self.field, self.field2)
return new_state
def set_spectrum(self, new_spectrum: np.ndarray):
self.spectrum = new_spectrum
self.spectrum2 = math.abs2(self.spectrum)
self.field = self.converter(self.spectrum)
self.field2 = math.abs2(self.field)
def copy(self) -> CurrentState:
new = CurrentState(
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
step=self.step,
conversion_factor=self.conversion_factor,
converter=self.converter,
spectrum=self.__spectrum,
return CurrentState(
self.length,
self.z,
self.current_step_size,
self.spectrum.copy(),
self.conversion_factor,
self.converter,
self.spectrum2.copy(),
self.field.copy(),
self.field2.copy(),
deepcopy(self.stats),
)
new.force_values(self.__spec2, self.__field, self.__field2)
return new
class ValueTracker(ABC):
def values(self) -> dict[str, float]:
return {}
def all_values(self) -> dict[str, float]:
out = self.values()
for operator in vars(self).values():
if isinstance(operator, ValueTracker):
out = operator.all_values() | out
return out
def __repr__(self) -> str:
value_pair_list = list(vars(self).items())
if len(value_pair_list) == 0:
value_pair_str_list = ""
elif len(value_pair_list) == 1:
value_pair_str_list = [self.__value_repr(value_pair_list[0][0], value_pair_list[0][1])]
else:
value_pair_str_list = [k + "=" + self.__value_repr(k, v) for k, v in value_pair_list]
return self.__class__.__name__ + "(" + ", ".join(value_pair_str_list) + ")"
def __value_repr(self, k: str, v) -> str:
if k.endswith("_const") and isinstance(v, (list, np.ndarray, tuple)):
return repr(v[0])
return repr(v)
class Operator(ValueTracker):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
pass
class NoOpTime(Operator):
@@ -540,7 +454,6 @@ class ConstantWaveVector(AbstractWaveVector):
dispersion_ind: np.ndarray,
w_order: np.ndarray,
):
self.beta_arr = np.zeros(w_num, dtype=float)
self.beta_arr[dispersion_ind] = fiber.beta(w_for_disp, n_op())[2:-2]
left_ind, *_, right_ind = np.nonzero(self.beta_arr[w_order])[0]
@@ -817,7 +730,6 @@ class VariableScalarGamma(AbstractGamma):
class Plasma(Operator):
mat_plasma: plasma.Plasma
gas_op: AbstractGas
ionization_fraction = 0.0
def __init__(self, dt: float, w: np.ndarray, gas_op: AbstractGas):
self.gas_op = gas_op
@@ -827,12 +739,9 @@ class Plasma(Operator):
def __call__(self, state: CurrentState) -> np.ndarray:
N0 = self.gas_op.number_density(state)
plasma_info = self.mat_plasma(state.field, N0)
self.ionization_fraction = plasma_info.electron_density[-1] / N0
state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0
return self.factor_out * np.fft.rfft(plasma_info.polarization)
def values(self) -> dict[str, float]:
return dict(ionization_fraction=self.ionization_fraction)
class NoPlasma(NoOpFreq, Plasma):
pass
@@ -863,7 +772,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss(
state.spec2,
state.spectrum2,
self.w,
self.dw,
self.gamma_op(state),
@@ -879,7 +788,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
self.gamma_op = gamma_op
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state))
return pulse.photon_number(state.spectrum2, self.w, self.dw, self.gamma_op(state))
class EnergyLoss(AbstractConservedQuantity):

View File

@@ -408,7 +408,7 @@ class Parameters:
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
spectrum_factor: float = Parameter(type_checker(float))
c_to_a_factor: np.ndarray = Parameter(type_checker(float, np.ndarray))
c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray))
w: np.ndarray = Parameter(type_checker(np.ndarray))
l: np.ndarray = Parameter(type_checker(np.ndarray))
w_c: np.ndarray = Parameter(type_checker(np.ndarray))

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from abc import abstractmethod
from collections import defaultdict
from typing import Iterator, Type
from typing import Any, Iterator, Type
import numba
import numpy as np
@@ -14,7 +14,6 @@ from scgenerator.operators import (
CurrentState,
LinearOperator,
NonLinearOperator,
ValueTracker,
)
from scgenerator.utils import get_arg_names
@@ -55,12 +54,12 @@ class IntegratorFactory:
return cls(**kwargs)
class Integrator(ValueTracker):
class Integrator:
linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
state: CurrentState
target_error: float
_tracked_values: dict[str, float]
_tracked_values: dict[float, dict[str, Any]]
logger: logging.Logger
__factory: IntegratorFactory = IntegratorFactory()
order = 4
@@ -110,15 +109,12 @@ class Integrator(ValueTracker):
"""
return self._tracked_values | dict(z=self.state.z, step=self.state.step)
def record_tracked_values(self):
self._tracked_values = super().all_values()
def nl(self, spectrum: np.ndarray) -> np.ndarray:
return self.nonlinear_operator(self.state.replace(spectrum))
def accept_step(
self, new_state: CurrentState, previous_step_size: float, next_step_size: float
) -> CurrentState:
):
self.state = new_state
self.state.current_step_size = next_step_size
self.state.z += previous_step_size

View File

@@ -2,7 +2,7 @@ name = "/Users/benoitsierro/tests/test_sc/Chang2011Fig2"
wavelength = 800e-9
shape = "gaussian"
energy = 2.5e-6
energy = 2.5e-7
width = 30e-15
core_radius = 10e-6
@@ -11,9 +11,8 @@ gas_name = "argon"
pressure = 3.2e5
length = 0.1
interpolation_range = [120e-9, 3000e-9]
full_field = true
photoionization = false
dt = 0.04e-15
t_num = 32768
z_num = 128
step_size = 10e-6

View File

@@ -1,19 +1,14 @@
import warnings
import numpy as np
import rediscache
import scgenerator as sc
from customfunc.app import PlotApp
from scipy.interpolate import interp1d
from tqdm import tqdm
# warnings.filterwarnings("error")
@rediscache.rcache
def get_specs(params: dict):
p = sc.Parameters(**params)
sim = sc.RK4IP(p)
return [s[-1] for s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict()
return [s.actual_spectrum for _, s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict()
def main():
@@ -25,7 +20,7 @@ def main():
rt = sc.PlotRange(-500, 500, "fs")
x, o, ext = rs.sort_axis(params.w)
vmin = -50
with PlotApp(i=(int, 0, params.z_num - 1)) as app:
with PlotApp(i=range(params.z_num)) as app:
spec_ax = app[0]
spec_ax.set_xlabel(rs.unit.label)
field_ax = app[1]
@@ -42,8 +37,8 @@ def main():
@app.cache
def compute(i):
xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params)
x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params, log=True)
xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params=params)
x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params=params, log=True)
# spec = np.where(spec > vmin, spec, vmin)
field2 = sc.abs2(field)
bot, top = sc.math.envelope_ind(field2)

View File

@@ -0,0 +1,31 @@
import numpy as np
import pytest
from scgenerator.operators import CurrentState
def test_creation():
x = (np.linspace(0, 1, 128, dtype=complex),)
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
assert cs.converter is np.fft.ifft
assert cs.stats == {}
assert np.allclose(cs.spectrum2, np.abs(np.fft.ifft(x)) ** 2)
with pytest.raises(ValueError):
cs = CurrentState(1.0, 0, 0.0, x, 1.0, spectrum2=np.abs(x) ** 3)
cs = CurrentState(1.0, 0, 0.1, x, 1.0, spectrum2=x.copy(), field=x.copy(), field2=x.copy())
assert np.allclose(cs.spectrum2, cs.spectrum)
assert np.allclose(cs.spectrum, cs.field)
assert np.allclose(cs.field, cs.field2)
def test_copy():
x = (np.linspace(0, 1, 128, dtype=complex),)
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
cs2 = cs.copy()
assert cs.spectrum is not cs2.spectrum
assert np.all(cs.field2 == cs2.field2)