Skip to content

Commit

Permalink
Fix knn query
Browse files Browse the repository at this point in the history
* the "far branch" should be checked also when we have not filled the
  KNN buffer.
  • Loading branch information
crvs committed Jun 4, 2024
1 parent 7928cea commit f64d3cf
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 42 deletions.
57 changes: 23 additions & 34 deletions KDTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ double KDNode::coord(size_t const& idx) { return x.at(idx); }
KDNode::operator bool() { return (!x.empty()); }
KDNode::operator point_t() { return x; }
KDNode::operator size_t() { return index; }
KDNode::operator pointIndex() { return pointIndex(x, index); }
KDNode::operator pointIndex() { return std::make_pair(x, index); }

KDNodePtr NewKDNodePtr() {
KDNodePtr mynode = std::make_shared<KDNode>();
return mynode;
}

inline double dist2(point_t const& a, point_t const& b) {
assert(a.size() == b.size());
double distc = 0;
for (size_t i = 0; i < a.size(); i++) {
double di = a.at(i) - b.at(i);
Expand All @@ -63,14 +64,6 @@ inline double dist2(KDNodePtr const& a, KDNodePtr const& b) {
return dist2(a->x, b->x);
}

inline double dist(point_t const& a, point_t const& b) {
return std::sqrt(dist2(a, b));
}

inline double dist(KDNodePtr const& a, KDNodePtr const& b) {
return std::sqrt(dist2(a, b));
}

comparer::comparer(size_t idx_) : idx{idx_} {}

inline bool comparer::compare_idx(pointIndex const& a, pointIndex const& b) {
Expand Down Expand Up @@ -145,7 +138,7 @@ void KDTree::node_query_(
insert_it->first != std::next(insert_it)->first) {
k_nearest_buffer.insert(insert_it, node_distance);
}
if (k_nearest_buffer.size() > num_nearest) {
while (k_nearest_buffer.size() > num_nearest) {
k_nearest_buffer.pop_back();
}
}
Expand All @@ -154,13 +147,14 @@ void KDTree::knearest_(
KDNodePtr const& branch, point_t const& pt, size_t const& level,
size_t const& num_nearest,
std::list<std::pair<KDNodePtr, double>>& k_nearest_buffer) {
if (branch == nullptr || !bool(*branch)) {
if (branch == nullptr || !static_cast<bool>(*branch)) {
return;
}

point_t branch_pt(*branch);
point_t branch_pt{*branch};
size_t dim = branch_pt.size();
assert(dim != 0);
assert(dim == pt.size());

double const dx = branch_pt.at(level) - pt.at(level);
double const dx2 = dx * dx;
Expand All @@ -173,7 +167,7 @@ void KDTree::knearest_(
node_query_(close_branch, pt, next_level, num_nearest, k_nearest_buffer);

// only check the other branch if it makes sense to do so
if (dx2 < k_nearest_buffer.back().second) {
if (dx2 < k_nearest_buffer.back().second || k_nearest_buffer.size() < num_nearest) {
node_query_(far_branch, pt, next_level, num_nearest, k_nearest_buffer);
}
};
Expand Down Expand Up @@ -205,7 +199,7 @@ size_t KDTree::nearest_index(point_t const& pt) {

pointIndex KDTree::nearest_pointIndex(point_t const& pt) {
KDNodePtr Nearest = nearest_(pt);
return pointIndex(point_t(*Nearest), size_t(*Nearest));
return static_cast<pointIndex>(*Nearest);
}

pointIndexArr KDTree::nearest_pointIndices(point_t const& pt,
Expand Down Expand Up @@ -244,7 +238,7 @@ indexArr KDTree::nearest_indices(point_t const& pt, size_t const& num_nearest) {
}

void KDTree::neighborhood_(KDNodePtr const& branch, point_t const& pt,
double const& rad, size_t const& level,
double const& rad2, size_t const& level,
pointIndexArr& nbh) {
if (!bool(*branch)) {
// branch has no point, means it is a leaf,
Expand All @@ -254,49 +248,44 @@ void KDTree::neighborhood_(KDNodePtr const& branch, point_t const& pt,

size_t const dim = pt.size();

double const r2 = rad * rad;

double const d = dist2(point_t(*branch), pt);
double const dx = point_t(*branch).at(level) - pt.at(level);
double const d = dist2(static_cast<point_t>(*branch), pt);
double const dx = static_cast<point_t>(*branch).at(level) - pt.at(level);
double const dx2 = dx * dx;

if (d <= r2) {
nbh.push_back(pointIndex(*branch));
if (d <= rad2) {
nbh.push_back(static_cast<pointIndex>(*branch));
}

KDNodePtr const close_branch = (dx > 0) ? branch->left : branch->right;
KDNodePtr const far_branch = (dx > 0) ? branch->right : branch->left;

size_t const next_level{(level + 1) % dim};
neighborhood_(close_branch, pt, rad, next_level, nbh);
if (dx2 < r2) {
neighborhood_(far_branch, pt, rad, next_level, nbh);
neighborhood_(close_branch, pt, rad2, next_level, nbh);
if (dx2 < rad2) {
neighborhood_(far_branch, pt, rad2, next_level, nbh);
}
}

pointIndexArr KDTree::neighborhood(point_t const& pt, double const& rad) {
size_t level = 0;
pointIndexArr nbh;
neighborhood_(root_, pt, rad, level, nbh);
neighborhood_(root_, pt, rad * rad, /*level*/ 0, nbh);
return nbh;
}

pointVec KDTree::neighborhood_points(point_t const& pt, double const& rad) {
size_t level = 0;
auto nbh = std::make_shared<pointIndexArr>();
neighborhood_(root_, pt, rad, level, *nbh);
neighborhood_(root_, pt, rad * rad, /*level*/ 0, *nbh);
pointVec nbhp{nbh->size()};
std::transform(nbh->begin(), nbh->end(), nbhp.begin(),
[](pointIndex x) { return x.first; });
auto const first = [](pointIndex const& x) { return x.first; };
std::transform(nbh->begin(), nbh->end(), nbhp.begin(), first);
return nbhp;
}

indexArr KDTree::neighborhood_indices(point_t const& pt, double const& rad) {
size_t level = 0;
auto nbh = std::make_shared<pointIndexArr>();
neighborhood_(root_, pt, rad, level, *nbh);
neighborhood_(root_, pt, rad * rad, /*level*/ 0, *nbh);
indexArr nbhi{nbh->size()};
std::transform(nbh->begin(), nbh->end(), nbhi.begin(),
[](pointIndex x) { return x.second; });
auto const second = [](pointIndex const& x) { return x.second; };
std::transform(nbh->begin(), nbh->end(), nbhi.begin(), second);
return nbhi;
}
6 changes: 1 addition & 5 deletions KDTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ KDNodePtr NewKDNodePtr();
inline double dist2(point_t const&, point_t const&);
inline double dist2(KDNodePtr const&, KDNodePtr const&);

// euclidean distance
inline double dist(point_t const&, point_t const&);
inline double dist(KDNodePtr const&, KDNodePtr const&);

// Need for sorting
class comparer {
public:
Expand Down Expand Up @@ -175,7 +171,7 @@ class KDTree {
KDNodePtr nearest_(point_t const& pt);

void neighborhood_(KDNodePtr const& branch, point_t const& pt,
double const& rad, size_t const& level,
double const& rad2, size_t const& level,
pointIndexArr& nbh);

KDNodePtr root_;
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

add_executable(construction_time construction_time.cpp)
target_link_libraries(construction_time KDTree)

Expand All @@ -16,3 +15,4 @@ add_test(NAME error_test COMMAND $<TARGET_FILE:error_test>)
add_test(NAME knn_error_test COMMAND $<TARGET_FILE:knn_error_test>)
add_test(NAME toy_test COMMAND $<TARGET_FILE:toy_test>)

add_custom_target(check ${CMAKE_CTEST_COMMAND} DEPENDS construction_time error_test knn_error_test toy_test)
4 changes: 2 additions & 2 deletions tests/knn_error_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <random>
#include <vector>

#define DIM 1
#define DIM 2

double getNum() { return ((double)rand() / (RAND_MAX)); }
using secondsf = std::chrono::duration<float>;
Expand Down Expand Up @@ -75,7 +75,7 @@ int main() {
std::chrono::nanoseconds bruteForceRetTotalTime{};

int correct = 0;
int const k = 3;
int const k = 10;
for (int i = 0; i < nIter; i++) {
// generate test points to build a tree
points = getListofGeneratedVectors(sizes);
Expand Down

0 comments on commit f64d3cf

Please sign in to comment.