From 24c51a7dd8eac98dfe4db77f60c40a3de7f30cbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 26 Sep 2023 16:32:06 +0200 Subject: [PATCH] njit update --- src/scgenerator/solver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index 52867d5..7f95859 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -69,13 +69,14 @@ class SimulationResult: return cls(spectra, stats, z) -@numba.jit(nopython=True) +@numba.njit def compute_diff(coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float: diff = coarse_spec - fine_spec diff2 = diff.imag**2 + diff.real**2 return np.sqrt(diff2.sum() / (fine_spec.real**2 + fine_spec.imag**2).sum()) +@numba.njit def weaknorm(fine: np.ndarray, coarse: np.ndarray, rtol: float, atol: float) -> float: alpha = max(max(np.sqrt(abs2(fine).sum()), np.sqrt(abs2(coarse).sum())), atol) return 1 / (alpha * rtol) * np.sqrt(abs2(coarse - fine).sum()) @@ -86,6 +87,7 @@ def norm_hairer(fine: np.ndarray, coarse: np.ndarray, rtol: float, atol: float) return np.sqrt(abs2((fine - coarse) / (atol + rtol * alpha)).mean()) +@numba.njit def pi_step_factor(error: float, last_error: float, order: int, eps: float = 0.8): """ computes the next step factor based on the current and last error.