got rid of SpectrumDescriptor

This commit is contained in:
Benoît Sierro
2021-11-12 09:47:37 +01:00
parent 4b9425b519
commit 1f8685048b
3 changed files with 99 additions and 83 deletions

View File

@@ -4,9 +4,8 @@ Nothing except the solver should depend on this file
""" """
from __future__ import annotations from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, replace from dataclasses import dataclass
from typing import Any, Callable from typing import Any, Callable
import numpy as np import numpy as np
@@ -18,17 +17,51 @@ from .physics import fiber, materials, pulse, units
from .utils import load_material_dico from .utils import load_material_dico
class SpectrumDescriptor: class CurrentState:
name: str length: float
spectrum: np.ndarray = None z: float
current_step_size: float
step: int
C_to_A_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray]
__spectrum: np.ndarray = None
__spec2: np.ndarray = None __spec2: np.ndarray = None
__field: np.ndarray = None __field: np.ndarray = None
__field2: np.ndarray = None __field2: np.ndarray = None
_converter: Callable[[np.ndarray], np.ndarray]
def __set__(self, instance: CurrentState, value: np.ndarray): def __init__(
self._converter = instance.converter self,
self.spectrum = value length: float,
z: float,
current_step_size: float,
step: int,
spectrum: np.ndarray,
C_to_A_factor: np.ndarray,
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft,
):
self.length = length
self.z = z
self.current_step_size = current_step_size
self.step = step
self.C_to_A_factor = C_to_A_factor
self.converter = converter
self.__spectrum = spectrum
@property
def z_ratio(self) -> float:
return self.z / self.length
@property
def actual_spectrum(self) -> np.ndarray:
return self.C_to_A_factor * self.spectrum
@property
def spectrum(self) -> np.ndarray:
return self.__spectrum
@spectrum.setter
def spectrum(self, new_value: np.ndarray):
self.__spectrum = new_value
self.__spec2 = None self.__spec2 = None
self.__field = None self.__field = None
self.__field2 = None self.__field2 = None
@@ -42,7 +75,7 @@ class SpectrumDescriptor:
@property @property
def field(self) -> np.ndarray: def field(self) -> np.ndarray:
if self.__field is None: if self.__field is None:
self.__field = self._converter(self.spectrum) self.__field = self.converter(self.spectrum)
return self.__field return self.__field
@property @property
@@ -69,37 +102,16 @@ class SpectrumDescriptor:
self.__field = field self.__field = field
self.__field2 = field2 self.__field2 = field2
def __delete__(self, instance):
raise AttributeError("Cannot delete Spectrum field")
def __set_name__(self, owner, name):
self.name = name
@dataclasses.dataclass
class CurrentState:
length: float
z: float
current_step_size: float
step: int
C_to_A_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
solution: SpectrumDescriptor = SpectrumDescriptor()
@property
def z_ratio(self) -> float:
return self.z / self.length
def replace(self, new_spectrum: np.ndarray) -> CurrentState: def replace(self, new_spectrum: np.ndarray) -> CurrentState:
"""returns a new state with new attributes""" """returns a new state with new attributes"""
return CurrentState( return CurrentState(
self.length, length=self.length,
self.z, z=self.z,
self.current_step_size, current_step_size=self.current_step_size,
self.step, step=self.step,
self.C_to_A_factor, C_to_A_factor=self.C_to_A_factor,
self.converter, converter=self.converter,
new_spectrum, spectrum=new_spectrum,
) )
def with_params(self, **params) -> CurrentState: def with_params(self, **params) -> CurrentState:
@@ -112,18 +124,22 @@ class CurrentState:
C_to_A_factor=self.C_to_A_factor, C_to_A_factor=self.C_to_A_factor,
converter=self.converter, converter=self.converter,
) )
new_state = CurrentState(solution=self.solution.spectrum, **(my_params | params)) new_state = CurrentState(spectrum=self.__spectrum, **(my_params | params))
new_state.solution.force_values( new_state.force_values(self.spec2, self.field, self.field2)
self.solution.spec2, self.solution.field, self.solution.field2
)
return new_state return new_state
def copy(self) -> CurrentState: def copy(self) -> CurrentState:
return replace(self, solution=self.solution.spectrum) new = CurrentState(
length=self.length,
@property z=self.z,
def actual_spectrum(self) -> np.ndarray: current_step_size=self.current_step_size,
return self.C_to_A_factor * self.solution.spectrum step=self.step,
C_to_A_factor=self.C_to_A_factor,
converter=self.converter,
spectrum=self.__spectrum,
)
new.force_values(self.__spec2, self.__field, self.__field2)
return new
class ValueTracker(ABC): class ValueTracker(ABC):
@@ -634,7 +650,7 @@ class EnvelopeRaman(AbstractRaman):
self.f_r = 0.245 if raman_type == "agrawal" else 0.18 self.f_r = 0.245 if raman_type == "agrawal" else 0.18
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.solution.field2)) return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.field2))
class FullFieldRaman(AbstractRaman): class FullFieldRaman(AbstractRaman):
@@ -644,7 +660,7 @@ class FullFieldRaman(AbstractRaman):
self.multiplier = units.epsilon0 * chi3 * self.f_r self.multiplier = units.epsilon0 * chi3 * self.f_r
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.multiplier * np.fft.ifft(np.fft.fft(state.solution.field2) * self.hr_w) return self.multiplier * np.fft.ifft(np.fft.fft(state.field2) * self.hr_w)
################################################## ##################################################
@@ -683,7 +699,7 @@ class EnvelopeSPM(AbstractSPM):
self.fraction = 1 - raman_op.f_r self.fraction = 1 - raman_op.f_r
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.fraction * state.solution.field2 return self.fraction * state.field2
class FullFieldSPM(AbstractSPM): class FullFieldSPM(AbstractSPM):
@@ -692,7 +708,7 @@ class FullFieldSPM(AbstractSPM):
self.factor = self.fraction * chi3 * units.epsilon0 self.factor = self.fraction * chi3 * units.epsilon0
def __call__(self, state: CurrentState) -> np.ndarray: def __call__(self, state: CurrentState) -> np.ndarray:
return self.factor * state.solution.field2 * state.solution.field return self.factor * state.field2 * state.field
################################################## ##################################################
@@ -809,7 +825,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss( return pulse.photon_number_with_loss(
state.solution.spec2, state.spec2,
self.w, self.w,
self.dw, self.dw,
self.gamma_op(state), self.gamma_op(state),
@@ -825,7 +841,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
self.gamma_op = gamma_op self.gamma_op = gamma_op
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.photon_number(state.solution.spec2, self.w, self.dw, self.gamma_op(state)) return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state))
class EnergyLoss(AbstractConservedQuantity): class EnergyLoss(AbstractConservedQuantity):
@@ -835,7 +851,7 @@ class EnergyLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy_with_loss( return pulse.pulse_energy_with_loss(
math.abs2(state.C_to_A_factor * state.solution.spectrum), math.abs2(state.C_to_A_factor * state.spectrum),
self.dw, self.dw,
self.loss_op(state), self.loss_op(state),
state.current_step_size, state.current_step_size,
@@ -847,7 +863,7 @@ class EnergyNoLoss(AbstractConservedQuantity):
self.dw = w[1] - w[0] self.dw = w[1] - w[0]
def __call__(self, state: CurrentState) -> float: def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.solution.spectrum), self.dw) return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.spectrum), self.dw)
def conserved_quantity( def conserved_quantity(
@@ -960,7 +976,7 @@ class EnvelopeNonLinearOperator(NonLinearOperator):
-1j -1j
* self.gamma_op(state) * self.gamma_op(state)
* (1 + self.ss_op(state)) * (1 + self.ss_op(state))
* np.fft.fft(state.solution.field * (self.spm_op(state) + self.raman_op(state))) * np.fft.fft(state.field * (self.spm_op(state) + self.raman_op(state)))
) )

