njit update

This commit is contained in:
Benoît Sierro
2023-09-26 16:32:06 +02:00
parent 10026ea8a0
commit 24c51a7dd8

View File

@@ -69,13 +69,14 @@ class SimulationResult:
return cls(spectra, stats, z) return cls(spectra, stats, z)
@numba.jit(nopython=True) @numba.njit
def compute_diff(coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float: def compute_diff(coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float:
diff = coarse_spec - fine_spec diff = coarse_spec - fine_spec
diff2 = diff.imag**2 + diff.real**2 diff2 = diff.imag**2 + diff.real**2
return np.sqrt(diff2.sum() / (fine_spec.real**2 + fine_spec.imag**2).sum()) return np.sqrt(diff2.sum() / (fine_spec.real**2 + fine_spec.imag**2).sum())
@numba.njit
def weaknorm(fine: np.ndarray, coarse: np.ndarray, rtol: float, atol: float) -> float: def weaknorm(fine: np.ndarray, coarse: np.ndarray, rtol: float, atol: float) -> float:
alpha = max(max(np.sqrt(abs2(fine).sum()), np.sqrt(abs2(coarse).sum())), atol) alpha = max(max(np.sqrt(abs2(fine).sum()), np.sqrt(abs2(coarse).sum())), atol)
return 1 / (alpha * rtol) * np.sqrt(abs2(coarse - fine).sum()) return 1 / (alpha * rtol) * np.sqrt(abs2(coarse - fine).sum())
@@ -86,6 +87,7 @@ def norm_hairer(fine: np.ndarray, coarse: np.ndarray, rtol: float, atol: float)
return np.sqrt(abs2((fine - coarse) / (atol + rtol * alpha)).mean()) return np.sqrt(abs2((fine - coarse) / (atol + rtol * alpha)).mean())
@numba.njit
def pi_step_factor(error: float, last_error: float, order: int, eps: float = 0.8): def pi_step_factor(error: float, last_error: float, order: int, eps: float = 0.8):
""" """
computes the next step factor based on the current and last error. computes the next step factor based on the current and last error.