g12 update
This commit is contained in:
@@ -100,6 +100,7 @@ def _power_fact_array(x, n):
|
||||
return result
|
||||
|
||||
|
||||
@numba.njit()
|
||||
def abs2(z: np.ndarray) -> np.ndarray:
|
||||
return z.real**2 + z.imag**2
|
||||
|
||||
|
||||
@@ -672,23 +672,23 @@ def g12(values):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
values : np.ndarray, shape (..., m, n)
|
||||
values : np.ndarray, shape (..., m, n, nt)
|
||||
complex values following sc-ordering
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape (..., n)
|
||||
np.ndarray, shape (..., n, nt)
|
||||
coherence function
|
||||
"""
|
||||
|
||||
# Create all the possible pairs of values
|
||||
n = len(values)
|
||||
field_pairs = itertools.combinations(values, 2)
|
||||
mean_spec = np.mean(math.abs2(values), axis=-2)
|
||||
mean_spec = np.mean(math.abs2(values), axis=-3)
|
||||
mask = mean_spec > 1e-15 * mean_spec.max()
|
||||
corr = np.zeros_like(values[0])
|
||||
for pair in field_pairs:
|
||||
corr[mask] += pair[0][mask].conj() * pair[1][mask]
|
||||
for left, right in field_pairs:
|
||||
corr[mask] += left[mask].conj() * right[mask]
|
||||
corr[mask] = corr[mask] / (n * (n - 1) / 2 * mean_spec[mask])
|
||||
|
||||
return np.abs(corr)
|
||||
|
||||
Reference in New Issue
Block a user