diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 4da0bf1..11d1bbe 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -100,6 +100,7 @@ def _power_fact_array(x, n): return result +@numba.njit() def abs2(z: np.ndarray) -> np.ndarray: return z.real**2 + z.imag**2 diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 0a1733a..9356a33 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -672,23 +672,23 @@ def g12(values): Parameters ---------- - values : np.ndarray, shape (..., m, n) + values : np.ndarray, shape (..., m, n, nt) complex values following sc-ordering Returns ------- - np.ndarray, shape (..., n) + np.ndarray, shape (..., n, nt) coherence function """ # Create all the possible pairs of values n = len(values) field_pairs = itertools.combinations(values, 2) - mean_spec = np.mean(math.abs2(values), axis=-2) + mean_spec = np.mean(math.abs2(values), axis=-3) mask = mean_spec > 1e-15 * mean_spec.max() corr = np.zeros_like(values[0]) - for pair in field_pairs: - corr[mask] += pair[0][mask].conj() * pair[1][mask] + for left, right in field_pairs: + corr[mask] += left[mask].conj() * right[mask] corr[mask] = corr[mask] / (n * (n - 1) / 2 * mean_spec[mask]) return np.abs(corr)