View File

@@ -97,10 +97,10 @@ class RK4IP:
step=0, step=0,
C_to_A_factor=self.params.c_to_a_factor, C_to_A_factor=self.params.c_to_a_factor,
converter=self.params.ifft, converter=self.params.ifft,
solution=self.params.spec_0.copy() / self.params.c_to_a_factor, spectrum=self.params.spec_0.copy() / self.params.c_to_a_factor,
) )
self.stored_spectra = self.params.recovery_last_stored * [None] + [ self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.init_state.solution.spectrum.copy() self.init_state.spectrum.copy()
] ]
self.tracked_values = TrackedValues() self.tracked_values = TrackedValues()
@@ -165,7 +165,7 @@ class RK4IP:
yield len(self.stored_spectra) - 1, state yield len(self.stored_spectra) - 1, state
if self.params.adapt_step_size: if self.params.adapt_step_size:
integrator = solver.ConservedQuantityIntegrator( integrator = solver.ConservedQuantityIntegrator(
self.init_state, state,
self.params.linear_operator, self.params.linear_operator,
self.params.nonlinear_operator, self.params.nonlinear_operator,
self.params.tolerated_error, self.params.tolerated_error,
@@ -173,7 +173,7 @@ class RK4IP:
) )
else: else:
integrator = solver.ConstantStepIntegrator( integrator = solver.ConstantStepIntegrator(
self.init_state, self.params.linear_operator, self.params.nonlinear_operator state, self.params.linear_operator, self.params.nonlinear_operator
) )
for state in integrator: for state in integrator:

