Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Fix hnsw bitset (#602)
Browse files: browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>

Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 authored Dec 13, 2022
1 parent 7590690 commit 6dec5e6
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 79 deletions.
13 changes: 8 additions & 5 deletions knowhere/index/vector_index/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,15 @@ IndexHNSW::QueryImpl(int64_t n, const float* xq, int64_t k, float* distances, in
size_t rst_size = rst.size();
auto p_single_dis = distances + index * k;
auto p_single_id = labels + index * k;
for (size_t idx = 0; idx < rst_size; ++idx) {
const auto& [dist, id] = rst[idx];
p_single_dis[idx] = transform ? (1 - dist) : dist;
p_single_id[idx] = id;
size_t idx = rst_size - 1;
while (!rst.empty()) {
auto& it = rst.top();
p_single_dis[idx] = transform ? (1 - it.first) : it.first;
p_single_id[idx] = it.second;
rst.pop();
idx--;
}
for (size_t idx = rst_size; idx < k; idx++) {
for (idx = rst_size; idx < k; idx++) {
p_single_dis[idx] = float(1.0 / 0.0);
p_single_id[idx] = -1;
}
Expand Down
150 changes: 78 additions & 72 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
mutable std::atomic<long> metric_hops;

template <bool has_deletions, bool collect_metrics = false>
std::vector<std::pair<dist_t, tableint>>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void* data_point, size_t ef, const faiss::BitsetView bitset,
const SearchParam* param = nullptr,
const knowhere::feder::hnsw::FederResultUniq& feder_result = nullptr) const {
Expand All @@ -250,80 +250,81 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}

auto& visited = visited_list_pool_->getFreeVisitedList();
std::vector<Neighbor> retset(ef + 1);

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
candidate_set;

dist_t lowerBound;
if (!has_deletions || !bitset.test((int64_t)ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
retset[0] = Neighbor(ep_id, dist, true);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
candidate_set.emplace(-dist, ep_id);
} else {
retset[0] = Neighbor(ep_id, std::numeric_limits<dist_t>::max(), true);
lowerBound = std::numeric_limits<dist_t>::max();
candidate_set.emplace(-lowerBound, ep_id);
}

visited[ep_id] = true;
size_t p = 0, cur_size = 1;
while (p < cur_size) {
int np = cur_size;
if (retset[p].flag) {
retset[p].flag = false;
tableint u = retset[p].id;
tableint* list = (tableint*)get_linklist0(u);
#if defined(USE_PREFETCH)
_mm_prefetch(list, _MM_HINT_T0);
#endif
int size = list[0];
while (!candidate_set.empty()) {
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();

if constexpr (collect_metrics) {
metric_hops++;
metric_distance_computations += size;
}
#if defined(USE_PREFETCH)
for (size_t i = 1; i <= size; ++i) {
_mm_prefetch(getDataByInternalId(list[i]), _MM_HINT_T0);
}
if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) {
break;
}
candidate_set.pop();

tableint current_node_id = current_node_pair.second;
int* data = (int*)get_linklist0(current_node_id);
size_t size = getListCount((linklistsizeint*)data);
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
if (collect_metrics) {
metric_hops++;
metric_distance_computations += size;
}

#ifdef USE_PREFETCH
for (size_t j = 1; j <= size; ++j) {
_mm_prefetch(getDataByInternalId(data[j]), _MM_HINT_T0);
}
#endif
for (size_t i = 1; i <= size; ++i) {
tableint v = list[i];
if (visited[v]) {
if (feder_result != nullptr) {
feder_result->visit_info_.AddVisitRecord(0, u, v, -1.0);
feder_result->id_set_.insert(u);
feder_result->id_set_.insert(v);
}
continue;
}
visited[v] = true;
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(v), dist_func_param_);

for (size_t j = 1; j <= size; j++) {
int candidate_id = *(data + j);
if (!visited[candidate_id]) {
visited[candidate_id] = true;
char* currObj1 = (getDataByInternalId(candidate_id));
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
if (feder_result != nullptr) {
feder_result->visit_info_.AddVisitRecord(0, u, v, dist);
feder_result->id_set_.insert(u);
feder_result->id_set_.insert(v);
}
if ((cur_size == ef && dist >= retset[ef - 1].distance) ||
(has_deletions && bitset.test((int64_t)v))) {
continue;
feder_result->visit_info_.AddVisitRecord(0, current_node_id, candidate_id, dist);
feder_result->id_set_.insert(current_node_id);
feder_result->id_set_.insert(candidate_id);
}
Neighbor nn(v, dist, true);
int r = InsertIntoPool(retset.data(), cur_size, nn);
if (cur_size < ef) {
++cur_size;

if (top_candidates.size() < ef || lowerBound > dist) {
candidate_set.emplace(-dist, candidate_id);
if (!has_deletions || !bitset.test((int64_t)candidate_id))
top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > ef)
top_candidates.pop();

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
if (r < np) {
np = r;
} else {
if (feder_result != nullptr) {
feder_result->visit_info_.AddVisitRecord(0, current_node_id, candidate_id, -1.0);
feder_result->id_set_.insert(current_node_id);
feder_result->id_set_.insert(candidate_id);
}
}
}
if (np <= p) {
p = np;
} else {
++p;
}
}

std::vector<std::pair<dist_t, tableint>> ans(cur_size);
for (int i = 0; i < cur_size; ++i) {
ans[i] = {retset[i].distance, retset[i].id};
}
return ans;
return top_candidates;
}

std::vector<tableint>
Expand Down Expand Up @@ -370,16 +371,17 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}

std::vector<std::pair<dist_t, labeltype>>
getNeighboursWithinRadius(std::vector<std::pair<dist_t, tableint>>& top_candidates, const void* data_point,
float radius, const faiss::BitsetView bitset,
getNeighboursWithinRadius(std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>,
CompareByFirst>& top_candidates,
const void* data_point, float radius, const faiss::BitsetView bitset,
const knowhere::feder::hnsw::FederResultUniq& feder_result = nullptr) const {
std::vector<std::pair<dist_t, labeltype>> result;
auto& visited = visited_list_pool_->getFreeVisitedList();

std::queue<std::pair<dist_t, tableint>> radius_queue;
while (!top_candidates.empty()) {
auto cand = top_candidates.back();
top_candidates.pop_back();
auto cand = top_candidates.top();
top_candidates.pop();
if (cand.first < radius) {
radius_queue.push(cand);
result.emplace_back(cand.first, cand.second);
Expand Down Expand Up @@ -1087,11 +1089,12 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return cur_c;
};

std::vector<std::pair<dist_t, labeltype>>
std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void* query_data, size_t k, const faiss::BitsetView bitset, const SearchParam* param = nullptr,
const knowhere::feder::hnsw::FederResultUniq& feder_result = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0)
return {};
return result;

tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
Expand Down Expand Up @@ -1134,8 +1137,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
}
}

std::vector<std::pair<dist_t, tableint>> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
top_candidates;
size_t ef = param ? param->ef_ : this->ef_;
if (!bitset.empty()) {
top_candidates =
Expand All @@ -1144,11 +1147,13 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
top_candidates =
searchBaseLayerST<false, true>(currObj, query_data, std::max(ef, k), bitset, param, feder_result);
}
std::vector<std::pair<dist_t, labeltype>> result;
size_t len = std::min(k, top_candidates.size());
result.reserve(len);
for (int i = 0; i < len; ++i) {
result.emplace_back(top_candidates[i].first, (labeltype)top_candidates[i].second);
while (top_candidates.size() > k) {
top_candidates.pop();
}
while (top_candidates.size() > 0) {
std::pair<dist_t, tableint> rez = top_candidates.top();
result.push(std::pair<dist_t, labeltype>(rez.first, rez.second));
top_candidates.pop();
}
return result;
};
Expand Down Expand Up @@ -1198,7 +1203,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
}

std::vector<std::pair<dist_t, tableint>> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
top_candidates;
size_t ef = param ? param->ef_ : this->ef_;
if (!bitset.empty()) {
top_candidates = searchBaseLayerST<true, true>(currObj, query_data, ef, bitset, param, feder_result);
Expand Down
14 changes: 12 additions & 2 deletions thirdparty/hnswlib/hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;

virtual std::vector<std::pair<dist_t, labeltype>>
virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, const faiss::BitsetView, const SearchParam*,
const knowhere::feder::hnsw::FederResultUniq&) const = 0;

Expand All @@ -202,7 +202,17 @@ AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
return searchKnn(query_data, k, bitset, nullptr, nullptr);
auto ret = searchKnn(query_data, k, bitset, nullptr, nullptr);
{
size_t sz = ret.size();
result.resize(sz);
while (!ret.empty()) {
result[--sz] = ret.top();
ret.pop();
}
}

return result;
}
}

Expand Down

0 comments on commit 6dec5e6

Please sign in to comment.