Skip to content

Commit

Permalink
MAINT: Add numba version of linear combination for linear kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwboth committed Jul 26, 2024
1 parent bdbc7f0 commit 3b4fd17
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/darsia/utils/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,56 @@ def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
return np.sum(np.multiply(x, y), axis=-1) + self.a

def linear_combination(
self,
signal: np.ndarray,
supports: np.ndarray,
interpolation_weights: np.ndarray,
) -> np.ndarray:
"""Linear combination using a numba version of the Gaussian kernel.
Args:
signal (np.ndarray): signal to be interpolated
supports (np.ndarray): supports
interpolation_weights (np.ndarray): interpolation weights
Returns:
np.ndarray: interpolated signal
"""

@numba.jit(
[
"float32(float32[:], float32[:,:], float32[:], float32)",
"float32[:](float32[:,:], float32[:,:], float32[:], float32)",
"float32[:,:](float32[:,:,:], float32[:,:], float32[:], float32)",
],
nopython=True,
parallel=True,
fastmath=True,
cache=True,
)
def _linear_combination_numba(
signal: np.ndarray,
supports: np.ndarray,
interpolation_weights: np.ndarray,
a: float,
):
"""Linear combination of the linear kernel."""
num_supports = len(supports)
output = interpolation_weights[0] * (
np.sum(np.multiply(signal, supports[0]), axis=-1) + a
)
for n in range(1, num_supports):
output += interpolation_weights[n] * (
np.sum(np.multiply(signal, supports[n]), axis=-1) + a
)
return output

return _linear_combination_numba(
signal, supports, interpolation_weights, self.a
)


class GaussianKernel(BaseKernel):
"""Gaussian kernel."""
Expand Down Expand Up @@ -105,6 +155,7 @@ def linear_combination(

@numba.jit(
[
"float32(float32[:], float32[:,:], float32[:], float32)",
"float32[:](float32[:,:], float32[:,:], float32[:], float32)",
"float32[:,:](float32[:,:,:], float32[:,:], float32[:], float32)",
],
Expand Down

0 comments on commit 3b4fd17

Please sign in to comment.