diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 85a14b5..076dc74 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Union import numpy as np -from . import math, operators, utils +from . import math, operators, utils, solver from .const import MANDATORY_PARAMETERS from .errors import EvaluatorError, NoDefaultError from .physics import fiber, materials, pulse, units @@ -377,6 +377,8 @@ default_rules: list[Rule] = [ Rule("loss_op", operators.NoLoss, priorities=-1), Rule("plasma_op", operators.NoPlasma, priorities=-1), Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1), + Rule("step_taker", solver.RK4IPStepTaker), + Rule("integrator", solver.ConstantStepIntegrator, priorities=-1), ] envelope_rules = default_rules + [ @@ -417,6 +419,7 @@ envelope_rules = default_rules + [ Rule("dispersion_op", operators.DirectDispersion), Rule("linear_operator", operators.EnvelopeLinearOperator), Rule("conserved_quantity", operators.conserved_quantity), + Rule("integrator", solver.ConservedQuantityIntegrator), ] full_field_rules = default_rules + [ @@ -440,4 +443,6 @@ full_field_rules = default_rules + [ operators.FullFieldLinearOperator, ), Rule("nonlinear_operator", operators.FullFieldNonLinearOperator), + # Integration + Rule("integrator", solver.LocalErrorIntegrator), ] diff --git a/src/scgenerator/operators.py b/src/scgenerator/operators.py index eb80db3..c3174ba 100644 --- a/src/scgenerator/operators.py +++ b/src/scgenerator/operators.py @@ -7,6 +7,7 @@ from __future__ import annotations import dataclasses from abc import ABC, abstractmethod from dataclasses import dataclass +import re from typing import Any, Callable import numpy as np @@ -23,14 +24,18 @@ class SpectrumDescriptor: name: str value: np.ndarray = None _counter = 0 - _full_field: bool = False _converter: Callable[[np.ndarray], np.ndarray] + def __init__(self, spec2_name: str, field_name: str, field2_name: str): + self.spec2_name = spec2_name + self.field_name = field_name + self.field2_name = field2_name + def __set__(self, instance: CurrentState, value: np.ndarray): self._counter += 1 - instance.spec2 = math.abs2(value) - instance.field = instance.converter(value) - instance.field2 = math.abs2(instance.field) + setattr(instance, self.spec2_name, math.abs2(value)) + setattr(instance, self.field_name, instance.converter(value)) + setattr(instance, self.field2_name, math.abs2(getattr(instance, self.field_name))) self.value = value def __get__(self, instance, owner): @@ -46,41 +51,96 @@ class SpectrumDescriptor: self.name = name +class SpectrumDescriptor2: + name: str + spectrum: np.ndarray = None + __spec2: np.ndarray = None + __field: np.ndarray = None + __field2: np.ndarray = None + _converter: Callable[[np.ndarray], np.ndarray] + + def __set__(self, instance: CurrentState, value: np.ndarray): + self._converter = instance.converter + self.spectrum = value + self.__spec2 = None + self.__field = None + self.__field2 = None + + @property + def spec2(self) -> np.ndarray: + if self.__spec2 is None: + self.__spec2 = math.abs2(self.spectrum) + return self.__spec2 + + @property + def field(self) -> np.ndarray: + if self.__field is None: + self.__field = self._converter(self.spectrum) + return self.__field + + @property + def field2(self) -> np.ndarray: + if self.__field2 is None: + self.__field2 = math.abs2(self.field) + return self.__field2 + + def __delete__(self, instance): + raise AttributeError("Cannot delete Spectrum field") + + def __set_name__(self, owner, name): + self.name = name + + @dataclasses.dataclass class CurrentState: length: float z: float - h: float + current_step_size: float + previous_step_size: float + step: int C_to_A_factor: np.ndarray converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft - spectrum: np.ndarray = SpectrumDescriptor() + spectrum: np.ndarray = SpectrumDescriptor("spec2", "field", "field2") spec2: np.ndarray = dataclasses.field(init=False) field: np.ndarray = dataclasses.field(init=False) field2: np.ndarray = dataclasses.field(init=False) + prev_spectrum: np.ndarray = SpectrumDescriptor("prev_spec2", "prev_field", "prev_field2") + prev_spec2: np.ndarray = dataclasses.field(init=False) + prev_field: np.ndarray = dataclasses.field(init=False) + prev_field2: np.ndarray = dataclasses.field(init=False) @property def z_ratio(self) -> float: return self.z / self.length - def replace(self, new_spectrum) -> CurrentState: - return CurrentState( - self.length, self.z, self.h, self.C_to_A_factor, self.converter, new_spectrum + def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState: + """returns a new state with new attributes""" + params = dict( + spectrum=new_spectrum, + length=self.length, + z=self.z, + current_step_size=self.current_step_size, + previous_step_size=self.previous_step_size, + step=self.step, + C_to_A_factor=self.C_to_A_factor, + converter=self.converter, ) + return CurrentState(**(params | new_params)) -class Operator(ABC): +class ValueTracker(ABC): def values(self) -> dict[str, float]: return {} - def get_values(self) -> dict[str, float]: + def all_values(self) -> dict[str, float]: out = self.values() - for operator in self.__dict__.values(): - if isinstance(operator, Operator): - out |= operator.get_values() + for operator in vars(self).values(): + if isinstance(operator, ValueTracker): + out = operator.all_values() | out return out def __repr__(self) -> str: - value_pair_list = list(self.__dict__.items()) + value_pair_list = list(vars(self).items()) if len(value_pair_list) == 0: value_pair_str_list = "" elif len(value_pair_list) == 1: @@ -95,6 +155,8 @@ class Operator(ABC): return repr(v[0]) return repr(v) + +class Operator(ValueTracker): @abstractmethod def __call__(self, state: CurrentState) -> np.ndarray: pass @@ -757,7 +819,12 @@ class PhotonNumberLoss(AbstractConservedQuantity): def __call__(self, state: CurrentState) -> float: return pulse.photon_number_with_loss( - state.spec2, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h + state.spec2, + self.w, + self.dw, + self.gamma_op(state), + self.loss_op(state), + state.current_step_size, ) @@ -778,7 +845,10 @@ class EnergyLoss(AbstractConservedQuantity): def __call__(self, state: CurrentState) -> float: return pulse.pulse_energy_with_loss( - math.abs2(state.C_to_A_factor * state.spectrum), self.dw, self.loss_op(state), state.h + math.abs2(state.C_to_A_factor * state.spectrum), + self.dw, + self.loss_op(state), + state.current_step_size, ) diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index e468462..cb13b14 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -13,18 +13,13 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV import numpy as np -from scgenerator.physics import units - from . import env, legacy, utils from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__ from .errors import EvaluatorError from .evaluator import Evaluator from .logger import get_logger -from .operators import ( - AbstractConservedQuantity, - LinearOperator, - NonLinearOperator, -) +from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator +from .solver import Integrator, StepTaker from .utils import fiber_folder, update_path_name from .variationer import VariationDescriptor, Variationer @@ -382,6 +377,8 @@ class Parameters: # computed linear_operator: LinearOperator = Parameter(type_checker(LinearOperator)) nonlinear_operator: NonLinearOperator = Parameter(type_checker(NonLinearOperator)) + step_taker: StepTaker = Parameter(type_checker(StepTaker)) + integrator: Integrator = Parameter(type_checker(Integrator)) conserved_quantity: AbstractConservedQuantity = Parameter( type_checker(AbstractConservedQuantity) ) diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index b8fc1fb..1450d3f 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -1,3 +1,4 @@ +from collections import defaultdict import multiprocessing import multiprocessing.connection import os @@ -13,7 +14,6 @@ from ..logger import get_logger from ..operators import CurrentState from ..parameter import Configuration, Parameters from ..pbar import PBars, ProgressBarActor, progress_worker -from ..const import ONE_2, ONE_3, ONE_6 try: import ray @@ -21,6 +21,15 @@ except ModuleNotFoundError: ray = None +class TrackedValues(defaultdict): + def __init__(self): + super().__init__(list) + + def append(self, d: dict[str, Any]): + for k, v in d.items(): + self[k].append(v) + + class RK4IP: params: Parameters save_data: bool @@ -53,19 +62,7 @@ class RK4IP: save_data : bool, optional save calculated spectra to disk, by default False """ - self.set(params, save_data) - def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]: - yield from self.irun() - - def __len__(self) -> int: - return self.params.z_num - - def set( - self, - params: Parameters, - save_data=False, - ): self.params = params self.save_data = save_data @@ -77,16 +74,12 @@ class RK4IP: self.logger = get_logger(self.params.output_path.name) - self.dw = self.params.w[1] - self.params.w[0] - self.z_targets = self.params.z_targets self.error_ok = ( params.tolerated_error if self.params.adapt_step_size else self.params.step_size ) - self._setup_sim_parameters() - - def _setup_sim_parameters(self): - # making sure to keep only the z that we want + # setup save targets + self.z_targets = self.params.z_targets self.z_stored = list(self.z_targets.copy()[0 : self.params.recovery_last_stored + 1]) self.z_targets = list(self.z_targets.copy()[self.params.recovery_last_stored :]) self.z_targets.sort() @@ -97,16 +90,18 @@ class RK4IP: C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4) else: C_to_A_factor = 1.0 - z = self.z_targets.pop(0) + # Initial step size if self.params.adapt_step_size: - initial_h = (self.z_targets[0] - z) / 2 + initial_h = (self.z_targets[1] - self.z_targets[0]) / 2 else: initial_h = self.error_ok self.state = CurrentState( length=self.params.length, - z=z, - h=initial_h, + z=self.z_targets.pop(0), + current_step_size=initial_h, + previous_step_size=0.0, + step=1, C_to_A_factor=C_to_A_factor, converter=self.params.ifft, spectrum=self.params.spec_0.copy() / C_to_A_factor, @@ -114,11 +109,7 @@ class RK4IP: self.stored_spectra = self.params.recovery_last_stored * [None] + [ self.state.spectrum.copy() ] - self.cons_qty = [ - self.params.conserved_quantity(self.state), - 0, - ] - self.size_fac = 2 ** (1 / 5) + self.tracked_values = TrackedValues() def _save_current_spectrum(self, num: int): """saves the spectrum and the corresponding cons_qty array @@ -128,8 +119,8 @@ class RK4IP: num : int index of the z postition """ - self._save_data(self.get_current_spectrum(), f"spectrum_{num}") - self._save_data(self.cons_qty, "cons_qty") + self.write(self.get_current_spectrum(), f"spectrum_{num}") + self.write(self.tracked_values, "tracked_values") self.step_saved() def get_current_spectrum(self) -> np.ndarray: @@ -142,7 +133,7 @@ class RK4IP: """ return self.state.C_to_A_factor * self.state.spectrum - def _save_data(self, data: np.ndarray, name: str): + def write(self, data: np.ndarray, name: str): """calls the appropriate method to save data Parameters @@ -168,7 +159,7 @@ class RK4IP: ) if self.save_data: - self._save_data(self.z_stored, "z.npy") + self.write(self.z_stored, "z.npy") return self.stored_spectra @@ -185,40 +176,36 @@ class RK4IP: spectrum """ - # Print introduction self.logger.debug( "Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0]) ) + store = False - # Start of the integration - step = 1 - store = False # store a spectrum - - yield step, len(self.stored_spectra) - 1, self.get_current_spectrum() + yield self.state.step, len(self.stored_spectra) - 1, self.get_current_spectrum() while self.state.z < self.params.length: - h_taken = self.take_step(step) + self.state = self.params.integrator(self.state) - step += 1 - self.cons_qty.append(0) + self.state.step += 1 + new_tracked_values = ( + dict(step=self.state.step, z=self.state.z) | self.params.integrator.all_values() + ) + self.logger.debug(f"tracked values at z={self.state.z} : {new_tracked_values}") + self.tracked_values.append(new_tracked_values) # Whether the current spectrum has to be stored depends on previous step if store: - self.logger.debug( - "{} steps, z = {:.4f}, h = {:.5g}".format(step, self.state.z, h_taken) - ) - current_spec = self.get_current_spectrum() self.stored_spectra.append(current_spec) - yield step, len(self.stored_spectra) - 1, current_spec + yield self.state.step, len(self.stored_spectra) - 1, current_spec self.z_stored.append(self.state.z) del self.z_targets[0] # reset the constant step size after a spectrum is stored if not self.params.adapt_step_size: - self.state.h = self.error_ok + self.state.current_step_size = self.error_ok if len(self.z_targets) == 0: break @@ -226,69 +213,19 @@ class RK4IP: # if the next step goes over a position at which we want to store # a spectrum, we shorten the step to reach this position exactly - if self.state.z + self.state.h >= self.z_targets[0]: + if self.state.z + self.state.current_step_size >= self.z_targets[0]: store = True - self.state.h = self.z_targets[0] - self.state.z - - def take_step(self, step: int) -> float: - """computes a new spectrum, whilst adjusting step size if required, until the error estimation - validates the new spectrum. Saves the result in the internal state attribute - - Parameters - ---------- - step : int - index of the current - - Returns - ------- - h : float - step sized used - """ - keep = False - h_next_step = self.state.h - while not keep: - h = h_next_step - - expD = np.exp(h * ONE_2 * self.params.linear_operator(self.state)) - - A_I = expD * self.state.spectrum - k1 = expD * (h * self.params.nonlinear_operator(self.state)) - k2 = h * self.params.nonlinear_operator(self.state.replace(A_I + k1 * ONE_2)) - k3 = h * self.params.nonlinear_operator(self.state.replace(A_I + k2 * ONE_2)) - k4 = h * self.params.nonlinear_operator(self.state.replace(expD * (A_I + k3))) - new_state = self.state.replace( - expD * (A_I + k1 * ONE_6 + k2 * ONE_3 + k3 * ONE_3) + k4 * ONE_6 - ) - - self.cons_qty[step] = self.params.conserved_quantity(new_state) - if self.params.adapt_step_size: - curr_p_change = np.abs(self.cons_qty[step - 1] - self.cons_qty[step]) - cons_qty_change_ok = self.error_ok * self.cons_qty[step - 1] - - if curr_p_change > 2 * cons_qty_change_ok: - progress_str = f"step {step} rejected with h = {h:.4e}, doing over" - self.logger.debug(progress_str) - keep = False - h_next_step = h * ONE_2 - elif cons_qty_change_ok < curr_p_change <= 2.0 * cons_qty_change_ok: - keep = True - h_next_step = h / self.size_fac - elif curr_p_change < 0.1 * cons_qty_change_ok: - keep = True - h_next_step = h * self.size_fac - else: - keep = True - h_next_step = h - else: - keep = True - self.state = new_state - self.state.h = h_next_step - self.state.z += h - return h + self.state.current_step_size = self.z_targets[0] - self.state.z def step_saved(self): pass + def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]: + yield from self.irun() + + def __len__(self) -> int: + return self.params.z_num + class SequentialRK4IP(RK4IP): def __init__( @@ -339,7 +276,7 @@ class RayRK4IP(RK4IP): ): self.worker_id = worker_id self.p_actor = p_actor - super().set( + super().__init__( params, save_data=save_data, ) diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py new file mode 100644 index 0000000..be892b7 --- /dev/null +++ b/src/scgenerator/solver.py @@ -0,0 +1,280 @@ +from abc import abstractmethod + +import numpy as np + +from . import math +from .logger import get_logger +from .operators import ( + AbstractConservedQuantity, + CurrentState, + LinearOperator, + NonLinearOperator, + ValueTracker, +) + +################################################## +################### STEP-TAKER ################### +################################################## + + +class StepTaker(ValueTracker): + linear_operator: LinearOperator + nonlinear_operator: NonLinearOperator + _tracked_values: dict[str, float] + + def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator): + self.linear_operator = linear_operator + self.nonlinear_operator = nonlinear_operator + self._tracked_values = {} + + @abstractmethod + def __call__(self, state: CurrentState, step_size: float) -> np.ndarray: + ... + + def all_values(self) -> dict[str, float]: + """override ValueTracker.all_values to account for the fact that operators are called + multiple times per step, sometimes with different state, so we use value recorded + earlier. Please call self.recorde_tracked_values() one time only just after calling + the linear and nonlinear operators in your StepTaker. + + Returns + ------- + dict[str, float] + tracked values + """ + return self.values() | self._tracked_values + + def record_tracked_values(self): + self._tracked_values = super().all_values() + + +class RK4IPStepTaker(StepTaker): + c2 = 1 / 2 + c3 = 1 / 3 + c6 = 1 / 6 + _cached_values: tuple[np.ndarray, np.ndarray] + _cached_key: float + _cache_hits: int + _cache_misses: int + + def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator): + super().__init__(linear_operator, nonlinear_operator) + self._cached_key = None + self._cached_values = None + self._cache_hits = 0 + self._cache_misses = 0 + + def __call__(self, state: CurrentState, step_size: float) -> np.ndarray: + h = step_size + l0, nl0 = self.cached_values(state) + expD = np.exp(h * self.c2 * l0) + + A_I = expD * state.spectrum + k1 = expD * (h * nl0) + k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2)) + k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2)) + k4 = h * self.nonlinear_operator(state.replace(expD * (A_I + k3))) + + return expD * (A_I + k1 * self.c6 + k2 * self.c3 + k3 * self.c3) + k4 * self.c6 + + def cached_values(self, state: CurrentState) -> tuple[np.ndarray, np.ndarray]: + """the evaluation of the linear and nonlinear operators at the start of the step don't + depend on the step size, so we cache them in case we need them more than once (which + can happen depending on the adaptive step size controller) + + + Parameters + ---------- + state : CurrentState + current state of the simulation. state.z is used as the key for the cache + + Returns + ------- + np.ndarray + result of the linear operator + np.ndarray + result of the nonlinear operator + """ + if self._cached_key != state.z: + self._cache_misses += 1 + self._cached_key = state.z + self._cached_values = self.linear_operator(state), self.nonlinear_operator(state) + self.record_tracked_values() + else: + self._cache_hits += 1 + return self._cached_values + + def values(self) -> dict[str, float]: + return dict(RK4IP_cache_hits=self._cache_hits, RK4IP_cache_misses=self._cache_misses) + + +################################################## +################### INTEGRATOR ################### +################################################## + + +class Integrator(ValueTracker): + last_step = 0.0 + + @abstractmethod + def __call__(self, state: CurrentState) -> CurrentState: + """propagate the state with a step size of state.current_step_size + and return a new state with updated z and previous_step_size attributes""" + ... + + +class ConstantStepIntegrator(Integrator): + def __call__(self, state: CurrentState) -> CurrentState: + new_state = state.replace(self.step_taker(state, state.current_step_size)) + new_state.z += new_state.current_step_size + new_state.previous_step_size = new_state.current_step_size + return new_state + + def values(self) -> dict[str, float]: + return dict(h=self.last_step) + + +class ConservedQuantityIntegrator(Integrator): + step_taker: StepTaker + conserved_quantity: AbstractConservedQuantity + last_quantity_value: float + tolerated_error: float + local_error: float = 0.0 + + def __init__( + self, + step_taker: StepTaker, + conserved_quantity: AbstractConservedQuantity, + tolerated_error: float, + ): + self.conserved_quantity = conserved_quantity + self.last_quantity_value = 0 + self.tolerated_error = tolerated_error + self.logger = get_logger(self.__class__.__name__) + self.size_fac = 2.0 ** (1.0 / 5.0) + self.step_taker = step_taker + + def __call__(self, state: CurrentState) -> CurrentState: + keep = False + h_next_step = state.current_step_size + while not keep: + h = h_next_step + + new_state = state.replace(self.step_taker(state, h)) + + new_qty = self.conserved_quantity(new_state) + delta = np.abs(new_qty - self.last_quantity_value) / self.last_quantity_value + + if delta > 2 * self.tolerated_error: + progress_str = f"step {state.step} rejected with h = {h:.4e}, doing over" + self.logger.info(progress_str) + keep = False + h_next_step = h * 0.5 + elif self.tolerated_error < delta <= 2.0 * self.tolerated_error: + keep = True + h_next_step = h / self.size_fac + elif delta < 0.1 * self.tolerated_error: + keep = True + h_next_step = h * self.size_fac + else: + keep = True + h_next_step = h + + self.local_error = delta + self.last_quantity_value = new_qty + new_state.current_step_size = h_next_step + new_state.previous_step_size = h + new_state.z += h + self.last_step = h + return new_state + + def values(self) -> dict[str, float]: + return dict( + cons_qty=self.last_quantity_value, h=self.last_step, relative_error=self.local_error + ) + + +class LocalErrorIntegrator(Integrator): + step_taker: StepTaker + tolerated_error: float + local_error: float + + def __init__(self, step_taker: StepTaker, tolerated_error: float, w_num: int): + self.tolerated_error = tolerated_error + self.local_error = 0.0 + self.logger = get_logger(self.__class__.__name__) + self.size_fac, self.fine_fac, self.coarse_fac = 2.0 ** (1.0 / 5.0), 16 / 15, -1 / 15 + self.step_taker = step_taker + + def __call__(self, state: CurrentState) -> CurrentState: + keep = False + h_next_step = state.current_step_size + while not keep: + h = h_next_step + h_half = h / 2 + coarse_spec = self.step_taker(state, h) + + fine_spec1 = self.step_taker(state, h_half) + fine_state = state.replace(fine_spec1, z=state.z + h_half) + fine_spec = self.step_taker(fine_state, h_half) + + delta = self.compute_diff(coarse_spec, fine_spec) + + if delta > 2 * self.tolerated_error: + keep = False + h_next_step = h_half + elif self.tolerated_error <= delta <= 2 * self.tolerated_error: + keep = True + h_next_step = h / self.size_fac + elif 0.5 * self.tolerated_error <= delta < self.tolerated_error: + keep = True + h_next_step = h + else: + keep = True + h_next_step = h * self.size_fac + + self.local_error = delta + fine_state.spectrum = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac + fine_state.current_step_size = h_next_step + fine_state.previous_step_size = h + fine_state.z += h + self.last_step = h + return fine_state + + def compute_diff(self, coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float: + return np.sqrt(math.abs2(coarse_spec - fine_spec).sum() / math.abs2(fine_spec).sum()) + + def values(self) -> dict[str, float]: + return dict(relative_error=self.local_error, h=self.last_step) + + +class ERK43(Integrator): + linear_operator: LinearOperator + nonlinear_operator: NonLinearOperator + dt: float + + def __init__( + self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator, dt: float + ): + self.linear_operator = linear_operator + self.nonlinear_operator = nonlinear_operator + self.dt = dt + + def __call__(self, state: CurrentState) -> CurrentState: + keep = False + h_next_step = state.current_step_size + while not keep: + h = h_next_step + expD = np.exp(h * 0.5 * self.linear_operator(state)) + A_I = expD * state.spectrum + k1 = expD * state.prev_spectrum + k2 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k1)) + k3 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k2)) + k4 = self.nonlinear_operator(state.replace(expD * A_I + h * k3)) + r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3)) + + new_fine = r + h / 6 * k4 + + k5 = self.nonlinear_operator(state.replace(new_fine)) + + new_coarse = r + h / 30 * (2 * k4 + 3 * k5) diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index a2f8e34..eb998c7 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from functools import cache from pathlib import Path from string import printable as str_printable -from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set +from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set, Union import numpy as np import pkg_resources as pkg @@ -251,12 +251,12 @@ def load_material_dico(name: str) -> dict[str, Any]: return tomli.loads(Paths.gets("materials"))[name] -def save_data(data: np.ndarray, data_dir: Path, file_name: str): +def save_data(data: Union[np.ndarray, MutableMapping], data_dir: Path, file_name: str): """saves numpy array to disk Parameters ---------- - data : np.ndarray + data : Union[np.ndarray, MutableMapping] data to save file_name : str file name @@ -266,7 +266,10 @@ def save_data(data: np.ndarray, data_dir: Path, file_name: str): identifier in the main data folder of the task, by default "" """ path = data_dir / file_name - np.save(path, data) + if isinstance(data, np.ndarray): + np.save(path, data) + elif isinstance(data, MutableMapping): + np.savez(path, **data) get_logger(__name__).debug(f"saved data in {path}") return