diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index c4a471c..3ed4a03 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Union import numpy as np -from . import math, operators, utils, solver +from . import math, operators, utils from .const import MANDATORY_PARAMETERS from .errors import EvaluatorError, NoDefaultError from .physics import fiber, materials, pulse, units @@ -324,6 +324,7 @@ default_rules: list[Rule] = [ Rule("L_D", pulse.L_D), Rule("L_NL", pulse.L_NL), Rule("L_sol", pulse.L_sol), + Rule("c_to_a_factor", lambda: 1.0, priorities=-1), # Fiber Dispersion Rule("w_for_disp", units.m, ["wl_for_disp"]), Rule("hr_w", fiber.delayed_raman_w), @@ -377,7 +378,6 @@ default_rules: list[Rule] = [ Rule("loss_op", operators.NoLoss, priorities=-1), Rule("plasma_op", operators.NoPlasma, priorities=-1), Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1), - Rule("integrator", solver.ERK54), ] envelope_rules = default_rules + [ @@ -387,6 +387,7 @@ envelope_rules = default_rules + [ Rule("pre_field_0", pulse.initial_field_envelope, priorities=1), Rule("spec_0", np.fft.fft, ["field_0"]), Rule("field_0", np.fft.ifft, ["spec_0"]), + Rule("c_to_a_factor", pulse.c_to_a_factor), # Dispersion Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion), Rule("beta2_coefficients", fiber.dispersion_coefficients), @@ -418,7 +419,6 @@ envelope_rules = default_rules + [ Rule("dispersion_op", operators.DirectDispersion), Rule("linear_operator", operators.EnvelopeLinearOperator), Rule("conserved_quantity", operators.conserved_quantity), - Rule("integrator", solver.ConservedQuantityIntegrator), ] full_field_rules = default_rules + [ @@ -442,6 +442,4 @@ full_field_rules = default_rules + [ operators.FullFieldLinearOperator, ), Rule("nonlinear_operator", operators.FullFieldNonLinearOperator), - # Integration - Rule("integrator", solver.LocalErrorIntegrator), ] diff --git a/src/scgenerator/operators.py b/src/scgenerator/operators.py index 79cf2b7..9f40022 100644 --- a/src/scgenerator/operators.py +++ b/src/scgenerator/operators.py @@ -52,6 +52,24 @@ class SpectrumDescriptor: self.__field2 = math.abs2(self.field) return self.__field2 + def force_values(self, spec2: np.ndarray, field: np.ndarray, field2: np.ndarray): + """force these values instead of recomputing them + + Parameters + ---------- + spectrum : np.ndarray + spectrum + spec2 : np.ndarray + |spectrum|^2 + field : np.ndarray + field = converter(spectrum) + field2 : np.ndarray + |field|^2 + """ + self.__spec2 = spec2 + self.__field = field + self.__field2 = field2 + def __delete__(self, instance): raise AttributeError("Cannot delete Spectrum field") @@ -64,7 +82,6 @@ class CurrentState: length: float z: float current_step_size: float - previous_step_size: float step: int C_to_A_factor: np.ndarray converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft @@ -74,19 +91,33 @@ class CurrentState: def z_ratio(self) -> float: return self.z / self.length - def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState: + def replace(self, new_spectrum: np.ndarray) -> CurrentState: """returns a new state with new attributes""" - params = dict( - solution=new_spectrum, + return CurrentState( + self.length, + self.z, + self.current_step_size, + self.step, + self.C_to_A_factor, + self.converter, + new_spectrum, + ) + + def with_params(self, **params) -> CurrentState: + """returns a new CurrentState with modified params, except for the solution""" + my_params = dict( length=self.length, z=self.z, current_step_size=self.current_step_size, - previous_step_size=self.previous_step_size, step=self.step, C_to_A_factor=self.C_to_A_factor, converter=self.converter, ) - return CurrentState(**(params | new_params)) + new_state = CurrentState(solution=self.solution.spectrum, **(my_params | params)) + new_state.solution.force_values( + self.solution.spec2, self.solution.field, self.solution.field2 + ) + return new_state def copy(self) -> CurrentState: return replace(self, solution=self.solution.spectrum) diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index cb13b14..9b3b8f1 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -6,7 +6,7 @@ import os import time from copy import copy from dataclasses import dataclass, field, fields -from functools import lru_cache +from functools import lru_cache, wraps from math import isnan from pathlib import Path from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar, Union @@ -19,7 +19,7 @@ from .errors import EvaluatorError from .evaluator import Evaluator from .logger import get_logger from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator -from .solver import Integrator, StepTaker +from .solver import Integrator from .utils import fiber_folder, update_path_name from .variationer import VariationDescriptor, Variationer @@ -37,6 +37,7 @@ def type_checker(*types): def validator(*args): pass + @wraps(validator) def _type_checker_wrapped(name, n): if not isinstance(n, types): raise TypeError( @@ -377,7 +378,6 @@ class Parameters: # computed linear_operator: LinearOperator = Parameter(type_checker(LinearOperator)) nonlinear_operator: NonLinearOperator = Parameter(type_checker(NonLinearOperator)) - step_taker: StepTaker = Parameter(type_checker(StepTaker)) integrator: Integrator = Parameter(type_checker(Integrator)) conserved_quantity: AbstractConservedQuantity = Parameter( type_checker(AbstractConservedQuantity) @@ -391,6 +391,7 @@ class Parameters: alpha: float = Parameter(non_negative(float, int)) gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray)) A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray)) + c_to_a_factor: np.ndarray = Parameter(type_checker(float, np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray)) l: np.ndarray = Parameter(type_checker(np.ndarray)) w_c: np.ndarray = Parameter(type_checker(np.ndarray)) diff --git a/src/scgenerator/physics/__init__.py b/src/scgenerator/physics/__init__.py index f59df64..f7008bf 100644 --- a/src/scgenerator/physics/__init__.py +++ b/src/scgenerator/physics/__init__.py @@ -104,8 +104,11 @@ def find_optimal_depth( ind = w > (w0 / 10) disp[ind] = material_dispersion(units.m.inv(w[ind]), material) - propagate = lambda z: spectrum * np.exp(-0.5j * disp * w_c ** 2 * z) - integrate = lambda z: math.abs2(np.fft.ifft(propagate(z))) + def propagate(z): + return spectrum * np.exp(-0.5j * disp * w_c ** 2 * z) + + def integrate(z): + return math.abs2(np.fft.ifft(propagate(z))) def score(z): return -np.nansum(integrate(z) ** 6) diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index f097887..7bea7e3 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -1127,3 +1127,7 @@ def capillary_loss(wl: np.ndarray, he_mode: tuple[int, int], core_radius: float) def extinction_distance(loss: T, ratio=1 / e) -> T: return np.log(ratio) / -loss + + +def L_eff(loss: T, length: float) -> T: + return -np.expm1(-loss * length) / loss diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index cda8c74..8d8e073 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -1,3 +1,4 @@ +import functools from typing import Any import numpy as np @@ -110,15 +111,7 @@ def number_density_van_der_waals( return np.min(roots) -def sellmeier_scalar( - wavelength: float, - material_dico: dict[str, Any], - pressure: float = None, - temperature: float = None, -) -> float: - return float(sellmeier(np.array([wavelength]), material_dico, pressure, temperature)[0]) - - +@functools.singledispatch def sellmeier( wl_for_disp: np.ndarray, material_dico: dict[str, Any], @@ -187,6 +180,17 @@ def sellmeier( return chi +@sellmeier.register +def sellmeier_scalar( + wavelength: float, + material_dico: dict[str, Any], + pressure: float = None, + temperature: float = None, +) -> float: + """n^2 - 1""" + return float(sellmeier(np.array([wavelength]), material_dico, pressure, temperature)[0]) + + def delta_gas(w, material_dico): """returns the value delta_t (eq. 24 in Markos(2017)) Parameters diff --git a/src/scgenerator/physics/plasma.py b/src/scgenerator/physics/plasma.py index cad4f09..155b0d8 100644 --- a/src/scgenerator/physics/plasma.py +++ b/src/scgenerator/physics/plasma.py @@ -64,8 +64,7 @@ class Plasma: field_abs = np.abs(field) delta = 1e-14 * field_abs.max() rate = self.rate(field_abs) - exp_int = expm1_int(rate, self.dt) - electron_density = N0 * exp_int + electron_density = free_electron_density(rate, self.dt, N0) dn_dt = (N0 - electron_density) * rate out = self.dt * cumulative_simpson( dn_dt * self.Ip / (field + delta) @@ -79,7 +78,5 @@ def adiabadicity(w: np.ndarray, I: float, field: np.ndarray) -> np.ndarray: return w * np.sqrt(2 * me * I) / (e * np.abs(field)) -def free_electron_density( - field: np.ndarray, dt: float, N0: float, rate: IonizationRate -) -> np.ndarray: - return N0 * (1 - np.exp(-dt * cumulative_simpson(rate(field)))) +def free_electron_density(rate: np.ndarray, dt: float, N0: float) -> np.ndarray: + return N0 * expm1_int(rate, dt) diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index c3f0099..d3eff47 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -215,6 +215,10 @@ def convert_field_units(envelope: np.ndarray, n: np.ndarray, A_eff: float) -> np return 2 * envelope.real / np.sqrt(2 * units.epsilon0 * units.c * n * A_eff) +def c_to_a_factor(A_eff_arr: np.ndarray) -> np.ndarray: + return (A_eff_arr / A_eff_arr[0]) ** (1 / 4) + + def conform_pulse_params( shape: Literal["gaussian", "sech"], width: float = None, diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 160d732..57211af 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -85,12 +85,6 @@ class RK4IP: self.z_targets.sort() self.store_num = len(self.z_targets) - # Setup initial values for every physical quantity that we want to track - if self.params.A_eff_arr is not None: - C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4) - else: - C_to_A_factor = 1.0 - # Initial step size if self.params.adapt_step_size: initial_h = (self.z_targets[1] - self.z_targets[0]) / 2 @@ -100,11 +94,10 @@ class RK4IP: length=self.params.length, z=self.z_targets.pop(0), current_step_size=initial_h, - previous_step_size=0.0, - step=1, - C_to_A_factor=C_to_A_factor, + step=0, + C_to_A_factor=self.params.c_to_a_factor, converter=self.params.ifft, - solution=self.params.spec_0.copy() / C_to_A_factor, + solution=self.params.spec_0.copy() / self.params.c_to_a_factor, ) self.stored_spectra = self.params.recovery_last_stored * [None] + [ self.init_state.solution.spectrum.copy() @@ -170,12 +163,18 @@ class RK4IP: store = False state = self.init_state.copy() yield len(self.stored_spectra) - 1, state - integrator = solver.RK4IPSD( - state, - self.params.linear_operator, - self.params.nonlinear_operator, - self.params.tolerated_error, - ) + if self.params.adapt_step_size: + integrator = solver.ConservedQuantityIntegrator( + self.init_state, + self.params.linear_operator, + self.params.nonlinear_operator, + self.params.tolerated_error, + self.params.conserved_quantity, + ) + else: + integrator = solver.ConstantStepIntegrator( + self.init_state, self.params.linear_operator, self.params.nonlinear_operator + ) for state in integrator: new_tracked_values = integrator.all_values() diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index d973f3c..076a3db 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -1,5 +1,8 @@ +from __future__ import annotations + +import logging from abc import abstractmethod -from typing import Iterator +from typing import Iterator, Type import numpy as np @@ -13,106 +16,6 @@ from .operators import ( ValueTracker, ) -################################################## -################### STEP-TAKER ################### -################################################## - - -class StepTaker(ValueTracker): - linear_operator: LinearOperator - nonlinear_operator: NonLinearOperator - _tracked_values: dict[str, float] - - def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator): - self.linear_operator = linear_operator - self.nonlinear_operator = nonlinear_operator - self._tracked_values = {} - - @abstractmethod - def __call__(self, state: CurrentState, step_size: float) -> np.ndarray: - ... - - def all_values(self) -> dict[str, float]: - """override ValueTracker.all_values to account for the fact that operators are called - multiple times per step, sometimes with different state, so we use value recorded - earlier. Please call self.recorde_tracked_values() one time only just after calling - the linear and nonlinear operators in your StepTaker. - - Returns - ------- - dict[str, float] - tracked values - """ - return self.values() | self._tracked_values - - def record_tracked_values(self): - self._tracked_values = super().all_values() - - -class RK4IPStepTaker(StepTaker): - c2 = 1 / 2 - c3 = 1 / 3 - c6 = 1 / 6 - _cached_values: tuple[np.ndarray, np.ndarray] - _cached_key: float - _cache_hits: int - _cache_misses: int - - def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator): - super().__init__(linear_operator, nonlinear_operator) - self._cached_key = None - self._cached_values = None - self._cache_hits = 0 - self._cache_misses = 0 - - def __call__(self, state: CurrentState, step_size: float) -> np.ndarray: - h = step_size - l0, nl0 = self.cached_values(state) - expD = np.exp(h * self.c2 * l0) - - A_I = expD * state.solution - k1 = expD * (h * nl0) - k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2)) - k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2)) - k4 = h * self.nonlinear_operator(state.replace(expD * (A_I + k3))) - - return expD * (A_I + k1 * self.c6 + k2 * self.c3 + k3 * self.c3) + k4 * self.c6 - - def cached_values(self, state: CurrentState) -> tuple[np.ndarray, np.ndarray]: - """the evaluation of the linear and nonlinear operators at the start of the step don't - depend on the step size, so we cache them in case we need them more than once (which - can happen depending on the adaptive step size controller) - - - Parameters - ---------- - state : CurrentState - current state of the simulation. state.z is used as the key for the cache - - Returns - ------- - np.ndarray - result of the linear operator - np.ndarray - result of the nonlinear operator - """ - if self._cached_key != state.z: - self._cache_misses += 1 - self._cached_key = state.z - self._cached_values = self.linear_operator(state), self.nonlinear_operator(state) - self.record_tracked_values() - else: - self._cache_hits += 1 - return self._cached_values - - def values(self) -> dict[str, float]: - return dict(RK4IP_cache_hits=self._cache_hits, RK4IP_cache_misses=self._cache_misses) - - -################################################## -################### INTEGRATOR ################### -################################################## - class Integrator(ValueTracker): linear_operator: LinearOperator @@ -120,6 +23,8 @@ class Integrator(ValueTracker): state: CurrentState tolerated_error: float _tracked_values: dict[str, float] + logger: logging.Logger + __registry: dict[str, Type[Integrator]] = {} def __init__( self, @@ -133,11 +38,19 @@ class Integrator(ValueTracker): self.nonlinear_operator = nonlinear_operator self.tolerated_error = tolerated_error self._tracked_values = {} + self.logger = get_logger(self.__class__.__name__) + + def __init_subclass__(cls): + cls.__registry[cls.__name__] = cls + + @classmethod + def get(cls, integr: str) -> Type[Integrator]: + return cls.__registry[integr] @abstractmethod def __iter__(self) -> Iterator[CurrentState]: """propagate the state with a step size of state.current_step_size - and yield a new state with updated z and previous_step_size attributes""" + and yield a new state with updated z and step attributes""" ... def all_values(self) -> dict[str, float]: @@ -151,7 +64,7 @@ class Integrator(ValueTracker): dict[str, float] tracked values """ - return self.values() | self._tracked_values + return self.values() | self._tracked_values | dict(z=self.state.z, step=self.state.step) def record_tracked_values(self): self._tracked_values = super().all_values() @@ -161,85 +74,104 @@ class Integrator(ValueTracker): class ConstantStepIntegrator(Integrator): - def __call__(self, state: CurrentState) -> CurrentState: - new_state = state.replace(self.step_taker(state, state.current_step_size)) - new_state.z += new_state.current_step_size - new_state.previous_step_size = new_state.current_step_size - return new_state + def __init__( + self, + init_state: CurrentState, + linear_operator: LinearOperator, + nonlinear_operator: NonLinearOperator, + ): + super().__init__(init_state, linear_operator, nonlinear_operator, 0.0) - def values(self) -> dict[str, float]: - return dict(h=self.last_step) + def __iter__(self) -> Iterator[CurrentState]: + while True: + lin = self.linear_operator(self.state) + nonlin = self.nonlinear_operator(self.state) + self.record_tracked_values() + new_spec = RK4IP_step( + self.nonlinear_operator, + self.state, + self.state.solution.spectrum, + self.state.current_step_size, + lin, + nonlin, + ) + + self.state.z += self.state.current_step_size + self.state.step += 1 + self.state.solution = new_spec + yield self.state class ConservedQuantityIntegrator(Integrator): - step_taker: StepTaker + last_qty: float conserved_quantity: AbstractConservedQuantity - last_quantity_value: float - tolerated_error: float - local_error: float = 0.0 + current_error: float = 0.0 def __init__( self, - step_taker: StepTaker, - conserved_quantity: AbstractConservedQuantity, + init_state: CurrentState, + linear_operator: LinearOperator, + nonlinear_operator: NonLinearOperator, tolerated_error: float, + conserved_quantity: AbstractConservedQuantity, ): + super().__init__(init_state, linear_operator, nonlinear_operator, tolerated_error) self.conserved_quantity = conserved_quantity - self.last_quantity_value = 0 - self.tolerated_error = tolerated_error - self.logger = get_logger(self.__class__.__name__) - self.size_fac = 2.0 ** (1.0 / 5.0) - self.step_taker = step_taker + self.last_qty = self.conserved_quantity(self.state) - def __call__(self, state: CurrentState) -> CurrentState: - keep = False - h_next_step = state.current_step_size - while not keep: - h = h_next_step + def __iter__(self) -> Iterator[CurrentState]: + h_next_step = self.state.current_step_size + size_fac = 2.0 ** (1.0 / 5.0) + while True: + lin = self.linear_operator(self.state) + nonlin = self.nonlinear_operator(self.state) + self.record_tracked_values() + while True: + h = h_next_step + new_state = self.state.replace( + RK4IP_step( + self.nonlinear_operator, + self.state, + self.state.solution.spectrum, + h, + lin, + nonlin, + ) + ) - new_state = state.replace(self.step_taker(state, h)) + new_qty = self.conserved_quantity(new_state) + self.current_error = np.abs(new_qty - self.last_qty) / self.last_qty - new_qty = self.conserved_quantity(new_state) - delta = np.abs(new_qty - self.last_quantity_value) / self.last_quantity_value - - if delta > 2 * self.tolerated_error: - progress_str = f"step {state.step} rejected with h = {h:.4e}, doing over" - self.logger.info(progress_str) - keep = False - h_next_step = h * 0.5 - elif self.tolerated_error < delta <= 2.0 * self.tolerated_error: - keep = True - h_next_step = h / self.size_fac - elif delta < 0.1 * self.tolerated_error: - keep = True - h_next_step = h * self.size_fac - else: - keep = True - h_next_step = h - - self.local_error = delta - self.last_quantity_value = new_qty - new_state.current_step_size = h_next_step - new_state.previous_step_size = h - new_state.z += h - self.last_step = h - return new_state + if self.current_error > 2 * self.tolerated_error: + h_next_step = h * 0.5 + elif self.tolerated_error < self.current_error <= 2.0 * self.tolerated_error: + h_next_step = h / size_fac + break + elif self.current_error < 0.1 * self.tolerated_error: + h_next_step = h * size_fac + break + else: + h_next_step = h + break + self.logger.info( + f"step {new_state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}" + ) + self.last_qty = new_qty + self.state = new_state + self.state.current_step_size = h_next_step + self.state.z += h + self.state.step += 1 + yield self.state def values(self) -> dict[str, float]: - return dict( - cons_qty=self.last_quantity_value, h=self.last_step, relative_error=self.local_error - ) + return dict(cons_qty=self.last_qty, relative_error=self.current_error) class RK4IPSD(Integrator): """Runge-Kutta 4 in Interaction Picture with step doubling""" - linear_operator: LinearOperator - nonlinear_operator: NonLinearOperator - tolerated_error: float - current_error: float - next_h_factor = 1.0 - current_error = 0.0 + next_h_factor: float = 1.0 + current_error: float = 0.0 def __iter__(self) -> Iterator[CurrentState]: h_next_step = self.state.current_step_size @@ -274,7 +206,6 @@ class RK4IPSD(Integrator): break self.state.current_step_size = h_next_step - self.state.previous_step_size = h self.state.z += h self.state.step += 1 self.state.solution = new_fine @@ -283,15 +214,7 @@ class RK4IPSD(Integrator): def take_step( self, h: float, spec: np.ndarray, lin: np.ndarray, nonlin: np.ndarray ) -> np.ndarray: - expD = np.exp(h * 0.5 * lin) - A_I = expD * spec - - k1 = expD * nonlin - k2 = self.nl(A_I + k1 * 0.5 * h) - k3 = self.nl(A_I + k2 * 0.5 * h) - k4 = self.nl(expD * (A_I + h * k3)) - - return expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3)) + h / 6 * k4 + return RK4IP_step(self.nonlinear_operator, self.state, spec, h, lin, nonlin) def compute_diff(self, coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float: return np.sqrt(math.abs2(coarse_spec - fine_spec).sum() / math.abs2(fine_spec).sum()) @@ -358,8 +281,11 @@ class ERK43(Integrator): h_next_step = self.next_h_factor * h if self.current_error <= self.tolerated_error: break + self.logger.info( + f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}" + ) + self.state.current_step_size = h_next_step - self.state.previous_step_size = h self.state.z += h self.state.step += 1 self.state.solution = new_fine @@ -377,7 +303,7 @@ class ERK43(Integrator): class ERK54(ERK43): def __iter__(self) -> Iterator[CurrentState]: - print("using ERK54") + self.logger.info("using ERK54") h_next_step = self.state.current_step_size k7 = self.nonlinear_operator(self.state) while True: @@ -413,10 +339,51 @@ class ERK54(ERK43): h_next_step = self.next_h_factor * h if self.current_error <= self.tolerated_error: break + self.logger.info( + f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}" + ) self.state.current_step_size = h_next_step - self.state.previous_step_size = h self.state.z += h self.state.step += 1 self.state.solution = new_fine k7 = tmp_k7 yield self.state + + +def RK4IP_step( + nonlinear_operator: NonLinearOperator, + init_state: CurrentState, + spectrum: np.ndarray, + h: float, + init_linear: np.ndarray, + init_nonlinear: np.ndarray, +) -> np.ndarray: + """Take a normal RK4IP step + + Parameters + ---------- + nonlinear_operator : NonLinearOperator + non linear operator + init_state : CurrentState + state at the start of the step + h : float + step size + init_linear : np.ndarray + linear operator already applied on the initial state + init_nonlinear : np.ndarray + nonlinear operator already applied on the initial state + + Returns + ------- + np.ndarray + resutling spectrum + """ + expD = np.exp(h * 0.5 * init_linear) + A_I = expD * spectrum + + k1 = expD * init_nonlinear + k2 = nonlinear_operator(init_state.replace(A_I + k1 * 0.5 * h)) + k3 = nonlinear_operator(init_state.replace(A_I + k2 * 0.5 * h)) + k4 = nonlinear_operator(init_state.replace(expD * (A_I + h * k3))) + + return expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3)) + h / 6 * k4 diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index eb998c7..1c7d38e 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -315,7 +315,7 @@ def ensure_folder(path: Path, prevent_overwrite: bool = True, mkdir=True) -> Pat def branch_id(branch: Path) -> tuple[int, int]: - sim_match = branch.parent.name.split()[0] + sim_match = branch.resolve().parent.name.split()[0] if sim_match.isdigit(): s_int = int(sim_match) else: