This commit is contained in:
Benoît Sierro
2021-11-09 16:10:33 +01:00
parent f757466fe1
commit 65b42bf2ee
4 changed files with 198 additions and 127 deletions

View File

@@ -377,8 +377,7 @@ default_rules: list[Rule] = [
Rule("loss_op", operators.NoLoss, priorities=-1), Rule("loss_op", operators.NoLoss, priorities=-1),
Rule("plasma_op", operators.NoPlasma, priorities=-1), Rule("plasma_op", operators.NoPlasma, priorities=-1),
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1), Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
Rule("step_taker", solver.RK4IPStepTaker), Rule("integrator", solver.ERK54),
Rule("integrator", solver.ConstantStepIntegrator, priorities=-1),
] ]
envelope_rules = default_rules + [ envelope_rules = default_rules + [

View File

@@ -6,8 +6,7 @@ from __future__ import annotations
import dataclasses import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, replace
import re
from typing import Any, Callable from typing import Any, Callable
import numpy as np import numpy as np
@@ -21,37 +20,6 @@ from .utils import load_material_dico
class SpectrumDescriptor: class SpectrumDescriptor:
name: str
value: np.ndarray = None
_counter = 0
_converter: Callable[[np.ndarray], np.ndarray]
def __init__(self, spec2_name: str, field_name: str, field2_name: str):
self.spec2_name = spec2_name
self.field_name = field_name
self.field2_name = field2_name
def __set__(self, instance: CurrentState, value: np.ndarray):
self._counter += 1
setattr(instance, self.spec2_name, math.abs2(value))
setattr(instance, self.field_name, instance.converter(value))
setattr(instance, self.field2_name, math.abs2(getattr(instance, self.field_name)))
self.value = value
def __get__(self, instance, owner):
return self.value
def __delete__(self, instance):
raise AttributeError("Cannot delete Spectrum field")
def __set_name__(self, owner, name):
for field_name in ["converter", "field", "field2", "spec2"]:
if not hasattr(owner, field_name):
raise AttributeError(f"{owner!r} doesn't have a {field_name!r} attribute")
self.name = name
class SpectrumDescriptor2:
name: str name: str
spectrum: np.ndarray = None spectrum: np.ndarray = None
__spec2: np.ndarray = None __spec2: np.ndarray = None
@@ -100,14 +68,7 @@ class CurrentState:
step: int step: int
C_to_A_factor: np.ndarray C_to_A_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
spectrum: np.ndarray = SpectrumDescriptor("spec2", "field", "field2") solution: SpectrumDescriptor = SpectrumDescriptor()
spec2: np.ndarray = dataclasses.field(init=False)
field: np.ndarray = dataclasses.field(init=False)
field2: np.ndarray = dataclasses.field(init=False)
prev_spectrum: np.ndarray = SpectrumDescriptor("prev_spec2", "prev_field", "prev_field2")
prev_spec2: np.ndarray = dataclasses.field(init=False)
prev_field: np.ndarray = dataclasses.field(init=False)
prev_field2: np.ndarray = dataclasses.field(init=False)
@property @property
def z_ratio(self) -> float: def z_ratio(self) -> float:
@@ -116,7 +77,7 @@ class CurrentState:
def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState: def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState:
"""returns a new state with new attributes""" """returns a new state with new attributes"""
params = dict( params = dict(
spectrum=new_spectrum, solution=new_spectrum,
length=self.length, length=self.length,
z=self.z, z=self.z,
current_step_size=self.current_step_size, current_step_size=self.current_step_size,
@@ -127,6 +88,13 @@ class CurrentState:
) )
return CurrentState(**(params | new_params)) return CurrentState(**(params | new_params))
def copy(self) -> CurrentState:
return replace(self, solution=self.solution.spectrum)
@property
def actual_spectrum(self) -> np.ndarray:
return self.C_to_A_factor * self.solution.spectrum
class ValueTracker(ABC): class ValueTracker(ABC):
def values(self) -> dict[str, float]: def values(self) -> dict[str, float]:
@@ -644,7 +612,7 @@ class EnvelopeRaman(AbstractRaman):
self.f_r = 0.245 if raman_type == "agrawal" else 0.18 self.f_r = 0.245 if raman_type == "agrawal" else 0.18
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.field2)) return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.solution.field2))
class FullFieldRaman(AbstractRaman): class FullFieldRaman(AbstractRaman):
@@ -654,7 +622,7 @@ class FullFieldRaman(AbstractRaman):
self.multiplier = units.epsilon0 * chi3 * self.f_r self.multiplier = units.epsilon0 * chi3 * self.f_r
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.multiplier * np.fft.ifft(np.fft.fft(state.field2) * self.hr_w) return self.multiplier * np.fft.ifft(np.fft.fft(state.solution.field2) * self.hr_w)
################################################## ##################################################
@@ -693,7 +661,7 @@ class EnvelopeSPM(AbstractSPM):
self.fraction = 1 - raman_op.f_r self.fraction = 1 - raman_op.f_r
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.fraction * state.field2 return self.fraction * state.solution.field2
class FullFieldSPM(AbstractSPM): class FullFieldSPM(AbstractSPM):
@@ -702,7 +670,7 @@ class FullFieldSPM(AbstractSPM):
self.factor = self.fraction * chi3 * units.epsilon0 self.factor = self.fraction * chi3 * units.epsilon0
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.factor * state.field2 * state.field return self.factor * state.solution.field2 * state.solution.field
################################################## ##################################################
@@ -819,7 +787,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss( return pulse.photon_number_with_loss(
state.spec2, state.solution.spec2,
self.w, self.w,
self.dw, self.dw,
self.gamma_op(state), self.gamma_op(state),
@@ -835,7 +803,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
self.gamma_op = gamma_op self.gamma_op = gamma_op
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state)) return pulse.photon_number(state.solution.spec2, self.w, self.dw, self.gamma_op(state))
class EnergyLoss(AbstractConservedQuantity): class EnergyLoss(AbstractConservedQuantity):
@@ -845,7 +813,7 @@ class EnergyLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy_with_loss( return pulse.pulse_energy_with_loss(
math.abs2(state.C_to_A_factor * state.spectrum), math.abs2(state.C_to_A_factor * state.solution.spectrum),
self.dw, self.dw,
self.loss_op(state), self.loss_op(state),
state.current_step_size, state.current_step_size,
@@ -857,7 +825,7 @@ class EnergyNoLoss(AbstractConservedQuantity):
self.dw = w[1] - w[0] self.dw = w[1] - w[0]
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.spectrum), self.dw) return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.solution.spectrum), self.dw)
def conserved_quantity( def conserved_quantity(
@@ -970,7 +938,7 @@ class EnvelopeNonLinearOperator(NonLinearOperator):
-1j -1j
* self.gamma_op(state) * self.gamma_op(state)
* (1 + self.ss_op(state)) * (1 + self.ss_op(state))
* np.fft.fft(state.field * (self.spm_op(state) + self.raman_op(state))) * np.fft.fft(state.solution.field * (self.spm_op(state) + self.raman_op(state)))
) )

