removed manual welch method for scipy's

This commit is contained in:
Benoît Sierro
2024-01-08 13:59:05 +01:00
parent eafb88a899
commit 2cfbb714de
2 changed files with 26 additions and 132 deletions

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
from typing import Callable, ClassVar, Sequence
import numpy as np
import scipy.signal as ss
from scipy.integrate import cumulative_trapezoid
from scgenerator import math, units
@@ -13,20 +14,7 @@ from scgenerator import math, units
class NoiseMeasurement:
freq: np.ndarray
psd: np.ndarray
phase: np.ndarray | None = None
rng: np.random.Generator = field(default_factory=np.random.default_rng)
_window_functions: ClassVar[dict[str, Callable[[int], np.ndarray]]] = {}
@classmethod
def window_function(cls, name: str):
if name in cls._window_functions:
raise ValueError(f"a function labeled {name!r} has already been registered")
def wrapper(func: Callable[[int], np.ndarray]):
cls._window_functions[name] = func
return func
return wrapper
@classmethod
def from_dBc(cls, freq: np.ndarray, psd_dBc: np.ndarray) -> NoiseMeasurement:
@@ -46,9 +34,9 @@ class NoiseMeasurement:
cls,
signal: Sequence[float],
dt: float = 1.0,
window: str | None = "Hann",
num_segments: int = 1,
force_no_dc: bool = True,
window: str | None = "hann",
nperseg: int | None = None,
detrend: bool | str = "constant",
) -> NoiseMeasurement:
"""
compute a PSD from a time-series measurement.
@@ -60,44 +48,29 @@ class NoiseMeasurement:
signal to process. You may or may not remove the DC component, as this will only affect
the 0 frequency bin of the PSD
window : str | None, optional
window to use on the input data to avoid leakage. Possible values are
'Square', 'Bartlett', 'Welch' and 'Hann' (default). You may register your own window
function by using `NoiseMeasurement.window_function` decorator, and use the name
you gave it as this argument.
`None` is an alias for square, since in that case, no windowing is performed.
num_segments : int, optional
number of segments to cut the signal into. This will trade lower frequency information
for better variance of the estimated PSD. The default 1 means no cutting.
force_no_dc : bool, optional
take out the DC component (0-frequency) of each segement after segmentation
refer to scipy.signal.welch for possible windows
nperseg : int, optional
number of points per segment. The PSD of each segment is computed and then averaged
to reduce variange. By default None, which means only one segment (i.e. the full signal
at once) is computed.
detrend : bool, optional
remove DC and optionally linear trend, by default only removes DC. See
scipy.signal.welch for more details.
"""
signal = np.asanyarray(signal)
if signal.ndim > 1:
raise ValueError(
f"got signal of shape {signal.shape}. Only one 1D signals are supported"
)
signal_segments = segments(signal, num_segments)
n = signal_segments.shape[-1]
try:
window_arr = cls._window_functions[window](n)
except KeyError:
raise ValueError(
f"window function {window!r} not found. "
f"Possible values are {set(cls._window_functions)}"
) from None
window_correction = np.sum(window_arr**2)
signal_segments = signal_segments * window_arr
if nperseg is None:
nperseg = len(signal)
if detrend is True:
detrend = "constant"
if window is None:
window = "boxcar"
freq, psd = ss.welch(signal, fs=1 / dt, window=window, nperseg=nperseg, detrend=detrend)
if force_no_dc:
signal_segments = (signal_segments.T - signal_segments.mean(axis=1)).T
freq = np.fft.rfftfreq(n, dt)
psd = np.fft.rfft(signal_segments) * np.sqrt(dt)
phase = math.mean_angle(psd)
psd = psd.real**2 + psd.imag**2
psd[..., 1:-1] *= 2
return cls(freq, psd.mean(axis=0) / window_correction, phase=phase)
return cls(freq, psd)
def plottable(
self,
@@ -249,58 +222,5 @@ def integrated_noise(freq: np.ndarray, psd: np.ndarray) -> float:
return np.sqrt(cumulative_trapezoid(np.abs(psd)[::-1], -freq[::-1], initial=0)[::-1])
@NoiseMeasurement.window_function("Square")
def square_window(n: int):
return np.ones(n)
NoiseMeasurement._window_functions[None] = square_window
@NoiseMeasurement.window_function("Bartlett")
def bartlett_window(n: int):
hn = 0.5 * n
return 1 - np.abs((np.arange(n) - hn) / hn)
@NoiseMeasurement.window_function("Welch")
def welch_window(n: int) -> np.ndarray:
hn = 0.5 * n
return 1 - ((np.arange(n) - hn) / hn) ** 2
@NoiseMeasurement.window_function("Hann")
def hann_window(n: int) -> np.ndarray:
return 0.5 * (1 - np.cos(2 * np.pi * np.arange(n) / n))
def segments(signal: np.ndarray, num_segments: int) -> np.ndarray:
"""
cut a signal into segments
"""
if num_segments < 1 or not isinstance(num_segments, (int, np.integer)):
raise ValueError(f"{num_segments = } but must be an integer and at least 1")
if num_segments == 1:
return signal[None]
n_init = len(signal)
seg_size = segement_size(n_init, num_segments)
seg = np.arange(seg_size)
off = int(n_init / (num_segments + 1))
return np.array([signal[seg + i * off] for i in range(num_segments)])
def segement_size(nt: int, num_segments: int) -> int:
return 1 << int(np.log2(nt / (num_segments + 1))) + 1
def get_frequencies(nt: int, num_segments: int, dt: float) -> np.ndarray:
"""
returns the frequency array that would be associated to a NoiseMeasurement.from_time_series
call, where `nt` is the size of the signal.
"""
return np.fft.rfftfreq(segement_size(nt, num_segments), dt)
def quantum_noise_limit(wavelength: float, power: float) -> float:
return units.m_rads(wavelength) * units.hbar * 2 / power

View File

@@ -4,40 +4,16 @@ import pytest
import scgenerator as sc
def test_segmentation():
t = np.arange(32)
r = np.arange(16)
assert np.all(sc.noise.segments(t, 3) == np.vstack([r, r + 8, r + 16]))
r = np.arange(8)
assert np.all(sc.noise.segments(t, 4) == np.vstack([r, r + 6, r + 12, r + 18]))
assert np.all(sc.noise.segments(t, 5) == np.vstack([r, r + 5, r + 10, r + 15, r + 20]))
assert np.all(sc.noise.segments(t, 6) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20]))
assert np.all(
sc.noise.segments(t, 7) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20, r + 24])
)
def test_normalisation():
rng = np.random.default_rng(56)
t = np.linspace(-10, 10, 512)
s = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15
target = np.sum(sc.abs2(np.fft.fft(s))) / 512
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "Square", force_no_dc=False)
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "boxcar", detrend=False)
assert np.sum(noise.psd) == pytest.approx(target)
def test_no_dc():
rng = np.random.default_rng(56)
t = np.linspace(-10, 10, 512)
s = rng.normal(0, 1, len(t)).cumsum()
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1)
assert noise.psd[0] == pytest.approx(0)
def test_time_and_back():
"""
sampling a time series from a spectrum and transforming
@@ -47,11 +23,9 @@ def test_time_and_back():
t = np.linspace(-10, 10, 512)
signal = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15
noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "Square", force_no_dc=False)
noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "boxcar", detrend=False)
_, new_signal = noise.time_series(len(t))
new_noise = sc.noise.NoiseMeasurement.from_time_series(
new_signal, 1, "Square", force_no_dc=False
)
new_noise = sc.noise.NoiseMeasurement.from_time_series(new_signal, 1, "boxcar", detrend=False)
assert new_noise.psd == pytest.approx(noise.psd)
@@ -62,9 +36,9 @@ def test_nyquist():
"""
signal = np.cos(np.arange(1024) * np.pi)
n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 1)
n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 3)
n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 15)
n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None)
n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 512)
n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 128)
assert n1.psd[-1] == n3.psd[-1] * 2 == n15.psd[-1] * 8