This commit is contained in:
Benoît Sierro
2021-10-19 13:59:16 +02:00
parent 134fd501c3
commit 24807371f3
6 changed files with 425 additions and 84 deletions

View File

@@ -292,9 +292,7 @@ def build_sim_grid(
time_window: float = None, time_window: float = None,
t_num: int = None, t_num: int = None,
dt: float = None, dt: float = None,
) -> tuple[ ) -> tuple[np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray]:
np.ndarray, np.ndarray, float, int, float, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray
]:
"""computes a bunch of values that relate to the simulation grid """computes a bunch of values that relate to the simulation grid
Parameters Parameters
@@ -332,8 +330,6 @@ def build_sim_grid(
pump angular frequency pump angular frequency
w : np.ndarray, shape (t_num, ) w : np.ndarray, shape (t_num, )
actual angualr frequency grid in rad/s actual angualr frequency grid in rad/s
w_power_fact : np.ndarray, shape (deg, t_num)
set of all the necessaray powers of w_c
l : np.ndarray, shape (t_num) l : np.ndarray, shape (t_num)
wavelengths in m wavelengths in m
""" """
@@ -343,9 +339,9 @@ def build_sim_grid(
dt = t[1] - t[0] dt = t[1] - t[0]
t_num = len(t) t_num = len(t)
z_targets = np.linspace(0, length, z_num) z_targets = np.linspace(0, length, z_num)
w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength, interpolation_degree) w_c, w0, w = update_frequency_domain(t, wavelength, interpolation_degree)
l = 2 * pi * c / w l = 2 * pi * c / w
return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact, l return z_targets, t, time_window, t_num, dt, w_c, w0, w, l
def update_frequency_domain( def update_frequency_domain(
@@ -365,10 +361,9 @@ def update_frequency_domain(
Returns Returns
------- -------
Tuple[np.ndarray, float, np.ndarray, np.ndarray] Tuple[np.ndarray, float, np.ndarray, np.ndarray]
w_c, w0, w, w_power_fact w_c, w0, w
""" """
w_c = wspace(t) w_c = wspace(t)
w0 = 2 * pi * c / wavelength w0 = 2 * pi * c / wavelength
w = w_c + w0 w = w_c + w0
w_power_fact = np.array([power_fact(w_c, k) for k in range(2, deg + 3)]) return w_c, w0, w
return w_c, w0, w, w_power_fact

View File

@@ -434,7 +434,6 @@ class Parameters(_AbstractParameters):
l: np.ndarray = Parameter(type_checker(np.ndarray)) l: np.ndarray = Parameter(type_checker(np.ndarray))
w_c: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray))
w0: float = Parameter(positive(float)) w0: float = Parameter(positive(float))
w_power_fact: np.ndarray = Parameter(validator_list(type_checker(np.ndarray)))
t: np.ndarray = Parameter(type_checker(np.ndarray)) t: np.ndarray = Parameter(type_checker(np.ndarray))
L_D: float = Parameter(non_negative(float, int)) L_D: float = Parameter(non_negative(float, int))
L_NL: float = Parameter(non_negative(float, int)) L_NL: float = Parameter(non_negative(float, int))
@@ -1045,7 +1044,7 @@ class Configuration:
default_rules: list[Rule] = [ default_rules: list[Rule] = [
# Grid # Grid
*Rule.deduce( *Rule.deduce(
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"], ["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "l"],
math.build_sim_grid, math.build_sim_grid,
["time_window", "t_num", "dt"], ["time_window", "t_num", "dt"],
2, 2,
@@ -1067,7 +1066,7 @@ default_rules: list[Rule] = [
Rule("pre_field_0", pulse.initial_field, priorities=1), Rule("pre_field_0", pulse.initial_field, priorities=1),
Rule( Rule(
"field_0", "field_0",
pulse.add_shot_noise, pulse.finalize_pulse,
[ [
"pre_field_0", "pre_field_0",
"quantum_noise", "quantum_noise",
@@ -1076,6 +1075,7 @@ default_rules: list[Rule] = [
"time_window", "time_window",
"dt", "dt",
"additional_noise_factor", "additional_noise_factor",
"input_transmission",
], ],
), ),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),

View File

@@ -986,10 +986,10 @@ def delayed_raman_t(t: np.ndarray, raman_type: str) -> np.ndarray:
return hr_arr return hr_arr
def delayed_raman_w(t: np.ndarray, dt: float, raman_type: str) -> np.ndarray: def delayed_raman_w(t: np.ndarray, raman_type: str) -> np.ndarray:
"""returns the delayed raman response function as function of w """returns the delayed raman response function as function of w
see delayed_raman_t for detailes""" see delayed_raman_t for detailes"""
return fft(delayed_raman_t(t, raman_type)) * dt return fft(delayed_raman_t(t, raman_type)) * (t[1] - t[0])
def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=0, hr_w=None): def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=0, hr_w=None):
@@ -1058,7 +1058,7 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=0,
return N_func return N_func
def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None), alpha=None): def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)):
""" """
dispersive operator dispersive operator
@@ -1083,10 +1083,7 @@ def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None), alpha=N
out = np.zeros_like(dispersion) out = np.zeros_like(dispersion)
out[where] = dispersion[where] out[where] = dispersion[where]
if alpha is None:
return -1j * out return -1j * out
else:
return -1j * out - alpha / 2
def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr): def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr):