View File

@@ -1,15 +1,15 @@
from collections import defaultdict
import multiprocessing import multiprocessing
import multiprocessing.connection import multiprocessing.connection
import os import os
from collections import defaultdict
from datetime import datetime from datetime import datetime
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Generator, Optional, Type, Union from typing import Any, Generator, Iterator, Optional, Type, Union
import numpy as np import numpy as np
from .. import utils from .. import solver, utils
from ..logger import get_logger from ..logger import get_logger
from ..operators import CurrentState from ..operators import CurrentState
from ..parameter import Configuration, Parameters from ..parameter import Configuration, Parameters
@@ -45,7 +45,7 @@ class RK4IP:
size_fac: float size_fac: float
cons_qty: list[float] cons_qty: list[float]
state: CurrentState init_state: CurrentState
stored_spectra: list[np.ndarray] stored_spectra: list[np.ndarray]
def __init__( def __init__(
@@ -96,7 +96,7 @@ class RK4IP:
initial_h = (self.z_targets[1] - self.z_targets[0]) / 2 initial_h = (self.z_targets[1] - self.z_targets[0]) / 2
else: else:
initial_h = self.error_ok initial_h = self.error_ok
self.state = CurrentState( self.init_state = CurrentState(
length=self.params.length, length=self.params.length,
z=self.z_targets.pop(0), z=self.z_targets.pop(0),
current_step_size=initial_h, current_step_size=initial_h,
@@ -104,14 +104,14 @@ class RK4IP:
step=1, step=1,
C_to_A_factor=C_to_A_factor, C_to_A_factor=C_to_A_factor,
converter=self.params.ifft, converter=self.params.ifft,
spectrum=self.params.spec_0.copy() / C_to_A_factor, solution=self.params.spec_0.copy() / C_to_A_factor,
) )
self.stored_spectra = self.params.recovery_last_stored * [None] + [ self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.state.spectrum.copy() self.init_state.solution.spectrum.copy()
] ]
self.tracked_values = TrackedValues() self.tracked_values = TrackedValues()
def _save_current_spectrum(self, num: int): def _save_current_spectrum(self, spectrum: np.ndarray, num: int):
"""saves the spectrum and the corresponding cons_qty array """saves the spectrum and the corresponding cons_qty array
Parameters Parameters
@@ -119,19 +119,8 @@ class RK4IP:
num : int num : int
index of the z postition index of the z postition
""" """
self.write(self.get_current_spectrum(), f"spectrum_{num}") self.write(spectrum, f"spectrum_{num}")
self.write(self.tracked_values, "tracked_values") self.write(self.tracked_values, "tracked_values")
self.step_saved()
def get_current_spectrum(self) -> np.ndarray:
"""returns the current spectrum
Returns
-------
np.ndarray
spectrum
"""
return self.state.C_to_A_factor * self.state.spectrum
def write(self, data: np.ndarray, name: str): def write(self, data: np.ndarray, name: str):
"""calls the appropriate method to save data """calls the appropriate method to save data
@@ -147,14 +136,15 @@ class RK4IP:
def run(self) -> list[np.ndarray]: def run(self) -> list[np.ndarray]:
time_start = datetime.today() time_start = datetime.today()
state = self.init_state
for step, num, _ in self.irun(): for num, state in self.irun():
if self.save_data: if self.save_data:
self._save_current_spectrum(num) self._save_current_spectrum(state.actual_spectrum, num)
self.step_saved(state)
self.logger.info( self.logger.info(
"propagation finished in {} steps ({} seconds)".format( "propagation finished in {} steps ({} seconds)".format(
step, (datetime.today() - time_start).total_seconds() state.step, (datetime.today() - time_start).total_seconds()
) )
) )
@@ -163,49 +153,49 @@ class RK4IP:
return self.stored_spectra return self.stored_spectra
def irun(self) -> Generator[tuple[int, int, np.ndarray], None, None]: def irun(self) -> Iterator[tuple[int, CurrentState]]:
"""run the simulation as a generator obj """run the simulation as a generator obj
Yields Yields
------- -------
int
current simulation step
int int
current number of spectra returned current number of spectra returned
np.ndarray CurrentState
spectrum current simulation state
""" """
self.logger.debug( self.logger.debug(
"Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0]) "Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0])
) )
store = False store = False
state = self.init_state.copy()
yield self.state.step, len(self.stored_spectra) - 1, self.get_current_spectrum() yield len(self.stored_spectra) - 1, state
integrator = solver.ERK54(
while self.state.z < self.params.length: state,
self.state = self.params.integrator(self.state) self.params.linear_operator,
self.params.nonlinear_operator,
self.state.step += 1 self.params.tolerated_error,
new_tracked_values = ( self.params.dt,
dict(step=self.state.step, z=self.state.z) | self.params.integrator.all_values()
) )
self.logger.debug(f"tracked values at z={self.state.z} : {new_tracked_values}") for state in integrator:
new_tracked_values = integrator.all_values()
self.logger.debug(f"tracked values at z={state.z} : {new_tracked_values}")
self.tracked_values.append(new_tracked_values) self.tracked_values.append(new_tracked_values)
# Whether the current spectrum has to be stored depends on previous step # Whether the current spectrum has to be stored depends on previous step
if store: if store:
current_spec = self.get_current_spectrum() current_spec = state.actual_spectrum
self.stored_spectra.append(current_spec) self.stored_spectra.append(current_spec)
yield self.state.step, len(self.stored_spectra) - 1, current_spec yield len(self.stored_spectra) - 1, state.copy()
self.z_stored.append(self.state.z) self.z_stored.append(state.z)
del self.z_targets[0] del self.z_targets[0]
# reset the constant step size after a spectrum is stored # reset the constant step size after a spectrum is stored
if not self.params.adapt_step_size: if not self.params.adapt_step_size:
self.state.current_step_size = self.error_ok integrator.state.current_step_size = self.error_ok
if len(self.z_targets) == 0: if len(self.z_targets) == 0:
break break
@@ -213,14 +203,14 @@ class RK4IP:
# if the next step goes over a position at which we want to store # if the next step goes over a position at which we want to store
# a spectrum, we shorten the step to reach this position exactly # a spectrum, we shorten the step to reach this position exactly
if self.state.z + self.state.current_step_size >= self.z_targets[0]: if state.z + integrator.state.current_step_size >= self.z_targets[0]:
store = True store = True
self.state.current_step_size = self.z_targets[0] - self.state.z integrator.state.current_step_size = self.z_targets[0] - state.z
def step_saved(self): def step_saved(self, state: CurrentState):
pass pass
def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]: def __iter__(self) -> Iterator[tuple[int, CurrentState]]:
yield from self.irun() yield from self.irun()
def __len__(self) -> int: def __len__(self) -> int:
@@ -240,8 +230,8 @@ class SequentialRK4IP(RK4IP):
save_data=save_data, save_data=save_data,
) )
def step_saved(self): def step_saved(self, state: CurrentState):
self.pbars.update(1, self.state.z / self.params.length - self.pbars[1].n) self.pbars.update(1, state.z / self.params.length - self.pbars[1].n)
class MutliProcRK4IP(RK4IP): class MutliProcRK4IP(RK4IP):
@@ -259,8 +249,8 @@ class MutliProcRK4IP(RK4IP):
save_data=save_data, save_data=save_data,
) )
def step_saved(self): def step_saved(self, state: CurrentState):
self.p_queue.put((self.worker_id, self.state.z / self.params.length)) self.p_queue.put((self.worker_id, state.z / self.params.length))
class RayRK4IP(RK4IP): class RayRK4IP(RK4IP):
@@ -286,8 +276,8 @@ class RayRK4IP(RK4IP):
self.set(params, p_actor, worker_id, save_data) self.set(params, p_actor, worker_id, save_data)
self.run() self.run()
def step_saved(self): def step_saved(self, state: CurrentState):
self.p_actor.update.remote(self.worker_id, self.state.z / self.params.length) self.p_actor.update.remote(self.worker_id, state.z / self.params.length)
self.p_actor.update.remote(0) self.p_actor.update.remote(0)

