misc
This commit is contained in:
@@ -377,8 +377,7 @@ default_rules: list[Rule] = [
|
||||
Rule("loss_op", operators.NoLoss, priorities=-1),
|
||||
Rule("plasma_op", operators.NoPlasma, priorities=-1),
|
||||
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
|
||||
Rule("step_taker", solver.RK4IPStepTaker),
|
||||
Rule("integrator", solver.ConstantStepIntegrator, priorities=-1),
|
||||
Rule("integrator", solver.ERK54),
|
||||
]
|
||||
|
||||
envelope_rules = default_rules + [
|
||||
|
||||
@@ -6,8 +6,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
@@ -21,37 +20,6 @@ from .utils import load_material_dico
|
||||
|
||||
|
||||
class SpectrumDescriptor:
|
||||
name: str
|
||||
value: np.ndarray = None
|
||||
_counter = 0
|
||||
_converter: Callable[[np.ndarray], np.ndarray]
|
||||
|
||||
def __init__(self, spec2_name: str, field_name: str, field2_name: str):
|
||||
self.spec2_name = spec2_name
|
||||
self.field_name = field_name
|
||||
self.field2_name = field2_name
|
||||
|
||||
def __set__(self, instance: CurrentState, value: np.ndarray):
|
||||
self._counter += 1
|
||||
setattr(instance, self.spec2_name, math.abs2(value))
|
||||
setattr(instance, self.field_name, instance.converter(value))
|
||||
setattr(instance, self.field2_name, math.abs2(getattr(instance, self.field_name)))
|
||||
self.value = value
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.value
|
||||
|
||||
def __delete__(self, instance):
|
||||
raise AttributeError("Cannot delete Spectrum field")
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
for field_name in ["converter", "field", "field2", "spec2"]:
|
||||
if not hasattr(owner, field_name):
|
||||
raise AttributeError(f"{owner!r} doesn't have a {field_name!r} attribute")
|
||||
self.name = name
|
||||
|
||||
|
||||
class SpectrumDescriptor2:
|
||||
name: str
|
||||
spectrum: np.ndarray = None
|
||||
__spec2: np.ndarray = None
|
||||
@@ -100,14 +68,7 @@ class CurrentState:
|
||||
step: int
|
||||
C_to_A_factor: np.ndarray
|
||||
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
|
||||
spectrum: np.ndarray = SpectrumDescriptor("spec2", "field", "field2")
|
||||
spec2: np.ndarray = dataclasses.field(init=False)
|
||||
field: np.ndarray = dataclasses.field(init=False)
|
||||
field2: np.ndarray = dataclasses.field(init=False)
|
||||
prev_spectrum: np.ndarray = SpectrumDescriptor("prev_spec2", "prev_field", "prev_field2")
|
||||
prev_spec2: np.ndarray = dataclasses.field(init=False)
|
||||
prev_field: np.ndarray = dataclasses.field(init=False)
|
||||
prev_field2: np.ndarray = dataclasses.field(init=False)
|
||||
solution: SpectrumDescriptor = SpectrumDescriptor()
|
||||
|
||||
@property
|
||||
def z_ratio(self) -> float:
|
||||
@@ -116,7 +77,7 @@ class CurrentState:
|
||||
def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState:
|
||||
"""returns a new state with new attributes"""
|
||||
params = dict(
|
||||
spectrum=new_spectrum,
|
||||
solution=new_spectrum,
|
||||
length=self.length,
|
||||
z=self.z,
|
||||
current_step_size=self.current_step_size,
|
||||
@@ -127,6 +88,13 @@ class CurrentState:
|
||||
)
|
||||
return CurrentState(**(params | new_params))
|
||||
|
||||
def copy(self) -> CurrentState:
|
||||
return replace(self, solution=self.solution.spectrum)
|
||||
|
||||
@property
|
||||
def actual_spectrum(self) -> np.ndarray:
|
||||
return self.C_to_A_factor * self.solution.spectrum
|
||||
|
||||
|
||||
class ValueTracker(ABC):
|
||||
def values(self) -> dict[str, float]:
|
||||
@@ -644,7 +612,7 @@ class EnvelopeRaman(AbstractRaman):
|
||||
self.f_r = 0.245 if raman_type == "agrawal" else 0.18
|
||||
|
||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.field2))
|
||||
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.solution.field2))
|
||||
|
||||
|
||||
class FullFieldRaman(AbstractRaman):
|
||||
@@ -654,7 +622,7 @@ class FullFieldRaman(AbstractRaman):
|
||||
self.multiplier = units.epsilon0 * chi3 * self.f_r
|
||||
|
||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||
return self.multiplier * np.fft.ifft(np.fft.fft(state.field2) * self.hr_w)
|
||||
return self.multiplier * np.fft.ifft(np.fft.fft(state.solution.field2) * self.hr_w)
|
||||
|
||||
|
||||
##################################################
|
||||
@@ -693,7 +661,7 @@ class EnvelopeSPM(AbstractSPM):
|
||||
self.fraction = 1 - raman_op.f_r
|
||||
|
||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||
return self.fraction * state.field2
|
||||
return self.fraction * state.solution.field2
|
||||
|
||||
|
||||
class FullFieldSPM(AbstractSPM):
|
||||
@@ -702,7 +670,7 @@ class FullFieldSPM(AbstractSPM):
|
||||
self.factor = self.fraction * chi3 * units.epsilon0
|
||||
|
||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||
return self.factor * state.field2 * state.field
|
||||
return self.factor * state.solution.field2 * state.solution.field
|
||||
|
||||
|
||||
##################################################
|
||||
@@ -819,7 +787,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.photon_number_with_loss(
|
||||
state.spec2,
|
||||
state.solution.spec2,
|
||||
self.w,
|
||||
self.dw,
|
||||
self.gamma_op(state),
|
||||
@@ -835,7 +803,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
|
||||
self.gamma_op = gamma_op
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state))
|
||||
return pulse.photon_number(state.solution.spec2, self.w, self.dw, self.gamma_op(state))
|
||||
|
||||
|
||||
class EnergyLoss(AbstractConservedQuantity):
|
||||
@@ -845,7 +813,7 @@ class EnergyLoss(AbstractConservedQuantity):
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.pulse_energy_with_loss(
|
||||
math.abs2(state.C_to_A_factor * state.spectrum),
|
||||
math.abs2(state.C_to_A_factor * state.solution.spectrum),
|
||||
self.dw,
|
||||
self.loss_op(state),
|
||||
state.current_step_size,
|
||||
@@ -857,7 +825,7 @@ class EnergyNoLoss(AbstractConservedQuantity):
|
||||
self.dw = w[1] - w[0]
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.spectrum), self.dw)
|
||||
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.solution.spectrum), self.dw)
|
||||
|
||||
|
||||
def conserved_quantity(
|
||||
@@ -970,7 +938,7 @@ class EnvelopeNonLinearOperator(NonLinearOperator):
|
||||
-1j
|
||||
* self.gamma_op(state)
|
||||
* (1 + self.ss_op(state))
|
||||
* np.fft.fft(state.field * (self.spm_op(state) + self.raman_op(state)))
|
||||
* np.fft.fft(state.solution.field * (self.spm_op(state) + self.raman_op(state)))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from collections import defaultdict
|
||||
import multiprocessing
|
||||
import multiprocessing.connection
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional, Type, Union
|
||||
from typing import Any, Generator, Iterator, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import utils
|
||||
from .. import solver, utils
|
||||
from ..logger import get_logger
|
||||
from ..operators import CurrentState
|
||||
from ..parameter import Configuration, Parameters
|
||||
@@ -45,7 +45,7 @@ class RK4IP:
|
||||
size_fac: float
|
||||
cons_qty: list[float]
|
||||
|
||||
state: CurrentState
|
||||
init_state: CurrentState
|
||||
stored_spectra: list[np.ndarray]
|
||||
|
||||
def __init__(
|
||||
@@ -96,7 +96,7 @@ class RK4IP:
|
||||
initial_h = (self.z_targets[1] - self.z_targets[0]) / 2
|
||||
else:
|
||||
initial_h = self.error_ok
|
||||
self.state = CurrentState(
|
||||
self.init_state = CurrentState(
|
||||
length=self.params.length,
|
||||
z=self.z_targets.pop(0),
|
||||
current_step_size=initial_h,
|
||||
@@ -104,14 +104,14 @@ class RK4IP:
|
||||
step=1,
|
||||
C_to_A_factor=C_to_A_factor,
|
||||
converter=self.params.ifft,
|
||||
spectrum=self.params.spec_0.copy() / C_to_A_factor,
|
||||
solution=self.params.spec_0.copy() / C_to_A_factor,
|
||||
)
|
||||
self.stored_spectra = self.params.recovery_last_stored * [None] + [
|
||||
self.state.spectrum.copy()
|
||||
self.init_state.solution.spectrum.copy()
|
||||
]
|
||||
self.tracked_values = TrackedValues()
|
||||
|
||||
def _save_current_spectrum(self, num: int):
|
||||
def _save_current_spectrum(self, spectrum: np.ndarray, num: int):
|
||||
"""saves the spectrum and the corresponding cons_qty array
|
||||
|
||||
Parameters
|
||||
@@ -119,19 +119,8 @@ class RK4IP:
|
||||
num : int
|
||||
index of the z postition
|
||||
"""
|
||||
self.write(self.get_current_spectrum(), f"spectrum_{num}")
|
||||
self.write(spectrum, f"spectrum_{num}")
|
||||
self.write(self.tracked_values, "tracked_values")
|
||||
self.step_saved()
|
||||
|
||||
def get_current_spectrum(self) -> np.ndarray:
|
||||
"""returns the current spectrum
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
spectrum
|
||||
"""
|
||||
return self.state.C_to_A_factor * self.state.spectrum
|
||||
|
||||
def write(self, data: np.ndarray, name: str):
|
||||
"""calls the appropriate method to save data
|
||||
@@ -147,14 +136,15 @@ class RK4IP:
|
||||
|
||||
def run(self) -> list[np.ndarray]:
|
||||
time_start = datetime.today()
|
||||
|
||||
for step, num, _ in self.irun():
|
||||
state = self.init_state
|
||||
for num, state in self.irun():
|
||||
if self.save_data:
|
||||
self._save_current_spectrum(num)
|
||||
self._save_current_spectrum(state.actual_spectrum, num)
|
||||
self.step_saved(state)
|
||||
|
||||
self.logger.info(
|
||||
"propagation finished in {} steps ({} seconds)".format(
|
||||
step, (datetime.today() - time_start).total_seconds()
|
||||
state.step, (datetime.today() - time_start).total_seconds()
|
||||
)
|
||||
)
|
||||
|
||||
@@ -163,49 +153,49 @@ class RK4IP:
|
||||
|
||||
return self.stored_spectra
|
||||
|
||||
def irun(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
|
||||
def irun(self) -> Iterator[tuple[int, CurrentState]]:
|
||||
"""run the simulation as a generator obj
|
||||
|
||||
Yields
|
||||
-------
|
||||
int
|
||||
current simulation step
|
||||
int
|
||||
current number of spectra returned
|
||||
np.ndarray
|
||||
spectrum
|
||||
CurrentState
|
||||
current simulation state
|
||||
"""
|
||||
|
||||
self.logger.debug(
|
||||
"Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0])
|
||||
)
|
||||
store = False
|
||||
|
||||
yield self.state.step, len(self.stored_spectra) - 1, self.get_current_spectrum()
|
||||
|
||||
while self.state.z < self.params.length:
|
||||
self.state = self.params.integrator(self.state)
|
||||
|
||||
self.state.step += 1
|
||||
new_tracked_values = (
|
||||
dict(step=self.state.step, z=self.state.z) | self.params.integrator.all_values()
|
||||
state = self.init_state.copy()
|
||||
yield len(self.stored_spectra) - 1, state
|
||||
integrator = solver.ERK54(
|
||||
state,
|
||||
self.params.linear_operator,
|
||||
self.params.nonlinear_operator,
|
||||
self.params.tolerated_error,
|
||||
self.params.dt,
|
||||
)
|
||||
self.logger.debug(f"tracked values at z={self.state.z} : {new_tracked_values}")
|
||||
for state in integrator:
|
||||
|
||||
new_tracked_values = integrator.all_values()
|
||||
self.logger.debug(f"tracked values at z={state.z} : {new_tracked_values}")
|
||||
self.tracked_values.append(new_tracked_values)
|
||||
|
||||
# Whether the current spectrum has to be stored depends on previous step
|
||||
if store:
|
||||
current_spec = self.get_current_spectrum()
|
||||
current_spec = state.actual_spectrum
|
||||
self.stored_spectra.append(current_spec)
|
||||
|
||||
yield self.state.step, len(self.stored_spectra) - 1, current_spec
|
||||
yield len(self.stored_spectra) - 1, state.copy()
|
||||
|
||||
self.z_stored.append(self.state.z)
|
||||
self.z_stored.append(state.z)
|
||||
del self.z_targets[0]
|
||||
|
||||
# reset the constant step size after a spectrum is stored
|
||||
if not self.params.adapt_step_size:
|
||||
self.state.current_step_size = self.error_ok
|
||||
integrator.state.current_step_size = self.error_ok
|
||||
|
||||
if len(self.z_targets) == 0:
|
||||
break
|
||||
@@ -213,14 +203,14 @@ class RK4IP:
|
||||
|
||||
# if the next step goes over a position at which we want to store
|
||||
# a spectrum, we shorten the step to reach this position exactly
|
||||
if self.state.z + self.state.current_step_size >= self.z_targets[0]:
|
||||
if state.z + integrator.state.current_step_size >= self.z_targets[0]:
|
||||
store = True
|
||||
self.state.current_step_size = self.z_targets[0] - self.state.z
|
||||
integrator.state.current_step_size = self.z_targets[0] - state.z
|
||||
|
||||
def step_saved(self):
|
||||
def step_saved(self, state: CurrentState):
|
||||
pass
|
||||
|
||||
def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
|
||||
def __iter__(self) -> Iterator[tuple[int, CurrentState]]:
|
||||
yield from self.irun()
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -240,8 +230,8 @@ class SequentialRK4IP(RK4IP):
|
||||
save_data=save_data,
|
||||
)
|
||||
|
||||
def step_saved(self):
|
||||
self.pbars.update(1, self.state.z / self.params.length - self.pbars[1].n)
|
||||
def step_saved(self, state: CurrentState):
|
||||
self.pbars.update(1, state.z / self.params.length - self.pbars[1].n)
|
||||
|
||||
|
||||
class MutliProcRK4IP(RK4IP):
|
||||
@@ -259,8 +249,8 @@ class MutliProcRK4IP(RK4IP):
|
||||
save_data=save_data,
|
||||
)
|
||||
|
||||
def step_saved(self):
|
||||
self.p_queue.put((self.worker_id, self.state.z / self.params.length))
|
||||
def step_saved(self, state: CurrentState):
|
||||
self.p_queue.put((self.worker_id, state.z / self.params.length))
|
||||
|
||||
|
||||
class RayRK4IP(RK4IP):
|
||||
@@ -286,8 +276,8 @@ class RayRK4IP(RK4IP):
|
||||
self.set(params, p_actor, worker_id, save_data)
|
||||
self.run()
|
||||
|
||||
def step_saved(self):
|
||||
self.p_actor.update.remote(self.worker_id, self.state.z / self.params.length)
|
||||
def step_saved(self, state: CurrentState):
|
||||
self.p_actor.update.remote(self.worker_id, state.z / self.params.length)
|
||||
self.p_actor.update.remote(0)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Iterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -69,7 +70,7 @@ class RK4IPStepTaker(StepTaker):
|
||||
l0, nl0 = self.cached_values(state)
|
||||
expD = np.exp(h * self.c2 * l0)
|
||||
|
||||
A_I = expD * state.spectrum
|
||||
A_I = expD * state.solution
|
||||
k1 = expD * (h * nl0)
|
||||
k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2))
|
||||
k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2))
|
||||
@@ -114,14 +115,37 @@ class RK4IPStepTaker(StepTaker):
|
||||
|
||||
|
||||
class Integrator(ValueTracker):
|
||||
last_step = 0.0
|
||||
linear_operator: LinearOperator
|
||||
nonlinear_operator: NonLinearOperator
|
||||
_tracked_values: dict[str, float]
|
||||
|
||||
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
|
||||
self.linear_operator = linear_operator
|
||||
self.nonlinear_operator = nonlinear_operator
|
||||
self._tracked_values = {}
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, state: CurrentState) -> CurrentState:
|
||||
def __iter__(self) -> Iterator[CurrentState]:
|
||||
"""propagate the state with a step size of state.current_step_size
|
||||
and return a new state with updated z and previous_step_size attributes"""
|
||||
and yield a new state with updated z and previous_step_size attributes"""
|
||||
...
|
||||
|
||||
def all_values(self) -> dict[str, float]:
|
||||
"""override ValueTracker.all_values to account for the fact that operators are called
|
||||
multiple times per step, sometimes with different state, so we use value recorded
|
||||
earlier. Please call self.recorde_tracked_values() one time only just after calling
|
||||
the linear and nonlinear operators in your StepTaker.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, float]
|
||||
tracked values
|
||||
"""
|
||||
return self.values() | self._tracked_values
|
||||
|
||||
def record_tracked_values(self):
|
||||
self._tracked_values = super().all_values()
|
||||
|
||||
|
||||
class ConstantStepIntegrator(Integrator):
|
||||
def __call__(self, state: CurrentState) -> CurrentState:
|
||||
@@ -234,7 +258,7 @@ class LocalErrorIntegrator(Integrator):
|
||||
h_next_step = h * self.size_fac
|
||||
|
||||
self.local_error = delta
|
||||
fine_state.spectrum = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
|
||||
fine_state.solution = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
|
||||
fine_state.current_step_size = h_next_step
|
||||
fine_state.previous_step_size = h
|
||||
fine_state.z += h
|
||||
@@ -249,32 +273,122 @@ class LocalErrorIntegrator(Integrator):
|
||||
|
||||
|
||||
class ERK43(Integrator):
|
||||
state: CurrentState
|
||||
linear_operator: LinearOperator
|
||||
nonlinear_operator: NonLinearOperator
|
||||
tolerated_error: float
|
||||
dt: float
|
||||
current_error: float
|
||||
next_h_factor = 1.0
|
||||
|
||||
def __init__(
|
||||
self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator, dt: float
|
||||
self,
|
||||
init_state: CurrentState,
|
||||
linear_operator: LinearOperator,
|
||||
nonlinear_operator: NonLinearOperator,
|
||||
tolerated_error: float,
|
||||
dt: float,
|
||||
):
|
||||
self.state = init_state
|
||||
self.linear_operator = linear_operator
|
||||
self.nonlinear_operator = nonlinear_operator
|
||||
self.dt = dt
|
||||
self.tolerated_error = tolerated_error
|
||||
self.current_error = 0.0
|
||||
|
||||
def __call__(self, state: CurrentState) -> CurrentState:
|
||||
keep = False
|
||||
h_next_step = state.current_step_size
|
||||
while not keep:
|
||||
def __iter__(self) -> Iterator[CurrentState]:
|
||||
h_next_step = self.state.current_step_size
|
||||
k5 = self.nonlinear_operator(self.state)
|
||||
while True:
|
||||
lin = self.linear_operator(self.state)
|
||||
self.record_tracked_values()
|
||||
while True:
|
||||
h = h_next_step
|
||||
expD = np.exp(h * 0.5 * self.linear_operator(state))
|
||||
A_I = expD * state.spectrum
|
||||
k1 = expD * state.prev_spectrum
|
||||
k2 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k1))
|
||||
k3 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k2))
|
||||
k4 = self.nonlinear_operator(state.replace(expD * A_I + h * k3))
|
||||
expD = np.exp(h * 0.5 * lin)
|
||||
A_I = expD * self.state.solution.spectrum
|
||||
k1 = expD * k5
|
||||
k2 = self.nl(A_I + 0.5 * h * k1)
|
||||
k3 = self.nl(A_I + 0.5 * h * k2)
|
||||
k4 = self.nl(expD * A_I + h * k3)
|
||||
r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3))
|
||||
|
||||
new_fine = r + h / 6 * k4
|
||||
|
||||
k5 = self.nonlinear_operator(state.replace(new_fine))
|
||||
tmp_k5 = self.nl(new_fine)
|
||||
|
||||
new_coarse = r + h / 30 * (2 * k4 + 3 * k5)
|
||||
new_coarse = r + h / 30 * (2 * k4 + 3 * tmp_k5)
|
||||
|
||||
self.current_error = np.sqrt(self.dt * math.abs2(new_fine - new_coarse).sum())
|
||||
self.next_h_factor = max(
|
||||
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
|
||||
)
|
||||
h_next_step = self.next_h_factor * h
|
||||
if self.current_error <= self.tolerated_error:
|
||||
break
|
||||
self.state.current_step_size = h_next_step
|
||||
self.state.previous_step_size = h
|
||||
self.state.z += h
|
||||
self.state.step += 1
|
||||
self.state.solution = new_fine
|
||||
k5 = tmp_k5
|
||||
yield self.state
|
||||
|
||||
def values(self) -> dict[str, float]:
|
||||
return dict(
|
||||
step=self.state.step,
|
||||
z=self.state.z,
|
||||
local_error=self.current_error,
|
||||
next_h_factor=self.next_h_factor,
|
||||
)
|
||||
|
||||
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
||||
return self.nonlinear_operator(self.state.replace(spectrum))
|
||||
|
||||
|
||||
class ERK54(ERK43):
|
||||
def __iter__(self) -> Iterator[CurrentState]:
|
||||
print("using ERK54")
|
||||
h_next_step = self.state.current_step_size
|
||||
k7 = self.nonlinear_operator(self.state)
|
||||
while True:
|
||||
lin = self.linear_operator(self.state)
|
||||
self.record_tracked_values()
|
||||
while True:
|
||||
h = h_next_step
|
||||
expD2 = np.exp(h * 0.5 * lin)
|
||||
expD4p = np.exp(h * 0.25 * lin)
|
||||
expD4m = 1 / expD4p
|
||||
|
||||
A_I = expD2 * self.state.solution.spectrum
|
||||
k1 = expD2 * k7
|
||||
k2 = self.nl(A_I + 0.5 * h * k1)
|
||||
k3 = expD4p * self.nl(expD4m * (A_I + h / 16 * (3 * k1 + k2)))
|
||||
k4 = self.nl(A_I + h / 4 * (k1 - k2 + 4 * k3))
|
||||
k5 = expD4m * self.nl(expD4p * (A_I + 3 * h / 16 * (k1 + 3 * k4)))
|
||||
k6 = self.nl(expD2 * (A_I + h / 7 * (-2 * k1 + k2 + 12 * k3 - 12 * k4 + 8 * k5)))
|
||||
|
||||
new_fine = (
|
||||
expD2 * (A_I + h / 90 * (7 * k1 + 32 * k3 + 12 * k4 + 32 * k5))
|
||||
+ 7 * h / 90 * k6
|
||||
)
|
||||
tmp_k7 = self.nl(new_fine)
|
||||
new_coarse = (
|
||||
expD2 * (A_I + h / 42 * (3 * k1 + 16 * k3 + 4 * k4 + 16 * k5)) + h / 14 * k7
|
||||
)
|
||||
|
||||
self.current_error = np.sqrt(
|
||||
self.dt * math.abs2(np.abs(new_fine) - np.abs(new_coarse)).sum()
|
||||
)
|
||||
self.next_h_factor = max(
|
||||
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
|
||||
)
|
||||
h_next_step = self.next_h_factor * h
|
||||
if self.current_error <= self.tolerated_error:
|
||||
break
|
||||
self.state.current_step_size = h_next_step
|
||||
self.state.previous_step_size = h
|
||||
self.state.z += h
|
||||
self.state.step += 1
|
||||
self.state.solution = new_fine
|
||||
k7 = tmp_k7
|
||||
yield self.state
|
||||
|
||||
Reference in New Issue
Block a user