added scipy fft alternative
This commit is contained in:
@@ -3,6 +3,7 @@ collection of purely mathematical function
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cache
|
from functools import cache
|
||||||
@@ -10,6 +11,7 @@ from typing import Callable, Sequence
|
|||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy import fft as sfft
|
||||||
from scipy.interpolate import interp1d, lagrange
|
from scipy.interpolate import interp1d, lagrange
|
||||||
from scipy.special import jn_zeros
|
from scipy.special import jn_zeros
|
||||||
|
|
||||||
@@ -20,10 +22,16 @@ c = 299792458.0
|
|||||||
def fft_functions(
|
def fft_functions(
|
||||||
full_field: bool,
|
full_field: bool,
|
||||||
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
|
) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
|
||||||
if full_field:
|
if platform.processor() == "arm":
|
||||||
return np.fft.rfft, np.fft.irfft
|
if full_field:
|
||||||
|
return sfft.rfft, sfft.irfft
|
||||||
|
else:
|
||||||
|
return sfft.fft, sfft.ifft
|
||||||
else:
|
else:
|
||||||
return np.fft.fft, np.fft.ifft
|
if full_field:
|
||||||
|
return np.fft.rfft, np.fft.irfft
|
||||||
|
else:
|
||||||
|
return np.fft.fft, np.fft.ifft
|
||||||
|
|
||||||
|
|
||||||
def expm1_int(y: np.ndarray, dx: float) -> np.ndarray:
|
def expm1_int(y: np.ndarray, dx: float) -> np.ndarray:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
This file includes Dispersion, NonLinear and Loss classes to be used in the solver
|
This file includes Dispersion, NonLinear and Loss classes to be used in the solver
|
||||||
Nothing except the solver should depend on this file
|
Nothing except the solver should depend on this file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
@@ -268,21 +269,32 @@ def constant_wave_vector(
|
|||||||
##################################################
|
##################################################
|
||||||
|
|
||||||
|
|
||||||
def envelope_raman(hr_w: np.ndarray, raman_fraction: float) -> FieldOperator:
|
def envelope_raman(
|
||||||
|
hr_w: np.ndarray,
|
||||||
|
raman_fraction: float,
|
||||||
|
fft: Callable[[np.ndarray], np.ndarray],
|
||||||
|
ifft: Callable[[np.ndarray], np.ndarray],
|
||||||
|
) -> FieldOperator:
|
||||||
def operate(field: np.ndarray, z: float) -> np.ndarray:
|
def operate(field: np.ndarray, z: float) -> np.ndarray:
|
||||||
return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field)))
|
return raman_fraction * ifft(hr_w * fft(math.abs2(field)))
|
||||||
|
|
||||||
return operate
|
return operate
|
||||||
|
|
||||||
|
|
||||||
def full_field_raman(
|
def full_field_raman(
|
||||||
raman_type: str, raman_fraction: float, t: np.ndarray, w: np.ndarray, chi3: float
|
raman_type: str,
|
||||||
|
raman_fraction: float,
|
||||||
|
t: np.ndarray,
|
||||||
|
w: np.ndarray,
|
||||||
|
chi3: float,
|
||||||
|
fft: Callable[[np.ndarray], np.ndarray],
|
||||||
|
ifft: Callable[[np.ndarray], np.ndarray],
|
||||||
) -> FieldOperator:
|
) -> FieldOperator:
|
||||||
hr_w = fiber.delayed_raman_w(t, raman_type)
|
hr_w = fiber.delayed_raman_w(t, raman_type)
|
||||||
factor_in = units.epsilon0 * chi3 * raman_fraction
|
factor_in = units.epsilon0 * chi3 * raman_fraction
|
||||||
|
|
||||||
def operate(field: np.ndarray, z: float) -> np.ndarray:
|
def operate(field: np.ndarray, z: float) -> np.ndarray:
|
||||||
return factor_in * field * np.fft.irfft(hr_w * np.fft.rfft(math.abs2(field)))
|
return factor_in * field * ifft(hr_w * fft(math.abs2(field)))
|
||||||
|
|
||||||
return operate
|
return operate
|
||||||
|
|
||||||
@@ -461,14 +473,16 @@ def envelope_nonlinear_operator(
|
|||||||
ss_op: VariableQuantity,
|
ss_op: VariableQuantity,
|
||||||
spm_op: FieldOperator,
|
spm_op: FieldOperator,
|
||||||
raman_op: FieldOperator,
|
raman_op: FieldOperator,
|
||||||
|
fft: Callable[[np.ndarray], np.ndarray],
|
||||||
|
ifft: Callable[[np.ndarray], np.ndarray],
|
||||||
) -> SpecOperator:
|
) -> SpecOperator:
|
||||||
def operate(spec: np.ndarray, z: float) -> np.ndarray:
|
def operate(spec: np.ndarray, z: float) -> np.ndarray:
|
||||||
field = np.fft.ifft(spec)
|
field = ifft(spec)
|
||||||
return (
|
return (
|
||||||
-1j
|
-1j
|
||||||
* gamma_op(z)
|
* gamma_op(z)
|
||||||
* (1 + ss_op(z))
|
* (1 + ss_op(z))
|
||||||
* np.fft.fft(field * (spm_op(field, z) + raman_op(field, z)))
|
* fft(field * (spm_op(field, z) + raman_op(field, z)))
|
||||||
)
|
)
|
||||||
|
|
||||||
return operate
|
return operate
|
||||||
@@ -480,10 +494,12 @@ def full_field_nonlinear_operator(
|
|||||||
spm_op: FieldOperator,
|
spm_op: FieldOperator,
|
||||||
plasma_op: FieldOperator,
|
plasma_op: FieldOperator,
|
||||||
fullfield_nl_prefactor: VariableQuantity,
|
fullfield_nl_prefactor: VariableQuantity,
|
||||||
|
fft: Callable[[np.ndarray], np.ndarray],
|
||||||
|
ifft: Callable[[np.ndarray], np.ndarray],
|
||||||
) -> SpecOperator:
|
) -> SpecOperator:
|
||||||
def operate(spec: np.ndarray, z: float) -> np.ndarray:
|
def operate(spec: np.ndarray, z: float) -> np.ndarray:
|
||||||
field = np.fft.irfft(spec)
|
field = ifft(spec)
|
||||||
total_nonlinear = spm_op(field) + raman_op(field) + plasma_op(field)
|
total_nonlinear = spm_op(field) + raman_op(field) + plasma_op(field)
|
||||||
return 1j * fullfield_nl_prefactor(z) * np.fft.rfft(total_nonlinear)
|
return 1j * fullfield_nl_prefactor(z) * fft(total_nonlinear)
|
||||||
|
|
||||||
return operate
|
return operate
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
|||||||
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
|
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy.fft as sfft
|
||||||
import scipy.signal as ss
|
import scipy.signal as ss
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -48,7 +49,12 @@ class Spectrum(np.ndarray):
|
|||||||
if buf[:2] != b"S\xc9":
|
if buf[:2] != b"S\xc9":
|
||||||
raise OSError("not a valid buffer")
|
raise OSError("not a valid buffer")
|
||||||
buf = buf[2:]
|
buf = buf[2:]
|
||||||
ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft
|
ifft = {
|
||||||
|
b"comp": np.fft.ifft,
|
||||||
|
b"real": np.fft.irfft,
|
||||||
|
b"scic": sfft.ifft,
|
||||||
|
b"scir": sfft.irfft,
|
||||||
|
}
|
||||||
shape_n = buf[4]
|
shape_n = buf[4]
|
||||||
nt, *shape = (
|
nt, *shape = (
|
||||||
int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n)
|
int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n)
|
||||||
@@ -122,6 +128,10 @@ class Spectrum(np.ndarray):
|
|||||||
f_name = b"comp"
|
f_name = b"comp"
|
||||||
elif self.ifft is np.fft.irfft:
|
elif self.ifft is np.fft.irfft:
|
||||||
f_name = b"real"
|
f_name = b"real"
|
||||||
|
elif self.ifft is sfft.ifft:
|
||||||
|
f_name = b"scic"
|
||||||
|
elif self.ifft is sfft.irfft:
|
||||||
|
f_name = b"scir"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"cannot export ifft function {self.ifft!r}")
|
raise ValueError(f"cannot export ifft function {self.ifft!r}")
|
||||||
|
|
||||||
@@ -298,12 +308,10 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
return self._current_index
|
return self._current_index
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum:
|
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray:
|
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: ...
|
||||||
...
|
|
||||||
|
|
||||||
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
@@ -326,12 +334,10 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
return array
|
return array
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
|
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray:
|
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: ...
|
||||||
...
|
|
||||||
|
|
||||||
def _load_slice(self, key: slice) -> Spectrum:
|
def _load_slice(self, key: slice) -> Spectrum:
|
||||||
self._warn_negative_index(key.start)
|
self._warn_negative_index(key.start)
|
||||||
@@ -412,12 +418,10 @@ class PropagationCollection:
|
|||||||
wl: np.ndarray
|
wl: np.ndarray
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: int) -> Propagation:
|
def __getitem__(self, key: int) -> Propagation: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: slice) -> list[Propagation]:
|
def __getitem__(self, key: slice) -> list[Propagation]: ...
|
||||||
...
|
|
||||||
|
|
||||||
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
|
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
|
||||||
return self.propagations[key]
|
return self.propagations[key]
|
||||||
|
|||||||
Reference in New Issue
Block a user