From 3b4fd175df256cd0c731b5cd5887a48f2f00a13e Mon Sep 17 00:00:00 2001 From: Jakub Both Date: Fri, 26 Jul 2024 16:08:16 +0200 Subject: [PATCH] MAINT: Add numba version of linear combination for linear kernels. --- src/darsia/utils/kernels.py | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/darsia/utils/kernels.py b/src/darsia/utils/kernels.py index 26c2ff99..9afc92b0 100644 --- a/src/darsia/utils/kernels.py +++ b/src/darsia/utils/kernels.py @@ -65,6 +65,56 @@ def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: """ return np.sum(np.multiply(x, y), axis=-1) + self.a + def linear_combination( + self, + signal: np.ndarray, + supports: np.ndarray, + interpolation_weights: np.ndarray, + ) -> np.ndarray: + """Linear combination using a numba version of the Gaussian kernel. + + Args: + signal (np.ndarray): signal to be interpolated + supports (np.ndarray): supports + interpolation_weights (np.ndarray): interpolation weights + + Returns: + np.ndarray: interpolated signal + + """ + + @numba.jit( + [ + "float32(float32[:], float32[:,:], float32[:], float32)", + "float32[:](float32[:,:], float32[:,:], float32[:], float32)", + "float32[:,:](float32[:,:,:], float32[:,:], float32[:], float32)", + ], + nopython=True, + parallel=True, + fastmath=True, + cache=True, + ) + def _linear_combination_numba( + signal: np.ndarray, + supports: np.ndarray, + interpolation_weights: np.ndarray, + a: float, + ): + """Linear combination of the linear kernel.""" + num_supports = len(supports) + output = interpolation_weights[0] * ( + np.sum(np.multiply(signal, supports[0]), axis=-1) + a + ) + for n in range(1, num_supports): + output += interpolation_weights[n] * ( + np.sum(np.multiply(signal, supports[n]), axis=-1) + a + ) + return output + + return _linear_combination_numba( + signal, supports, interpolation_weights, self.a + ) + class GaussianKernel(BaseKernel): """Gaussian kernel.""" @@ -105,6 +155,7 @@ def linear_combination( @numba.jit( [ + "float32(float32[:], float32[:,:], float32[:], float32)", "float32[:](float32[:,:], float32[:,:], float32[:], float32)", "float32[:,:](float32[:,:,:], float32[:,:], float32[:], float32)", ],