correction attempt on Step Doubling method

This commit is contained in:
Benoît Sierro
2021-11-09 18:18:42 +01:00
parent 65b42bf2ee
commit c83b4879fd
2 changed files with 79 additions and 52 deletions

View File

@@ -170,12 +170,11 @@ class RK4IP:
store = False
state = self.init_state.copy()
yield len(self.stored_spectra) - 1, state
integrator = solver.ERK54(
integrator = solver.RK4IPSD(
state,
self.params.linear_operator,
self.params.nonlinear_operator,
self.params.tolerated_error,
self.params.dt,
)
for state in integrator:

View File

@@ -117,11 +117,21 @@ class RK4IPStepTaker(StepTaker):
class Integrator(ValueTracker):
linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
state: CurrentState
tolerated_error: float
_tracked_values: dict[str, float]
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
def __init__(
self,
init_state: CurrentState,
linear_operator: LinearOperator,
nonlinear_operator: NonLinearOperator,
tolerated_error: float,
):
self.state = init_state
self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator
self.tolerated_error = tolerated_error
self._tracked_values = {}
@abstractmethod
@@ -146,6 +156,9 @@ class Integrator(ValueTracker):
def record_tracked_values(self):
self._tracked_values = super().all_values()
def nl(self, spectrum: np.ndarray) -> np.ndarray:
return self.nonlinear_operator(self.state.replace(spectrum))
class ConstantStepIntegrator(Integrator):
def __call__(self, state: CurrentState) -> CurrentState:
@@ -218,58 +231,78 @@ class ConservedQuantityIntegrator(Integrator):
)
class LocalErrorIntegrator(Integrator):
step_taker: StepTaker
class RK4IPSD(Integrator):
"""Runge-Kutta 4 in Interaction Picture with step doubling"""
linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
tolerated_error: float
local_error: float
current_error: float
next_h_factor = 1.0
current_error = 0.0
def __init__(self, step_taker: StepTaker, tolerated_error: float, w_num: int):
self.tolerated_error = tolerated_error
self.local_error = 0.0
self.logger = get_logger(self.__class__.__name__)
self.size_fac, self.fine_fac, self.coarse_fac = 2.0 ** (1.0 / 5.0), 16 / 15, -1 / 15
self.step_taker = step_taker
def __call__(self, state: CurrentState) -> CurrentState:
keep = False
h_next_step = state.current_step_size
while not keep:
def __iter__(self) -> Iterator[CurrentState]:
h_next_step = self.state.current_step_size
size_fac = 2.0 ** (1.0 / 5.0)
while True:
lin = self.linear_operator(self.state)
nonlin = self.nonlinear_operator(self.state)
self.record_tracked_values()
while True:
h = h_next_step
h_half = h / 2
coarse_spec = self.step_taker(state, h)
new_fine_inter = self.take_step(h / 2, self.state.solution.spectrum, lin, nonlin)
new_fine_inter_state = self.state.replace(new_fine_inter)
new_fine = self.take_step(
h / 2,
new_fine_inter,
self.linear_operator(new_fine_inter_state),
self.nonlinear_operator(new_fine_inter_state),
)
new_coarse = self.take_step(h, self.state.solution.spectrum, lin, nonlin)
self.current_error = self.compute_diff(new_coarse, new_fine)
fine_spec1 = self.step_taker(state, h_half)
fine_state = state.replace(fine_spec1, z=state.z + h_half)
fine_spec = self.step_taker(fine_state, h_half)
delta = self.compute_diff(coarse_spec, fine_spec)
if delta > 2 * self.tolerated_error:
keep = False
h_next_step = h_half
elif self.tolerated_error <= delta <= 2 * self.tolerated_error:
keep = True
h_next_step = h / self.size_fac
elif 0.5 * self.tolerated_error <= delta < self.tolerated_error:
keep = True
if self.current_error > 2 * self.tolerated_error:
h_next_step = h * 0.5
elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error:
h_next_step = h / size_fac
break
elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error:
h_next_step = h
break
else:
keep = True
h_next_step = h * self.size_fac
h_next_step = h * size_fac
break
self.local_error = delta
fine_state.solution = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
fine_state.current_step_size = h_next_step
fine_state.previous_step_size = h
fine_state.z += h
self.last_step = h
return fine_state
self.state.current_step_size = h_next_step
self.state.previous_step_size = h
self.state.z += h
self.state.step += 1
self.state.solution = new_fine
yield self.state
def take_step(
self, h: float, spec: np.ndarray, lin: np.ndarray, nonlin: np.ndarray
) -> np.ndarray:
expD = np.exp(h * 0.5 * lin)
A_I = expD * spec
k1 = expD * nonlin
k2 = self.nl(A_I + k1 * 0.5 * h)
k3 = self.nl(A_I + k2 * 0.5 * h)
k4 = self.nl(expD * (A_I + h * k3))
return expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3)) + h / 6 * k4
def compute_diff(self, coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float:
return np.sqrt(math.abs2(coarse_spec - fine_spec).sum() / math.abs2(fine_spec).sum())
def values(self) -> dict[str, float]:
return dict(relative_error=self.local_error, h=self.last_step)
return dict(
step=self.state.step,
z=self.state.z,
local_error=self.current_error,
next_h_factor=self.next_h_factor,
)
class ERK43(Integrator):
@@ -287,12 +320,12 @@ class ERK43(Integrator):
linear_operator: LinearOperator,
nonlinear_operator: NonLinearOperator,
tolerated_error: float,
dt: float,
dw: float,
):
self.state = init_state
self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator
self.dt = dt
self.dw = dw
self.tolerated_error = tolerated_error
self.current_error = 0.0
@@ -318,7 +351,7 @@ class ERK43(Integrator):
new_coarse = r + h / 30 * (2 * k4 + 3 * tmp_k5)
self.current_error = np.sqrt(self.dt * math.abs2(new_fine - new_coarse).sum())
self.current_error = np.sqrt(self.dw * math.abs2(new_fine - new_coarse).sum())
self.next_h_factor = max(
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
)
@@ -341,9 +374,6 @@ class ERK43(Integrator):
next_h_factor=self.next_h_factor,
)
def nl(self, spectrum: np.ndarray) -> np.ndarray:
return self.nonlinear_operator(self.state.replace(spectrum))
class ERK54(ERK43):
def __iter__(self) -> Iterator[CurrentState]:
@@ -376,9 +406,7 @@ class ERK54(ERK43):
expD2 * (A_I + h / 42 * (3 * k1 + 16 * k3 + 4 * k4 + 16 * k5)) + h / 14 * k7
)
self.current_error = np.sqrt(
self.dt * math.abs2(np.abs(new_fine) - np.abs(new_coarse)).sum()
)
self.current_error = np.sqrt(self.dw * math.abs2(new_fine - new_coarse).sum())
self.next_h_factor = max(
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
)