diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 3ed4a03..a34449e 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -5,6 +5,8 @@ from typing import Any, Callable, Optional, Union import numpy as np +from scgenerator import solver + from . import math, operators, utils from .const import MANDATORY_PARAMETERS from .errors import EvaluatorError, NoDefaultError @@ -377,7 +379,17 @@ default_rules: list[Rule] = [ Rule("raman_op", operators.NoRaman, priorities=-1), Rule("loss_op", operators.NoLoss, priorities=-1), Rule("plasma_op", operators.NoPlasma, priorities=-1), - Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1), + # solvers + Rule("integrator", solver.ConstantStepIntegrator, conditions=dict(adapt_step_size=False)), + Rule( + "integrator", + solver.ConservedQuantityIntegrator, + conditions=dict(adapt_step_size=True), + priorities=1, + ), + Rule("integrator", solver.RK4IPSD, conditions=dict(adapt_step_size=True)), + Rule("integrator", solver.ERK43, conditions=dict(adapt_step_size=True)), + Rule("integrator", solver.ERK54, conditions=dict(adapt_step_size=True), priorities=1), ] envelope_rules = default_rules + [ diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index 9ca8987..83408ab 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -72,6 +72,16 @@ class Integrator(ValueTracker): def nl(self, spectrum: np.ndarray) -> np.ndarray: return self.nonlinear_operator(self.state.replace(spectrum)) + def accept_step( + self, new_state: CurrentState, previous_step_size: float, next_step_size: float + ) -> CurrentState: + self.state = new_state + self.state.current_step_size = next_step_size + self.state.z += previous_step_size + self.state.step += 1 + self.logger.debug(f"accepted step {self.state.step} with h={previous_step_size}") + return self.state + class ConstantStepIntegrator(Integrator): def __init__( @@ -87,7 +97,7 @@ class ConstantStepIntegrator(Integrator): lin = self.linear_operator(self.state) nonlin = self.nonlinear_operator(self.state) self.record_tracked_values() - new_spec = rk4ip_step( + self.state.spectrum = rk4ip_step( self.nonlinear_operator, self.state, self.state.spectrum, @@ -95,11 +105,11 @@ class ConstantStepIntegrator(Integrator): lin, nonlin, ) - - self.state.z += self.state.current_step_size - self.state.step += 1 - self.state = new_spec - yield self.state + yield self.accept_step( + self.state, + self.state.current_step_size, + self.state.current_step_size, + ) class ConservedQuantityIntegrator(Integrator): @@ -156,11 +166,7 @@ class ConservedQuantityIntegrator(Integrator): f"step {new_state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}" ) self.last_qty = new_qty - self.state = new_state - self.state.current_step_size = h_next_step - self.state.z += h - self.state.step += 1 - yield self.state + yield self.accept_step(new_state, h, h_next_step) def values(self) -> dict[str, float]: return dict(cons_qty=self.last_qty, relative_error=self.current_error) @@ -204,11 +210,8 @@ class RK4IPSD(Integrator): h_next_step = h * size_fac break - self.state.current_step_size = h_next_step - self.state.z += h - self.state.step += 1 - self.state = new_fine - yield self.state + self.state.spectrum = new_fine + yield self.accept_step(self.state, h, h_next_step) def take_step( self, h: float, spec: np.ndarray, lin: np.ndarray, nonlin: np.ndarray @@ -284,12 +287,9 @@ class ERK43(Integrator): f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}" ) - self.state.current_step_size = h_next_step - self.state.z += h - self.state.step += 1 - self.state = new_fine k5 = tmp_k5 - yield self.state + self.state.spectrum = new_fine + yield self.accept_step(self.state, h, h_next_step) def values(self) -> dict[str, float]: return dict( @@ -341,12 +341,9 @@ class ERK54(ERK43): self.logger.info( f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}" ) - self.state.current_step_size = h_next_step - self.state.z += h - self.state.step += 1 - self.state = new_fine k7 = tmp_k7 - yield self.state + self.state.spectrum = new_fine + yield self.accept_step(self.state, h, h_next_step) def rk4ip_step(