Trying to make Local Error method work, no success

This commit is contained in:
Benoît Sierro
2021-11-09 09:53:33 +01:00
parent 0d5d529ba3
commit f757466fe1
6 changed files with 427 additions and 135 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Union
import numpy as np import numpy as np
from . import math, operators, utils from . import math, operators, utils, solver
from .const import MANDATORY_PARAMETERS from .const import MANDATORY_PARAMETERS
from .errors import EvaluatorError, NoDefaultError from .errors import EvaluatorError, NoDefaultError
from .physics import fiber, materials, pulse, units from .physics import fiber, materials, pulse, units
@@ -377,6 +377,8 @@ default_rules: list[Rule] = [
Rule("loss_op", operators.NoLoss, priorities=-1), Rule("loss_op", operators.NoLoss, priorities=-1),
Rule("plasma_op", operators.NoPlasma, priorities=-1), Rule("plasma_op", operators.NoPlasma, priorities=-1),
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1), Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
Rule("step_taker", solver.RK4IPStepTaker),
Rule("integrator", solver.ConstantStepIntegrator, priorities=-1),
] ]
envelope_rules = default_rules + [ envelope_rules = default_rules + [
@@ -417,6 +419,7 @@ envelope_rules = default_rules + [
Rule("dispersion_op", operators.DirectDispersion), Rule("dispersion_op", operators.DirectDispersion),
Rule("linear_operator", operators.EnvelopeLinearOperator), Rule("linear_operator", operators.EnvelopeLinearOperator),
Rule("conserved_quantity", operators.conserved_quantity), Rule("conserved_quantity", operators.conserved_quantity),
Rule("integrator", solver.ConservedQuantityIntegrator),
] ]
full_field_rules = default_rules + [ full_field_rules = default_rules + [
@@ -440,4 +443,6 @@ full_field_rules = default_rules + [
operators.FullFieldLinearOperator, operators.FullFieldLinearOperator,
), ),
Rule("nonlinear_operator", operators.FullFieldNonLinearOperator), Rule("nonlinear_operator", operators.FullFieldNonLinearOperator),
# Integration
Rule("integrator", solver.LocalErrorIntegrator),
] ]

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import dataclasses import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
import re
from typing import Any, Callable from typing import Any, Callable
import numpy as np import numpy as np
@@ -23,14 +24,18 @@ class SpectrumDescriptor:
name: str name: str
value: np.ndarray = None value: np.ndarray = None
_counter = 0 _counter = 0
_full_field: bool = False
_converter: Callable[[np.ndarray], np.ndarray] _converter: Callable[[np.ndarray], np.ndarray]
def __init__(self, spec2_name: str, field_name: str, field2_name: str):
self.spec2_name = spec2_name
self.field_name = field_name
self.field2_name = field2_name
def __set__(self, instance: CurrentState, value: np.ndarray): def __set__(self, instance: CurrentState, value: np.ndarray):
self._counter += 1 self._counter += 1
instance.spec2 = math.abs2(value) setattr(instance, self.spec2_name, math.abs2(value))
instance.field = instance.converter(value) setattr(instance, self.field_name, instance.converter(value))
instance.field2 = math.abs2(instance.field) setattr(instance, self.field2_name, math.abs2(getattr(instance, self.field_name)))
self.value = value self.value = value
def __get__(self, instance, owner): def __get__(self, instance, owner):
@@ -46,41 +51,96 @@ class SpectrumDescriptor:
self.name = name self.name = name
class SpectrumDescriptor2:
name: str
spectrum: np.ndarray = None
__spec2: np.ndarray = None
__field: np.ndarray = None
__field2: np.ndarray = None
_converter: Callable[[np.ndarray], np.ndarray]
def __set__(self, instance: CurrentState, value: np.ndarray):
self._converter = instance.converter
self.spectrum = value
self.__spec2 = None
self.__field = None
self.__field2 = None
@property
def spec2(self) -> np.ndarray:
if self.__spec2 is None:
self.__spec2 = math.abs2(self.spectrum)
return self.__spec2
@property
def field(self) -> np.ndarray:
if self.__field is None:
self.__field = self._converter(self.spectrum)
return self.__field
@property
def field2(self) -> np.ndarray:
if self.__field2 is None:
self.__field2 = math.abs2(self.field)
return self.__field2
def __delete__(self, instance):
raise AttributeError("Cannot delete Spectrum field")
def __set_name__(self, owner, name):
self.name = name
@dataclasses.dataclass @dataclasses.dataclass
class CurrentState: class CurrentState:
length: float length: float
z: float z: float
h: float current_step_size: float
previous_step_size: float
step: int
C_to_A_factor: np.ndarray C_to_A_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
spectrum: np.ndarray = SpectrumDescriptor() spectrum: np.ndarray = SpectrumDescriptor("spec2", "field", "field2")
spec2: np.ndarray = dataclasses.field(init=False) spec2: np.ndarray = dataclasses.field(init=False)
field: np.ndarray = dataclasses.field(init=False) field: np.ndarray = dataclasses.field(init=False)
field2: np.ndarray = dataclasses.field(init=False) field2: np.ndarray = dataclasses.field(init=False)
prev_spectrum: np.ndarray = SpectrumDescriptor("prev_spec2", "prev_field", "prev_field2")
prev_spec2: np.ndarray = dataclasses.field(init=False)
prev_field: np.ndarray = dataclasses.field(init=False)
prev_field2: np.ndarray = dataclasses.field(init=False)
@property @property
def z_ratio(self) -> float: def z_ratio(self) -> float:
return self.z / self.length return self.z / self.length
def replace(self, new_spectrum) -> CurrentState: def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState:
return CurrentState( """returns a new state with new attributes"""
self.length, self.z, self.h, self.C_to_A_factor, self.converter, new_spectrum params = dict(
spectrum=new_spectrum,
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
previous_step_size=self.previous_step_size,
step=self.step,
C_to_A_factor=self.C_to_A_factor,
converter=self.converter,
) )
return CurrentState(**(params | new_params))
class Operator(ABC): class ValueTracker(ABC):
def values(self) -> dict[str, float]: def values(self) -> dict[str, float]:
return {} return {}
def get_values(self) -> dict[str, float]: def all_values(self) -> dict[str, float]:
out = self.values() out = self.values()
for operator in self.__dict__.values(): for operator in vars(self).values():
if isinstance(operator, Operator): if isinstance(operator, ValueTracker):
out |= operator.get_values() out = operator.all_values() | out
return out return out
def __repr__(self) -> str: def __repr__(self) -> str:
value_pair_list = list(self.__dict__.items()) value_pair_list = list(vars(self).items())
if len(value_pair_list) == 0: if len(value_pair_list) == 0:
value_pair_str_list = "" value_pair_str_list = ""
elif len(value_pair_list) == 1: elif len(value_pair_list) == 1:
@@ -95,6 +155,8 @@ class Operator(ABC):
return repr(v[0]) return repr(v[0])
return repr(v) return repr(v)
class Operator(ValueTracker):
@abstractmethod @abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
pass pass
@@ -757,7 +819,12 @@ class PhotonNumberLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss( return pulse.photon_number_with_loss(
state.spec2, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h state.spec2,
self.w,
self.dw,
self.gamma_op(state),
self.loss_op(state),
state.current_step_size,
) )
@@ -778,7 +845,10 @@ class EnergyLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy_with_loss( return pulse.pulse_energy_with_loss(
math.abs2(state.C_to_A_factor * state.spectrum), self.dw, self.loss_op(state), state.h math.abs2(state.C_to_A_factor * state.spectrum),
self.dw,
self.loss_op(state),
state.current_step_size,
) )

