cleanup with operators/value tracker/current_state
- current_state now always computes its derived values - it implements a copy function
This commit is contained in:
@@ -32,3 +32,4 @@ convention = "numpy"
|
|||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
|
|
||||||
|
|||||||
@@ -312,7 +312,7 @@ default_rules: list[Rule] = [
|
|||||||
Rule("w_num", len, ["w"]),
|
Rule("w_num", len, ["w"]),
|
||||||
Rule("dw", lambda w: w[1] - w[0]),
|
Rule("dw", lambda w: w[1] - w[0]),
|
||||||
Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
|
Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
|
||||||
Rule("interpolation_range", lambda dt: (2 * units.c * dt, 8e-6)),
|
Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
|
||||||
# Pulse
|
# Pulse
|
||||||
Rule("field_0", pulse.finalize_pulse),
|
Rule("field_0", pulse.finalize_pulse),
|
||||||
Rule(["input_time", "input_field"], pulse.load_custom_field),
|
Rule(["input_time", "input_field"], pulse.load_custom_field),
|
||||||
@@ -340,7 +340,6 @@ default_rules: list[Rule] = [
|
|||||||
Rule("L_NL", pulse.L_NL),
|
Rule("L_NL", pulse.L_NL),
|
||||||
Rule("L_sol", pulse.L_sol),
|
Rule("L_sol", pulse.L_sol),
|
||||||
Rule("c_to_a_factor", lambda: 1, priorities=-1),
|
Rule("c_to_a_factor", lambda: 1, priorities=-1),
|
||||||
Rule("c_to_a_factor", pulse.c_to_a_factor),
|
|
||||||
# Fiber Dispersion
|
# Fiber Dispersion
|
||||||
Rule("w_for_disp", units.m, ["wl_for_disp"]),
|
Rule("w_for_disp", units.m, ["wl_for_disp"]),
|
||||||
Rule("hr_w", fiber.delayed_raman_w),
|
Rule("hr_w", fiber.delayed_raman_w),
|
||||||
@@ -419,6 +418,7 @@ envelope_rules = default_rules + [
|
|||||||
# Pulse
|
# Pulse
|
||||||
Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1),
|
Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1),
|
||||||
Rule("pre_field_0", pulse.initial_field_envelope, priorities=1),
|
Rule("pre_field_0", pulse.initial_field_envelope, priorities=1),
|
||||||
|
Rule("c_to_a_factor", pulse.c_to_a_factor),
|
||||||
# Dispersion
|
# Dispersion
|
||||||
Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion),
|
Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion),
|
||||||
Rule("beta2_coefficients", fiber.dispersion_coefficients),
|
Rule("beta2_coefficients", fiber.dispersion_coefficients),
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ Nothing except the solver should depend on this file
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
@@ -20,25 +21,25 @@ class CurrentState:
|
|||||||
length: float
|
length: float
|
||||||
z: float
|
z: float
|
||||||
current_step_size: float
|
current_step_size: float
|
||||||
step: int
|
conversion_factor: np.ndarray | float
|
||||||
conversion_factor: np.ndarray
|
|
||||||
converter: Callable[[np.ndarray], np.ndarray]
|
converter: Callable[[np.ndarray], np.ndarray]
|
||||||
__spectrum: np.ndarray
|
stats: dict[str, Any]
|
||||||
__spec2: np.ndarray
|
spectrum: np.ndarray
|
||||||
__field: np.ndarray
|
spec2: np.ndarray
|
||||||
__field2: np.ndarray
|
field: np.ndarray
|
||||||
|
field2: np.ndarray
|
||||||
|
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"length",
|
"length",
|
||||||
"z",
|
"z",
|
||||||
"current_step_size",
|
"current_step_size",
|
||||||
"step",
|
|
||||||
"conversion_factor",
|
"conversion_factor",
|
||||||
"converter",
|
"converter",
|
||||||
"_CurrentState__spectrum",
|
"spectrum",
|
||||||
"_CurrentState__spec2",
|
"spectrum2",
|
||||||
"_CurrentState__field",
|
"field",
|
||||||
"_CurrentState__field2",
|
"field2",
|
||||||
|
"stats",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -46,18 +47,31 @@ class CurrentState:
|
|||||||
length: float,
|
length: float,
|
||||||
z: float,
|
z: float,
|
||||||
current_step_size: float,
|
current_step_size: float,
|
||||||
step: int,
|
|
||||||
spectrum: np.ndarray,
|
spectrum: np.ndarray,
|
||||||
conversion_factor: np.ndarray,
|
conversion_factor: np.ndarray | float,
|
||||||
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft,
|
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft,
|
||||||
|
spectrum2: np.ndarray | None = None,
|
||||||
|
field: np.ndarray | None = None,
|
||||||
|
field2: np.ndarray | None = None,
|
||||||
|
stats: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
self.length = length
|
self.length = length
|
||||||
self.z = z
|
self.z = z
|
||||||
self.current_step_size = current_step_size
|
self.current_step_size = current_step_size
|
||||||
self.step = step
|
|
||||||
self.conversion_factor = conversion_factor
|
self.conversion_factor = conversion_factor
|
||||||
self.converter = converter
|
self.converter = converter
|
||||||
self.spectrum = spectrum
|
|
||||||
|
if spectrum2 is None and field is None and field2 is None:
|
||||||
|
self.set_spectrum(spectrum)
|
||||||
|
elif any(el is None for el in (spectrum2, field, field2)):
|
||||||
|
raise ValueError(
|
||||||
|
"You must provide either all three of (spectrum2, field, field2) or none of them"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.spectrum2 = spectrum2
|
||||||
|
self.field = field
|
||||||
|
self.field2 = field2
|
||||||
|
self.stats = stats or {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def z_ratio(self) -> float:
|
def z_ratio(self) -> float:
|
||||||
@@ -67,125 +81,25 @@ class CurrentState:
|
|||||||
def actual_spectrum(self) -> np.ndarray:
|
def actual_spectrum(self) -> np.ndarray:
|
||||||
return self.conversion_factor * self.spectrum
|
return self.conversion_factor * self.spectrum
|
||||||
|
|
||||||
@property
|
def set_spectrum(self, new_spectrum: np.ndarray):
|
||||||
def spectrum(self) -> np.ndarray:
|
self.spectrum = new_spectrum
|
||||||
return self.__spectrum
|
self.spectrum2 = math.abs2(self.spectrum)
|
||||||
|
self.field = self.converter(self.spectrum)
|
||||||
@spectrum.setter
|
self.field2 = math.abs2(self.field)
|
||||||
def spectrum(self, new_value: np.ndarray):
|
|
||||||
self.__spectrum = new_value
|
|
||||||
self.__spec2 = None
|
|
||||||
self.__field = None
|
|
||||||
self.__field2 = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def spec2(self) -> np.ndarray:
|
|
||||||
if self.__spec2 is None:
|
|
||||||
self.__spec2 = math.abs2(self.spectrum)
|
|
||||||
return self.__spec2
|
|
||||||
|
|
||||||
@property
|
|
||||||
def field(self) -> np.ndarray:
|
|
||||||
if self.__field is None:
|
|
||||||
self.__field = self.converter(self.spectrum)
|
|
||||||
return self.__field
|
|
||||||
|
|
||||||
@property
|
|
||||||
def field2(self) -> np.ndarray:
|
|
||||||
if self.__field2 is None:
|
|
||||||
self.__field2 = math.abs2(self.field)
|
|
||||||
return self.__field2
|
|
||||||
|
|
||||||
def force_values(self, spec2: np.ndarray, field: np.ndarray, field2: np.ndarray):
|
|
||||||
"""force these values instead of recomputing them
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
spectrum : np.ndarray
|
|
||||||
spectrum
|
|
||||||
spec2 : np.ndarray
|
|
||||||
|spectrum|^2
|
|
||||||
field : np.ndarray
|
|
||||||
field = converter(spectrum)
|
|
||||||
field2 : np.ndarray
|
|
||||||
|field|^2
|
|
||||||
"""
|
|
||||||
self.__spec2 = spec2
|
|
||||||
self.__field = field
|
|
||||||
self.__field2 = field2
|
|
||||||
|
|
||||||
def replace(self, new_spectrum: np.ndarray) -> CurrentState:
|
|
||||||
"""returns a new state with new attributes"""
|
|
||||||
return CurrentState(
|
|
||||||
length=self.length,
|
|
||||||
z=self.z,
|
|
||||||
current_step_size=self.current_step_size,
|
|
||||||
step=self.step,
|
|
||||||
conversion_factor=self.conversion_factor,
|
|
||||||
converter=self.converter,
|
|
||||||
spectrum=new_spectrum,
|
|
||||||
)
|
|
||||||
|
|
||||||
def with_params(self, **params) -> CurrentState:
|
|
||||||
"""returns a new CurrentState with modified params, except for the solution"""
|
|
||||||
my_params = dict(
|
|
||||||
length=self.length,
|
|
||||||
z=self.z,
|
|
||||||
current_step_size=self.current_step_size,
|
|
||||||
step=self.step,
|
|
||||||
conversion_factor=self.conversion_factor,
|
|
||||||
converter=self.converter,
|
|
||||||
)
|
|
||||||
new_state = CurrentState(spectrum=self.__spectrum, **(my_params | params))
|
|
||||||
new_state.force_values(self.spec2, self.field, self.field2)
|
|
||||||
return new_state
|
|
||||||
|
|
||||||
def copy(self) -> CurrentState:
|
def copy(self) -> CurrentState:
|
||||||
new = CurrentState(
|
return CurrentState(
|
||||||
length=self.length,
|
self.length,
|
||||||
z=self.z,
|
self.z,
|
||||||
current_step_size=self.current_step_size,
|
self.current_step_size,
|
||||||
step=self.step,
|
self.spectrum.copy(),
|
||||||
conversion_factor=self.conversion_factor,
|
self.conversion_factor,
|
||||||
converter=self.converter,
|
self.converter,
|
||||||
spectrum=self.__spectrum,
|
self.spectrum2.copy(),
|
||||||
|
self.field.copy(),
|
||||||
|
self.field2.copy(),
|
||||||
|
deepcopy(self.stats),
|
||||||
)
|
)
|
||||||
new.force_values(self.__spec2, self.__field, self.__field2)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
class ValueTracker(ABC):
|
|
||||||
def values(self) -> dict[str, float]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def all_values(self) -> dict[str, float]:
|
|
||||||
out = self.values()
|
|
||||||
for operator in vars(self).values():
|
|
||||||
if isinstance(operator, ValueTracker):
|
|
||||||
out = operator.all_values() | out
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
value_pair_list = list(vars(self).items())
|
|
||||||
if len(value_pair_list) == 0:
|
|
||||||
value_pair_str_list = ""
|
|
||||||
elif len(value_pair_list) == 1:
|
|
||||||
value_pair_str_list = [self.__value_repr(value_pair_list[0][0], value_pair_list[0][1])]
|
|
||||||
else:
|
|
||||||
value_pair_str_list = [k + "=" + self.__value_repr(k, v) for k, v in value_pair_list]
|
|
||||||
|
|
||||||
return self.__class__.__name__ + "(" + ", ".join(value_pair_str_list) + ")"
|
|
||||||
|
|
||||||
def __value_repr(self, k: str, v) -> str:
|
|
||||||
if k.endswith("_const") and isinstance(v, (list, np.ndarray, tuple)):
|
|
||||||
return repr(v[0])
|
|
||||||
return repr(v)
|
|
||||||
|
|
||||||
|
|
||||||
class Operator(ValueTracker):
|
|
||||||
@abstractmethod
|
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NoOpTime(Operator):
|
class NoOpTime(Operator):
|
||||||
@@ -540,7 +454,6 @@ class ConstantWaveVector(AbstractWaveVector):
|
|||||||
dispersion_ind: np.ndarray,
|
dispersion_ind: np.ndarray,
|
||||||
w_order: np.ndarray,
|
w_order: np.ndarray,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.beta_arr = np.zeros(w_num, dtype=float)
|
self.beta_arr = np.zeros(w_num, dtype=float)
|
||||||
self.beta_arr[dispersion_ind] = fiber.beta(w_for_disp, n_op())[2:-2]
|
self.beta_arr[dispersion_ind] = fiber.beta(w_for_disp, n_op())[2:-2]
|
||||||
left_ind, *_, right_ind = np.nonzero(self.beta_arr[w_order])[0]
|
left_ind, *_, right_ind = np.nonzero(self.beta_arr[w_order])[0]
|
||||||
@@ -817,7 +730,6 @@ class VariableScalarGamma(AbstractGamma):
|
|||||||
class Plasma(Operator):
|
class Plasma(Operator):
|
||||||
mat_plasma: plasma.Plasma
|
mat_plasma: plasma.Plasma
|
||||||
gas_op: AbstractGas
|
gas_op: AbstractGas
|
||||||
ionization_fraction = 0.0
|
|
||||||
|
|
||||||
def __init__(self, dt: float, w: np.ndarray, gas_op: AbstractGas):
|
def __init__(self, dt: float, w: np.ndarray, gas_op: AbstractGas):
|
||||||
self.gas_op = gas_op
|
self.gas_op = gas_op
|
||||||
@@ -827,12 +739,9 @@ class Plasma(Operator):
|
|||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
N0 = self.gas_op.number_density(state)
|
N0 = self.gas_op.number_density(state)
|
||||||
plasma_info = self.mat_plasma(state.field, N0)
|
plasma_info = self.mat_plasma(state.field, N0)
|
||||||
self.ionization_fraction = plasma_info.electron_density[-1] / N0
|
state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0
|
||||||
return self.factor_out * np.fft.rfft(plasma_info.polarization)
|
return self.factor_out * np.fft.rfft(plasma_info.polarization)
|
||||||
|
|
||||||
def values(self) -> dict[str, float]:
|
|
||||||
return dict(ionization_fraction=self.ionization_fraction)
|
|
||||||
|
|
||||||
|
|
||||||
class NoPlasma(NoOpFreq, Plasma):
|
class NoPlasma(NoOpFreq, Plasma):
|
||||||
pass
|
pass
|
||||||
@@ -863,7 +772,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
|
|||||||
|
|
||||||
def __call__(self, state: CurrentState) -> float:
|
def __call__(self, state: CurrentState) -> float:
|
||||||
return pulse.photon_number_with_loss(
|
return pulse.photon_number_with_loss(
|
||||||
state.spec2,
|
state.spectrum2,
|
||||||
self.w,
|
self.w,
|
||||||
self.dw,
|
self.dw,
|
||||||
self.gamma_op(state),
|
self.gamma_op(state),
|
||||||
@@ -879,7 +788,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
|
|||||||
self.gamma_op = gamma_op
|
self.gamma_op = gamma_op
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> float:
|
def __call__(self, state: CurrentState) -> float:
|
||||||
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state))
|
return pulse.photon_number(state.spectrum2, self.w, self.dw, self.gamma_op(state))
|
||||||
|
|
||||||
|
|
||||||
class EnergyLoss(AbstractConservedQuantity):
|
class EnergyLoss(AbstractConservedQuantity):
|
||||||
|
|||||||
@@ -408,7 +408,7 @@ class Parameters:
|
|||||||
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
spectrum_factor: float = Parameter(type_checker(float))
|
spectrum_factor: float = Parameter(type_checker(float))
|
||||||
c_to_a_factor: np.ndarray = Parameter(type_checker(float, np.ndarray))
|
c_to_a_factor: np.ndarray = Parameter(type_checker(float, int, np.ndarray))
|
||||||
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
l: np.ndarray = Parameter(type_checker(np.ndarray))
|
l: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
|
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Iterator, Type
|
from typing import Any, Iterator, Type
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -14,7 +14,6 @@ from scgenerator.operators import (
|
|||||||
CurrentState,
|
CurrentState,
|
||||||
LinearOperator,
|
LinearOperator,
|
||||||
NonLinearOperator,
|
NonLinearOperator,
|
||||||
ValueTracker,
|
|
||||||
)
|
)
|
||||||
from scgenerator.utils import get_arg_names
|
from scgenerator.utils import get_arg_names
|
||||||
|
|
||||||
@@ -55,12 +54,12 @@ class IntegratorFactory:
|
|||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Integrator(ValueTracker):
|
class Integrator:
|
||||||
linear_operator: LinearOperator
|
linear_operator: LinearOperator
|
||||||
nonlinear_operator: NonLinearOperator
|
nonlinear_operator: NonLinearOperator
|
||||||
state: CurrentState
|
state: CurrentState
|
||||||
target_error: float
|
target_error: float
|
||||||
_tracked_values: dict[str, float]
|
_tracked_values: dict[float, dict[str, Any]]
|
||||||
logger: logging.Logger
|
logger: logging.Logger
|
||||||
__factory: IntegratorFactory = IntegratorFactory()
|
__factory: IntegratorFactory = IntegratorFactory()
|
||||||
order = 4
|
order = 4
|
||||||
@@ -109,16 +108,13 @@ class Integrator(ValueTracker):
|
|||||||
tracked values
|
tracked values
|
||||||
"""
|
"""
|
||||||
return self._tracked_values | dict(z=self.state.z, step=self.state.step)
|
return self._tracked_values | dict(z=self.state.z, step=self.state.step)
|
||||||
|
|
||||||
def record_tracked_values(self):
|
|
||||||
self._tracked_values = super().all_values()
|
|
||||||
|
|
||||||
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
||||||
return self.nonlinear_operator(self.state.replace(spectrum))
|
return self.nonlinear_operator(self.state.replace(spectrum))
|
||||||
|
|
||||||
def accept_step(
|
def accept_step(
|
||||||
self, new_state: CurrentState, previous_step_size: float, next_step_size: float
|
self, new_state: CurrentState, previous_step_size: float, next_step_size: float
|
||||||
) -> CurrentState:
|
):
|
||||||
self.state = new_state
|
self.state = new_state
|
||||||
self.state.current_step_size = next_step_size
|
self.state.current_step_size = next_step_size
|
||||||
self.state.z += previous_step_size
|
self.state.z += previous_step_size
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ name = "/Users/benoitsierro/tests/test_sc/Chang2011Fig2"
|
|||||||
|
|
||||||
wavelength = 800e-9
|
wavelength = 800e-9
|
||||||
shape = "gaussian"
|
shape = "gaussian"
|
||||||
energy = 2.5e-6
|
energy = 2.5e-7
|
||||||
width = 30e-15
|
width = 30e-15
|
||||||
|
|
||||||
core_radius = 10e-6
|
core_radius = 10e-6
|
||||||
@@ -11,9 +11,8 @@ gas_name = "argon"
|
|||||||
pressure = 3.2e5
|
pressure = 3.2e5
|
||||||
|
|
||||||
length = 0.1
|
length = 0.1
|
||||||
interpolation_range = [120e-9, 3000e-9]
|
|
||||||
full_field = true
|
full_field = true
|
||||||
|
photoionization = false
|
||||||
dt = 0.04e-15
|
dt = 0.04e-15
|
||||||
t_num = 32768
|
t_num = 32768
|
||||||
z_num = 128
|
z_num = 128
|
||||||
step_size = 10e-6
|
|
||||||
|
|||||||
@@ -1,19 +1,14 @@
|
|||||||
import warnings
|
|
||||||
import numpy as np
|
|
||||||
import rediscache
|
|
||||||
import scgenerator as sc
|
import scgenerator as sc
|
||||||
from customfunc.app import PlotApp
|
from customfunc.app import PlotApp
|
||||||
from scipy.interpolate import interp1d
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# warnings.filterwarnings("error")
|
# warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|
||||||
@rediscache.rcache
|
|
||||||
def get_specs(params: dict):
|
def get_specs(params: dict):
|
||||||
p = sc.Parameters(**params)
|
p = sc.Parameters(**params)
|
||||||
sim = sc.RK4IP(p)
|
sim = sc.RK4IP(p)
|
||||||
return [s[-1] for s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict()
|
return [s.actual_spectrum for _, s in tqdm(sim.irun(), total=p.z_num)], p.dump_dict()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -25,7 +20,7 @@ def main():
|
|||||||
rt = sc.PlotRange(-500, 500, "fs")
|
rt = sc.PlotRange(-500, 500, "fs")
|
||||||
x, o, ext = rs.sort_axis(params.w)
|
x, o, ext = rs.sort_axis(params.w)
|
||||||
vmin = -50
|
vmin = -50
|
||||||
with PlotApp(i=(int, 0, params.z_num - 1)) as app:
|
with PlotApp(i=range(params.z_num)) as app:
|
||||||
spec_ax = app[0]
|
spec_ax = app[0]
|
||||||
spec_ax.set_xlabel(rs.unit.label)
|
spec_ax.set_xlabel(rs.unit.label)
|
||||||
field_ax = app[1]
|
field_ax = app[1]
|
||||||
@@ -42,8 +37,8 @@ def main():
|
|||||||
|
|
||||||
@app.cache
|
@app.cache
|
||||||
def compute(i):
|
def compute(i):
|
||||||
xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params)
|
xt, field = sc.transform_1D_values(params.ifft(specs[i]), rt, params=params)
|
||||||
x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params, log=True)
|
x, spec = sc.transform_1D_values(sc.abs2(specs[i]), rs, params=params, log=True)
|
||||||
# spec = np.where(spec > vmin, spec, vmin)
|
# spec = np.where(spec > vmin, spec, vmin)
|
||||||
field2 = sc.abs2(field)
|
field2 = sc.abs2(field)
|
||||||
bot, top = sc.math.envelope_ind(field2)
|
bot, top = sc.math.envelope_ind(field2)
|
||||||
|
|||||||
31
tests/test_current_state.py
Normal file
31
tests/test_current_state.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scgenerator.operators import CurrentState
|
||||||
|
|
||||||
|
|
||||||
|
def test_creation():
|
||||||
|
x = (np.linspace(0, 1, 128, dtype=complex),)
|
||||||
|
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
|
||||||
|
|
||||||
|
assert cs.converter is np.fft.ifft
|
||||||
|
assert cs.stats == {}
|
||||||
|
assert np.allclose(cs.spectrum2, np.abs(np.fft.ifft(x)) ** 2)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
cs = CurrentState(1.0, 0, 0.0, x, 1.0, spectrum2=np.abs(x) ** 3)
|
||||||
|
|
||||||
|
cs = CurrentState(1.0, 0, 0.1, x, 1.0, spectrum2=x.copy(), field=x.copy(), field2=x.copy())
|
||||||
|
|
||||||
|
assert np.allclose(cs.spectrum2, cs.spectrum)
|
||||||
|
assert np.allclose(cs.spectrum, cs.field)
|
||||||
|
assert np.allclose(cs.field, cs.field2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy():
|
||||||
|
x = (np.linspace(0, 1, 128, dtype=complex),)
|
||||||
|
cs = CurrentState(1.0, 0, 0.1, x, 1.0)
|
||||||
|
cs2 = cs.copy()
|
||||||
|
|
||||||
|
assert cs.spectrum is not cs2.spectrum
|
||||||
|
assert np.all(cs.field2 == cs2.field2)
|
||||||
Reference in New Issue
Block a user