added scipy fft alternative

This commit is contained in:
2024-02-06 10:10:25 +01:00
parent 864b3ba187
commit 3741954d69
3 changed files with 52 additions and 24 deletions

View File

@@ -3,6 +3,7 @@ collection of purely mathematical function
"""
import math
import platform
import warnings
from dataclasses import dataclass
from functools import cache
@@ -10,6 +11,7 @@ from typing import Callable, Sequence
import numba
import numpy as np
from scipy import fft as sfft
from scipy.interpolate import interp1d, lagrange
from scipy.special import jn_zeros
@@ -20,10 +22,16 @@ c = 299792458.0
def fft_functions(
full_field: bool,
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
if full_field:
return np.fft.rfft, np.fft.irfft
if platform.processor() == "arm":
if full_field:
return sfft.rfft, sfft.irfft
else:
return sfft.fft, sfft.ifft
else:
return np.fft.fft, np.fft.ifft
if full_field:
return np.fft.rfft, np.fft.irfft
else:
return np.fft.fft, np.fft.ifft
def expm1_int(y: np.ndarray, dx: float) -> np.ndarray:

View File

@@ -2,6 +2,7 @@
This file includes Dispersion, NonLinear and Loss classes to be used in the solver
Nothing except the solver should depend on this file
"""
from __future__ import annotations
from typing import Callable
@@ -268,21 +269,32 @@ def constant_wave_vector(
##################################################
def envelope_raman(hr_w: np.ndarray, raman_fraction: float) -> FieldOperator:
def envelope_raman(
hr_w: np.ndarray,
raman_fraction: float,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> FieldOperator:
def operate(field: np.ndarray, z: float) -> np.ndarray:
return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field)))
return raman_fraction * ifft(hr_w * fft(math.abs2(field)))
return operate
def full_field_raman(
raman_type: str, raman_fraction: float, t: np.ndarray, w: np.ndarray, chi3: float
raman_type: str,
raman_fraction: float,
t: np.ndarray,
w: np.ndarray,
chi3: float,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> FieldOperator:
hr_w = fiber.delayed_raman_w(t, raman_type)
factor_in = units.epsilon0 * chi3 * raman_fraction
def operate(field: np.ndarray, z: float) -> np.ndarray:
return factor_in * field * np.fft.irfft(hr_w * np.fft.rfft(math.abs2(field)))
return factor_in * field * ifft(hr_w * fft(math.abs2(field)))
return operate
@@ -461,14 +473,16 @@ def envelope_nonlinear_operator(
ss_op: VariableQuantity,
spm_op: FieldOperator,
raman_op: FieldOperator,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> SpecOperator:
def operate(spec: np.ndarray, z: float) -> np.ndarray:
field = np.fft.ifft(spec)
field = ifft(spec)
return (
-1j
* gamma_op(z)
* (1 + ss_op(z))
* np.fft.fft(field * (spm_op(field, z) + raman_op(field, z)))
* fft(field * (spm_op(field, z) + raman_op(field, z)))
)
return operate
@@ -480,10 +494,12 @@ def full_field_nonlinear_operator(
spm_op: FieldOperator,
plasma_op: FieldOperator,
fullfield_nl_prefactor: VariableQuantity,
fft: Callable[[np.ndarray], np.ndarray],
ifft: Callable[[np.ndarray], np.ndarray],
) -> SpecOperator:
def operate(spec: np.ndarray, z: float) -> np.ndarray:
field = np.fft.irfft(spec)
field = ifft(spec)
total_nonlinear = spm_op(field) + raman_op(field) + plasma_op(field)
return 1j * fullfield_nl_prefactor(z) * np.fft.rfft(total_nonlinear)
return 1j * fullfield_nl_prefactor(z) * fft(total_nonlinear)
return operate

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
import numpy as np
import scipy.fft as sfft
import scipy.signal as ss
from tqdm import tqdm
@@ -48,7 +49,12 @@ class Spectrum(np.ndarray):
if buf[:2] != b"S\xc9":
raise OSError("not a valid buffer")
buf = buf[2:]
ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft
ifft = {
b"comp": np.fft.ifft,
b"real": np.fft.irfft,
b"scic": sfft.ifft,
b"scir": sfft.irfft,
}
shape_n = buf[4]
nt, *shape = (
int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n)
@@ -122,6 +128,10 @@ class Spectrum(np.ndarray):
f_name = b"comp"
elif self.ifft is np.fft.irfft:
f_name = b"real"
elif self.ifft is sfft.ifft:
f_name = b"scic"
elif self.ifft is sfft.irfft:
f_name = b"scir"
else:
raise ValueError(f"cannot export ifft function {self.ifft!r}")
@@ -298,12 +308,10 @@ class Propagation(Generic[ParamsOrNone]):
return self._current_index
@overload
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum:
...
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: ...
@overload
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray:
...
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: ...
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
if isinstance(key, slice):
@@ -326,12 +334,10 @@ class Propagation(Generic[ParamsOrNone]):
return array
@overload
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
...
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: ...
@overload
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray:
...
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: ...
def _load_slice(self, key: slice) -> Spectrum:
self._warn_negative_index(key.start)
@@ -412,12 +418,10 @@ class PropagationCollection:
wl: np.ndarray
@overload
def __getitem__(self, key: int) -> Propagation:
...
def __getitem__(self, key: int) -> Propagation: ...
@overload
def __getitem__(self, key: slice) -> list[Propagation]:
...
def __getitem__(self, key: slice) -> list[Propagation]: ...
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
return self.propagations[key]