View File

@@ -13,18 +13,13 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV
import numpy as np import numpy as np
from scgenerator.physics import units
from . import env, legacy, utils from . import env, legacy, utils
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__ from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
from .errors import EvaluatorError from .errors import EvaluatorError
from .evaluator import Evaluator from .evaluator import Evaluator
from .logger import get_logger from .logger import get_logger
from .operators import ( from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
AbstractConservedQuantity, from .solver import Integrator, StepTaker
LinearOperator,
NonLinearOperator,
)
from .utils import fiber_folder, update_path_name from .utils import fiber_folder, update_path_name
from .variationer import VariationDescriptor, Variationer from .variationer import VariationDescriptor, Variationer
@@ -382,6 +377,8 @@ class Parameters:
# computed # computed
linear_operator: LinearOperator = Parameter(type_checker(LinearOperator)) linear_operator: LinearOperator = Parameter(type_checker(LinearOperator))
nonlinear_operator: NonLinearOperator = Parameter(type_checker(NonLinearOperator)) nonlinear_operator: NonLinearOperator = Parameter(type_checker(NonLinearOperator))
step_taker: StepTaker = Parameter(type_checker(StepTaker))
integrator: Integrator = Parameter(type_checker(Integrator))
conserved_quantity: AbstractConservedQuantity = Parameter( conserved_quantity: AbstractConservedQuantity = Parameter(
type_checker(AbstractConservedQuantity) type_checker(AbstractConservedQuantity)
) )

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
import multiprocessing import multiprocessing
import multiprocessing.connection import multiprocessing.connection
import os import os
@@ -13,7 +14,6 @@ from ..logger import get_logger
from ..operators import CurrentState from ..operators import CurrentState
from ..parameter import Configuration, Parameters from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker from ..pbar import PBars, ProgressBarActor, progress_worker
from ..const import ONE_2, ONE_3, ONE_6
try: try:
import ray import ray
@@ -21,6 +21,15 @@ except ModuleNotFoundError:
ray = None ray = None
class TrackedValues(defaultdict):
def __init__(self):
super().__init__(list)
def append(self, d: dict[str, Any]):
for k, v in d.items():
self[k].append(v)
class RK4IP: class RK4IP:
params: Parameters params: Parameters
save_data: bool save_data: bool
@@ -53,19 +62,7 @@ class RK4IP:
save_data : bool, optional save_data : bool, optional
save calculated spectra to disk, by default False save calculated spectra to disk, by default False
""" """
self.set(params, save_data)
def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
yield from self.irun()
def __len__(self) -> int:
return self.params.z_num
def set(
self,
params: Parameters,
save_data=False,
):
self.params = params self.params = params
self.save_data = save_data self.save_data = save_data
@@ -77,16 +74,12 @@ class RK4IP:
self.logger = get_logger(self.params.output_path.name) self.logger = get_logger(self.params.output_path.name)
self.dw = self.params.w[1] - self.params.w[0]
self.z_targets = self.params.z_targets
self.error_ok = ( self.error_ok = (
params.tolerated_error if self.params.adapt_step_size else self.params.step_size params.tolerated_error if self.params.adapt_step_size else self.params.step_size
) )
self._setup_sim_parameters() # setup save targets
self.z_targets = self.params.z_targets
def _setup_sim_parameters(self):
# making sure to keep only the z that we want
self.z_stored = list(self.z_targets.copy()[0 : self.params.recovery_last_stored + 1]) self.z_stored = list(self.z_targets.copy()[0 : self.params.recovery_last_stored + 1])
self.z_targets = list(self.z_targets.copy()[self.params.recovery_last_stored :]) self.z_targets = list(self.z_targets.copy()[self.params.recovery_last_stored :])
self.z_targets.sort() self.z_targets.sort()
@@ -97,16 +90,18 @@ class RK4IP:
C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4) C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
else: else:
C_to_A_factor = 1.0 C_to_A_factor = 1.0
z = self.z_targets.pop(0)
# Initial step size # Initial step size
if self.params.adapt_step_size: if self.params.adapt_step_size:
initial_h = (self.z_targets[0] - z) / 2 initial_h = (self.z_targets[1] - self.z_targets[0]) / 2
else: else:
initial_h = self.error_ok initial_h = self.error_ok
self.state = CurrentState( self.state = CurrentState(
length=self.params.length, length=self.params.length,
z=z, z=self.z_targets.pop(0),
h=initial_h, current_step_size=initial_h,
previous_step_size=0.0,
step=1,
C_to_A_factor=C_to_A_factor, C_to_A_factor=C_to_A_factor,
converter=self.params.ifft, converter=self.params.ifft,
spectrum=self.params.spec_0.copy() / C_to_A_factor, spectrum=self.params.spec_0.copy() / C_to_A_factor,
@@ -114,11 +109,7 @@ class RK4IP:
self.stored_spectra = self.params.recovery_last_stored * [None] + [ self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.state.spectrum.copy() self.state.spectrum.copy()
] ]
self.cons_qty = [ self.tracked_values = TrackedValues()
self.params.conserved_quantity(self.state),
0,
]
self.size_fac = 2 ** (1 / 5)
def _save_current_spectrum(self, num: int): def _save_current_spectrum(self, num: int):
"""saves the spectrum and the corresponding cons_qty array """saves the spectrum and the corresponding cons_qty array
@@ -128,8 +119,8 @@ class RK4IP:
num : int num : int
index of the z postition index of the z postition
""" """
self._save_data(self.get_current_spectrum(), f"spectrum_{num}") self.write(self.get_current_spectrum(), f"spectrum_{num}")
self._save_data(self.cons_qty, "cons_qty") self.write(self.tracked_values, "tracked_values")
self.step_saved() self.step_saved()
def get_current_spectrum(self) -> np.ndarray: def get_current_spectrum(self) -> np.ndarray:
@@ -142,7 +133,7 @@ class RK4IP:
""" """
return self.state.C_to_A_factor * self.state.spectrum return self.state.C_to_A_factor * self.state.spectrum
def _save_data(self, data: np.ndarray, name: str): def write(self, data: np.ndarray, name: str):
"""calls the appropriate method to save data """calls the appropriate method to save data
Parameters Parameters
@@ -168,7 +159,7 @@ class RK4IP:
) )
if self.save_data: if self.save_data:
self._save_data(self.z_stored, "z.npy") self.write(self.z_stored, "z.npy")
return self.stored_spectra return self.stored_spectra
@@ -185,40 +176,36 @@ class RK4IP:
spectrum spectrum
""" """
# Print introduction
self.logger.debug( self.logger.debug(
"Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0]) "Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0])
) )
store = False
# Start of the integration yield self.state.step, len(self.stored_spectra) - 1, self.get_current_spectrum()
step = 1
store = False # store a spectrum
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
while self.state.z < self.params.length: while self.state.z < self.params.length:
h_taken = self.take_step(step) self.state = self.params.integrator(self.state)
step += 1 self.state.step += 1
self.cons_qty.append(0) new_tracked_values = (
dict(step=self.state.step, z=self.state.z) | self.params.integrator.all_values()
)
self.logger.debug(f"tracked values at z={self.state.z} : {new_tracked_values}")
self.tracked_values.append(new_tracked_values)
# Whether the current spectrum has to be stored depends on previous step # Whether the current spectrum has to be stored depends on previous step
if store: if store:
self.logger.debug(
"{} steps, z = {:.4f}, h = {:.5g}".format(step, self.state.z, h_taken)
)
current_spec = self.get_current_spectrum() current_spec = self.get_current_spectrum()
self.stored_spectra.append(current_spec) self.stored_spectra.append(current_spec)
yield step, len(self.stored_spectra) - 1, current_spec yield self.state.step, len(self.stored_spectra) - 1, current_spec
self.z_stored.append(self.state.z) self.z_stored.append(self.state.z)
del self.z_targets[0] del self.z_targets[0]
# reset the constant step size after a spectrum is stored # reset the constant step size after a spectrum is stored
if not self.params.adapt_step_size: if not self.params.adapt_step_size:
self.state.h = self.error_ok self.state.current_step_size = self.error_ok
if len(self.z_targets) == 0: if len(self.z_targets) == 0:
break break
@@ -226,69 +213,19 @@ class RK4IP:
# if the next step goes over a position at which we want to store # if the next step goes over a position at which we want to store
# a spectrum, we shorten the step to reach this position exactly # a spectrum, we shorten the step to reach this position exactly
if self.state.z + self.state.h >= self.z_targets[0]: if self.state.z + self.state.current_step_size >= self.z_targets[0]:
store = True store = True
self.state.h = self.z_targets[0] - self.state.z self.state.current_step_size = self.z_targets[0] - self.state.z
def take_step(self, step: int) -> float:
"""computes a new spectrum, whilst adjusting step size if required, until the error estimation
validates the new spectrum. Saves the result in the internal state attribute
Parameters
----------
step : int
index of the current
Returns
-------
h : float
step sized used
"""
keep = False
h_next_step = self.state.h
while not keep:
h = h_next_step
expD = np.exp(h * ONE_2 * self.params.linear_operator(self.state))
A_I = expD * self.state.spectrum
k1 = expD * (h * self.params.nonlinear_operator(self.state))
k2 = h * self.params.nonlinear_operator(self.state.replace(A_I + k1 * ONE_2))
k3 = h * self.params.nonlinear_operator(self.state.replace(A_I + k2 * ONE_2))
k4 = h * self.params.nonlinear_operator(self.state.replace(expD * (A_I + k3)))
new_state = self.state.replace(
expD * (A_I + k1 * ONE_6 + k2 * ONE_3 + k3 * ONE_3) + k4 * ONE_6
)
self.cons_qty[step] = self.params.conserved_quantity(new_state)
if self.params.adapt_step_size:
curr_p_change = np.abs(self.cons_qty[step - 1] - self.cons_qty[step])
cons_qty_change_ok = self.error_ok * self.cons_qty[step - 1]
if curr_p_change > 2 * cons_qty_change_ok:
progress_str = f"step {step} rejected with h = {h:.4e}, doing over"
self.logger.debug(progress_str)
keep = False
h_next_step = h * ONE_2
elif cons_qty_change_ok < curr_p_change <= 2.0 * cons_qty_change_ok:
keep = True
h_next_step = h / self.size_fac
elif curr_p_change < 0.1 * cons_qty_change_ok:
keep = True
h_next_step = h * self.size_fac
else:
keep = True
h_next_step = h
else:
keep = True
self.state = new_state
self.state.h = h_next_step
self.state.z += h
return h
def step_saved(self): def step_saved(self):
pass pass
def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
yield from self.irun()
def __len__(self) -> int:
return self.params.z_num
class SequentialRK4IP(RK4IP): class SequentialRK4IP(RK4IP):
def __init__( def __init__(
@@ -339,7 +276,7 @@ class RayRK4IP(RK4IP):
): ):
self.worker_id = worker_id self.worker_id = worker_id
self.p_actor = p_actor self.p_actor = p_actor
super().set( super().__init__(
params, params,
save_data=save_data, save_data=save_data,
) )

