g12 update

This commit is contained in:
Benoît Sierro
2023-08-09 15:21:46 +02:00
parent a4e9dc4b9b
commit 5bbee88554
2 changed files with 6 additions and 5 deletions

View File

@@ -100,6 +100,7 @@ def _power_fact_array(x, n):
return result return result
@numba.njit()
def abs2(z: np.ndarray) -> np.ndarray: def abs2(z: np.ndarray) -> np.ndarray:
return z.real**2 + z.imag**2 return z.real**2 + z.imag**2

View File

@@ -672,23 +672,23 @@ def g12(values):
Parameters Parameters
---------- ----------
values : np.ndarray, shape (..., m, n) values : np.ndarray, shape (..., m, n, nt)
complex values following sc-ordering complex values following sc-ordering
Returns Returns
------- -------
np.ndarray, shape (..., n) np.ndarray, shape (..., n, nt)
coherence function coherence function
""" """
# Create all the possible pairs of values # Create all the possible pairs of values
n = len(values) n = len(values)
field_pairs = itertools.combinations(values, 2) field_pairs = itertools.combinations(values, 2)
mean_spec = np.mean(math.abs2(values), axis=-2) mean_spec = np.mean(math.abs2(values), axis=-3)
mask = mean_spec > 1e-15 * mean_spec.max() mask = mean_spec > 1e-15 * mean_spec.max()
corr = np.zeros_like(values[0]) corr = np.zeros_like(values[0])
for pair in field_pairs: for left, right in field_pairs:
corr[mask] += pair[0][mask].conj() * pair[1][mask] corr[mask] += left[mask].conj() * right[mask]
corr[mask] = corr[mask] / (n * (n - 1) / 2 * mean_spec[mask]) corr[mask] = corr[mask] / (n * (n - 1) / 2 * mean_spec[mask])
return np.abs(corr) return np.abs(corr)