removed manual welch method for scipy's
This commit is contained in:
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Callable, ClassVar, Sequence
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal as ss
|
||||
from scipy.integrate import cumulative_trapezoid
|
||||
|
||||
from scgenerator import math, units
|
||||
@@ -13,20 +14,7 @@ from scgenerator import math, units
|
||||
class NoiseMeasurement:
|
||||
freq: np.ndarray
|
||||
psd: np.ndarray
|
||||
phase: np.ndarray | None = None
|
||||
rng: np.random.Generator = field(default_factory=np.random.default_rng)
|
||||
_window_functions: ClassVar[dict[str, Callable[[int], np.ndarray]]] = {}
|
||||
|
||||
@classmethod
|
||||
def window_function(cls, name: str):
|
||||
if name in cls._window_functions:
|
||||
raise ValueError(f"a function labeled {name!r} has already been registered")
|
||||
|
||||
def wrapper(func: Callable[[int], np.ndarray]):
|
||||
cls._window_functions[name] = func
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
@classmethod
|
||||
def from_dBc(cls, freq: np.ndarray, psd_dBc: np.ndarray) -> NoiseMeasurement:
|
||||
@@ -46,9 +34,9 @@ class NoiseMeasurement:
|
||||
cls,
|
||||
signal: Sequence[float],
|
||||
dt: float = 1.0,
|
||||
window: str | None = "Hann",
|
||||
num_segments: int = 1,
|
||||
force_no_dc: bool = True,
|
||||
window: str | None = "hann",
|
||||
nperseg: int | None = None,
|
||||
detrend: bool | str = "constant",
|
||||
) -> NoiseMeasurement:
|
||||
"""
|
||||
compute a PSD from a time-series measurement.
|
||||
@@ -60,44 +48,29 @@ class NoiseMeasurement:
|
||||
signal to process. You may or may not remove the DC component, as this will only affect
|
||||
the 0 frequency bin of the PSD
|
||||
window : str | None, optional
|
||||
window to use on the input data to avoid leakage. Possible values are
|
||||
'Square', 'Bartlett', 'Welch' and 'Hann' (default). You may register your own window
|
||||
function by using `NoiseMeasurement.window_function` decorator, and use the name
|
||||
you gave it as this argument.
|
||||
`None` is an alias for square, since in that case, no windowing is performed.
|
||||
num_segments : int, optional
|
||||
number of segments to cut the signal into. This will trade lower frequency information
|
||||
for better variance of the estimated PSD. The default 1 means no cutting.
|
||||
force_no_dc : bool, optional
|
||||
take out the DC component (0-frequency) of each segement after segmentation
|
||||
refer to scipy.signal.welch for possible windows
|
||||
nperseg : int, optional
|
||||
number of points per segment. The PSD of each segment is computed and then averaged
|
||||
to reduce variange. By default None, which means only one segment (i.e. the full signal
|
||||
at once) is computed.
|
||||
detrend : bool, optional
|
||||
remove DC and optionally linear trend, by default only removes DC. See
|
||||
scipy.signal.welch for more details.
|
||||
"""
|
||||
signal = np.asanyarray(signal)
|
||||
if signal.ndim > 1:
|
||||
raise ValueError(
|
||||
f"got signal of shape {signal.shape}. Only one 1D signals are supported"
|
||||
)
|
||||
signal_segments = segments(signal, num_segments)
|
||||
n = signal_segments.shape[-1]
|
||||
try:
|
||||
window_arr = cls._window_functions[window](n)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"window function {window!r} not found. "
|
||||
f"Possible values are {set(cls._window_functions)}"
|
||||
) from None
|
||||
window_correction = np.sum(window_arr**2)
|
||||
signal_segments = signal_segments * window_arr
|
||||
if nperseg is None:
|
||||
nperseg = len(signal)
|
||||
if detrend is True:
|
||||
detrend = "constant"
|
||||
if window is None:
|
||||
window = "boxcar"
|
||||
freq, psd = ss.welch(signal, fs=1 / dt, window=window, nperseg=nperseg, detrend=detrend)
|
||||
|
||||
if force_no_dc:
|
||||
signal_segments = (signal_segments.T - signal_segments.mean(axis=1)).T
|
||||
|
||||
freq = np.fft.rfftfreq(n, dt)
|
||||
psd = np.fft.rfft(signal_segments) * np.sqrt(dt)
|
||||
phase = math.mean_angle(psd)
|
||||
psd = psd.real**2 + psd.imag**2
|
||||
psd[..., 1:-1] *= 2
|
||||
|
||||
return cls(freq, psd.mean(axis=0) / window_correction, phase=phase)
|
||||
return cls(freq, psd)
|
||||
|
||||
def plottable(
|
||||
self,
|
||||
@@ -249,58 +222,5 @@ def integrated_noise(freq: np.ndarray, psd: np.ndarray) -> float:
|
||||
return np.sqrt(cumulative_trapezoid(np.abs(psd)[::-1], -freq[::-1], initial=0)[::-1])
|
||||
|
||||
|
||||
@NoiseMeasurement.window_function("Square")
|
||||
def square_window(n: int):
|
||||
return np.ones(n)
|
||||
|
||||
|
||||
NoiseMeasurement._window_functions[None] = square_window
|
||||
|
||||
|
||||
@NoiseMeasurement.window_function("Bartlett")
|
||||
def bartlett_window(n: int):
|
||||
hn = 0.5 * n
|
||||
return 1 - np.abs((np.arange(n) - hn) / hn)
|
||||
|
||||
|
||||
@NoiseMeasurement.window_function("Welch")
|
||||
def welch_window(n: int) -> np.ndarray:
|
||||
hn = 0.5 * n
|
||||
return 1 - ((np.arange(n) - hn) / hn) ** 2
|
||||
|
||||
|
||||
@NoiseMeasurement.window_function("Hann")
|
||||
def hann_window(n: int) -> np.ndarray:
|
||||
return 0.5 * (1 - np.cos(2 * np.pi * np.arange(n) / n))
|
||||
|
||||
|
||||
def segments(signal: np.ndarray, num_segments: int) -> np.ndarray:
|
||||
"""
|
||||
cut a signal into segments
|
||||
"""
|
||||
if num_segments < 1 or not isinstance(num_segments, (int, np.integer)):
|
||||
raise ValueError(f"{num_segments = } but must be an integer and at least 1")
|
||||
if num_segments == 1:
|
||||
return signal[None]
|
||||
n_init = len(signal)
|
||||
seg_size = segement_size(n_init, num_segments)
|
||||
seg = np.arange(seg_size)
|
||||
off = int(n_init / (num_segments + 1))
|
||||
return np.array([signal[seg + i * off] for i in range(num_segments)])
|
||||
|
||||
|
||||
def segement_size(nt: int, num_segments: int) -> int:
|
||||
return 1 << int(np.log2(nt / (num_segments + 1))) + 1
|
||||
|
||||
|
||||
def get_frequencies(nt: int, num_segments: int, dt: float) -> np.ndarray:
|
||||
"""
|
||||
returns the frequency array that would be associated to a NoiseMeasurement.from_time_series
|
||||
call, where `nt` is the size of the signal.
|
||||
"""
|
||||
|
||||
return np.fft.rfftfreq(segement_size(nt, num_segments), dt)
|
||||
|
||||
|
||||
def quantum_noise_limit(wavelength: float, power: float) -> float:
|
||||
return units.m_rads(wavelength) * units.hbar * 2 / power
|
||||
|
||||
@@ -4,40 +4,16 @@ import pytest
|
||||
import scgenerator as sc
|
||||
|
||||
|
||||
def test_segmentation():
|
||||
t = np.arange(32)
|
||||
|
||||
r = np.arange(16)
|
||||
assert np.all(sc.noise.segments(t, 3) == np.vstack([r, r + 8, r + 16]))
|
||||
|
||||
r = np.arange(8)
|
||||
|
||||
assert np.all(sc.noise.segments(t, 4) == np.vstack([r, r + 6, r + 12, r + 18]))
|
||||
assert np.all(sc.noise.segments(t, 5) == np.vstack([r, r + 5, r + 10, r + 15, r + 20]))
|
||||
assert np.all(sc.noise.segments(t, 6) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20]))
|
||||
assert np.all(
|
||||
sc.noise.segments(t, 7) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20, r + 24])
|
||||
)
|
||||
|
||||
|
||||
def test_normalisation():
|
||||
rng = np.random.default_rng(56)
|
||||
t = np.linspace(-10, 10, 512)
|
||||
s = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15
|
||||
|
||||
target = np.sum(sc.abs2(np.fft.fft(s))) / 512
|
||||
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "Square", force_no_dc=False)
|
||||
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "boxcar", detrend=False)
|
||||
assert np.sum(noise.psd) == pytest.approx(target)
|
||||
|
||||
|
||||
def test_no_dc():
|
||||
rng = np.random.default_rng(56)
|
||||
t = np.linspace(-10, 10, 512)
|
||||
s = rng.normal(0, 1, len(t)).cumsum()
|
||||
noise = sc.noise.NoiseMeasurement.from_time_series(s, 1)
|
||||
assert noise.psd[0] == pytest.approx(0)
|
||||
|
||||
|
||||
def test_time_and_back():
|
||||
"""
|
||||
sampling a time series from a spectrum and transforming
|
||||
@@ -47,11 +23,9 @@ def test_time_and_back():
|
||||
t = np.linspace(-10, 10, 512)
|
||||
signal = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15
|
||||
|
||||
noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "Square", force_no_dc=False)
|
||||
noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "boxcar", detrend=False)
|
||||
_, new_signal = noise.time_series(len(t))
|
||||
new_noise = sc.noise.NoiseMeasurement.from_time_series(
|
||||
new_signal, 1, "Square", force_no_dc=False
|
||||
)
|
||||
new_noise = sc.noise.NoiseMeasurement.from_time_series(new_signal, 1, "boxcar", detrend=False)
|
||||
assert new_noise.psd == pytest.approx(noise.psd)
|
||||
|
||||
|
||||
@@ -62,9 +36,9 @@ def test_nyquist():
|
||||
"""
|
||||
signal = np.cos(np.arange(1024) * np.pi)
|
||||
|
||||
n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 1)
|
||||
n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 3)
|
||||
n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 15)
|
||||
n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None)
|
||||
n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 512)
|
||||
n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 128)
|
||||
|
||||
assert n1.psd[-1] == n3.psd[-1] * 2 == n15.psd[-1] * 8
|
||||
|
||||
|
||||
Reference in New Issue
Block a user