280
src/scgenerator/solver.py Normal file
View File

@@ -0,0 +1,280 @@
from abc import abstractmethod
import numpy as np
from . import math
from .logger import get_logger
from .operators import (
AbstractConservedQuantity,
CurrentState,
LinearOperator,
NonLinearOperator,
ValueTracker,
)
##################################################
################### STEP-TAKER ###################
##################################################
class StepTaker(ValueTracker):
linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
_tracked_values: dict[str, float]
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator
self._tracked_values = {}
@abstractmethod
def __call__(self, state: CurrentState, step_size: float) -> np.ndarray:
...
def all_values(self) -> dict[str, float]:
"""override ValueTracker.all_values to account for the fact that operators are called
multiple times per step, sometimes with different state, so we use value recorded
earlier. Please call self.recorde_tracked_values() one time only just after calling
the linear and nonlinear operators in your StepTaker.
Returns
-------
dict[str, float]
tracked values
"""
return self.values() | self._tracked_values
def record_tracked_values(self):
self._tracked_values = super().all_values()
class RK4IPStepTaker(StepTaker):
c2 = 1 / 2
c3 = 1 / 3
c6 = 1 / 6
_cached_values: tuple[np.ndarray, np.ndarray]
_cached_key: float
_cache_hits: int
_cache_misses: int
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
super().__init__(linear_operator, nonlinear_operator)
self._cached_key = None
self._cached_values = None
self._cache_hits = 0
self._cache_misses = 0
def __call__(self, state: CurrentState, step_size: float) -> np.ndarray:
h = step_size
l0, nl0 = self.cached_values(state)
expD = np.exp(h * self.c2 * l0)
A_I = expD * state.spectrum
k1 = expD * (h * nl0)
k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2))
k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2))
k4 = h * self.nonlinear_operator(state.replace(expD * (A_I + k3)))
return expD * (A_I + k1 * self.c6 + k2 * self.c3 + k3 * self.c3) + k4 * self.c6
def cached_values(self, state: CurrentState) -> tuple[np.ndarray, np.ndarray]:
"""the evaluation of the linear and nonlinear operators at the start of the step don't
depend on the step size, so we cache them in case we need them more than once (which
can happen depending on the adaptive step size controller)
Parameters
----------
state : CurrentState
current state of the simulation. state.z is used as the key for the cache
Returns
-------
np.ndarray
result of the linear operator
np.ndarray
result of the nonlinear operator
"""
if self._cached_key != state.z:
self._cache_misses += 1
self._cached_key = state.z
self._cached_values = self.linear_operator(state), self.nonlinear_operator(state)
self.record_tracked_values()
else:
self._cache_hits += 1
return self._cached_values
def values(self) -> dict[str, float]:
return dict(RK4IP_cache_hits=self._cache_hits, RK4IP_cache_misses=self._cache_misses)
##################################################
################### INTEGRATOR ###################
##################################################
class Integrator(ValueTracker):
last_step = 0.0
@abstractmethod
def __call__(self, state: CurrentState) -> CurrentState:
"""propagate the state with a step size of state.current_step_size
and return a new state with updated z and previous_step_size attributes"""
...
class ConstantStepIntegrator(Integrator):
def __call__(self, state: CurrentState) -> CurrentState:
new_state = state.replace(self.step_taker(state, state.current_step_size))
new_state.z += new_state.current_step_size
new_state.previous_step_size = new_state.current_step_size
return new_state
def values(self) -> dict[str, float]:
return dict(h=self.last_step)
class ConservedQuantityIntegrator(Integrator):
step_taker: StepTaker
conserved_quantity: AbstractConservedQuantity
last_quantity_value: float
tolerated_error: float
local_error: float = 0.0
def __init__(
self,
step_taker: StepTaker,
conserved_quantity: AbstractConservedQuantity,
tolerated_error: float,
):
self.conserved_quantity = conserved_quantity
self.last_quantity_value = 0
self.tolerated_error = tolerated_error
self.logger = get_logger(self.__class__.__name__)
self.size_fac = 2.0 ** (1.0 / 5.0)
self.step_taker = step_taker
def __call__(self, state: CurrentState) -> CurrentState:
keep = False
h_next_step = state.current_step_size
while not keep:
h = h_next_step
new_state = state.replace(self.step_taker(state, h))
new_qty = self.conserved_quantity(new_state)
delta = np.abs(new_qty - self.last_quantity_value) / self.last_quantity_value
if delta > 2 * self.tolerated_error:
progress_str = f"step {state.step} rejected with h = {h:.4e}, doing over"
self.logger.info(progress_str)
keep = False
h_next_step = h * 0.5
elif self.tolerated_error < delta <= 2.0 * self.tolerated_error:
keep = True
h_next_step = h / self.size_fac
elif delta < 0.1 * self.tolerated_error:
keep = True
h_next_step = h * self.size_fac
else:
keep = True
h_next_step = h
self.local_error = delta
self.last_quantity_value = new_qty
new_state.current_step_size = h_next_step
new_state.previous_step_size = h
new_state.z += h
self.last_step = h
return new_state
def values(self) -> dict[str, float]:
return dict(
cons_qty=self.last_quantity_value, h=self.last_step, relative_error=self.local_error
)
class LocalErrorIntegrator(Integrator):
step_taker: StepTaker
tolerated_error: float
local_error: float
def __init__(self, step_taker: StepTaker, tolerated_error: float, w_num: int):
self.tolerated_error = tolerated_error
self.local_error = 0.0
self.logger = get_logger(self.__class__.__name__)
self.size_fac, self.fine_fac, self.coarse_fac = 2.0 ** (1.0 / 5.0), 16 / 15, -1 / 15
self.step_taker = step_taker
def __call__(self, state: CurrentState) -> CurrentState:
keep = False
h_next_step = state.current_step_size
while not keep:
h = h_next_step
h_half = h / 2
coarse_spec = self.step_taker(state, h)
fine_spec1 = self.step_taker(state, h_half)
fine_state = state.replace(fine_spec1, z=state.z + h_half)
fine_spec = self.step_taker(fine_state, h_half)
delta = self.compute_diff(coarse_spec, fine_spec)
if delta > 2 * self.tolerated_error:
keep = False
h_next_step = h_half
elif self.tolerated_error <= delta <= 2 * self.tolerated_error:
keep = True
h_next_step = h / self.size_fac
elif 0.5 * self.tolerated_error <= delta < self.tolerated_error:
keep = True
h_next_step = h
else:
keep = True
h_next_step = h * self.size_fac
self.local_error = delta
fine_state.spectrum = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
fine_state.current_step_size = h_next_step
fine_state.previous_step_size = h
fine_state.z += h
self.last_step = h
return fine_state
def compute_diff(self, coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float:
return np.sqrt(math.abs2(coarse_spec - fine_spec).sum() / math.abs2(fine_spec).sum())
def values(self) -> dict[str, float]:
return dict(relative_error=self.local_error, h=self.last_step)
class ERK43(Integrator):
linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
dt: float
def __init__(
self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator, dt: float
):
self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator
self.dt = dt
def __call__(self, state: CurrentState) -> CurrentState:
keep = False
h_next_step = state.current_step_size
while not keep:
h = h_next_step
expD = np.exp(h * 0.5 * self.linear_operator(state))
A_I = expD * state.spectrum
k1 = expD * state.prev_spectrum
k2 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k1))
k3 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k2))
k4 = self.nonlinear_operator(state.replace(expD * A_I + h * k3))
r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3))
new_fine = r + h / 6 * k4
k5 = self.nonlinear_operator(state.replace(new_fine))
new_coarse = r + h / 30 * (2 * k4 + 3 * k5)

