removed manual welch method for scipy's

This commit is contained in:
Benoît Sierro
2024-01-08 13:59:05 +01:00
parent eafb88a899
commit 2cfbb714de
2 changed files with 26 additions and 132 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
from typing import Callable, ClassVar, Sequence from typing import Callable, ClassVar, Sequence
import numpy as np import numpy as np
import scipy.signal as ss
from scipy.integrate import cumulative_trapezoid from scipy.integrate import cumulative_trapezoid
from scgenerator import math, units from scgenerator import math, units
@@ -13,20 +14,7 @@ from scgenerator import math, units
class NoiseMeasurement: class NoiseMeasurement:
freq: np.ndarray freq: np.ndarray
psd: np.ndarray psd: np.ndarray
phase: np.ndarray | None = None
rng: np.random.Generator = field(default_factory=np.random.default_rng) rng: np.random.Generator = field(default_factory=np.random.default_rng)
_window_functions: ClassVar[dict[str, Callable[[int], np.ndarray]]] = {}
@classmethod
def window_function(cls, name: str):
if name in cls._window_functions:
raise ValueError(f"a function labeled {name!r} has already been registered")
def wrapper(func: Callable[[int], np.ndarray]):
cls._window_functions[name] = func
return func
return wrapper
@classmethod @classmethod
def from_dBc(cls, freq: np.ndarray, psd_dBc: np.ndarray) -> NoiseMeasurement: def from_dBc(cls, freq: np.ndarray, psd_dBc: np.ndarray) -> NoiseMeasurement:
@@ -46,9 +34,9 @@ class NoiseMeasurement:
cls, cls,
signal: Sequence[float], signal: Sequence[float],
dt: float = 1.0, dt: float = 1.0,
window: str | None = "Hann", window: str | None = "hann",
num_segments: int = 1, nperseg: int | None = None,
force_no_dc: bool = True, detrend: bool | str = "constant",
) -> NoiseMeasurement: ) -> NoiseMeasurement:
""" """
compute a PSD from a time-series measurement. compute a PSD from a time-series measurement.
@@ -60,44 +48,29 @@ class NoiseMeasurement:
signal to process. You may or may not remove the DC component, as this will only affect signal to process. You may or may not remove the DC component, as this will only affect
the 0 frequency bin of the PSD the 0 frequency bin of the PSD
window : str | None, optional window : str | None, optional
window to use on the input data to avoid leakage. Possible values are refer to scipy.signal.welch for possible windows
'Square', 'Bartlett', 'Welch' and 'Hann' (default). You may register your own window nperseg : int, optional
function by using `NoiseMeasurement.window_function` decorator, and use the name number of points per segment. The PSD of each segment is computed and then averaged
you gave it as this argument. to reduce variange. By default None, which means only one segment (i.e. the full signal
`None` is an alias for square, since in that case, no windowing is performed. at once) is computed.
num_segments : int, optional detrend : bool, optional
number of segments to cut the signal into. This will trade lower frequency information remove DC and optionally linear trend, by default only removes DC. See
for better variance of the estimated PSD. The default 1 means no cutting. scipy.signal.welch for more details.
force_no_dc : bool, optional
take out the DC component (0-frequency) of each segement after segmentation
""" """
signal = np.asanyarray(signal) signal = np.asanyarray(signal)
if signal.ndim > 1: if signal.ndim > 1:
raise ValueError( raise ValueError(
f"got signal of shape {signal.shape}. Only one 1D signals are supported" f"got signal of shape {signal.shape}. Only one 1D signals are supported"
) )
signal_segments = segments(signal, num_segments) if nperseg is None:
n = signal_segments.shape[-1] nperseg = len(signal)
try: if detrend is True:
window_arr = cls._window_functions[window](n) detrend = "constant"
except KeyError: if window is None:
raise ValueError( window = "boxcar"
f"window function {window!r} not found. " freq, psd = ss.welch(signal, fs=1 / dt, window=window, nperseg=nperseg, detrend=detrend)
f"Possible values are {set(cls._window_functions)}"
) from None
window_correction = np.sum(window_arr**2)
signal_segments = signal_segments * window_arr
if force_no_dc: return cls(freq, psd)
signal_segments = (signal_segments.T - signal_segments.mean(axis=1)).T
freq = np.fft.rfftfreq(n, dt)
psd = np.fft.rfft(signal_segments) * np.sqrt(dt)
phase = math.mean_angle(psd)
psd = psd.real**2 + psd.imag**2
psd[..., 1:-1] *= 2
return cls(freq, psd.mean(axis=0) / window_correction, phase=phase)
def plottable( def plottable(
self, self,
@@ -249,58 +222,5 @@ def integrated_noise(freq: np.ndarray, psd: np.ndarray) -> float:
return np.sqrt(cumulative_trapezoid(np.abs(psd)[::-1], -freq[::-1], initial=0)[::-1]) return np.sqrt(cumulative_trapezoid(np.abs(psd)[::-1], -freq[::-1], initial=0)[::-1])
@NoiseMeasurement.window_function("Square")
def square_window(n: int):
return np.ones(n)
NoiseMeasurement._window_functions[None] = square_window
@NoiseMeasurement.window_function("Bartlett")
def bartlett_window(n: int):
hn = 0.5 * n
return 1 - np.abs((np.arange(n) - hn) / hn)
@NoiseMeasurement.window_function("Welch")
def welch_window(n: int) -> np.ndarray:
hn = 0.5 * n
return 1 - ((np.arange(n) - hn) / hn) ** 2
@NoiseMeasurement.window_function("Hann")
def hann_window(n: int) -> np.ndarray:
return 0.5 * (1 - np.cos(2 * np.pi * np.arange(n) / n))
def segments(signal: np.ndarray, num_segments: int) -> np.ndarray:
"""
cut a signal into segments
"""
if num_segments < 1 or not isinstance(num_segments, (int, np.integer)):
raise ValueError(f"{num_segments = } but must be an integer and at least 1")
if num_segments == 1:
return signal[None]
n_init = len(signal)
seg_size = segement_size(n_init, num_segments)
seg = np.arange(seg_size)
off = int(n_init / (num_segments + 1))
return np.array([signal[seg + i * off] for i in range(num_segments)])
def segement_size(nt: int, num_segments: int) -> int:
return 1 << int(np.log2(nt / (num_segments + 1))) + 1
def get_frequencies(nt: int, num_segments: int, dt: float) -> np.ndarray:
"""
returns the frequency array that would be associated to a NoiseMeasurement.from_time_series
call, where `nt` is the size of the signal.
"""
return np.fft.rfftfreq(segement_size(nt, num_segments), dt)
def quantum_noise_limit(wavelength: float, power: float) -> float: def quantum_noise_limit(wavelength: float, power: float) -> float:
return units.m_rads(wavelength) * units.hbar * 2 / power return units.m_rads(wavelength) * units.hbar * 2 / power

View File

@@ -4,40 +4,16 @@ import pytest
import scgenerator as sc import scgenerator as sc
def test_segmentation():
t = np.arange(32)
r = np.arange(16)
assert np.all(sc.noise.segments(t, 3) == np.vstack([r, r + 8, r + 16]))
r = np.arange(8)
assert np.all(sc.noise.segments(t, 4) == np.vstack([r, r + 6, r + 12, r + 18]))
assert np.all(sc.noise.segments(t, 5) == np.vstack([r, r + 5, r + 10, r + 15, r + 20]))
assert np.all(sc.noise.segments(t, 6) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20]))
assert np.all(
sc.noise.segments(t, 7) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20, r + 24])
)
def test_normalisation(): def test_normalisation():
rng = np.random.default_rng(56) rng = np.random.default_rng(56)
t = np.linspace(-10, 10, 512) t = np.linspace(-10, 10, 512)
s = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15 s = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15
target = np.sum(sc.abs2(np.fft.fft(s))) / 512 target = np.sum(sc.abs2(np.fft.fft(s))) / 512
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "Square", force_no_dc=False) noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "boxcar", detrend=False)
assert np.sum(noise.psd) == pytest.approx(target) assert np.sum(noise.psd) == pytest.approx(target)
def test_no_dc():
rng = np.random.default_rng(56)
t = np.linspace(-10, 10, 512)
s = rng.normal(0, 1, len(t)).cumsum()
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1)
assert noise.psd[0] == pytest.approx(0)
def test_time_and_back(): def test_time_and_back():
""" """
sampling a time series from a spectrum and transforming sampling a time series from a spectrum and transforming
@@ -47,11 +23,9 @@ def test_time_and_back():
t = np.linspace(-10, 10, 512) t = np.linspace(-10, 10, 512)
signal = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15 signal = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15
noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "Square", force_no_dc=False) noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "boxcar", detrend=False)
_, new_signal = noise.time_series(len(t)) _, new_signal = noise.time_series(len(t))
new_noise = sc.noise.NoiseMeasurement.from_time_series( new_noise = sc.noise.NoiseMeasurement.from_time_series(new_signal, 1, "boxcar", detrend=False)
new_signal, 1, "Square", force_no_dc=False
)
assert new_noise.psd == pytest.approx(noise.psd) assert new_noise.psd == pytest.approx(noise.psd)
@@ -62,9 +36,9 @@ def test_nyquist():
""" """
signal = np.cos(np.arange(1024) * np.pi) signal = np.cos(np.arange(1024) * np.pi)
n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 1) n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None)
n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 3) n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 512)
n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 15) n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 128)
assert n1.psd[-1] == n3.psd[-1] * 2 == n15.psd[-1] * 8 assert n1.psd[-1] == n3.psd[-1] * 2 == n15.psd[-1] * 8