diff --git a/src/scgenerator/noise.py b/src/scgenerator/noise.py index 63cf7a0..7a1065a 100644 --- a/src/scgenerator/noise.py +++ b/src/scgenerator/noise.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from typing import Callable, ClassVar, Sequence import numpy as np +import scipy.signal as ss from scipy.integrate import cumulative_trapezoid from scgenerator import math, units @@ -13,20 +14,7 @@ from scgenerator import math, units class NoiseMeasurement: freq: np.ndarray psd: np.ndarray - phase: np.ndarray | None = None rng: np.random.Generator = field(default_factory=np.random.default_rng) - _window_functions: ClassVar[dict[str, Callable[[int], np.ndarray]]] = {} - - @classmethod - def window_function(cls, name: str): - if name in cls._window_functions: - raise ValueError(f"a function labeled {name!r} has already been registered") - - def wrapper(func: Callable[[int], np.ndarray]): - cls._window_functions[name] = func - return func - - return wrapper @classmethod def from_dBc(cls, freq: np.ndarray, psd_dBc: np.ndarray) -> NoiseMeasurement: @@ -46,9 +34,9 @@ class NoiseMeasurement: cls, signal: Sequence[float], dt: float = 1.0, - window: str | None = "Hann", - num_segments: int = 1, - force_no_dc: bool = True, + window: str | None = "hann", + nperseg: int | None = None, + detrend: bool | str = "constant", ) -> NoiseMeasurement: """ compute a PSD from a time-series measurement. @@ -60,44 +48,29 @@ class NoiseMeasurement: signal to process. You may or may not remove the DC component, as this will only affect the 0 frequency bin of the PSD window : str | None, optional - window to use on the input data to avoid leakage. Possible values are - 'Square', 'Bartlett', 'Welch' and 'Hann' (default). You may register your own window - function by using `NoiseMeasurement.window_function` decorator, and use the name - you gave it as this argument. - `None` is an alias for square, since in that case, no windowing is performed. - num_segments : int, optional - number of segments to cut the signal into. This will trade lower frequency information - for better variance of the estimated PSD. The default 1 means no cutting. - force_no_dc : bool, optional - take out the DC component (0-frequency) of each segement after segmentation + refer to scipy.signal.welch for possible windows + nperseg : int, optional + number of points per segment. The PSD of each segment is computed and then averaged + to reduce variange. By default None, which means only one segment (i.e. the full signal + at once) is computed. + detrend : bool, optional + remove DC and optionally linear trend, by default only removes DC. See + scipy.signal.welch for more details. """ signal = np.asanyarray(signal) if signal.ndim > 1: raise ValueError( f"got signal of shape {signal.shape}. Only one 1D signals are supported" ) - signal_segments = segments(signal, num_segments) - n = signal_segments.shape[-1] - try: - window_arr = cls._window_functions[window](n) - except KeyError: - raise ValueError( - f"window function {window!r} not found. " - f"Possible values are {set(cls._window_functions)}" - ) from None - window_correction = np.sum(window_arr**2) - signal_segments = signal_segments * window_arr + if nperseg is None: + nperseg = len(signal) + if detrend is True: + detrend = "constant" + if window is None: + window = "boxcar" + freq, psd = ss.welch(signal, fs=1 / dt, window=window, nperseg=nperseg, detrend=detrend) - if force_no_dc: - signal_segments = (signal_segments.T - signal_segments.mean(axis=1)).T - - freq = np.fft.rfftfreq(n, dt) - psd = np.fft.rfft(signal_segments) * np.sqrt(dt) - phase = math.mean_angle(psd) - psd = psd.real**2 + psd.imag**2 - psd[..., 1:-1] *= 2 - - return cls(freq, psd.mean(axis=0) / window_correction, phase=phase) + return cls(freq, psd) def plottable( self, @@ -249,58 +222,5 @@ def integrated_noise(freq: np.ndarray, psd: np.ndarray) -> float: return np.sqrt(cumulative_trapezoid(np.abs(psd)[::-1], -freq[::-1], initial=0)[::-1]) -@NoiseMeasurement.window_function("Square") -def square_window(n: int): - return np.ones(n) - - -NoiseMeasurement._window_functions[None] = square_window - - -@NoiseMeasurement.window_function("Bartlett") -def bartlett_window(n: int): - hn = 0.5 * n - return 1 - np.abs((np.arange(n) - hn) / hn) - - -@NoiseMeasurement.window_function("Welch") -def welch_window(n: int) -> np.ndarray: - hn = 0.5 * n - return 1 - ((np.arange(n) - hn) / hn) ** 2 - - -@NoiseMeasurement.window_function("Hann") -def hann_window(n: int) -> np.ndarray: - return 0.5 * (1 - np.cos(2 * np.pi * np.arange(n) / n)) - - -def segments(signal: np.ndarray, num_segments: int) -> np.ndarray: - """ - cut a signal into segments - """ - if num_segments < 1 or not isinstance(num_segments, (int, np.integer)): - raise ValueError(f"{num_segments = } but must be an integer and at least 1") - if num_segments == 1: - return signal[None] - n_init = len(signal) - seg_size = segement_size(n_init, num_segments) - seg = np.arange(seg_size) - off = int(n_init / (num_segments + 1)) - return np.array([signal[seg + i * off] for i in range(num_segments)]) - - -def segement_size(nt: int, num_segments: int) -> int: - return 1 << int(np.log2(nt / (num_segments + 1))) + 1 - - -def get_frequencies(nt: int, num_segments: int, dt: float) -> np.ndarray: - """ - returns the frequency array that would be associated to a NoiseMeasurement.from_time_series - call, where `nt` is the size of the signal. - """ - - return np.fft.rfftfreq(segement_size(nt, num_segments), dt) - - def quantum_noise_limit(wavelength: float, power: float) -> float: return units.m_rads(wavelength) * units.hbar * 2 / power diff --git a/tests/test_noise.py b/tests/test_noise.py index 3e470e2..3d36111 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -4,40 +4,16 @@ import pytest import scgenerator as sc -def test_segmentation(): - t = np.arange(32) - - r = np.arange(16) - assert np.all(sc.noise.segments(t, 3) == np.vstack([r, r + 8, r + 16])) - - r = np.arange(8) - - assert np.all(sc.noise.segments(t, 4) == np.vstack([r, r + 6, r + 12, r + 18])) - assert np.all(sc.noise.segments(t, 5) == np.vstack([r, r + 5, r + 10, r + 15, r + 20])) - assert np.all(sc.noise.segments(t, 6) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20])) - assert np.all( - sc.noise.segments(t, 7) == np.vstack([r, r + 4, r + 8, r + 12, r + 16, r + 20, r + 24]) - ) - - def test_normalisation(): rng = np.random.default_rng(56) t = np.linspace(-10, 10, 512) s = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15 target = np.sum(sc.abs2(np.fft.fft(s))) / 512 - noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "Square", force_no_dc=False) + noise = sc.noise.NoiseMeasurement.from_time_series(s, 1, "boxcar", detrend=False) assert np.sum(noise.psd) == pytest.approx(target) -def test_no_dc(): - rng = np.random.default_rng(56) - t = np.linspace(-10, 10, 512) - s = rng.normal(0, 1, len(t)).cumsum() - noise = sc.noise.NoiseMeasurement.from_time_series(s, 1) - assert noise.psd[0] == pytest.approx(0) - - def test_time_and_back(): """ sampling a time series from a spectrum and transforming @@ -47,11 +23,9 @@ def test_time_and_back(): t = np.linspace(-10, 10, 512) signal = np.exp(-((t / 2.568) ** 2)) + rng.random(len(t)) / 15 - noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "Square", force_no_dc=False) + noise = sc.noise.NoiseMeasurement.from_time_series(signal, 1, "boxcar", detrend=False) _, new_signal = noise.time_series(len(t)) - new_noise = sc.noise.NoiseMeasurement.from_time_series( - new_signal, 1, "Square", force_no_dc=False - ) + new_noise = sc.noise.NoiseMeasurement.from_time_series(new_signal, 1, "boxcar", detrend=False) assert new_noise.psd == pytest.approx(noise.psd) @@ -62,9 +36,9 @@ def test_nyquist(): """ signal = np.cos(np.arange(1024) * np.pi) - n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 1) - n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 3) - n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 15) + n1 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None) + n3 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 512) + n15 = sc.noise.NoiseMeasurement.from_time_series(signal, 1, None, 128) assert n1.psd[-1] == n3.psd[-1] * 2 == n15.psd[-1] * 8