From 3741954d69cdfd833e0153beadac85df0a552cdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 6 Feb 2024 10:10:25 +0100 Subject: [PATCH] added scipy fft alternative --- src/scgenerator/math.py | 14 +++++++++++--- src/scgenerator/operators.py | 32 ++++++++++++++++++++++++-------- src/scgenerator/spectra.py | 30 +++++++++++++++++------------- 3 files changed, 52 insertions(+), 24 deletions(-) diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 3c5e14f..6457c91 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -3,6 +3,7 @@ collection of purely mathematical function """ import math +import platform import warnings from dataclasses import dataclass from functools import cache @@ -10,6 +11,7 @@ from typing import Callable, Sequence import numba import numpy as np +from scipy import fft as sfft from scipy.interpolate import interp1d, lagrange from scipy.special import jn_zeros @@ -20,10 +22,16 @@ c = 299792458.0 def fft_functions( full_field: bool, ) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]: - if full_field: - return np.fft.rfft, np.fft.irfft + if platform.processor() == "arm": + if full_field: + return sfft.rfft, sfft.irfft + else: + return sfft.fft, sfft.ifft else: - return np.fft.fft, np.fft.ifft + if full_field: + return np.fft.rfft, np.fft.irfft + else: + return np.fft.fft, np.fft.ifft def expm1_int(y: np.ndarray, dx: float) -> np.ndarray: diff --git a/src/scgenerator/operators.py b/src/scgenerator/operators.py index 62071d9..bf2c4ae 100644 --- a/src/scgenerator/operators.py +++ b/src/scgenerator/operators.py @@ -2,6 +2,7 @@ This file includes Dispersion, NonLinear and Loss classes to be used in the solver Nothing except the solver should depend on this file """ + from __future__ import annotations from typing import Callable @@ -268,21 +269,32 @@ def constant_wave_vector( ################################################## -def envelope_raman(hr_w: np.ndarray, raman_fraction: float) -> FieldOperator: +def envelope_raman( + hr_w: np.ndarray, + raman_fraction: float, + fft: Callable[[np.ndarray], np.ndarray], + ifft: Callable[[np.ndarray], np.ndarray], +) -> FieldOperator: def operate(field: np.ndarray, z: float) -> np.ndarray: - return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field))) + return raman_fraction * ifft(hr_w * fft(math.abs2(field))) return operate def full_field_raman( - raman_type: str, raman_fraction: float, t: np.ndarray, w: np.ndarray, chi3: float + raman_type: str, + raman_fraction: float, + t: np.ndarray, + w: np.ndarray, + chi3: float, + fft: Callable[[np.ndarray], np.ndarray], + ifft: Callable[[np.ndarray], np.ndarray], ) -> FieldOperator: hr_w = fiber.delayed_raman_w(t, raman_type) factor_in = units.epsilon0 * chi3 * raman_fraction def operate(field: np.ndarray, z: float) -> np.ndarray: - return factor_in * field * np.fft.irfft(hr_w * np.fft.rfft(math.abs2(field))) + return factor_in * field * ifft(hr_w * fft(math.abs2(field))) return operate @@ -461,14 +473,16 @@ def envelope_nonlinear_operator( ss_op: VariableQuantity, spm_op: FieldOperator, raman_op: FieldOperator, + fft: Callable[[np.ndarray], np.ndarray], + ifft: Callable[[np.ndarray], np.ndarray], ) -> SpecOperator: def operate(spec: np.ndarray, z: float) -> np.ndarray: - field = np.fft.ifft(spec) + field = ifft(spec) return ( -1j * gamma_op(z) * (1 + ss_op(z)) - * np.fft.fft(field * (spm_op(field, z) + raman_op(field, z))) + * fft(field * (spm_op(field, z) + raman_op(field, z))) ) return operate @@ -480,10 +494,12 @@ def full_field_nonlinear_operator( spm_op: FieldOperator, plasma_op: FieldOperator, fullfield_nl_prefactor: VariableQuantity, + fft: Callable[[np.ndarray], np.ndarray], + ifft: Callable[[np.ndarray], np.ndarray], ) -> SpecOperator: def operate(spec: np.ndarray, z: float) -> np.ndarray: - field = np.fft.irfft(spec) + field = ifft(spec) total_nonlinear = spm_op(field) + raman_op(field) + plasma_op(field) - return 1j * fullfield_nl_prefactor(z) * np.fft.rfft(total_nonlinear) + return 1j * fullfield_nl_prefactor(z) * fft(total_nonlinear) return operate diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index e129613..4c07de1 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload import numpy as np +import scipy.fft as sfft import scipy.signal as ss from tqdm import tqdm @@ -48,7 +49,12 @@ class Spectrum(np.ndarray): if buf[:2] != b"S\xc9": raise OSError("not a valid buffer") buf = buf[2:] - ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft + ifft = { + b"comp": np.fft.ifft, + b"real": np.fft.irfft, + b"scic": sfft.ifft, + b"scir": sfft.irfft, + } shape_n = buf[4] nt, *shape = ( int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n) @@ -122,6 +128,10 @@ class Spectrum(np.ndarray): f_name = b"comp" elif self.ifft is np.fft.irfft: f_name = b"real" + elif self.ifft is sfft.ifft: + f_name = b"scic" + elif self.ifft is sfft.irfft: + f_name = b"scir" else: raise ValueError(f"cannot export ifft function {self.ifft!r}") @@ -298,12 +308,10 @@ class Propagation(Generic[ParamsOrNone]): return self._current_index @overload - def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: - ... + def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: ... @overload - def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: - ... + def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: ... def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray: if isinstance(key, slice): @@ -326,12 +334,10 @@ class Propagation(Generic[ParamsOrNone]): return array @overload - def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: - ... + def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: ... @overload - def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: - ... + def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: ... def _load_slice(self, key: slice) -> Spectrum: self._warn_negative_index(key.start) @@ -412,12 +418,10 @@ class PropagationCollection: wl: np.ndarray @overload - def __getitem__(self, key: int) -> Propagation: - ... + def __getitem__(self, key: int) -> Propagation: ... @overload - def __getitem__(self, key: slice) -> list[Propagation]: - ... + def __getitem__(self, key: slice) -> list[Propagation]: ... def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]: return self.propagations[key]