diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 22132aa..e486f2a 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -689,11 +689,8 @@ def mean_angle(values: np.ndarray, axis: int = 0): array([ 0.92387953+0.38268343j, 0.28978415+0.95709203j, -1. +0.j ]) """ - new_shape = values.shape[:axis] + values.shape[axis + 1 :] - total_phase = np.sum( - values / np.abs(values), - axis=axis, - where=values != 0, - out=np.zeros(new_shape, dtype="complex"), + values = np.divide( + values, np.abs(values), out=np.zeros(values.shape, dtype=values.dtype), where=values != 0 ) + total_phase = np.sum(values, axis=axis) return (total_phase) / np.abs(total_phase)