added scipy fft alternative
This commit is contained in:
@@ -3,6 +3,7 @@ collection of purely mathematical function
|
||||
"""
|
||||
|
||||
import math
|
||||
import platform
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
@@ -10,6 +11,7 @@ from typing import Callable, Sequence
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from scipy import fft as sfft
|
||||
from scipy.interpolate import interp1d, lagrange
|
||||
from scipy.special import jn_zeros
|
||||
|
||||
@@ -20,6 +22,12 @@ c = 299792458.0
|
||||
def fft_functions(
|
||||
full_field: bool,
|
||||
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
|
||||
if platform.processor() == "arm":
|
||||
if full_field:
|
||||
return sfft.rfft, sfft.irfft
|
||||
else:
|
||||
return sfft.fft, sfft.ifft
|
||||
else:
|
||||
if full_field:
|
||||
return np.fft.rfft, np.fft.irfft
|
||||
else:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
This file includes Dispersion, NonLinear and Loss classes to be used in the solver
|
||||
Nothing except the solver should depend on this file
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
@@ -268,21 +269,32 @@ def constant_wave_vector(
|
||||
##################################################
|
||||
|
||||
|
||||
def envelope_raman(hr_w: np.ndarray, raman_fraction: float) -> FieldOperator:
|
||||
def envelope_raman(
|
||||
hr_w: np.ndarray,
|
||||
raman_fraction: float,
|
||||
fft: Callable[[np.ndarray], np.ndarray],
|
||||
ifft: Callable[[np.ndarray], np.ndarray],
|
||||
) -> FieldOperator:
|
||||
def operate(field: np.ndarray, z: float) -> np.ndarray:
|
||||
return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field)))
|
||||
return raman_fraction * ifft(hr_w * fft(math.abs2(field)))
|
||||
|
||||
return operate
|
||||
|
||||
|
||||
def full_field_raman(
|
||||
raman_type: str, raman_fraction: float, t: np.ndarray, w: np.ndarray, chi3: float
|
||||
raman_type: str,
|
||||
raman_fraction: float,
|
||||
t: np.ndarray,
|
||||
w: np.ndarray,
|
||||
chi3: float,
|
||||
fft: Callable[[np.ndarray], np.ndarray],
|
||||
ifft: Callable[[np.ndarray], np.ndarray],
|
||||
) -> FieldOperator:
|
||||
hr_w = fiber.delayed_raman_w(t, raman_type)
|
||||
factor_in = units.epsilon0 * chi3 * raman_fraction
|
||||
|
||||
def operate(field: np.ndarray, z: float) -> np.ndarray:
|
||||
return factor_in * field * np.fft.irfft(hr_w * np.fft.rfft(math.abs2(field)))
|
||||
return factor_in * field * ifft(hr_w * fft(math.abs2(field)))
|
||||
|
||||
return operate
|
||||
|
||||
@@ -461,14 +473,16 @@ def envelope_nonlinear_operator(
|
||||
ss_op: VariableQuantity,
|
||||
spm_op: FieldOperator,
|
||||
raman_op: FieldOperator,
|
||||
fft: Callable[[np.ndarray], np.ndarray],
|
||||
ifft: Callable[[np.ndarray], np.ndarray],
|
||||
) -> SpecOperator:
|
||||
def operate(spec: np.ndarray, z: float) -> np.ndarray:
|
||||
field = np.fft.ifft(spec)
|
||||
field = ifft(spec)
|
||||
return (
|
||||
-1j
|
||||
* gamma_op(z)
|
||||
* (1 + ss_op(z))
|
||||
* np.fft.fft(field * (spm_op(field, z) + raman_op(field, z)))
|
||||
* fft(field * (spm_op(field, z) + raman_op(field, z)))
|
||||
)
|
||||
|
||||
return operate
|
||||
@@ -480,10 +494,12 @@ def full_field_nonlinear_operator(
|
||||
spm_op: FieldOperator,
|
||||
plasma_op: FieldOperator,
|
||||
fullfield_nl_prefactor: VariableQuantity,
|
||||
fft: Callable[[np.ndarray], np.ndarray],
|
||||
ifft: Callable[[np.ndarray], np.ndarray],
|
||||
) -> SpecOperator:
|
||||
def operate(spec: np.ndarray, z: float) -> np.ndarray:
|
||||
field = np.fft.irfft(spec)
|
||||
field = ifft(spec)
|
||||
total_nonlinear = spm_op(field) + raman_op(field) + plasma_op(field)
|
||||
return 1j * fullfield_nl_prefactor(z) * np.fft.rfft(total_nonlinear)
|
||||
return 1j * fullfield_nl_prefactor(z) * fft(total_nonlinear)
|
||||
|
||||
return operate
|
||||
|
||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
|
||||
|
||||
import numpy as np
|
||||
import scipy.fft as sfft
|
||||
import scipy.signal as ss
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -48,7 +49,12 @@ class Spectrum(np.ndarray):
|
||||
if buf[:2] != b"S\xc9":
|
||||
raise OSError("not a valid buffer")
|
||||
buf = buf[2:]
|
||||
ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft
|
||||
ifft = {
|
||||
b"comp": np.fft.ifft,
|
||||
b"real": np.fft.irfft,
|
||||
b"scic": sfft.ifft,
|
||||
b"scir": sfft.irfft,
|
||||
}
|
||||
shape_n = buf[4]
|
||||
nt, *shape = (
|
||||
int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n)
|
||||
@@ -122,6 +128,10 @@ class Spectrum(np.ndarray):
|
||||
f_name = b"comp"
|
||||
elif self.ifft is np.fft.irfft:
|
||||
f_name = b"real"
|
||||
elif self.ifft is sfft.ifft:
|
||||
f_name = b"scic"
|
||||
elif self.ifft is sfft.irfft:
|
||||
f_name = b"scir"
|
||||
else:
|
||||
raise ValueError(f"cannot export ifft function {self.ifft!r}")
|
||||
|
||||
@@ -298,12 +308,10 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
return self._current_index
|
||||
|
||||
@overload
|
||||
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum:
|
||||
...
|
||||
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray:
|
||||
...
|
||||
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: ...
|
||||
|
||||
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
||||
if isinstance(key, slice):
|
||||
@@ -326,12 +334,10 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
return array
|
||||
|
||||
@overload
|
||||
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
|
||||
...
|
||||
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: ...
|
||||
|
||||
@overload
|
||||
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray:
|
||||
...
|
||||
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: ...
|
||||
|
||||
def _load_slice(self, key: slice) -> Spectrum:
|
||||
self._warn_negative_index(key.start)
|
||||
@@ -412,12 +418,10 @@ class PropagationCollection:
|
||||
wl: np.ndarray
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> Propagation:
|
||||
...
|
||||
def __getitem__(self, key: int) -> Propagation: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> list[Propagation]:
|
||||
...
|
||||
def __getitem__(self, key: slice) -> list[Propagation]: ...
|
||||
|
||||
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
|
||||
return self.propagations[key]
|
||||
|
||||
Reference in New Issue
Block a user