misc
This commit is contained in:
@@ -377,8 +377,7 @@ default_rules: list[Rule] = [
|
|||||||
Rule("loss_op", operators.NoLoss, priorities=-1),
|
Rule("loss_op", operators.NoLoss, priorities=-1),
|
||||||
Rule("plasma_op", operators.NoPlasma, priorities=-1),
|
Rule("plasma_op", operators.NoPlasma, priorities=-1),
|
||||||
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
|
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
|
||||||
Rule("step_taker", solver.RK4IPStepTaker),
|
Rule("integrator", solver.ERK54),
|
||||||
Rule("integrator", solver.ConstantStepIntegrator, priorities=-1),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
envelope_rules = default_rules + [
|
envelope_rules = default_rules + [
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, replace
|
||||||
import re
|
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -21,37 +20,6 @@ from .utils import load_material_dico
|
|||||||
|
|
||||||
|
|
||||||
class SpectrumDescriptor:
|
class SpectrumDescriptor:
|
||||||
name: str
|
|
||||||
value: np.ndarray = None
|
|
||||||
_counter = 0
|
|
||||||
_converter: Callable[[np.ndarray], np.ndarray]
|
|
||||||
|
|
||||||
def __init__(self, spec2_name: str, field_name: str, field2_name: str):
|
|
||||||
self.spec2_name = spec2_name
|
|
||||||
self.field_name = field_name
|
|
||||||
self.field2_name = field2_name
|
|
||||||
|
|
||||||
def __set__(self, instance: CurrentState, value: np.ndarray):
|
|
||||||
self._counter += 1
|
|
||||||
setattr(instance, self.spec2_name, math.abs2(value))
|
|
||||||
setattr(instance, self.field_name, instance.converter(value))
|
|
||||||
setattr(instance, self.field2_name, math.abs2(getattr(instance, self.field_name)))
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
def __delete__(self, instance):
|
|
||||||
raise AttributeError("Cannot delete Spectrum field")
|
|
||||||
|
|
||||||
def __set_name__(self, owner, name):
|
|
||||||
for field_name in ["converter", "field", "field2", "spec2"]:
|
|
||||||
if not hasattr(owner, field_name):
|
|
||||||
raise AttributeError(f"{owner!r} doesn't have a {field_name!r} attribute")
|
|
||||||
self.name = name
|
|
||||||
|
|
||||||
|
|
||||||
class SpectrumDescriptor2:
|
|
||||||
name: str
|
name: str
|
||||||
spectrum: np.ndarray = None
|
spectrum: np.ndarray = None
|
||||||
__spec2: np.ndarray = None
|
__spec2: np.ndarray = None
|
||||||
@@ -100,14 +68,7 @@ class CurrentState:
|
|||||||
step: int
|
step: int
|
||||||
C_to_A_factor: np.ndarray
|
C_to_A_factor: np.ndarray
|
||||||
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
|
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
|
||||||
spectrum: np.ndarray = SpectrumDescriptor("spec2", "field", "field2")
|
solution: SpectrumDescriptor = SpectrumDescriptor()
|
||||||
spec2: np.ndarray = dataclasses.field(init=False)
|
|
||||||
field: np.ndarray = dataclasses.field(init=False)
|
|
||||||
field2: np.ndarray = dataclasses.field(init=False)
|
|
||||||
prev_spectrum: np.ndarray = SpectrumDescriptor("prev_spec2", "prev_field", "prev_field2")
|
|
||||||
prev_spec2: np.ndarray = dataclasses.field(init=False)
|
|
||||||
prev_field: np.ndarray = dataclasses.field(init=False)
|
|
||||||
prev_field2: np.ndarray = dataclasses.field(init=False)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def z_ratio(self) -> float:
|
def z_ratio(self) -> float:
|
||||||
@@ -116,7 +77,7 @@ class CurrentState:
|
|||||||
def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState:
|
def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState:
|
||||||
"""returns a new state with new attributes"""
|
"""returns a new state with new attributes"""
|
||||||
params = dict(
|
params = dict(
|
||||||
spectrum=new_spectrum,
|
solution=new_spectrum,
|
||||||
length=self.length,
|
length=self.length,
|
||||||
z=self.z,
|
z=self.z,
|
||||||
current_step_size=self.current_step_size,
|
current_step_size=self.current_step_size,
|
||||||
@@ -127,6 +88,13 @@ class CurrentState:
|
|||||||
)
|
)
|
||||||
return CurrentState(**(params | new_params))
|
return CurrentState(**(params | new_params))
|
||||||
|
|
||||||
|
def copy(self) -> CurrentState:
|
||||||
|
return replace(self, solution=self.solution.spectrum)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def actual_spectrum(self) -> np.ndarray:
|
||||||
|
return self.C_to_A_factor * self.solution.spectrum
|
||||||
|
|
||||||
|
|
||||||
class ValueTracker(ABC):
|
class ValueTracker(ABC):
|
||||||
def values(self) -> dict[str, float]:
|
def values(self) -> dict[str, float]:
|
||||||
@@ -644,7 +612,7 @@ class EnvelopeRaman(AbstractRaman):
|
|||||||
self.f_r = 0.245 if raman_type == "agrawal" else 0.18
|
self.f_r = 0.245 if raman_type == "agrawal" else 0.18
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.field2))
|
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.solution.field2))
|
||||||
|
|
||||||
|
|
||||||
class FullFieldRaman(AbstractRaman):
|
class FullFieldRaman(AbstractRaman):
|
||||||
@@ -654,7 +622,7 @@ class FullFieldRaman(AbstractRaman):
|
|||||||
self.multiplier = units.epsilon0 * chi3 * self.f_r
|
self.multiplier = units.epsilon0 * chi3 * self.f_r
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
return self.multiplier * np.fft.ifft(np.fft.fft(state.field2) * self.hr_w)
|
return self.multiplier * np.fft.ifft(np.fft.fft(state.solution.field2) * self.hr_w)
|
||||||
|
|
||||||
|
|
||||||
##################################################
|
##################################################
|
||||||
@@ -693,7 +661,7 @@ class EnvelopeSPM(AbstractSPM):
|
|||||||
self.fraction = 1 - raman_op.f_r
|
self.fraction = 1 - raman_op.f_r
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
return self.fraction * state.field2
|
return self.fraction * state.solution.field2
|
||||||
|
|
||||||
|
|
||||||
class FullFieldSPM(AbstractSPM):
|
class FullFieldSPM(AbstractSPM):
|
||||||
@@ -702,7 +670,7 @@ class FullFieldSPM(AbstractSPM):
|
|||||||
self.factor = self.fraction * chi3 * units.epsilon0
|
self.factor = self.fraction * chi3 * units.epsilon0
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> np.ndarray:
|
def __call__(self, state: CurrentState) -> np.ndarray:
|
||||||
return self.factor * state.field2 * state.field
|
return self.factor * state.solution.field2 * state.solution.field
|
||||||
|
|
||||||
|
|
||||||
##################################################
|
##################################################
|
||||||
@@ -819,7 +787,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
|
|||||||
|
|
||||||
def __call__(self, state: CurrentState) -> float:
|
def __call__(self, state: CurrentState) -> float:
|
||||||
return pulse.photon_number_with_loss(
|
return pulse.photon_number_with_loss(
|
||||||
state.spec2,
|
state.solution.spec2,
|
||||||
self.w,
|
self.w,
|
||||||
self.dw,
|
self.dw,
|
||||||
self.gamma_op(state),
|
self.gamma_op(state),
|
||||||
@@ -835,7 +803,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
|
|||||||
self.gamma_op = gamma_op
|
self.gamma_op = gamma_op
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> float:
|
def __call__(self, state: CurrentState) -> float:
|
||||||
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state))
|
return pulse.photon_number(state.solution.spec2, self.w, self.dw, self.gamma_op(state))
|
||||||
|
|
||||||
|
|
||||||
class EnergyLoss(AbstractConservedQuantity):
|
class EnergyLoss(AbstractConservedQuantity):
|
||||||
@@ -845,7 +813,7 @@ class EnergyLoss(AbstractConservedQuantity):
|
|||||||
|
|
||||||
def __call__(self, state: CurrentState) -> float:
|
def __call__(self, state: CurrentState) -> float:
|
||||||
return pulse.pulse_energy_with_loss(
|
return pulse.pulse_energy_with_loss(
|
||||||
math.abs2(state.C_to_A_factor * state.spectrum),
|
math.abs2(state.C_to_A_factor * state.solution.spectrum),
|
||||||
self.dw,
|
self.dw,
|
||||||
self.loss_op(state),
|
self.loss_op(state),
|
||||||
state.current_step_size,
|
state.current_step_size,
|
||||||
@@ -857,7 +825,7 @@ class EnergyNoLoss(AbstractConservedQuantity):
|
|||||||
self.dw = w[1] - w[0]
|
self.dw = w[1] - w[0]
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> float:
|
def __call__(self, state: CurrentState) -> float:
|
||||||
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.spectrum), self.dw)
|
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.solution.spectrum), self.dw)
|
||||||
|
|
||||||
|
|
||||||
def conserved_quantity(
|
def conserved_quantity(
|
||||||
@@ -970,7 +938,7 @@ class EnvelopeNonLinearOperator(NonLinearOperator):
|
|||||||
-1j
|
-1j
|
||||||
* self.gamma_op(state)
|
* self.gamma_op(state)
|
||||||
* (1 + self.ss_op(state))
|
* (1 + self.ss_op(state))
|
||||||
* np.fft.fft(state.field * (self.spm_op(state) + self.raman_op(state)))
|
* np.fft.fft(state.solution.field * (self.spm_op(state) + self.raman_op(state)))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
from collections import defaultdict
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import multiprocessing.connection
|
import multiprocessing.connection
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generator, Optional, Type, Union
|
from typing import Any, Generator, Iterator, Optional, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .. import utils
|
from .. import solver, utils
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..operators import CurrentState
|
from ..operators import CurrentState
|
||||||
from ..parameter import Configuration, Parameters
|
from ..parameter import Configuration, Parameters
|
||||||
@@ -45,7 +45,7 @@ class RK4IP:
|
|||||||
size_fac: float
|
size_fac: float
|
||||||
cons_qty: list[float]
|
cons_qty: list[float]
|
||||||
|
|
||||||
state: CurrentState
|
init_state: CurrentState
|
||||||
stored_spectra: list[np.ndarray]
|
stored_spectra: list[np.ndarray]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -96,7 +96,7 @@ class RK4IP:
|
|||||||
initial_h = (self.z_targets[1] - self.z_targets[0]) / 2
|
initial_h = (self.z_targets[1] - self.z_targets[0]) / 2
|
||||||
else:
|
else:
|
||||||
initial_h = self.error_ok
|
initial_h = self.error_ok
|
||||||
self.state = CurrentState(
|
self.init_state = CurrentState(
|
||||||
length=self.params.length,
|
length=self.params.length,
|
||||||
z=self.z_targets.pop(0),
|
z=self.z_targets.pop(0),
|
||||||
current_step_size=initial_h,
|
current_step_size=initial_h,
|
||||||
@@ -104,14 +104,14 @@ class RK4IP:
|
|||||||
step=1,
|
step=1,
|
||||||
C_to_A_factor=C_to_A_factor,
|
C_to_A_factor=C_to_A_factor,
|
||||||
converter=self.params.ifft,
|
converter=self.params.ifft,
|
||||||
spectrum=self.params.spec_0.copy() / C_to_A_factor,
|
solution=self.params.spec_0.copy() / C_to_A_factor,
|
||||||
)
|
)
|
||||||
self.stored_spectra = self.params.recovery_last_stored * [None] + [
|
self.stored_spectra = self.params.recovery_last_stored * [None] + [
|
||||||
self.state.spectrum.copy()
|
self.init_state.solution.spectrum.copy()
|
||||||
]
|
]
|
||||||
self.tracked_values = TrackedValues()
|
self.tracked_values = TrackedValues()
|
||||||
|
|
||||||
def _save_current_spectrum(self, num: int):
|
def _save_current_spectrum(self, spectrum: np.ndarray, num: int):
|
||||||
"""saves the spectrum and the corresponding cons_qty array
|
"""saves the spectrum and the corresponding cons_qty array
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -119,19 +119,8 @@ class RK4IP:
|
|||||||
num : int
|
num : int
|
||||||
index of the z postition
|
index of the z postition
|
||||||
"""
|
"""
|
||||||
self.write(self.get_current_spectrum(), f"spectrum_{num}")
|
self.write(spectrum, f"spectrum_{num}")
|
||||||
self.write(self.tracked_values, "tracked_values")
|
self.write(self.tracked_values, "tracked_values")
|
||||||
self.step_saved()
|
|
||||||
|
|
||||||
def get_current_spectrum(self) -> np.ndarray:
|
|
||||||
"""returns the current spectrum
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
spectrum
|
|
||||||
"""
|
|
||||||
return self.state.C_to_A_factor * self.state.spectrum
|
|
||||||
|
|
||||||
def write(self, data: np.ndarray, name: str):
|
def write(self, data: np.ndarray, name: str):
|
||||||
"""calls the appropriate method to save data
|
"""calls the appropriate method to save data
|
||||||
@@ -147,14 +136,15 @@ class RK4IP:
|
|||||||
|
|
||||||
def run(self) -> list[np.ndarray]:
|
def run(self) -> list[np.ndarray]:
|
||||||
time_start = datetime.today()
|
time_start = datetime.today()
|
||||||
|
state = self.init_state
|
||||||
for step, num, _ in self.irun():
|
for num, state in self.irun():
|
||||||
if self.save_data:
|
if self.save_data:
|
||||||
self._save_current_spectrum(num)
|
self._save_current_spectrum(state.actual_spectrum, num)
|
||||||
|
self.step_saved(state)
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"propagation finished in {} steps ({} seconds)".format(
|
"propagation finished in {} steps ({} seconds)".format(
|
||||||
step, (datetime.today() - time_start).total_seconds()
|
state.step, (datetime.today() - time_start).total_seconds()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -163,49 +153,49 @@ class RK4IP:
|
|||||||
|
|
||||||
return self.stored_spectra
|
return self.stored_spectra
|
||||||
|
|
||||||
def irun(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
|
def irun(self) -> Iterator[tuple[int, CurrentState]]:
|
||||||
"""run the simulation as a generator obj
|
"""run the simulation as a generator obj
|
||||||
|
|
||||||
Yields
|
Yields
|
||||||
-------
|
-------
|
||||||
int
|
|
||||||
current simulation step
|
|
||||||
int
|
int
|
||||||
current number of spectra returned
|
current number of spectra returned
|
||||||
np.ndarray
|
CurrentState
|
||||||
spectrum
|
current simulation state
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0])
|
"Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0])
|
||||||
)
|
)
|
||||||
store = False
|
store = False
|
||||||
|
state = self.init_state.copy()
|
||||||
|
yield len(self.stored_spectra) - 1, state
|
||||||
|
integrator = solver.ERK54(
|
||||||
|
state,
|
||||||
|
self.params.linear_operator,
|
||||||
|
self.params.nonlinear_operator,
|
||||||
|
self.params.tolerated_error,
|
||||||
|
self.params.dt,
|
||||||
|
)
|
||||||
|
for state in integrator:
|
||||||
|
|
||||||
yield self.state.step, len(self.stored_spectra) - 1, self.get_current_spectrum()
|
new_tracked_values = integrator.all_values()
|
||||||
|
self.logger.debug(f"tracked values at z={state.z} : {new_tracked_values}")
|
||||||
while self.state.z < self.params.length:
|
|
||||||
self.state = self.params.integrator(self.state)
|
|
||||||
|
|
||||||
self.state.step += 1
|
|
||||||
new_tracked_values = (
|
|
||||||
dict(step=self.state.step, z=self.state.z) | self.params.integrator.all_values()
|
|
||||||
)
|
|
||||||
self.logger.debug(f"tracked values at z={self.state.z} : {new_tracked_values}")
|
|
||||||
self.tracked_values.append(new_tracked_values)
|
self.tracked_values.append(new_tracked_values)
|
||||||
|
|
||||||
# Whether the current spectrum has to be stored depends on previous step
|
# Whether the current spectrum has to be stored depends on previous step
|
||||||
if store:
|
if store:
|
||||||
current_spec = self.get_current_spectrum()
|
current_spec = state.actual_spectrum
|
||||||
self.stored_spectra.append(current_spec)
|
self.stored_spectra.append(current_spec)
|
||||||
|
|
||||||
yield self.state.step, len(self.stored_spectra) - 1, current_spec
|
yield len(self.stored_spectra) - 1, state.copy()
|
||||||
|
|
||||||
self.z_stored.append(self.state.z)
|
self.z_stored.append(state.z)
|
||||||
del self.z_targets[0]
|
del self.z_targets[0]
|
||||||
|
|
||||||
# reset the constant step size after a spectrum is stored
|
# reset the constant step size after a spectrum is stored
|
||||||
if not self.params.adapt_step_size:
|
if not self.params.adapt_step_size:
|
||||||
self.state.current_step_size = self.error_ok
|
integrator.state.current_step_size = self.error_ok
|
||||||
|
|
||||||
if len(self.z_targets) == 0:
|
if len(self.z_targets) == 0:
|
||||||
break
|
break
|
||||||
@@ -213,14 +203,14 @@ class RK4IP:
|
|||||||
|
|
||||||
# if the next step goes over a position at which we want to store
|
# if the next step goes over a position at which we want to store
|
||||||
# a spectrum, we shorten the step to reach this position exactly
|
# a spectrum, we shorten the step to reach this position exactly
|
||||||
if self.state.z + self.state.current_step_size >= self.z_targets[0]:
|
if state.z + integrator.state.current_step_size >= self.z_targets[0]:
|
||||||
store = True
|
store = True
|
||||||
self.state.current_step_size = self.z_targets[0] - self.state.z
|
integrator.state.current_step_size = self.z_targets[0] - state.z
|
||||||
|
|
||||||
def step_saved(self):
|
def step_saved(self, state: CurrentState):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
|
def __iter__(self) -> Iterator[tuple[int, CurrentState]]:
|
||||||
yield from self.irun()
|
yield from self.irun()
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@@ -240,8 +230,8 @@ class SequentialRK4IP(RK4IP):
|
|||||||
save_data=save_data,
|
save_data=save_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def step_saved(self):
|
def step_saved(self, state: CurrentState):
|
||||||
self.pbars.update(1, self.state.z / self.params.length - self.pbars[1].n)
|
self.pbars.update(1, state.z / self.params.length - self.pbars[1].n)
|
||||||
|
|
||||||
|
|
||||||
class MutliProcRK4IP(RK4IP):
|
class MutliProcRK4IP(RK4IP):
|
||||||
@@ -259,8 +249,8 @@ class MutliProcRK4IP(RK4IP):
|
|||||||
save_data=save_data,
|
save_data=save_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def step_saved(self):
|
def step_saved(self, state: CurrentState):
|
||||||
self.p_queue.put((self.worker_id, self.state.z / self.params.length))
|
self.p_queue.put((self.worker_id, state.z / self.params.length))
|
||||||
|
|
||||||
|
|
||||||
class RayRK4IP(RK4IP):
|
class RayRK4IP(RK4IP):
|
||||||
@@ -286,8 +276,8 @@ class RayRK4IP(RK4IP):
|
|||||||
self.set(params, p_actor, worker_id, save_data)
|
self.set(params, p_actor, worker_id, save_data)
|
||||||
self.run()
|
self.run()
|
||||||
|
|
||||||
def step_saved(self):
|
def step_saved(self, state: CurrentState):
|
||||||
self.p_actor.update.remote(self.worker_id, self.state.z / self.params.length)
|
self.p_actor.update.remote(self.worker_id, state.z / self.params.length)
|
||||||
self.p_actor.update.remote(0)
|
self.p_actor.update.remote(0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -69,7 +70,7 @@ class RK4IPStepTaker(StepTaker):
|
|||||||
l0, nl0 = self.cached_values(state)
|
l0, nl0 = self.cached_values(state)
|
||||||
expD = np.exp(h * self.c2 * l0)
|
expD = np.exp(h * self.c2 * l0)
|
||||||
|
|
||||||
A_I = expD * state.spectrum
|
A_I = expD * state.solution
|
||||||
k1 = expD * (h * nl0)
|
k1 = expD * (h * nl0)
|
||||||
k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2))
|
k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2))
|
||||||
k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2))
|
k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2))
|
||||||
@@ -114,14 +115,37 @@ class RK4IPStepTaker(StepTaker):
|
|||||||
|
|
||||||
|
|
||||||
class Integrator(ValueTracker):
|
class Integrator(ValueTracker):
|
||||||
last_step = 0.0
|
linear_operator: LinearOperator
|
||||||
|
nonlinear_operator: NonLinearOperator
|
||||||
|
_tracked_values: dict[str, float]
|
||||||
|
|
||||||
|
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
|
||||||
|
self.linear_operator = linear_operator
|
||||||
|
self.nonlinear_operator = nonlinear_operator
|
||||||
|
self._tracked_values = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, state: CurrentState) -> CurrentState:
|
def __iter__(self) -> Iterator[CurrentState]:
|
||||||
"""propagate the state with a step size of state.current_step_size
|
"""propagate the state with a step size of state.current_step_size
|
||||||
and return a new state with updated z and previous_step_size attributes"""
|
and yield a new state with updated z and previous_step_size attributes"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def all_values(self) -> dict[str, float]:
|
||||||
|
"""override ValueTracker.all_values to account for the fact that operators are called
|
||||||
|
multiple times per step, sometimes with different state, so we use value recorded
|
||||||
|
earlier. Please call self.recorde_tracked_values() one time only just after calling
|
||||||
|
the linear and nonlinear operators in your StepTaker.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict[str, float]
|
||||||
|
tracked values
|
||||||
|
"""
|
||||||
|
return self.values() | self._tracked_values
|
||||||
|
|
||||||
|
def record_tracked_values(self):
|
||||||
|
self._tracked_values = super().all_values()
|
||||||
|
|
||||||
|
|
||||||
class ConstantStepIntegrator(Integrator):
|
class ConstantStepIntegrator(Integrator):
|
||||||
def __call__(self, state: CurrentState) -> CurrentState:
|
def __call__(self, state: CurrentState) -> CurrentState:
|
||||||
@@ -234,7 +258,7 @@ class LocalErrorIntegrator(Integrator):
|
|||||||
h_next_step = h * self.size_fac
|
h_next_step = h * self.size_fac
|
||||||
|
|
||||||
self.local_error = delta
|
self.local_error = delta
|
||||||
fine_state.spectrum = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
|
fine_state.solution = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
|
||||||
fine_state.current_step_size = h_next_step
|
fine_state.current_step_size = h_next_step
|
||||||
fine_state.previous_step_size = h
|
fine_state.previous_step_size = h
|
||||||
fine_state.z += h
|
fine_state.z += h
|
||||||
@@ -249,32 +273,122 @@ class LocalErrorIntegrator(Integrator):
|
|||||||
|
|
||||||
|
|
||||||
class ERK43(Integrator):
|
class ERK43(Integrator):
|
||||||
|
state: CurrentState
|
||||||
linear_operator: LinearOperator
|
linear_operator: LinearOperator
|
||||||
nonlinear_operator: NonLinearOperator
|
nonlinear_operator: NonLinearOperator
|
||||||
|
tolerated_error: float
|
||||||
dt: float
|
dt: float
|
||||||
|
current_error: float
|
||||||
|
next_h_factor = 1.0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator, dt: float
|
self,
|
||||||
|
init_state: CurrentState,
|
||||||
|
linear_operator: LinearOperator,
|
||||||
|
nonlinear_operator: NonLinearOperator,
|
||||||
|
tolerated_error: float,
|
||||||
|
dt: float,
|
||||||
):
|
):
|
||||||
|
self.state = init_state
|
||||||
self.linear_operator = linear_operator
|
self.linear_operator = linear_operator
|
||||||
self.nonlinear_operator = nonlinear_operator
|
self.nonlinear_operator = nonlinear_operator
|
||||||
self.dt = dt
|
self.dt = dt
|
||||||
|
self.tolerated_error = tolerated_error
|
||||||
|
self.current_error = 0.0
|
||||||
|
|
||||||
def __call__(self, state: CurrentState) -> CurrentState:
|
def __iter__(self) -> Iterator[CurrentState]:
|
||||||
keep = False
|
h_next_step = self.state.current_step_size
|
||||||
h_next_step = state.current_step_size
|
k5 = self.nonlinear_operator(self.state)
|
||||||
while not keep:
|
while True:
|
||||||
h = h_next_step
|
lin = self.linear_operator(self.state)
|
||||||
expD = np.exp(h * 0.5 * self.linear_operator(state))
|
self.record_tracked_values()
|
||||||
A_I = expD * state.spectrum
|
while True:
|
||||||
k1 = expD * state.prev_spectrum
|
h = h_next_step
|
||||||
k2 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k1))
|
expD = np.exp(h * 0.5 * lin)
|
||||||
k3 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k2))
|
A_I = expD * self.state.solution.spectrum
|
||||||
k4 = self.nonlinear_operator(state.replace(expD * A_I + h * k3))
|
k1 = expD * k5
|
||||||
r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3))
|
k2 = self.nl(A_I + 0.5 * h * k1)
|
||||||
|
k3 = self.nl(A_I + 0.5 * h * k2)
|
||||||
|
k4 = self.nl(expD * A_I + h * k3)
|
||||||
|
r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3))
|
||||||
|
|
||||||
new_fine = r + h / 6 * k4
|
new_fine = r + h / 6 * k4
|
||||||
|
|
||||||
k5 = self.nonlinear_operator(state.replace(new_fine))
|
tmp_k5 = self.nl(new_fine)
|
||||||
|
|
||||||
new_coarse = r + h / 30 * (2 * k4 + 3 * k5)
|
new_coarse = r + h / 30 * (2 * k4 + 3 * tmp_k5)
|
||||||
|
|
||||||
|
self.current_error = np.sqrt(self.dt * math.abs2(new_fine - new_coarse).sum())
|
||||||
|
self.next_h_factor = max(
|
||||||
|
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
|
||||||
|
)
|
||||||
|
h_next_step = self.next_h_factor * h
|
||||||
|
if self.current_error <= self.tolerated_error:
|
||||||
|
break
|
||||||
|
self.state.current_step_size = h_next_step
|
||||||
|
self.state.previous_step_size = h
|
||||||
|
self.state.z += h
|
||||||
|
self.state.step += 1
|
||||||
|
self.state.solution = new_fine
|
||||||
|
k5 = tmp_k5
|
||||||
|
yield self.state
|
||||||
|
|
||||||
|
def values(self) -> dict[str, float]:
|
||||||
|
return dict(
|
||||||
|
step=self.state.step,
|
||||||
|
z=self.state.z,
|
||||||
|
local_error=self.current_error,
|
||||||
|
next_h_factor=self.next_h_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
||||||
|
return self.nonlinear_operator(self.state.replace(spectrum))
|
||||||
|
|
||||||
|
|
||||||
|
class ERK54(ERK43):
|
||||||
|
def __iter__(self) -> Iterator[CurrentState]:
|
||||||
|
print("using ERK54")
|
||||||
|
h_next_step = self.state.current_step_size
|
||||||
|
k7 = self.nonlinear_operator(self.state)
|
||||||
|
while True:
|
||||||
|
lin = self.linear_operator(self.state)
|
||||||
|
self.record_tracked_values()
|
||||||
|
while True:
|
||||||
|
h = h_next_step
|
||||||
|
expD2 = np.exp(h * 0.5 * lin)
|
||||||
|
expD4p = np.exp(h * 0.25 * lin)
|
||||||
|
expD4m = 1 / expD4p
|
||||||
|
|
||||||
|
A_I = expD2 * self.state.solution.spectrum
|
||||||
|
k1 = expD2 * k7
|
||||||
|
k2 = self.nl(A_I + 0.5 * h * k1)
|
||||||
|
k3 = expD4p * self.nl(expD4m * (A_I + h / 16 * (3 * k1 + k2)))
|
||||||
|
k4 = self.nl(A_I + h / 4 * (k1 - k2 + 4 * k3))
|
||||||
|
k5 = expD4m * self.nl(expD4p * (A_I + 3 * h / 16 * (k1 + 3 * k4)))
|
||||||
|
k6 = self.nl(expD2 * (A_I + h / 7 * (-2 * k1 + k2 + 12 * k3 - 12 * k4 + 8 * k5)))
|
||||||
|
|
||||||
|
new_fine = (
|
||||||
|
expD2 * (A_I + h / 90 * (7 * k1 + 32 * k3 + 12 * k4 + 32 * k5))
|
||||||
|
+ 7 * h / 90 * k6
|
||||||
|
)
|
||||||
|
tmp_k7 = self.nl(new_fine)
|
||||||
|
new_coarse = (
|
||||||
|
expD2 * (A_I + h / 42 * (3 * k1 + 16 * k3 + 4 * k4 + 16 * k5)) + h / 14 * k7
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_error = np.sqrt(
|
||||||
|
self.dt * math.abs2(np.abs(new_fine) - np.abs(new_coarse)).sum()
|
||||||
|
)
|
||||||
|
self.next_h_factor = max(
|
||||||
|
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
|
||||||
|
)
|
||||||
|
h_next_step = self.next_h_factor * h
|
||||||
|
if self.current_error <= self.tolerated_error:
|
||||||
|
break
|
||||||
|
self.state.current_step_size = h_next_step
|
||||||
|
self.state.previous_step_size = h
|
||||||
|
self.state.z += h
|
||||||
|
self.state.step += 1
|
||||||
|
self.state.solution = new_fine
|
||||||
|
k7 = tmp_k7
|
||||||
|
yield self.state
|
||||||
|
|||||||
Reference in New Issue
Block a user