diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index bd9e9cf..c32f013 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -17,6 +17,7 @@ from pathlib import Path from typing import Literal, Tuple, TypeVar import matplotlib.pyplot as plt +import numba import numpy as np from numpy import pi from numpy.fft import fft, fftshift, ifft @@ -715,34 +716,42 @@ def spectrogram( return spec -def g12(values: np.ndarray): +def g12(values: np.ndarray, axis: int = 0): """ computes the first order coherence function of a ensemble of values Parameters ---------- - values : np.ndarray, shape (..., m, n, nt) + values : np.ndarray, shape (..., n, nt) complex values following sc-ordering + axis : int, optional + axis to collapse on which to compute the coherence Returns ------- - np.ndarray, shape (..., n, nt) + np.ndarray, shape (..., nt) coherence function """ - # Create all the possible pairs of values n = len(values) - field_pairs = itertools.combinations(values, 2) - mean_spec = np.mean(math.abs2(values), axis=-3) + mean_spec = np.mean(math.abs2(values), axis=axis) + corr = np.zeros_like(mean_spec, dtype=complex) mask = mean_spec > 1e-15 * mean_spec.max() - corr = np.zeros_like(values[0]) - for left, right in field_pairs: - corr[mask] += left[mask].conj() * right[mask] + corr[mask] = _g12_fast(values[..., mask]) corr[mask] = corr[mask] / (n * (n - 1) / 2 * mean_spec[mask]) - return np.abs(corr) +@numba.njit() +def _g12_fast(values: np.ndarray) -> np.ndarray: + corr = np.zeros_like(values[0]) + n = len(values) + for i in range(n - 1): + for j in range(i + 1, n): + corr += values[i].conj() * values[j] + return corr + + def avg_g12(values: np.ndarray): """ computes the average of the coherence function weighted by amplitude of spectrum