got rid of SpectrumDescriptor

This commit is contained in:
Benoît Sierro
2021-11-12 09:47:37 +01:00
parent 4b9425b519
commit 1f8685048b
3 changed files with 99 additions and 83 deletions

View File

@@ -4,9 +4,8 @@ Nothing except the solver should depend on this file
"""
from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from dataclasses import dataclass
from typing import Any, Callable
import numpy as np
@@ -18,17 +17,51 @@ from .physics import fiber, materials, pulse, units
from .utils import load_material_dico
class SpectrumDescriptor:
name: str
spectrum: np.ndarray = None
class CurrentState:
length: float
z: float
current_step_size: float
step: int
C_to_A_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray]
__spectrum: np.ndarray = None
__spec2: np.ndarray = None
__field: np.ndarray = None
__field2: np.ndarray = None
_converter: Callable[[np.ndarray], np.ndarray]
def __set__(self, instance: CurrentState, value: np.ndarray):
self._converter = instance.converter
self.spectrum = value
def __init__(
self,
length: float,
z: float,
current_step_size: float,
step: int,
spectrum: np.ndarray,
C_to_A_factor: np.ndarray,
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft,
):
self.length = length
self.z = z
self.current_step_size = current_step_size
self.step = step
self.C_to_A_factor = C_to_A_factor
self.converter = converter
self.__spectrum = spectrum
@property
def z_ratio(self) -> float:
return self.z / self.length
@property
def actual_spectrum(self) -> np.ndarray:
return self.C_to_A_factor * self.spectrum
@property
def spectrum(self) -> np.ndarray:
return self.__spectrum
@spectrum.setter
def spectrum(self, new_value: np.ndarray):
self.__spectrum = new_value
self.__spec2 = None
self.__field = None
self.__field2 = None
@@ -42,7 +75,7 @@ class SpectrumDescriptor:
@property
def field(self) -> np.ndarray:
if self.__field is None:
self.__field = self._converter(self.spectrum)
self.__field = self.converter(self.spectrum)
return self.__field
@property
@@ -69,37 +102,16 @@ class SpectrumDescriptor:
self.__field = field
self.__field2 = field2
def __delete__(self, instance):
raise AttributeError("Cannot delete Spectrum field")
def __set_name__(self, owner, name):
self.name = name
@dataclasses.dataclass
class CurrentState:
length: float
z: float
current_step_size: float
step: int
C_to_A_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
solution: SpectrumDescriptor = SpectrumDescriptor()
@property
def z_ratio(self) -> float:
return self.z / self.length
def replace(self, new_spectrum: np.ndarray) -> CurrentState:
"""returns a new state with new attributes"""
return CurrentState(
self.length,
self.z,
self.current_step_size,
self.step,
self.C_to_A_factor,
self.converter,
new_spectrum,
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
step=self.step,
C_to_A_factor=self.C_to_A_factor,
converter=self.converter,
spectrum=new_spectrum,
)
def with_params(self, **params) -> CurrentState:
@@ -112,18 +124,22 @@ class CurrentState:
C_to_A_factor=self.C_to_A_factor,
converter=self.converter,
)
new_state = CurrentState(solution=self.solution.spectrum, **(my_params | params))
new_state.solution.force_values(
self.solution.spec2, self.solution.field, self.solution.field2
)
new_state = CurrentState(spectrum=self.__spectrum, **(my_params | params))
new_state.force_values(self.spec2, self.field, self.field2)
return new_state
def copy(self) -> CurrentState:
return replace(self, solution=self.solution.spectrum)
@property
def actual_spectrum(self) -> np.ndarray:
return self.C_to_A_factor * self.solution.spectrum
new = CurrentState(
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
step=self.step,
C_to_A_factor=self.C_to_A_factor,
converter=self.converter,
spectrum=self.__spectrum,
)
new.force_values(self.__spec2, self.__field, self.__field2)
return new
class ValueTracker(ABC):
@@ -634,7 +650,7 @@ class EnvelopeRaman(AbstractRaman):
self.f_r = 0.245 if raman_type == "agrawal" else 0.18
def __call__(self, state: CurrentState) -> np.ndarray:
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.solution.field2))
return self.f_r * np.fft.ifft(self.hr_w * np.fft.fft(state.field2))
class FullFieldRaman(AbstractRaman):
@@ -644,7 +660,7 @@ class FullFieldRaman(AbstractRaman):
self.multiplier = units.epsilon0 * chi3 * self.f_r
def __call__(self, state: CurrentState) -> np.ndarray:
return self.multiplier * np.fft.ifft(np.fft.fft(state.solution.field2) * self.hr_w)
return self.multiplier * np.fft.ifft(np.fft.fft(state.field2) * self.hr_w)
##################################################
@@ -683,7 +699,7 @@ class EnvelopeSPM(AbstractSPM):
self.fraction = 1 - raman_op.f_r
def __call__(self, state: CurrentState) -> np.ndarray:
return self.fraction * state.solution.field2
return self.fraction * state.field2
class FullFieldSPM(AbstractSPM):
@@ -692,7 +708,7 @@ class FullFieldSPM(AbstractSPM):
self.factor = self.fraction * chi3 * units.epsilon0
def __call__(self, state: CurrentState) -> np.ndarray:
return self.factor * state.solution.field2 * state.solution.field
return self.factor * state.field2 * state.field
##################################################
@@ -809,7 +825,7 @@ class PhotonNumberLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss(
state.solution.spec2,
state.spec2,
self.w,
self.dw,
self.gamma_op(state),
@@ -825,7 +841,7 @@ class PhotonNumberNoLoss(AbstractConservedQuantity):
self.gamma_op = gamma_op
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number(state.solution.spec2, self.w, self.dw, self.gamma_op(state))
return pulse.photon_number(state.spec2, self.w, self.dw, self.gamma_op(state))
class EnergyLoss(AbstractConservedQuantity):
@@ -835,7 +851,7 @@ class EnergyLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy_with_loss(
math.abs2(state.C_to_A_factor * state.solution.spectrum),
math.abs2(state.C_to_A_factor * state.spectrum),
self.dw,
self.loss_op(state),
state.current_step_size,
@@ -847,7 +863,7 @@ class EnergyNoLoss(AbstractConservedQuantity):
self.dw = w[1] - w[0]
def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.solution.spectrum), self.dw)
return pulse.pulse_energy(math.abs2(state.C_to_A_factor * state.spectrum), self.dw)
def conserved_quantity(
@@ -960,7 +976,7 @@ class EnvelopeNonLinearOperator(NonLinearOperator):
-1j
* self.gamma_op(state)
* (1 + self.ss_op(state))
* np.fft.fft(state.solution.field * (self.spm_op(state) + self.raman_op(state)))
* np.fft.fft(state.field * (self.spm_op(state) + self.raman_op(state)))
)

View File

@@ -97,10 +97,10 @@ class RK4IP:
step=0,
C_to_A_factor=self.params.c_to_a_factor,
converter=self.params.ifft,
solution=self.params.spec_0.copy() / self.params.c_to_a_factor,
spectrum=self.params.spec_0.copy() / self.params.c_to_a_factor,
)
self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.init_state.solution.spectrum.copy()
self.init_state.spectrum.copy()
]
self.tracked_values = TrackedValues()
@@ -165,7 +165,7 @@ class RK4IP:
yield len(self.stored_spectra) - 1, state
if self.params.adapt_step_size:
integrator = solver.ConservedQuantityIntegrator(
self.init_state,
state,
self.params.linear_operator,
self.params.nonlinear_operator,
self.params.tolerated_error,
@@ -173,7 +173,7 @@ class RK4IP:
)
else:
integrator = solver.ConstantStepIntegrator(
self.init_state, self.params.linear_operator, self.params.nonlinear_operator
state, self.params.linear_operator, self.params.nonlinear_operator
)
for state in integrator:

View File

@@ -90,7 +90,7 @@ class ConstantStepIntegrator(Integrator):
new_spec = rk4ip_step(
self.nonlinear_operator,
self.state,
self.state.solution.spectrum,
self.state.spectrum,
self.state.current_step_size,
lin,
nonlin,
@@ -98,7 +98,7 @@ class ConstantStepIntegrator(Integrator):
self.state.z += self.state.current_step_size
self.state.step += 1
self.state.solution = new_spec
self.state = new_spec
yield self.state
@@ -128,27 +128,25 @@ class ConservedQuantityIntegrator(Integrator):
self.record_tracked_values()
while True:
h = h_next_step
new_state = self.state.replace(
rk4ip_step(
new_spec = rk4ip_step(
self.nonlinear_operator,
self.state,
self.state.solution.spectrum,
self.state.spectrum,
h,
lin,
nonlin,
)
)
new_state = self.state.replace(new_spec)
new_qty = self.conserved_quantity(new_state)
self.current_error = np.abs(new_qty - self.last_qty)
error_ok = self.last_qty * self.tolerated_error
self.current_error = np.abs(new_qty - self.last_qty) / self.last_qty
if self.current_error > 2 * error_ok:
if self.current_error > 2 * self.tolerated_error:
h_next_step = h * 0.5
elif error_ok < self.current_error <= 2.0 * error_ok:
elif self.tolerated_error < self.current_error <= 2.0 * self.tolerated_error:
h_next_step = h / size_fac
break
elif self.current_error < 0.1 * error_ok:
elif self.current_error < 0.1 * self.tolerated_error:
h_next_step = h * size_fac
break
else:
@@ -183,7 +181,7 @@ class RK4IPSD(Integrator):
self.record_tracked_values()
while True:
h = h_next_step
new_fine_inter = self.take_step(h / 2, self.state.solution.spectrum, lin, nonlin)
new_fine_inter = self.take_step(h / 2, self.state.spectrum, lin, nonlin)
new_fine_inter_state = self.state.replace(new_fine_inter)
new_fine = self.take_step(
h / 2,
@@ -191,7 +189,7 @@ class RK4IPSD(Integrator):
self.linear_operator(new_fine_inter_state),
self.nonlinear_operator(new_fine_inter_state),
)
new_coarse = self.take_step(h, self.state.solution.spectrum, lin, nonlin)
new_coarse = self.take_step(h, self.state.spectrum, lin, nonlin)
self.current_error = self.compute_diff(new_coarse, new_fine)
if self.current_error > 2 * self.tolerated_error:
@@ -209,7 +207,7 @@ class RK4IPSD(Integrator):
self.state.current_step_size = h_next_step
self.state.z += h
self.state.step += 1
self.state.solution = new_fine
self.state = new_fine
yield self.state
def take_step(
@@ -262,7 +260,7 @@ class ERK43(Integrator):
while True:
h = h_next_step
expD = np.exp(h * 0.5 * lin)
A_I = expD * self.state.solution.spectrum
A_I = expD * self.state.spectrum
k1 = expD * k5
k2 = self.nl(A_I + 0.5 * h * k1)
k3 = self.nl(A_I + 0.5 * h * k2)
@@ -289,7 +287,7 @@ class ERK43(Integrator):
self.state.current_step_size = h_next_step
self.state.z += h
self.state.step += 1
self.state.solution = new_fine
self.state = new_fine
k5 = tmp_k5
yield self.state
@@ -316,7 +314,7 @@ class ERK54(ERK43):
expD4p = np.exp(h * 0.25 * lin)
expD4m = 1 / expD4p
A_I = expD2 * self.state.solution.spectrum
A_I = expD2 * self.state.spectrum
k1 = expD2 * k7
k2 = self.nl(A_I + 0.5 * h * k1)
k3 = expD4p * self.nl(expD4m * (A_I + h / 16 * (3 * k1 + k2)))
@@ -346,7 +344,7 @@ class ERK54(ERK43):
self.state.current_step_size = h_next_step
self.state.z += h
self.state.step += 1
self.state.solution = new_fine
self.state = new_fine
k7 = tmp_k7
yield self.state
@@ -369,6 +367,8 @@ def rk4ip_step(
state at the start of the step
h : float
step size
spectrum : np.ndarray
spectrum to propagate
init_linear : np.ndarray
linear operator already applied on the initial state
init_nonlinear : np.ndarray