View File

@@ -0,0 +1,357 @@
"""
This file includes Dispersion, NonLinear and Loss classes to be used in the solver
Nothing except the solver should depend on this file
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import numpy as np
from scipy.interpolate import interp1d
from . import fiber
from .. import math
class SpectrumDescriptor:
name: str
value: np.ndarray
def __set__(self, instance, value):
instance.field = np.fft.ifft(value)
self.value = value
def __get__(self, instance, owner):
return self.value
def __delete__(self, instance):
raise AttributeError("Cannot delete Spectrum field")
def __set_name__(self, owner, name):
self.name = name
@dataclass
class CurrentState:
length: float
z: float
h: float
spectrum: np.ndarray = SpectrumDescriptor()
field: np.ndarray = field(init=False)
@property
def z_ratio(self) -> float:
return self.z / self.length
class NoOp:
def __init__(self, w: np.ndarray):
self.zero_arr = np.zeros_like(w)
##################################################
################### DISPERSION ###################
##################################################
class AbstractDispersion(ABC):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the dispersion in the frequency domain
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
dispersive component
"""
class ConstantPolyDispersion(AbstractDispersion):
"""
dispersion approximated by fitting a polynom on the dispersion and
evaluating on the envelope
"""
coefs: np.ndarray
w_c: np.ndarray
def __init__(
self,
wl_for_disp: np.ndarray,
beta2_arr: np.ndarray,
w0: float,
w_c: np.ndarray,
interpolation_range: tuple[float, float] = None,
interpolation_degree: int = 8,
):
self.coefs = fiber.dispersion_coefficients(
wl_for_disp, beta2_arr, w0, interpolation_range, interpolation_degree
)
self.w_c = w_c
self.w_power_fact = np.array(
[math.power_fact(w_c, k) for k in range(2, interpolation_degree + 3)]
)
def __call__(self, state: CurrentState) -> np.ndarray:
return fiber.fast_dispersion_op(self.w_c, self.coefs, self.w_power_fact)
##################################################
##################### LINEAR #####################
##################################################
class LinearOperator:
def __init__(self, disp: AbstractDispersion, loss: AbstractLoss):
self.disp = disp
self.loss = loss
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the linear operator to be multiplied by the spectrum in the frequency domain
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
linear component
"""
return self.disp(state) - self.loss(state) / 2
##################################################
################### NON LINEAR ###################
##################################################
# Raman
class AbstractRaman(ABC):
f_r: float = 0.0
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the raman component
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
raman component
"""
class NoRaman(NoOp, AbstractRaman):
def __call__(self, state: CurrentState) -> np.ndarray:
return self.zero_arr
class Raman(AbstractRaman):
def __init__(self, raman_type: str, t: np.ndarray):
self.hr_w = fiber.delayed_raman_w(t, raman_type)
self.f_r = 0.245 if raman_type == "agrawal" else 0.18
def __call__(self, state: CurrentState) -> np.ndarray:
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(math.abs2(state.field)))
# SPM
class AbstractSPM(ABC):
fraction: float = 1.0
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the SPM component
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
SPM component
"""
class NoSPM(NoOp, AbstractSPM):
def __call__(self, state: CurrentState) -> np.ndarray:
return self.zero_arr
class SPM(AbstractSPM):
def __init__(self, raman_op: AbstractRaman):
self.fraction = 1 - raman_op.f_r
def __call__(self, state: CurrentState) -> np.ndarray:
return self.fraction * math.abs2(state.field)
# Selt Steepening
class AbstractSelfSteepening(ABC):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the self-steepening component
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
self-steepening component
"""
class NoSelfSteepening(NoOp, AbstractSelfSteepening):
def __call__(self, state: CurrentState) -> np.ndarray:
return self.zero_arr
class SelfSteepening(AbstractSelfSteepening):
def __init__(self, w_c: np.ndarray, w0: float):
self.arr = w_c / w0
def __call__(self, state: CurrentState) -> np.ndarray:
return self.arr
# Gamma operator
class AbstractGamma(ABC):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the gamma component
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
gamma component
"""
class NoGamma(AbstractSPM):
def __init__(self, w: np.ndarray) -> None:
self.ones_arr = np.ones_like(w)
def __call__(self, state: CurrentState) -> np.ndarray:
return self.ones_arr
class ConstantGamma(AbstractSelfSteepening):
def __init__(self, gamma_arr: np.ndarray):
self.arr = gamma_arr
def __call__(self, state: CurrentState) -> np.ndarray:
return self.arr
# Nonlinear combination
class AbstractNonLinearOperator(ABC):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the nonlinear operator applied on the spectrum in the frequency domain
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
nonlinear component
"""
class EnvelopeNonLinearOperator(AbstractNonLinearOperator):
def __init__(
self,
gamma_op: AbstractGamma,
ss_op: AbstractSelfSteepening,
spm_op: AbstractSPM,
raman_op: AbstractRaman,
):
self.gamma_op = gamma_op
##################################################
###################### LOSS ######################
##################################################
class AbstractLoss(ABC):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
"""returns the loss in the frequency domain
Parameters
----------
state : CurrentState
current state of the simulation
Returns
-------
np.ndarray
loss in 1/m
"""
class ConstantLoss(AbstractLoss):
alpha_arr: np.ndarray
def __init__(self, alpha: float, w: np.ndarray):
self.alpha_arr = alpha * np.ones_like(w)
def __call__(self, state: CurrentState) -> np.ndarray:
return self.alpha_arr
class CapillaryLoss(ConstantLoss):
def __init__(
self,
l: np.ndarray,
core_radius: float,
interpolation_range: tuple[float, float],
he_mode: tuple[int, int],
):
mask = (l < interpolation_range[1]) & (l > 0)
alpha = fiber.capillary_loss(l[mask], he_mode, core_radius)
self.alpha_arr = np.zeros_like(l)
self.alpha_arr[mask] = alpha
class CustomConstantLoss(ConstantLoss):
def __init__(self, l: np.ndarray, loss_file: str):
loss_data = np.load(loss_file)
wl = loss_data["wavelength"]
loss = loss_data["loss"]
self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l)

