diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 5ad3505..1b70331 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,5 +1,5 @@ # ruff: noqa -from scgenerator import io, math, operators, plotting +from scgenerator import io, math, noise, operators, plotting from scgenerator.helpers import * from scgenerator.io import MemoryIOHandler, ZipFileIOHandler from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace diff --git a/src/scgenerator/noise.py b/src/scgenerator/noise.py index 82406c6..1f71f65 100644 --- a/src/scgenerator/noise.py +++ b/src/scgenerator/noise.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Callable, ClassVar import numpy as np from scipy.integrate import cumulative_trapezoid @@ -43,14 +44,17 @@ class NoiseMeasurement: phase: np.ndarray | None = None psd_interp: interp1d = field(init=False) is_uniform: bool = field(default=False, init=False) + _window_functions: ClassVar[dict[str, Callable[[int], np.ndarray]]] = {} - def __post_init__(self): - df = np.diff(self.freq) - if df.std() / df.mean() < 1e-12: - self.is_uniform = True - self.psd_interp = interp1d( - self.freq, self.psd, fill_value=(0, self.psd[-1]), bounds_error=False - ) + @classmethod + def window_function(cls, name: str): + def wrapper(func: Callable[[int], np.ndarray]): + if name in cls._window_functions: + raise ValueError(f"a function labeled {name!r} has already been registered") + cls._window_functions[name] = func + return func + + return wrapper @classmethod def from_dBc(cls, freq: np.ndarray, psd_dBc: np.ndarray) -> NoiseMeasurement: @@ -66,11 +70,33 @@ class NoiseMeasurement: return cls(freq, psd) @classmethod - def from_time_series(cls, time: np.ndarray, signal: np.ndarray) -> NoiseMeasurement: + def from_time_series( + cls, time: np.ndarray, signal: np.ndarray, window: str | None = None + ) -> NoiseMeasurement: + correction = 1 + n = len(time) + if window is not None: + win_arr = cls._window_functions[window](n) + signal = signal * win_arr + correction = np.sum(win_arr**2) / n + freq = np.fft.rfftfreq(len(time), time[1] - time[0]) dt = time[1] - time[0] - psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt) - return cls(freq, psd.real**2 + psd.imag**2, phase=np.angle(psd)) + psd = np.fft.rfft(signal) / np.sqrt(0.5 * n / dt) + psd = psd.real**2 + psd.imag**2 + return cls(freq, psd / correction, phase=np.angle(psd)) + + def __post_init__(self): + df = np.diff(self.freq) + if df.std() / df.mean() < 1e-12: + self.is_uniform = True + self.psd_interp = interp1d( + self.freq, self.psd, fill_value=(0, self.psd[-1]), bounds_error=False + ) + + @property + def psd_dBc(self) -> np.ndarray: + return np.log10(self.psd) * 10 def sample_spectrum(self, nt: int, dt: float | None = None) -> tuple[np.ndarray, np.ndarray]: """ @@ -131,3 +157,20 @@ class NoiseMeasurement: The 0th component is the total RIN in the frequency range covered by the measurement """ return integrated_rin(self.freq, self.psd) + + +@NoiseMeasurement.window_function("Bartlett") +def bartlett_window(n: int): + hn = 0.5 * n + return 1 - np.abs((np.arange(n) - hn) / hn) + + +@NoiseMeasurement.window_function("Welch") +def welch_window(n: int) -> np.ndarray: + hn = 0.5 * n + return 1 - ((np.arange(n) - hn) / hn) ** 2 + + +@NoiseMeasurement.window_function("Hann") +def hann_window(n: int) -> np.ndarray: + return 0.5 * (1 - np.cos(2 * np.pi * np.arange(n) / n))