Skip to content

Commit

Permalink
add unit test index
Browse files Browse the repository at this point in the history
  • Loading branch information
bcharlier committed Mar 29, 2024
1 parent 83524cd commit ad3930c
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions pykeops/pykeops/test/test_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
from pykeops.torch import LazyTensor
import pytest


def fun_torch(A, I, J):
return A[I, J].sum(axis=1)


def fun_keops(A, I, J):
ncol = A.shape[1]
A = LazyTensor(A.flatten())
I = LazyTensor((I + 0.0)[..., None])
J = LazyTensor((J + 0.0)[..., None])
K = A[I * ncol + J]
return K.sum(axis=1).flatten()


P, Q = 12, 5
M, N = 300, 200
device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn((P, Q), requires_grad=True, device=device)
I = torch.randint(P, (M, 1), device=device)
J = torch.randint(Q, (1, N), device=device)

res_torch = fun_torch(A, I, J)
print(res_torch)

res_keops = fun_keops(A, I, J)
print(res_keops)


def test_index():
assert torch.allclose(res_torch, res_torch)


# testing gradients
def test_index_grad():
loss_torch = (res_torch ** 2).sum()
loss_keops = (res_keops ** 2).sum()
assert torch.allclose(torch.autograd.grad(loss_torch, [A])[0],
torch.autograd.grad(loss_keops, [A])[0]
)


0 comments on commit ad3930c

Please sign in to comment.