View File

@@ -1,4 +1,5 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Iterator
import numpy as np import numpy as np
@@ -69,7 +70,7 @@ class RK4IPStepTaker(StepTaker):
l0, nl0 = self.cached_values(state) l0, nl0 = self.cached_values(state)
expD = np.exp(h * self.c2 * l0) expD = np.exp(h * self.c2 * l0)
A_I = expD * state.spectrum A_I = expD * state.solution
k1 = expD * (h * nl0) k1 = expD * (h * nl0)
k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2)) k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2))
k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2)) k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2))
@@ -114,14 +115,37 @@ class RK4IPStepTaker(StepTaker):
class Integrator(ValueTracker): class Integrator(ValueTracker):
last_step = 0.0 linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
_tracked_values: dict[str, float]
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator
self._tracked_values = {}
@abstractmethod @abstractmethod
def __call__(self, state: CurrentState) -> CurrentState: def __iter__(self) -> Iterator[CurrentState]:
"""propagate the state with a step size of state.current_step_size """propagate the state with a step size of state.current_step_size
and return a new state with updated z and previous_step_size attributes""" and yield a new state with updated z and previous_step_size attributes"""
... ...
def all_values(self) -> dict[str, float]:
"""override ValueTracker.all_values to account for the fact that operators are called
multiple times per step, sometimes with different state, so we use value recorded
earlier. Please call self.recorde_tracked_values() one time only just after calling
the linear and nonlinear operators in your StepTaker.
Returns
-------
dict[str, float]
tracked values
"""
return self.values() | self._tracked_values
def record_tracked_values(self):
self._tracked_values = super().all_values()
class ConstantStepIntegrator(Integrator): class ConstantStepIntegrator(Integrator):
def __call__(self, state: CurrentState) -> CurrentState: def __call__(self, state: CurrentState) -> CurrentState:
@@ -234,7 +258,7 @@ class LocalErrorIntegrator(Integrator):
h_next_step = h * self.size_fac h_next_step = h * self.size_fac
self.local_error = delta self.local_error = delta
fine_state.spectrum = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac fine_state.solution = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
fine_state.current_step_size = h_next_step fine_state.current_step_size = h_next_step
fine_state.previous_step_size = h fine_state.previous_step_size = h
fine_state.z += h fine_state.z += h
@@ -249,32 +273,122 @@ class LocalErrorIntegrator(Integrator):
class ERK43(Integrator): class ERK43(Integrator):
state: CurrentState
linear_operator: LinearOperator linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator nonlinear_operator: NonLinearOperator
tolerated_error: float
dt: float dt: float
current_error: float
next_h_factor = 1.0
def __init__( def __init__(
self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator, dt: float self,
init_state: CurrentState,
linear_operator: LinearOperator,
nonlinear_operator: NonLinearOperator,
tolerated_error: float,
dt: float,
): ):
self.state = init_state
self.linear_operator = linear_operator self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator self.nonlinear_operator = nonlinear_operator
self.dt = dt self.dt = dt
self.tolerated_error = tolerated_error
self.current_error = 0.0
def __call__(self, state: CurrentState) -> CurrentState: def __iter__(self) -> Iterator[CurrentState]:
keep = False h_next_step = self.state.current_step_size
h_next_step = state.current_step_size k5 = self.nonlinear_operator(self.state)
while not keep: while True:
lin = self.linear_operator(self.state)
self.record_tracked_values()
while True:
h = h_next_step h = h_next_step
expD = np.exp(h * 0.5 * self.linear_operator(state)) expD = np.exp(h * 0.5 * lin)
A_I = expD * state.spectrum A_I = expD * self.state.solution.spectrum
k1 = expD * state.prev_spectrum k1 = expD * k5
k2 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k1)) k2 = self.nl(A_I + 0.5 * h * k1)
k3 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k2)) k3 = self.nl(A_I + 0.5 * h * k2)
k4 = self.nonlinear_operator(state.replace(expD * A_I + h * k3)) k4 = self.nl(expD * A_I + h * k3)
r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3)) r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3))
new_fine = r + h / 6 * k4 new_fine = r + h / 6 * k4
k5 = self.nonlinear_operator(state.replace(new_fine)) tmp_k5 = self.nl(new_fine)
new_coarse = r + h / 30 * (2 * k4 + 3 * k5) new_coarse = r + h / 30 * (2 * k4 + 3 * tmp_k5)
self.current_error = np.sqrt(self.dt * math.abs2(new_fine - new_coarse).sum())
self.next_h_factor = max(
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
)
h_next_step = self.next_h_factor * h
if self.current_error <= self.tolerated_error:
break
self.state.current_step_size = h_next_step
self.state.previous_step_size = h
self.state.z += h
self.state.step += 1
self.state.solution = new_fine
k5 = tmp_k5
yield self.state
def values(self) -> dict[str, float]:
return dict(
step=self.state.step,
z=self.state.z,
local_error=self.current_error,
next_h_factor=self.next_h_factor,
)
def nl(self, spectrum: np.ndarray) -> np.ndarray:
return self.nonlinear_operator(self.state.replace(spectrum))
class ERK54(ERK43):
def __iter__(self) -> Iterator[CurrentState]:
print("using ERK54")
h_next_step = self.state.current_step_size
k7 = self.nonlinear_operator(self.state)
while True:
lin = self.linear_operator(self.state)
self.record_tracked_values()
while True:
h = h_next_step
expD2 = np.exp(h * 0.5 * lin)
expD4p = np.exp(h * 0.25 * lin)
expD4m = 1 / expD4p
A_I = expD2 * self.state.solution.spectrum
k1 = expD2 * k7
k2 = self.nl(A_I + 0.5 * h * k1)
k3 = expD4p * self.nl(expD4m * (A_I + h / 16 * (3 * k1 + k2)))
k4 = self.nl(A_I + h / 4 * (k1 - k2 + 4 * k3))
k5 = expD4m * self.nl(expD4p * (A_I + 3 * h / 16 * (k1 + 3 * k4)))
k6 = self.nl(expD2 * (A_I + h / 7 * (-2 * k1 + k2 + 12 * k3 - 12 * k4 + 8 * k5)))
new_fine = (
expD2 * (A_I + h / 90 * (7 * k1 + 32 * k3 + 12 * k4 + 32 * k5))
+ 7 * h / 90 * k6
)
tmp_k7 = self.nl(new_fine)
new_coarse = (
expD2 * (A_I + h / 42 * (3 * k1 + 16 * k3 + 4 * k4 + 16 * k5)) + h / 14 * k7
)
self.current_error = np.sqrt(
self.dt * math.abs2(np.abs(new_fine) - np.abs(new_coarse)).sum()
)
self.next_h_factor = max(
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
)
h_next_step = self.next_h_factor * h
if self.current_error <= self.tolerated_error:
break
self.state.current_step_size = h_next_step
self.state.previous_step_size = h
self.state.z += h
self.state.step += 1
self.state.solution = new_fine
k7 = tmp_k7
yield self.state