diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 7c683c4..160d732 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -170,12 +170,11 @@ class RK4IP: store = False state = self.init_state.copy() yield len(self.stored_spectra) - 1, state - integrator = solver.ERK54( + integrator = solver.RK4IPSD( state, self.params.linear_operator, self.params.nonlinear_operator, self.params.tolerated_error, - self.params.dt, ) for state in integrator: diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index a886e58..d973f3c 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -117,11 +117,21 @@ class RK4IPStepTaker(StepTaker): class Integrator(ValueTracker): linear_operator: LinearOperator nonlinear_operator: NonLinearOperator + state: CurrentState + tolerated_error: float _tracked_values: dict[str, float] - def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator): + def __init__( + self, + init_state: CurrentState, + linear_operator: LinearOperator, + nonlinear_operator: NonLinearOperator, + tolerated_error: float, + ): + self.state = init_state self.linear_operator = linear_operator self.nonlinear_operator = nonlinear_operator + self.tolerated_error = tolerated_error self._tracked_values = {} @abstractmethod @@ -146,6 +156,9 @@ class Integrator(ValueTracker): def record_tracked_values(self): self._tracked_values = super().all_values() + def nl(self, spectrum: np.ndarray) -> np.ndarray: + return self.nonlinear_operator(self.state.replace(spectrum)) + class ConstantStepIntegrator(Integrator): def __call__(self, state: CurrentState) -> CurrentState: @@ -218,58 +231,78 @@ class ConservedQuantityIntegrator(Integrator): ) -class LocalErrorIntegrator(Integrator): - step_taker: StepTaker +class RK4IPSD(Integrator): + """Runge-Kutta 4 in Interaction Picture with step doubling""" + + linear_operator: LinearOperator + nonlinear_operator: NonLinearOperator tolerated_error: float - local_error: float + current_error: float + next_h_factor = 1.0 + current_error = 0.0 - def __init__(self, step_taker: StepTaker, tolerated_error: float, w_num: int): - self.tolerated_error = tolerated_error - self.local_error = 0.0 - self.logger = get_logger(self.__class__.__name__) - self.size_fac, self.fine_fac, self.coarse_fac = 2.0 ** (1.0 / 5.0), 16 / 15, -1 / 15 - self.step_taker = step_taker + def __iter__(self) -> Iterator[CurrentState]: + h_next_step = self.state.current_step_size + size_fac = 2.0 ** (1.0 / 5.0) + while True: + lin = self.linear_operator(self.state) + nonlin = self.nonlinear_operator(self.state) + self.record_tracked_values() + while True: + h = h_next_step + new_fine_inter = self.take_step(h / 2, self.state.solution.spectrum, lin, nonlin) + new_fine_inter_state = self.state.replace(new_fine_inter) + new_fine = self.take_step( + h / 2, + new_fine_inter, + self.linear_operator(new_fine_inter_state), + self.nonlinear_operator(new_fine_inter_state), + ) + new_coarse = self.take_step(h, self.state.solution.spectrum, lin, nonlin) + self.current_error = self.compute_diff(new_coarse, new_fine) - def __call__(self, state: CurrentState) -> CurrentState: - keep = False - h_next_step = state.current_step_size - while not keep: - h = h_next_step - h_half = h / 2 - coarse_spec = self.step_taker(state, h) + if self.current_error > 2 * self.tolerated_error: + h_next_step = h * 0.5 + elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error: + h_next_step = h / size_fac + break + elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error: + h_next_step = h + break + else: + h_next_step = h * size_fac + break - fine_spec1 = self.step_taker(state, h_half) - fine_state = state.replace(fine_spec1, z=state.z + h_half) - fine_spec = self.step_taker(fine_state, h_half) + self.state.current_step_size = h_next_step + self.state.previous_step_size = h + self.state.z += h + self.state.step += 1 + self.state.solution = new_fine + yield self.state - delta = self.compute_diff(coarse_spec, fine_spec) + def take_step( + self, h: float, spec: np.ndarray, lin: np.ndarray, nonlin: np.ndarray + ) -> np.ndarray: + expD = np.exp(h * 0.5 * lin) + A_I = expD * spec - if delta > 2 * self.tolerated_error: - keep = False - h_next_step = h_half - elif self.tolerated_error <= delta <= 2 * self.tolerated_error: - keep = True - h_next_step = h / self.size_fac - elif 0.5 * self.tolerated_error <= delta < self.tolerated_error: - keep = True - h_next_step = h - else: - keep = True - h_next_step = h * self.size_fac + k1 = expD * nonlin + k2 = self.nl(A_I + k1 * 0.5 * h) + k3 = self.nl(A_I + k2 * 0.5 * h) + k4 = self.nl(expD * (A_I + h * k3)) - self.local_error = delta - fine_state.solution = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac - fine_state.current_step_size = h_next_step - fine_state.previous_step_size = h - fine_state.z += h - self.last_step = h - return fine_state + return expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3)) + h / 6 * k4 def compute_diff(self, coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float: return np.sqrt(math.abs2(coarse_spec - fine_spec).sum() / math.abs2(fine_spec).sum()) def values(self) -> dict[str, float]: - return dict(relative_error=self.local_error, h=self.last_step) + return dict( + step=self.state.step, + z=self.state.z, + local_error=self.current_error, + next_h_factor=self.next_h_factor, + ) class ERK43(Integrator): @@ -287,12 +320,12 @@ class ERK43(Integrator): linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator, tolerated_error: float, - dt: float, + dw: float, ): self.state = init_state self.linear_operator = linear_operator self.nonlinear_operator = nonlinear_operator - self.dt = dt + self.dw = dw self.tolerated_error = tolerated_error self.current_error = 0.0 @@ -318,7 +351,7 @@ class ERK43(Integrator): new_coarse = r + h / 30 * (2 * k4 + 3 * tmp_k5) - self.current_error = np.sqrt(self.dt * math.abs2(new_fine - new_coarse).sum()) + self.current_error = np.sqrt(self.dw * math.abs2(new_fine - new_coarse).sum()) self.next_h_factor = max( 0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25) ) @@ -341,9 +374,6 @@ class ERK43(Integrator): next_h_factor=self.next_h_factor, ) - def nl(self, spectrum: np.ndarray) -> np.ndarray: - return self.nonlinear_operator(self.state.replace(spectrum)) - class ERK54(ERK43): def __iter__(self) -> Iterator[CurrentState]: @@ -376,9 +406,7 @@ class ERK54(ERK43): expD2 * (A_I + h / 42 * (3 * k1 + 16 * k3 + 4 * k4 + 16 * k5)) + h / 14 * k7 ) - self.current_error = np.sqrt( - self.dt * math.abs2(np.abs(new_fine) - np.abs(new_coarse)).sum() - ) + self.current_error = np.sqrt(self.dw * math.abs2(new_fine - new_coarse).sum()) self.next_h_factor = max( 0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25) )