View File

@@ -14,7 +14,7 @@ from dataclasses import dataclass
from functools import cache from functools import cache
from pathlib import Path from pathlib import Path
from string import printable as str_printable from string import printable as str_printable
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set, Union
import numpy as np import numpy as np
import pkg_resources as pkg import pkg_resources as pkg
@@ -251,12 +251,12 @@ def load_material_dico(name: str) -> dict[str, Any]:
return tomli.loads(Paths.gets("materials"))[name] return tomli.loads(Paths.gets("materials"))[name]
def save_data(data: np.ndarray, data_dir: Path, file_name: str): def save_data(data: Union[np.ndarray, MutableMapping], data_dir: Path, file_name: str):
"""saves numpy array to disk """saves numpy array to disk
Parameters Parameters
---------- ----------
data : np.ndarray data : Union[np.ndarray, MutableMapping]
data to save data to save
file_name : str file_name : str
file name file name
@@ -266,7 +266,10 @@ def save_data(data: np.ndarray, data_dir: Path, file_name: str):
identifier in the main data folder of the task, by default "" identifier in the main data folder of the task, by default ""
""" """
path = data_dir / file_name path = data_dir / file_name
if isinstance(data, np.ndarray):
np.save(path, data) np.save(path, data)
elif isinstance(data, MutableMapping):
np.savez(path, **data)
get_logger(__name__).debug(f"saved data in {path}") get_logger(__name__).debug(f"saved data in {path}")
return return