Skip to content

Commit

Permalink
Merge pull request #32 from nicolas-chaulet/debug
Browse files Browse the repository at this point in the history
Fix bug with negative index accessing random memory
  • Loading branch information
nicolas-chaulet authored May 14, 2020
2 parents c5cbbae + 7bb9aa3 commit 7d16352
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
11 changes: 8 additions & 3 deletions cpu/src/ball_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ at::Tensor degree(at::Tensor row, int64_t num_nodes)
{
auto zero = at::zeros(num_nodes, row.options());
auto one = at::ones(row.size(0), row.options());
return zero.scatter_add_(0, row, one);
auto out = zero.scatter_add_(0, row, one);
return out;
}

std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tensor query,
Expand All @@ -79,9 +80,13 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);

int max_count = 0;
auto batch_access = query_batch.accessor<int64_t, 1>();
auto q_batch_access = query_batch.accessor<int64_t, 1>();
auto s_batch_access = support_batch.accessor<int64_t, 1>();

auto batch_size = q_batch_access[query_batch.size(0) - 1] + 1;
TORCH_CHECK(batch_size == (s_batch_access[support_batch.size(0) - 1] + 1),
"Both batches need to have the same number of samples.")

auto batch_size = batch_access[-1] + 1;
query_batch = degree(query_batch, batch_size);
query_batch = at::cat({at::zeros(1, query_batch.options()), query_batch.cumsum(0)}, 0);
support_batch = degree(support_batch, batch_size);
Expand Down
24 changes: 20 additions & 4 deletions test/test_ballquerry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import unittest
import torch
from torch_points_kernels import ball_query
import numpy.testing as npt
import numpy as np
from sklearn.neighbors import KDTree
import os
import sys

ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
sys.path.insert(0, ROOT)

from . import run_if_cuda
from test import run_if_cuda
from torch_points_kernels import ball_query


class TestBall(unittest.TestCase):
Expand Down Expand Up @@ -76,10 +81,10 @@ def test_simple_gpu(self):
npt.assert_array_almost_equal(dist2, dist2_answer)

def test_simple_cpu(self):
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float)
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float)
y = torch.tensor([[0, 0, 0]]).to(torch.float)

batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
batch_x = torch.from_numpy(np.asarray([0, 0, 0, 0])).long()
batch_y = torch.from_numpy(np.asarray([0])).long()

idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
Expand All @@ -93,6 +98,17 @@ def test_simple_cpu(self):
npt.assert_array_almost_equal(idx, idx_answer)
npt.assert_array_almost_equal(dist2, dist2_answer)


def test_breaks(self):
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float)
y = torch.tensor([[0, 0, 0]]).to(torch.float)

batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
batch_y = torch.from_numpy(np.asarray([0])).long()

with self.assertRaises(RuntimeError):
idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)

def test_random_cpu(self):
a = torch.randn(100, 3).to(torch.float)
b = torch.randn(50, 3).to(torch.float)
Expand Down

0 comments on commit 7d16352

Please sign in to comment.