From ee82d3d4861879d28f1eb680c9f0df2da80ac5f1 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Thu, 14 May 2020 08:41:11 +0000 Subject: [PATCH 1/3] Add logging --- cpu/src/ball_query.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cpu/src/ball_query.cpp b/cpu/src/ball_query.cpp index 954b67c..2c5ec72 100644 --- a/cpu/src/ball_query.cpp +++ b/cpu/src/ball_query.cpp @@ -55,9 +55,14 @@ std::pair ball_query(at::Tensor support, at::Tensor quer at::Tensor degree(at::Tensor row, int64_t num_nodes) { + std::cout << "-- Start degree" << std::endl; + std::cout << "Num nodes " << num_nodes << std::endl; auto zero = at::zeros(num_nodes, row.options()); + std::cout << "Row size " << row.size(0) << std::endl; auto one = at::ones(row.size(0), row.options()); - return zero.scatter_add_(0, row, one); + auto out = zero.scatter_add_(0, row, one); + std::cout << "-- End degree" << std::endl; + return out; } std::pair batch_ball_query(at::Tensor support, at::Tensor query, From a5d61459a38d1f2f36acd89ed652d4ee8f2027ea Mon Sep 17 00:00:00 2001 From: Nicolas Date: Thu, 14 May 2020 09:41:21 +0000 Subject: [PATCH 2/3] Remove negative indexing --- cpu/src/ball_query.cpp | 8 ++++++-- test/test_ballquerry.py | 24 ++++++++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/cpu/src/ball_query.cpp b/cpu/src/ball_query.cpp index 2c5ec72..444b13c 100644 --- a/cpu/src/ball_query.cpp +++ b/cpu/src/ball_query.cpp @@ -84,9 +84,13 @@ std::pair batch_ball_query(at::Tensor support, at::Tenso auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU); int max_count = 0; - auto batch_access = query_batch.accessor(); + auto q_batch_access = query_batch.accessor(); + auto s_batch_access = support_batch.accessor(); + + auto batch_size = q_batch_access[query_batch.size(0) - 1] + 1; + TORCH_CHECK(batch_size == (s_batch_access[support_batch.size(0) - 1] + 1), + "Both batches need to have the same number of samples.") - auto batch_size = batch_access[-1] + 1; query_batch = degree(query_batch, batch_size); query_batch = at::cat({at::zeros(1, query_batch.options()), query_batch.cumsum(0)}, 0); support_batch = degree(support_batch, batch_size); diff --git a/test/test_ballquerry.py b/test/test_ballquerry.py index bbcd40e..0b32b96 100644 --- a/test/test_ballquerry.py +++ b/test/test_ballquerry.py @@ -1,11 +1,16 @@ import unittest import torch -from torch_points_kernels import ball_query import numpy.testing as npt import numpy as np from sklearn.neighbors import KDTree +import os +import sys + +ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") +sys.path.insert(0, ROOT) -from . import run_if_cuda +from test import run_if_cuda +from torch_points_kernels import ball_query class TestBall(unittest.TestCase): @@ -76,10 +81,10 @@ def test_simple_gpu(self): npt.assert_array_almost_equal(dist2, dist2_answer) def test_simple_cpu(self): - x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float) + x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float) y = torch.tensor([[0, 0, 0]]).to(torch.float) - batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long() + batch_x = torch.from_numpy(np.asarray([0, 0, 0, 0])).long() batch_y = torch.from_numpy(np.asarray([0])).long() idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y) @@ -93,6 +98,17 @@ def test_simple_cpu(self): npt.assert_array_almost_equal(idx, idx_answer) npt.assert_array_almost_equal(dist2, dist2_answer) + + def test_breaks(self): + x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float) + y = torch.tensor([[0, 0, 0]]).to(torch.float) + + batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long() + batch_y = torch.from_numpy(np.asarray([0])).long() + + with self.assertRaises(RuntimeError): + idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y) + def test_random_cpu(self): a = torch.randn(100, 3).to(torch.float) b = torch.randn(50, 3).to(torch.float) From d1132d00c3eacf9e066734612b483dec79be9dc1 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Thu, 14 May 2020 09:56:15 +0000 Subject: [PATCH 3/3] Clean logging --- cpu/src/ball_query.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpu/src/ball_query.cpp b/cpu/src/ball_query.cpp index 444b13c..c9a4777 100644 --- a/cpu/src/ball_query.cpp +++ b/cpu/src/ball_query.cpp @@ -55,13 +55,9 @@ std::pair ball_query(at::Tensor support, at::Tensor quer at::Tensor degree(at::Tensor row, int64_t num_nodes) { - std::cout << "-- Start degree" << std::endl; - std::cout << "Num nodes " << num_nodes << std::endl; auto zero = at::zeros(num_nodes, row.options()); - std::cout << "Row size " << row.size(0) << std::endl; auto one = at::ones(row.size(0), row.options()); auto out = zero.scatter_add_(0, row, one); - std::cout << "-- End degree" << std::endl; return out; }