cleanup with operators/value tracker/current_state
- current_state now always computes its derived values - it implements a copy function
This commit is contained in:
@@ -32,3 +32,4 @@ convention = "numpy"
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
|
||||
|
||||
@@ -312,7 +312,7 @@ default_rules: list[Rule] = [
|
||||
Rule("w_num", len, ["w"]),
|
||||
Rule("dw", lambda w: w[1] - w[0]),
|
||||
Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
|
||||
Rule("interpolation_range", lambda dt: (2 * units.c * dt, 8e-6)),
|
||||
Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
|
||||
# Pulse
|
||||
Rule("field_0", pulse.finalize_pulse),
|
||||
Rule(["input_time", "input_field"], pulse.load_custom_field),
|
||||
@@ -340,7 +340,6 @@ default_rules: list[Rule] = [
|
||||
Rule("L_NL", pulse.L_NL),
|
||||
Rule("L_sol", pulse.L_sol),
|
||||
Rule("c_to_a_factor", lambda: 1, priorities=-1),
|
||||
Rule("c_to_a_factor", pulse.c_to_a_factor),
|
||||
# Fiber Dispersion
|
||||
Rule("w_for_disp", units.m, ["wl_for_disp"]),
|
||||
Rule("hr_w", fiber.delayed_raman_w),
|
||||
@@ -419,6 +418,7 @@ envelope_rules = default_rules + [
|
||||
# Pulse
|
||||
Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1),
|
||||
Rule("pre_field_0", pulse.initial_field_envelope, priorities=1),
|
||||
Rule("c_to_a_factor", pulse.c_to_a_factor),
|
||||
# Dispersion
|
||||
Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion),
|
||||
Rule("beta2_coefficients", fiber.dispersion_coefficients),
|
||||
|
||||
@@ -5,8 +5,9 @@ Nothing except the solver should depend on this file
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from scipy.interpolate import interp1d
|
||||
@@ -20,25 +21,25 @@ class CurrentState:
|
||||
length: float
|
||||
z: float
|
||||
current_step_size: float
|
||||
step: int
|
||||
conversion_factor: np.ndarray
|
||||
conversion_factor: np.ndarray | float
|
||||
converter: Callable[[np.ndarray], np.ndarray]
|
||||
__spectrum: np.ndarray
|
||||
__spec2: np.ndarray
|
||||
__field: np.ndarray
|
||||
__field2: np.ndarray
|
||||
stats: dict[str, Any]
|
||||
spectrum: np.ndarray
|
||||
spec2: np.ndarray
|
||||
field: np.ndarray
|
||||
field2: np.ndarray
|
||||
|
||||
__slots__ = [
|
||||
"length",
|
||||
"z",
|
||||
"current_step_size",
|
||||
"step",
|
||||
"conversion_factor",
|
||||
"converter",
|
||||
"_CurrentState__spectrum",
|
||||
"_CurrentState__spec2",
|
||||
"_CurrentState__field",
|
||||
"_CurrentState__field2",
|
||||
"spectrum",
|
||||
"spectrum2",
|
||||
"field",
|
||||
"field2",
|
||||
"stats",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -46,18 +47,31 @@ class CurrentState:
|
||||
length: float,
|
||||
z: float,
|
||||
current_step_size: float,
|
||||
step: int,
|
||||
spectrum: np.ndarray,
|
||||
conversion_factor: np.ndarray,
|
||||
conversion_factor: np.ndarray | float,
|
||||
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft,
|
||||
spectrum2: np.ndarray | None = None,
|
||||
field: np.ndarray | None = None,
|
||||
field2: np.ndarray | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
self.length = length
|
||||
self.z = z
|
||||
self.current_step_size = current_step_size
|
||||
self.step = step
|
||||
self.conversion_factor = conversion_factor
|
||||
self.converter = converter
|
||||
self.spectrum = spectrum
|
||||
|
||||
if spectrum2 is None and field is None and field2 is None:
|
||||
self.set_spectrum(spectrum)
|
||||
elif any(el is None for el in (spectrum2, field, field2)):
|
||||
raise ValueError(
|
||||
"You must provide either all three of (spectrum2, field, field2) or none of them"
|
||||
)
|
||||
else:
|
||||
self.spectrum2 = spectrum2
|
||||
self.field = field
|
||||
self.field2 = field2
|
||||
self.stats = stats or {}
|
||||
|
||||
@property
|
||||
def z_ratio(self) -> float:
|
||||
@@ -67,125 +81,25 @@ class CurrentState:
|
||||
def actual_spectrum(self) -> np.ndarray:
|
||||
return self.conversion_factor * self.spectrum
|
||||
|
||||
@property
|
||||
def spectrum(self) -> np.ndarray:
|
||||
return self.__spectrum
|
||||
|
||||
@spectrum.setter
|
||||
def spectrum(self, new_value: np.ndarray):
|
||||
self.__spectrum = new_value
|
||||
self.__spec2 = None
|
||||
self.__field = None
|
||||
self.__field2 = None
|
||||
|
||||
@property
|
||||
def spec2(self) -> np.ndarray:
|
||||
if self.__spec2 is None:
|
||||
self.__spec2 = math.abs2(self.spectrum)
|
||||
return self.__spec2
|
||||
|
||||
@property
|
||||
def field(self) -> np.ndarray:
|
||||
if self.__field is None:
|
||||
self.__field = self.converter(self.spectrum)
|
||||
return self.__field
|
||||
|
||||
@property
|
||||
def field2(self) -> np.ndarray:
|
||||
if self.__field2 is None:
|
||||
self.__field2 = math.abs2(self.field)
|
||||
return self.__field2
|
||||
|
||||
def force_values(self, spec2: np.ndarray, field: np.ndarray, field2: np.ndarray):
|
||||
"""force these values instead of recomputing them
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spectrum : np.ndarray
|
||||
spectrum
|
||||
spec2 : np.ndarray
|
||||
|spectrum|^2
|
||||
field : np.ndarray
|
||||
field = converter(spectrum)
|
||||
field2 : np.ndarray
|
||||
|field|^2
|
||||
"""
|
||||
self.__spec2 = spec2
|
||||
self.__field = field
|
||||
self.__field2 = field2
|
||||
|
||||
def replace(self, new_spectrum: np.ndarray) -> CurrentState:
|
||||
"""returns a new state with new attributes"""
|
||||
return CurrentState(
|
||||
length=self.length,
|
||||
z=self.z,
|
||||
current_step_size=self.current_step_size,
|
||||
step=self.step,
|
||||
conversion_factor=self.conversion_factor,
|
||||
converter=self.converter,
|
||||
spectrum=new_spectrum,
|
||||
)
|
||||
|
||||
def with_params(self, **params) -> CurrentState:
|
||||
"""returns a new CurrentState with modified params, except for the solution"""
|
||||
my_params = dict(
|
||||
length=self.length,
|
||||
z=self.z,
|
||||
current_step_size=self.current_step_size,
|
||||
step=self.step,
|
||||
conversion_factor=self.conversion_factor,
|
||||
converter=self.converter,
|
||||
)
|
||||
new_state = CurrentState(spectrum=self.__spectrum, **(my_params | params))
|
||||
new_state.force_values(self.spec2, self.field, self.field2)
|
||||
return new_state
|
||||
def set_spectrum(self, new_spectrum: np.ndarray):
|
||||
self.spectrum = new_spectrum
|
||||
self.spectrum2 = math.abs2(self.spectrum)
|
||||
self.field = self.converter(self.spectrum)
|
||||
self.field2 = math.abs2(self.field)
|
||||
|
||||
def copy(self) -> CurrentState:
|
||||
new = CurrentState(
|
||||
length=self.length,
|
||||
z=self.z,
|
||||
current_step_size=self.current_step_size,
|
||||
step=self.step,
|
||||
conversion_factor=self.conversion_factor,
|
||||
converter=self.converter,
|
||||
spectrum=self.__spectrum,
|
||||
return CurrentState(
|
||||
self.length,
|
||||
self.z,
|
||||
self.current_step_size,
|
||||
self.spectrum.copy(),
|
||||
self.conversion_factor,
|
||||
self.converter,
|
||||
self.spectrum2.copy(),
|
||||
self.field.copy(),
|
||||
self.field2.copy(),
|
||||
deepcopy(self.stats),
|
||||
)
|
||||
new.force_values(self.__spec2, self.__field, self.__field2)
|
||||
return new
|
||||
|
||||
|
||||
class ValueTracker(ABC):
|
||||
def values(self) -> dict[str, float]:
|
||||
return {}
|
||||
|
||||
def all_values(self) -> dict[str, float]:
|
||||
out = self.values()
|
||||
for operator in vars(self).values():
|
||||
if isinstance(operator, ValueTracker):
|
||||
out = operator.all_values() | out
|
||||
return out
|
||||
|
||||
def __repr__(self) -> str:
|
||||
value_pair_list = list(vars(self).items())
|
||||
if len(value_pair_list) == 0:
|
||||
value_pair_str_list = ""
|
||||
elif len(value_pair_list) == 1:
|
||||
value_pair_str_list = [self.__value_repr(value_pair_list[0][0], value_pair_list[0][1])]
|
||||
else:
|
||||
value_pair_str_list = [k + "=" + self.__value_repr(k, v) for k, v in value_pair_list]
|
||||
|
||||
return self.__class__.__name__ + "(" + ", ".join(value_pair_str_list) + ")"
|
||||
|
||||
def __value_repr(self, k: str, v) -> str:
|
||||
if k.endswith("_const") and isinstance(v, (list, np.ndarray, tuple)):
|
||||
return repr(v[0])
|
||||
return repr(v)
|
||||
|
||||
|
||||
class Operator(ValueTracker):
|
||||
@abstractmethod
|
||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
class NoOpTime(Operator):
|
||||
@@ -540,7 +454,6 @@ class ConstantWaveVector(AbstractWaveVector):
|
||||
dispersion_ind: np.ndarray,
|
||||
w_order: np.ndarray,
|
||||
):
|
||||
|
||||
self.beta_arr = np.zeros(w_num, dtype=float)
|
||||
self.beta_arr[dispersion_ind] = fiber.beta(w_for_disp, n_op())[2:-2]
|
||||
left_ind, *_, right_ind = np.nonzero(self.beta_arr[w_order])[0]
|
||||
@@ -817,7 +730,6 @@ class VariableScalarGamma(AbstractGamma):
|
||||
class Plasma(Operator):
|
||||
mat_plasma: plasma.Plasma
|
||||
gas_op: AbstractGas
|
||||
ionization_fraction = 0.0
|
||||
|
||||
def __init__(self, dt: float, w: np.ndarray, gas_op: AbstractGas):
|
||||
self.gas_op = gas_op
|
||||
@@ -827,12 +739,9 @@ class Plasma(Operator):
|
||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||
N0 = self.gas_op.number_density(state)
|
||||
plasma_info = self.mat_plasma(state.field, N0)
|
||||
self.ionization_fraction = plasma_info.electron_density[-1] / N0
|
||||
state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0
|
||||
return self.factor_out * np.fft.rfft(plasma_info.polarization)
|
||||
|
||||
def values(self) -> dict[str, float]:
|
||||
return dict(ionization_fraction=self.ionization_fraction)
|
||||
|
||||
|
||||
class NoPlasma(NoOpFreq, Plasma):
|
||||
pass
|
||||
@@ -863,7 +772,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.photon_number_with_loss(
|
||||
state.spec2,
|
||||
state.spectrum2,
|
||||
self.w,
|
||||
self.dw,
|
||||
self.gamma_op(state),
|
||||
@@ -879,7 +788,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
|
||||
self.gamma_op = gamma_op
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state))
|
||||
return pulse.photon_number(state.spectrum2, self.w, self.dw, self.gamma_op(state))
|
||||
|
||||
|
||||
class EnergyLoss(AbstractConservedQuantity):
|
||||
|
||||
@@ -408,7 +408,7 @@ class Parameters:
|
||||
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
spectrum_factor: float = Parameter(type_checker(float))
|
||||
c_to_a_factor: np.ndarray = Parameter(type_checker(float, np.ndarray))
|
||||
c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray))
|
||||
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
l: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Iterator, Type
|
||||
from typing import Any, Iterator, Type
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
@@ -14,7 +14,6 @@ from scgenerator.operators import (
|
||||
CurrentState,
|
||||
LinearOperator,
|
||||
NonLinearOperator,
|
||||
ValueTracker,
|
||||
)
|
||||
from scgenerator.utils import get_arg_names
|
||||
|
||||
@@ -55,12 +54,12 @@ class IntegratorFactory:
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
class Integrator(ValueTracker):
|
||||
class Integrator:
|
||||
linear_operator: LinearOperator
|
||||
nonlinear_operator: NonLinearOperator
|
||||
state: CurrentState
|
||||
target_error: float
|
||||
_tracked_values: dict[str, float]
|
||||
_tracked_values: dict[float, dict[str, Any]]
|
||||
logger: logging.Logger
|
||||
__factory: IntegratorFactory = IntegratorFactory()
|
||||
order = 4
|
||||
@@ -109,16 +108,13 @@ class Integrator(ValueTracker):
|
||||
tracked values
|
||||
"""
|
||||
return self._tracked_values | dict(z=self.state.z, step=self.state.step)
|
||||
|
||||
def record_tracked_values(self):
|
||||
self._tracked_values = super().all_values()
|
||||
|
||||
|
||||
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
||||
return self.nonlinear_operator(self.state.replace(spectrum))
|
||||
|
||||
def accept_step(
|
||||
self, new_state: CurrentState, previous_step_size: float, next_step_size: float
|
||||
) -> CurrentState:
|
||||
):
|
||||
self.state = new_state
|
||||
self.state.current_step_size = next_step_size
|
||||
self.state.z += previous_step_size
|
||||
|
||||
@@ -2,7 +2,7 @@ name = "/Users/benoitsierro/tests/test_sc/Chang2011Fig2"
|
||||
|
||||
wavelength = 800e-9
|
||||
shape = "gaussian"
|
||||
energy = 2.5e-6
|
||||
energy = 2.5e-7
|
||||
width = 30e-15
|
||||
|
||||
core_radius = 10e-6
|
||||
@@ -11,9 +11,8 @@ gas_name = "argon"
|
||||
pressure = 3.2e5
|
||||
|
||||
length = 0.1
|
||||
interpolation_range = [120e-9, 3000e-9]
|
||||
full_field = true
|
||||
photoionization = false
|
||||
dt = 0.04e-15
|
||||
t_num = 32768
|
||||
z_num = 128
|
||||
step_size = 10e-6
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
import warnings
|
||||
import numpy as np
|
||||
import rediscache
|
||||
import scgenerator as sc
|
||||
from customfunc.app import PlotApp
|
||||
from scipy.interpolate import interp1d
|
||||
from tqdm import tqdm
|
||||
|
||||
# warnings.filterwarnings("error")
|
||||
|
||||
|
||||
@rediscache.rcache
|
||||
def get_specs(params: dict):
|
||||
p = sc.Parameters(**params)
|
||||
sim = sc.RK4IP(p)
|
||||
return [s[-1] for s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict()
|
||||
return [s.actual_spectrum for _, s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict()
|
||||
|
||||
|
||||
def main():
|
||||
@@ -25,7 +20,7 @@ def main():
|
||||
rt = sc.PlotRange(-500, 500, "fs")
|
||||
x, o, ext = rs.sort_axis(params.w)
|
||||
vmin = -50
|
||||
with PlotApp(i=(int, 0, params.z_num - 1)) as app:
|
||||
with PlotApp(i=range(params.z_num)) as app:
|
||||
spec_ax = app[0]
|
||||
spec_ax.set_xlabel(rs.unit.label)
|
||||
field_ax = app[1]
|
||||
@@ -42,8 +37,8 @@ def main():
|
||||
|
||||
@app.cache
|
||||
def compute(i):
|
||||
xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params)
|
||||
x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params, log=True)
|
||||
xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params=params)
|
||||
x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params=params, log=True)
|
||||
# spec = np.where(spec > vmin, spec, vmin)
|
||||
field2 = sc.abs2(field)
|
||||
bot, top = sc.math.envelope_ind(field2)
|
||||
|
||||
31
tests/test_current_state.py
Normal file
31
tests/test_current_state.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scgenerator.operators import CurrentState
|
||||
|
||||
|
||||
def test_creation():
|
||||
x = (np.linspace(0, 1, 128, dtype=complex),)
|
||||
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
|
||||
|
||||
assert cs.converter is np.fft.ifft
|
||||
assert cs.stats == {}
|
||||
assert np.allclose(cs.spectrum2, np.abs(np.fft.ifft(x)) ** 2)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cs = CurrentState(1.0, 0, 0.0, x, 1.0, spectrum2=np.abs(x) ** 3)
|
||||
|
||||
cs = CurrentState(1.0, 0, 0.1, x, 1.0, spectrum2=x.copy(), field=x.copy(), field2=x.copy())
|
||||
|
||||
assert np.allclose(cs.spectrum2, cs.spectrum)
|
||||
assert np.allclose(cs.spectrum, cs.field)
|
||||
assert np.allclose(cs.field, cs.field2)
|
||||
|
||||
|
||||
def test_copy():
|
||||
x = (np.linspace(0, 1, 128, dtype=complex),)
|
||||
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
|
||||
cs2 = cs.copy()
|
||||
|
||||
assert cs.spectrum is not cs2.spectrum
|
||||
assert np.all(cs.field2 == cs2.field2)
|
||||
Reference in New Issue
Block a user