View File

@@ -447,7 +447,7 @@ def shot_noise(w_c, w0, T, dt, additional_noise_factor=1.0):
return out * additional_noise_factor return out * additional_noise_factor
def add_shot_noise( def finalize_pulse(
field_0: np.ndarray, field_0: np.ndarray,
quantum_noise: bool, quantum_noise: bool,
w_c: bool, w_c: bool,
@@ -455,10 +455,11 @@ def add_shot_noise(
time_window: float, time_window: float,
dt: float, dt: float,
additional_noise_factor: float, additional_noise_factor: float,
input_transmission: float,
) -> np.ndarray: ) -> np.ndarray:
if quantum_noise: if quantum_noise:
field_0 = field_0 + shot_noise(w_c, w0, time_window, dt, additional_noise_factor) field_0 = field_0 + shot_noise(w_c, w0, time_window, dt, additional_noise_factor)
return field_0 return np.sqrt(input_transmission) * field_0
def mean_phase(spectra): def mean_phase(spectra):

View File

@@ -15,6 +15,7 @@ from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker from ..pbar import PBars, ProgressBarActor, progress_worker
from . import pulse from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op from .fiber import create_non_linear_op, fast_dispersion_op
from .properties import CurrentState
try: try:
import ray import ray
@@ -22,18 +23,6 @@ except ModuleNotFoundError:
ray = None ray = None
@dataclass
class CurrentState:
length: float
spectrum: np.ndarray
z: float
h: float
@property
def z_ratio(self) -> float:
return self.z / self.length
class RK4IP: class RK4IP:
def __init__( def __init__(
self, self,
@@ -58,7 +47,7 @@ class RK4IP:
yield from self.irun() yield from self.irun()
def __len__(self) -> int: def __len__(self) -> int:
return self.len return self.params.z_num
def set( def set(
self, self,
@@ -79,62 +68,61 @@ class RK4IP:
self.logger = get_logger(self.params.output_path) self.logger = get_logger(self.params.output_path)
self.resuming = False self.resuming = False
self.w_c = params.w_c self.dw = self.params.w[1] - self.params.w[0]
self.w = params.w self.z_targets = self.params.z_targets
self.dw = self.w[1] - self.w[0]
self.w0 = params.w0
self.w_power_fact = params.w_power_fact
self.alpha = params.alpha_arr
self.spec_0 = np.sqrt(params.input_transmission) * params.spec_0
self.z_targets = params.z_targets
self.len = len(params.z_targets)
self.z_final = params.length
self.beta2_coefficients = ( self.beta2_coefficients = (
params.beta_func if params.beta_func is not None else params.beta2_coefficients params.beta_func if params.beta_func is not None else params.beta2_coefficients
) )
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma_arr self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma_arr
self.C_to_A_factor = (params.A_eff_arr / params.A_eff_arr[0]) ** (1 / 4) self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
self.behaviors = params.behaviors self.error_ok = (
self.raman_type = params.raman_type params.tolerated_error if self.params.adapt_step_size else self.params.step_size
self.hr_w = params.hr_w )
self.adapt_step_size = params.adapt_step_size
self.error_ok = params.tolerated_error if self.adapt_step_size else params.step_size
self.dynamic_dispersion = params.dynamic_dispersion
self.starting_num = params.recovery_last_stored
self._setup_functions() self._setup_functions()
self._setup_sim_parameters() self._setup_sim_parameters()
def _setup_functions(self): def _setup_functions(self):
self.N_func = create_non_linear_op( self.N_func = create_non_linear_op(
self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, hr_w=self.hr_w self.params.behaviors,
self.params.w_c,
self.params.w0,
self.gamma,
self.params.raman_type,
hr_w=self.params.hr_w,
) )
if self.dynamic_dispersion: if self.params.dynamic_dispersion:
self.disp = lambda r: fast_dispersion_op( self.disp = lambda r: fast_dispersion_op(
self.w_c, self.beta2_coefficients(r), self.w_power_fact, alpha=self.alpha self.params.w_c,
self.beta2_coefficients(r),
self.params.w_power_fact,
alpha=self.params.alpha_arr,
) )
else: else:
self.disp = lambda r: fast_dispersion_op( self.disp = lambda r: fast_dispersion_op(
self.w_c, self.beta2_coefficients, self.w_power_fact, alpha=self.alpha self.params.w_c,
self.beta2_coefficients,
self.params.w_power_fact,
alpha=self.params.alpha_arr,
) )
# Set up which quantity is conserved for adaptive step size # Set up which quantity is conserved for adaptive step size
if self.adapt_step_size: if self.params.adapt_step_size:
if "raman" in self.behaviors and self.alpha is not None: if "raman" in self.params.behaviors and self.params.alpha_arr is not None:
self.logger.debug("Conserved quantity : photon number with loss") self.logger.debug("Conserved quantity : photon number with loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number_with_loss( self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number_with_loss(
spectrum, self.w, self.dw, self.gamma, self.alpha, h spectrum, self.params.w, self.dw, self.gamma, self.params.alpha_arr, h
) )
elif "raman" in self.behaviors: elif "raman" in self.params.behaviors:
self.logger.debug("Conserved quantity : photon number without loss") self.logger.debug("Conserved quantity : photon number without loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number( self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number(
spectrum, self.w, self.dw, self.gamma spectrum, self.params.w, self.dw, self.gamma
) )
elif self.alpha is not None: elif self.params.alpha_arr is not None:
self.logger.debug("Conserved quantity : energy with loss") self.logger.debug("Conserved quantity : energy with loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy_with_loss( self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy_with_loss(
self.C_to_A_factor * spectrum, self.dw, self.alpha, h self.C_to_A_factor * spectrum, self.dw, self.params.alpha_arr, h
) )
else: else:
self.logger.debug("Conserved quantity : energy without loss") self.logger.debug("Conserved quantity : energy without loss")
@@ -147,26 +135,29 @@ class RK4IP:
def _setup_sim_parameters(self): def _setup_sim_parameters(self):
# making sure to keep only the z that we want # making sure to keep only the z that we want
self.z_stored = list(self.z_targets.copy()[0 : self.starting_num + 1]) self.z_stored = list(self.z_targets.copy()[0 : self.params.recovery_last_stored + 1])
self.z_targets = list(self.z_targets.copy()[self.starting_num :]) self.z_targets = list(self.z_targets.copy()[self.params.recovery_last_stored :])
self.z_targets.sort() self.z_targets.sort()
self.store_num = len(self.z_targets) self.store_num = len(self.z_targets)
# Initial setup of simulation parameters # Initial setup of simulation parameters
self.d_w = self.w_c[1] - self.w_c[0] # resolution of the frequency grid
self.z = self.z_targets.pop(0) self.z = self.z_targets.pop(0)
# Setup initial values for every physical quantity that we want to track # Setup initial values for every physical quantity that we want to track
self.current_spectrum = self.spec_0.copy() / self.C_to_A_factor self.state = CurrentState(
self.stored_spectra = self.starting_num * [None] + [self.current_spectrum.copy()] length=self.params.length, spectrum=self.params.spec_0.copy() / self.C_to_A_factor
)
self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.state.spectrum.copy()
]
self.cons_qty = [ self.cons_qty = [
self.conserved_quantity_func(self.current_spectrum, 0), self.conserved_quantity_func(self.state.spectrum, 0),
0, 0,
] ]
self.size_fac = 2 ** (1 / 5) self.size_fac = 2 ** (1 / 5)
# Initial step size # Initial step size
if self.adapt_step_size: if self.params.adapt_step_size:
self.initial_h = (self.z_targets[0] - self.z) / 2 self.initial_h = (self.z_targets[0] - self.z) / 2
else: else:
self.initial_h = self.error_ok self.initial_h = self.error_ok
@@ -179,7 +170,7 @@ class RK4IP:
num : int num : int
index of the z postition index of the z postition
""" """
self._save_data(self.C_to_A_factor * self.current_spectrum, f"spectrum_{num}") self._save_data(self.C_to_A_factor * self.state.spectrum, f"spectrum_{num}")
self._save_data(self.cons_qty, f"cons_qty") self._save_data(self.cons_qty, f"cons_qty")
self.step_saved() self.step_saved()
@@ -191,7 +182,7 @@ class RK4IP:
np.ndarray np.ndarray
spectrum spectrum
""" """
return self.C_to_A_factor * self.current_spectrum return self.C_to_A_factor * self.state.spectrum
def _save_data(self, data: np.ndarray, name: str): def _save_data(self, data: np.ndarray, name: str):
"""calls the appropriate method to save data """calls the appropriate method to save data
@@ -249,9 +240,9 @@ class RK4IP:
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum() yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
while self.z < self.z_final: while self.z < self.params.length:
h_taken, h_next_step, self.current_spectrum = self.take_step( h_taken, h_next_step, self.state.spectrum = self.take_step(
step, h_next_step, self.current_spectrum.copy() step, h_next_step, self.state.spectrum.copy()
) )
self.z += h_taken self.z += h_taken
@@ -262,7 +253,7 @@ class RK4IP:
if store: if store:
self.logger.debug("{} steps, z = {:.4f}, h = {:.5g}".format(step, self.z, h_taken)) self.logger.debug("{} steps, z = {:.4f}, h = {:.5g}".format(step, self.z, h_taken))
self.stored_spectra.append(self.current_spectrum) self.stored_spectra.append(self.state.spectrum)
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum() yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
@@ -270,7 +261,7 @@ class RK4IP:
del self.z_targets[0] del self.z_targets[0]
# reset the constant step size after a spectrum is stored # reset the constant step size after a spectrum is stored
if not self.adapt_step_size: if not self.params.adapt_step_size:
h_next_step = self.error_ok h_next_step = self.error_ok
if len(self.z_targets) == 0: if len(self.z_targets) == 0:
@@ -310,7 +301,7 @@ class RK4IP:
keep = False keep = False
while not keep: while not keep:
h = h_next_step h = h_next_step
z_ratio = self.z / self.z_final z_ratio = self.z / self.params.length
expD = np.exp(h / 2 * self.disp(z_ratio)) expD = np.exp(h / 2 * self.disp(z_ratio))
@@ -321,7 +312,7 @@ class RK4IP:
k4 = h * self.N_func(expD * (A_I + k3), z_ratio) k4 = h * self.N_func(expD * (A_I + k3), z_ratio)
new_spectrum = expD * (A_I + k1 / 6 + k2 / 3 + k3 / 3) + k4 / 6 new_spectrum = expD * (A_I + k1 / 6 + k2 / 3 + k3 / 3) + k4 / 6
if self.adapt_step_size: if self.params.adapt_step_size:
self.cons_qty[step] = self.conserved_quantity_func(new_spectrum, h) self.cons_qty[step] = self.conserved_quantity_func(new_spectrum, h)
curr_p_change = np.abs(self.cons_qty[step - 1] - self.cons_qty[step]) curr_p_change = np.abs(self.cons_qty[step - 1] - self.cons_qty[step])
cons_qty_change_ok = self.error_ok * self.cons_qty[step - 1] cons_qty_change_ok = self.error_ok * self.cons_qty[step - 1]
@@ -364,7 +355,7 @@ class SequentialRK4IP(RK4IP):
) )
def step_saved(self): def step_saved(self):
self.pbars.update(1, self.z / self.z_final - self.pbars[1].n) self.pbars.update(1, self.z / self.params.length - self.pbars[1].n)
class MutliProcRK4IP(RK4IP): class MutliProcRK4IP(RK4IP):
@@ -385,7 +376,7 @@ class MutliProcRK4IP(RK4IP):
) )
def step_saved(self): def step_saved(self):
self.p_queue.put((self.worker_id, self.z / self.z_final)) self.p_queue.put((self.worker_id, self.z / self.params.length))
class RayRK4IP(RK4IP): class RayRK4IP(RK4IP):
@@ -414,7 +405,7 @@ class RayRK4IP(RK4IP):
self.run() self.run()
def step_saved(self): def step_saved(self):
self.p_actor.update.remote(self.worker_id, self.z / self.z_final) self.p_actor.update.remote(self.worker_id, self.z / self.params.length)
self.p_actor.update.remote(0) self.p_actor.update.remote(0)
@@ -500,7 +491,7 @@ class Simulations:
self.ensure_finised_and_complete() self.ensure_finised_and_complete()
def _run_available(self): def _run_available(self):
for variable, params in self.configuration: for _, params in self.configuration:
params.compute() params.compute()
utils.save_parameters(params.prepare_for_dump(), params.output_path) utils.save_parameters(params.prepare_for_dump(), params.output_path)