View File

@@ -90,7 +90,7 @@ class ConstantStepIntegrator(Integrator):
new_spec = rk4ip_step( new_spec = rk4ip_step(
self.nonlinear_operator, self.nonlinear_operator,
self.state, self.state,
self.state.solution.spectrum, self.state.spectrum,
self.state.current_step_size, self.state.current_step_size,
lin, lin,
nonlin, nonlin,
@@ -98,7 +98,7 @@ class ConstantStepIntegrator(Integrator):
self.state.z += self.state.current_step_size self.state.z += self.state.current_step_size
self.state.step += 1 self.state.step += 1
self.state.solution = new_spec self.state = new_spec
yield self.state yield self.state
@@ -128,27 +128,25 @@ class ConservedQuantityIntegrator(Integrator):
self.record_tracked_values() self.record_tracked_values()
while True: while True:
h = h_next_step h = h_next_step
new_state = self.state.replace( new_spec = rk4ip_step(
rk4ip_step( self.nonlinear_operator,
self.nonlinear_operator, self.state,
self.state, self.state.spectrum,
self.state.solution.spectrum, h,
h, lin,
lin, nonlin,
nonlin,
)
) )
new_state = self.state.replace(new_spec)
new_qty = self.conserved_quantity(new_state) new_qty = self.conserved_quantity(new_state)
self.current_error = np.abs(new_qty - self.last_qty) self.current_error = np.abs(new_qty - self.last_qty) / self.last_qty
error_ok = self.last_qty * self.tolerated_error
if self.current_error > 2 * error_ok: if self.current_error > 2 * self.tolerated_error:
h_next_step = h * 0.5 h_next_step = h * 0.5
elif error_ok < self.current_error <= 2.0 * error_ok: elif self.tolerated_error < self.current_error <= 2.0 * self.tolerated_error:
h_next_step = h / size_fac h_next_step = h / size_fac
break break
elif self.current_error < 0.1 * error_ok: elif self.current_error < 0.1 * self.tolerated_error:
h_next_step = h * size_fac h_next_step = h * size_fac
break break
else: else:
@@ -183,7 +181,7 @@ class RK4IPSD(Integrator):
self.record_tracked_values() self.record_tracked_values()
while True: while True:
h = h_next_step h = h_next_step
new_fine_inter = self.take_step(h / 2, self.state.solution.spectrum, lin, nonlin) new_fine_inter = self.take_step(h / 2, self.state.spectrum, lin, nonlin)
new_fine_inter_state = self.state.replace(new_fine_inter) new_fine_inter_state = self.state.replace(new_fine_inter)
new_fine = self.take_step( new_fine = self.take_step(
h / 2, h / 2,
@@ -191,7 +189,7 @@ class RK4IPSD(Integrator):
self.linear_operator(new_fine_inter_state), self.linear_operator(new_fine_inter_state),
self.nonlinear_operator(new_fine_inter_state), self.nonlinear_operator(new_fine_inter_state),
) )
new_coarse = self.take_step(h, self.state.solution.spectrum, lin, nonlin) new_coarse = self.take_step(h, self.state.spectrum, lin, nonlin)
self.current_error = self.compute_diff(new_coarse, new_fine) self.current_error = self.compute_diff(new_coarse, new_fine)
if self.current_error > 2 * self.tolerated_error: if self.current_error > 2 * self.tolerated_error:
@@ -209,7 +207,7 @@ class RK4IPSD(Integrator):
self.state.current_step_size = h_next_step self.state.current_step_size = h_next_step
self.state.z += h self.state.z += h
self.state.step += 1 self.state.step += 1
self.state.solution = new_fine self.state = new_fine
yield self.state yield self.state
def take_step( def take_step(
@@ -262,7 +260,7 @@ class ERK43(Integrator):
while True: while True:
h = h_next_step h = h_next_step
expD = np.exp(h * 0.5 * lin) expD = np.exp(h * 0.5 * lin)
A_I = expD * self.state.solution.spectrum A_I = expD * self.state.spectrum
k1 = expD * k5 k1 = expD * k5
k2 = self.nl(A_I + 0.5 * h * k1) k2 = self.nl(A_I + 0.5 * h * k1)
k3 = self.nl(A_I + 0.5 * h * k2) k3 = self.nl(A_I + 0.5 * h * k2)
@@ -289,7 +287,7 @@ class ERK43(Integrator):
self.state.current_step_size = h_next_step self.state.current_step_size = h_next_step
self.state.z += h self.state.z += h
self.state.step += 1 self.state.step += 1
self.state.solution = new_fine self.state = new_fine
k5 = tmp_k5 k5 = tmp_k5
yield self.state yield self.state
@@ -316,7 +314,7 @@ class ERK54(ERK43):
expD4p = np.exp(h * 0.25 * lin) expD4p = np.exp(h * 0.25 * lin)
expD4m = 1 / expD4p expD4m = 1 / expD4p
A_I = expD2 * self.state.solution.spectrum A_I = expD2 * self.state.spectrum
k1 = expD2 * k7 k1 = expD2 * k7
k2 = self.nl(A_I + 0.5 * h * k1) k2 = self.nl(A_I + 0.5 * h * k1)
k3 = expD4p * self.nl(expD4m * (A_I + h / 16 * (3 * k1 + k2))) k3 = expD4p * self.nl(expD4m * (A_I + h / 16 * (3 * k1 + k2)))
@@ -346,7 +344,7 @@ class ERK54(ERK43):
self.state.current_step_size = h_next_step self.state.current_step_size = h_next_step
self.state.z += h self.state.z += h
self.state.step += 1 self.state.step += 1
self.state.solution = new_fine self.state = new_fine
k7 = tmp_k7 k7 = tmp_k7
yield self.state yield self.state
@@ -369,6 +367,8 @@ def rk4ip_step(
state at the start of the step state at the start of the step
h : float h : float
step size step size
spectrum : np.ndarray
spectrum to propagate
init_linear : np.ndarray init_linear : np.ndarray
linear operator already applied on the initial state linear operator already applied on the initial state
init_nonlinear : np.ndarray init_nonlinear : np.ndarray