added scipy fft alternative

This commit is contained in:
2024-02-06 10:10:25 +01:00
parent 864b3ba187
commit 3741954d69
3 changed files with 52 additions and 24 deletions

View File

@@ -3,6 +3,7 @@ collection of purely mathematical function
""" """
import math import math
import platform
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache from functools import cache
@@ -10,6 +11,7 @@ from typing import Callable, Sequence
import numba import numba
import numpy as np import numpy as np
from scipy import fft as sfft
from scipy.interpolate import interp1d, lagrange from scipy.interpolate import interp1d, lagrange
from scipy.special import jn_zeros from scipy.special import jn_zeros
@@ -20,10 +22,16 @@ c = 299792458.0
def fft_functions( def fft_functions(
full_field: bool, full_field: bool,
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]: ) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
if full_field: if platform.processor() == "arm":
return np.fft.rfft, np.fft.irfft if full_field:
return sfft.rfft, sfft.irfft
else:
return sfft.fft, sfft.ifft
else: else:
return np.fft.fft, np.fft.ifft if full_field:
return np.fft.rfft, np.fft.irfft
else:
return np.fft.fft, np.fft.ifft
def expm1_int(y: np.ndarray, dx: float) -> np.ndarray: def expm1_int(y: np.ndarray, dx: float) -> np.ndarray:

View File

@@ -2,6 +2,7 @@
This file includes Dispersion, NonLinear and Loss classes to be used in the solver This file includes Dispersion, NonLinear and Loss classes to be used in the solver
Nothing except the solver should depend on this file Nothing except the solver should depend on this file
""" """
from __future__ import annotations from __future__ import annotations
from typing import Callable from typing import Callable
@@ -268,21 +269,32 @@ def constant_wave_vector(
################################################## ##################################################
def envelope_raman(hr_w: np.ndarray, raman_fraction: float) -> FieldOperator: def envelope_raman(
hr_w: np.ndarray,
raman_fraction: float,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> FieldOperator:
def operate(field: np.ndarray, z: float) -> np.ndarray: def operate(field: np.ndarray, z: float) -> np.ndarray:
return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field))) return raman_fraction * ifft(hr_w * fft(math.abs2(field)))
return operate return operate
def full_field_raman( def full_field_raman(
raman_type: str, raman_fraction: float, t: np.ndarray, w: np.ndarray, chi3: float raman_type: str,
raman_fraction: float,
t: np.ndarray,
w: np.ndarray,
chi3: float,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> FieldOperator: ) -> FieldOperator:
hr_w = fiber.delayed_raman_w(t, raman_type) hr_w = fiber.delayed_raman_w(t, raman_type)
factor_in = units.epsilon0 * chi3 * raman_fraction factor_in = units.epsilon0 * chi3 * raman_fraction
def operate(field: np.ndarray, z: float) -> np.ndarray: def operate(field: np.ndarray, z: float) -> np.ndarray:
return factor_in * field * np.fft.irfft(hr_w * np.fft.rfft(math.abs2(field))) return factor_in * field * ifft(hr_w * fft(math.abs2(field)))
return operate return operate
@@ -461,14 +473,16 @@ def envelope_nonlinear_operator(
ss_op: VariableQuantity, ss_op: VariableQuantity,
spm_op: FieldOperator, spm_op: FieldOperator,
raman_op: FieldOperator, raman_op: FieldOperator,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> SpecOperator: ) -> SpecOperator:
def operate(spec: np.ndarray, z: float) -> np.ndarray: def operate(spec: np.ndarray, z: float) -> np.ndarray:
field = np.fft.ifft(spec) field = ifft(spec)
return ( return (
-1j -1j
* gamma_op(z) * gamma_op(z)
* (1 + ss_op(z)) * (1 + ss_op(z))
* np.fft.fft(field * (spm_op(field, z) + raman_op(field, z))) * fft(field * (spm_op(field, z) + raman_op(field, z)))
) )
return operate return operate
@@ -480,10 +494,12 @@ def full_field_nonlinear_operator(
spm_op: FieldOperator, spm_op: FieldOperator,
plasma_op: FieldOperator, plasma_op: FieldOperator,
fullfield_nl_prefactor: VariableQuantity, fullfield_nl_prefactor: VariableQuantity,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> SpecOperator: ) -> SpecOperator:
def operate(spec: np.ndarray, z: float) -> np.ndarray: def operate(spec: np.ndarray, z: float) -> np.ndarray:
field = np.fft.irfft(spec) field = ifft(spec)
total_nonlinear = spm_op(field) + raman_op(field) + plasma_op(field) total_nonlinear = spm_op(field) + raman_op(field) + plasma_op(field)
return 1j * fullfield_nl_prefactor(z) * np.fft.rfft(total_nonlinear) return 1j * fullfield_nl_prefactor(z) * fft(total_nonlinear)
return operate return operate

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
import numpy as np import numpy as np
import scipy.fft as sfft
import scipy.signal as ss import scipy.signal as ss
from tqdm import tqdm from tqdm import tqdm
@@ -48,7 +49,12 @@ class Spectrum(np.ndarray):
if buf[:2] != b"S\xc9": if buf[:2] != b"S\xc9":
raise OSError("not a valid buffer") raise OSError("not a valid buffer")
buf = buf[2:] buf = buf[2:]
ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft ifft = {
b"comp": np.fft.ifft,
b"real": np.fft.irfft,
b"scic": sfft.ifft,
b"scir": sfft.irfft,
}
shape_n = buf[4] shape_n = buf[4]
nt, *shape = ( nt, *shape = (
int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n) int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n)
@@ -122,6 +128,10 @@ class Spectrum(np.ndarray):
f_name = b"comp" f_name = b"comp"
elif self.ifft is np.fft.irfft: elif self.ifft is np.fft.irfft:
f_name = b"real" f_name = b"real"
elif self.ifft is sfft.ifft:
f_name = b"scic"
elif self.ifft is sfft.irfft:
f_name = b"scir"
else: else:
raise ValueError(f"cannot export ifft function {self.ifft!r}") raise ValueError(f"cannot export ifft function {self.ifft!r}")
@@ -298,12 +308,10 @@ class Propagation(Generic[ParamsOrNone]):
return self._current_index return self._current_index
@overload @overload
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: ...
...
@overload @overload
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: ...
...
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray: def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
if isinstance(key, slice): if isinstance(key, slice):
@@ -326,12 +334,10 @@ class Propagation(Generic[ParamsOrNone]):
return array return array
@overload @overload
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: ...
...
@overload @overload
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: ...
...
def _load_slice(self, key: slice) -> Spectrum: def _load_slice(self, key: slice) -> Spectrum:
self._warn_negative_index(key.start) self._warn_negative_index(key.start)
@@ -412,12 +418,10 @@ class PropagationCollection:
wl: np.ndarray wl: np.ndarray
@overload @overload
def __getitem__(self, key: int) -> Propagation: def __getitem__(self, key: int) -> Propagation: ...
...
@overload @overload
def __getitem__(self, key: slice) -> list[Propagation]: def __getitem__(self, key: slice) -> list[Propagation]: ...
...
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]: def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
return self.propagations[key] return self.propagations[key]