change: better coherence function

This commit is contained in:
Benoît Sierro
2023-08-16 15:18:56 +02:00
parent eae75a5fd6
commit 3f64b669d5

View File

@@ -17,6 +17,7 @@ from pathlib import Path
from typing import Literal, Tuple, TypeVar from typing import Literal, Tuple, TypeVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numba
import numpy as np import numpy as np
from numpy import pi from numpy import pi
from numpy.fft import fft, fftshift, ifft from numpy.fft import fft, fftshift, ifft
@@ -715,34 +716,42 @@ def spectrogram(
return spec return spec
def g12(values: np.ndarray): def g12(values: np.ndarray, axis: int = 0):
""" """
computes the first order coherence function of a ensemble of values computes the first order coherence function of a ensemble of values
Parameters Parameters
---------- ----------
values : np.ndarray, shape (..., m, n, nt) values : np.ndarray, shape (..., n, nt)
complex values following sc-ordering complex values following sc-ordering
axis : int, optional
axis to collapse on which to compute the coherence
Returns Returns
------- -------
np.ndarray, shape (..., n, nt) np.ndarray, shape (..., nt)
coherence function coherence function
""" """
# Create all the possible pairs of values
n = len(values) n = len(values)
field_pairs = itertools.combinations(values, 2) mean_spec = np.mean(math.abs2(values), axis=axis)
mean_spec = np.mean(math.abs2(values), axis=-3) corr = np.zeros_like(mean_spec, dtype=complex)
mask = mean_spec > 1e-15 * mean_spec.max() mask = mean_spec > 1e-15 * mean_spec.max()
corr = np.zeros_like(values[0]) corr[mask] = _g12_fast(values[..., mask])
for left, right in field_pairs:
corr[mask] += left[mask].conj() * right[mask]
corr[mask] = corr[mask] / (n * (n - 1) / 2 * mean_spec[mask]) corr[mask] = corr[mask] / (n * (n - 1) / 2 * mean_spec[mask])
return np.abs(corr) return np.abs(corr)
@numba.njit()
def _g12_fast(values: np.ndarray) -> np.ndarray:
corr = np.zeros_like(values[0])
n = len(values)
for i in range(n - 1):
for j in range(i + 1, n):
corr += values[i].conj() * values[j]
return corr
def avg_g12(values: np.ndarray): def avg_g12(values: np.ndarray):
""" """
computes the average of the coherence function weighted by amplitude of spectrum computes the average of the coherence function weighted by amplitude of spectrum