diff --git a/knowhere/common/Heap.h b/knowhere/common/Heap.h new file mode 100644 index 000000000..396e26a82 --- /dev/null +++ b/knowhere/common/Heap.h @@ -0,0 +1,62 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +namespace knowhere { + +// Maintain intermediate top-k results via maxheap +// TODO: this naive implementation might be optimized later +// 1. Based on top-k and pushed element count to switch strategy +// 2. 
Combine `pop` and `push` operation to `replace` +template <typename DisT, typename IdT> +class ResultMaxHeap { + public: + explicit ResultMaxHeap(size_t k) : k_(k) {} + + inline std::optional<std::pair<DisT, IdT>> + Pop() { + if (pq.empty()) { + return std::nullopt; + } + std::optional<std::pair<DisT, IdT>> res = pq.top(); // current worst (largest-distance) entry + pq.pop(); + return res; + } + + inline void + Push(DisT dis, IdT id) { + if (pq.size() < k_) { + pq.emplace(dis, id); + return; + } + + if (!pq.empty() && dis < pq.top().first) { // queue is empty here only when k_ == 0; never call top() on it + pq.pop(); + pq.emplace(dis, id); + } + } + + inline size_t + Size() const { + return pq.size(); + } + + private: + size_t k_; + std::priority_queue<std::pair<DisT, IdT>> pq; +}; + +} // namespace knowhere diff --git a/knowhere/index/vector_index/IndexDiskANN.cpp b/knowhere/index/vector_index/IndexDiskANN.cpp index 440ff4a7f..654a0214d 100644 --- a/knowhere/index/vector_index/IndexDiskANN.cpp +++ b/knowhere/index/vector_index/IndexDiskANN.cpp @@ -403,7 +403,7 @@ IndexDiskANN::Query(const DatasetPtr& dataset_ptr, const Config& config, cons futures.push_back(pool_->push([&, index = row]() { pq_flash_index_->cached_beam_search(query + (index * dim), k, query_conf.search_list_size, p_id + (index * k), p_dist + (index * k), query_conf.beamwidth, false, - nullptr, nullptr, bitset); + nullptr, nullptr, bitset, query_conf.filter_threshold); })); } diff --git a/knowhere/index/vector_index/IndexDiskANNConfig.cpp b/knowhere/index/vector_index/IndexDiskANNConfig.cpp index 3998c7cf2..62095aab2 100644 --- a/knowhere/index/vector_index/IndexDiskANNConfig.cpp +++ b/knowhere/index/vector_index/IndexDiskANNConfig.cpp @@ -37,6 +37,7 @@ static constexpr const char* kAioMaxnr = "aio_maxnr"; static constexpr const char* kK = "k"; static constexpr const char* kBeamwidth = "beamwidth"; +static constexpr const char* kFilterThreshold = "filter_threshold"; static constexpr const char* kRadius = "radius"; static constexpr const char* kRangeFilter = "range_filter"; @@ -65,6 +66,8 @@ static constexpr std::optional kDiskPqBytesMaxValue = std::nullopt; static constexpr uint32_t kSearchListSizeMaxValue = 200; 
static constexpr uint32_t kBeamwidthMinValue = 1; static constexpr uint32_t kBeamwidthMaxValue = 128; +static constexpr float kFilterThresholdMinValue = -1; +static constexpr float kFilterThresholdMaxValue = 1; static constexpr uint64_t kKMinValue = 1; static constexpr std::optional kKMaxValue = std::nullopt; static constexpr uint64_t kAioMaxnrMinValue = 1; @@ -206,8 +209,10 @@ from_json(const Config& config, DiskANNPrepareConfig& prep_conf) { void to_json(Config& config, const DiskANNQueryConfig& query_conf) { - config = - Config{{kK, query_conf.k}, {kSearchListSize, query_conf.search_list_size}, {kBeamwidth, query_conf.beamwidth}}; + config = Config{{kK, query_conf.k}, + {kSearchListSize, query_conf.search_list_size}, + {kBeamwidth, query_conf.beamwidth}, + {kFilterThreshold, query_conf.filter_threshold}}; } void @@ -218,6 +223,8 @@ from_json(const Config& config, DiskANNQueryConfig& query_conf) { std::max(kSearchListSizeMaxValue, static_cast(10 * query_conf.k)), query_conf.search_list_size); CheckNumericParamAndSet(config, kBeamwidth, kBeamwidthMinValue, kBeamwidthMaxValue, query_conf.beamwidth); + CheckNumericParamAndSet(config, kFilterThreshold, kFilterThresholdMinValue, kFilterThresholdMaxValue, + query_conf.filter_threshold); } void diff --git a/knowhere/index/vector_index/IndexDiskANNConfig.h b/knowhere/index/vector_index/IndexDiskANNConfig.h index 3929c5419..ffc25866c 100644 --- a/knowhere/index/vector_index/IndexDiskANNConfig.h +++ b/knowhere/index/vector_index/IndexDiskANNConfig.h @@ -92,6 +92,10 @@ struct DiskANNQueryConfig { // slightly higher total number of IO requests to SSD per query. For the highest query throughput with a fixed SSD // IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search. uint32_t beamwidth = 8; + // The threshold which determines when to switch to PQ + Refine strategy based on the number of bits set. 
The + // value should be in range of [0.0, 1.0] which means when greater or equal to x% of the bits are set, + // use PQ + Refine. Defaults to -1.0f; negative values will use dynamic threshold calculator given topk. + float filter_threshold = -1.0f; static DiskANNQueryConfig Get(const Config& config); diff --git a/thirdparty/DiskANN/include/pq_flash_index.h b/thirdparty/DiskANN/include/pq_flash_index.h index dd1b1191d..a31c19f26 100644 --- a/thirdparty/DiskANN/include/pq_flash_index.h +++ b/thirdparty/DiskANN/include/pq_flash_index.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include #include #include @@ -98,8 +99,9 @@ namespace diskann { const T *query, const _u64 k_search, const _u64 l_search, _s64 *res_ids, float *res_dists, const _u64 beam_width, const bool use_reorder_data = false, QueryStats *stats = nullptr, - const knowhere::feder::diskann::FederResultUniq& feder = nullptr, - faiss::BitsetView bitset_view = nullptr); + const knowhere::feder::diskann::FederResultUniq &feder = nullptr, + faiss::BitsetView bitset_view = nullptr, + const float filter_ratio = -1.0f); DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range, @@ -144,6 +146,23 @@ namespace diskann { : sector_buf + (node_id % nnodes_per_sector) * max_node_len; } + inline void copy_vec_base_data(T *des, const int64_t des_idx, void *src); + + // Init thread data and return the query norm if available. + // If there is no value, there is nothing to do with the given query + std::optional init_thread_data(ThreadData &data, const T *query1); + + // Brute force search for the given query. Use beam search rather than + // sending a whole bunch of requests at once to avoid all threads sending I/O + // requests and the time overlaps. + // The beam width is adjusted in the function. 
+ void brute_force_beam_search( + ThreadData &data, const float query_norm, const _u64 k_search, + _s64 *indices, float *distances, const _u64 beam_width_param, IOContext &ctx, + QueryStats *stats, + const knowhere::feder::diskann::FederResultUniq &feder, + faiss::BitsetView bitset_view); + // index info // nhood of node `i` is in sector: [i / nnodes_per_sector] // offset in sector: [(i % nnodes_per_sector) * max_node_len] diff --git a/thirdparty/DiskANN/include/utils.h b/thirdparty/DiskANN/include/utils.h index 77fffe176..614da03b0 100644 --- a/thirdparty/DiskANN/include/utils.h +++ b/thirdparty/DiskANN/include/utils.h @@ -33,7 +33,6 @@ typedef HANDLE FileHandle; typedef int FileHandle; #endif -#include "utils.h" #include "logger.h" #include "cached_io.h" #include "ann_exception.h" diff --git a/thirdparty/DiskANN/src/pq_flash_index.cpp b/thirdparty/DiskANN/src/pq_flash_index.cpp index bb31aad32..ed169528d 100644 --- a/thirdparty/DiskANN/src/pq_flash_index.cpp +++ b/thirdparty/DiskANN/src/pq_flash_index.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -19,6 +20,7 @@ #include "timer.h" #include "utils.h" +#include "knowhere/common/Heap.h" #include "tsl/robin_set.h" #ifdef _WINDOWS @@ -82,6 +84,14 @@ namespace { } return h; } + + constexpr _u64 kRefineBeamWidthFactor = 2; + constexpr _u64 kBruteForceTopkRefineExpansionFactor = 2; + auto calcFilterThreshold = [](const auto topk) -> float { + return std::max(-0.04570166137874405f * log2(topk + 58.96422392240403) + + 1.1982775974217197, + 0.5); + }; } // namespace namespace diskann { @@ -815,32 +825,13 @@ namespace diskann { #endif template - void PQFlashIndex::cached_beam_search(const T *query1, const _u64 k_search, - const _u64 l_search, _s64 *indices, - float * distances, - const _u64 beam_width, - const bool use_reorder_data, - QueryStats *stats, - const knowhere::feder::diskann::FederResultUniq& feder_result, - faiss::BitsetView bitset_view) { - ThreadData data = 
this->thread_data.pop(); - while (data.scratch.sector_scratch == nullptr) { - this->thread_data.wait_for_push_notify(); - data = this->thread_data.pop(); - } - - if (beam_width > MAX_N_SECTOR_READS) - throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS", - -1, __FUNCSIG__, __FILE__, __LINE__); - + std::optional PQFlashIndex::init_thread_data(ThreadData &data, + const T *query1) { // copy query to thread specific aligned and allocated memory (for distance // calculations we need aligned data) - double query_norm = 0; - const T * query = data.scratch.aligned_query_T; - const float *query_float = data.scratch.aligned_query_float; - - auto q_dim = this->data_dim; - if(metric== diskann::Metric::INNER_PRODUCT ){ + float query_norm = 0; + auto q_dim = this->data_dim; + if (metric == diskann::Metric::INNER_PRODUCT) { // query_dim need to be specially treated when using IP q_dim--; } @@ -854,11 +845,8 @@ namespace diskann { // if inner product, we laso normalize the query and set the last coordinate // to 0 (this is the extra coordindate used to convert MIPS to L2 search) if (metric == diskann::Metric::INNER_PRODUCT) { - if(query_norm == 0){ - // return an empty answer when calcu a zero point - this->thread_data.push(data); - this->thread_data.push_notify_all(); - return; + if (query_norm == 0) { + return std::nullopt; } query_norm = std::sqrt(query_norm); data.scratch.aligned_query_T[this->data_dim - 1] = 0; @@ -869,11 +857,202 @@ namespace diskann { } } + data.scratch.reset(); + return query_norm; + } + + template + void PQFlashIndex::brute_force_beam_search( + ThreadData &data, const float query_norm, const _u64 k_search, + _s64 *indices, float *distances, const _u64 beam_width_param, IOContext &ctx, + QueryStats *stats, const knowhere::feder::diskann::FederResultUniq &feder, + faiss::BitsetView bitset_view) { + auto query_scratch = &(data.scratch); + const T *query = data.scratch.aligned_query_T; + auto beam_width = beam_width_param * 
kRefineBeamWidthFactor; + const float *query_float = data.scratch.aligned_query_float; + float *pq_dists = query_scratch->aligned_pqtable_dist_scratch; + pq_table.populate_chunk_distances(query_float, pq_dists); + float *dist_scratch = query_scratch->aligned_dist_scratch; + _u8 *pq_coord_scratch = query_scratch->aligned_pq_coord_scratch; + constexpr _u32 pq_batch_size = MAX_GRAPH_DEGREE; + std::vector pq_batch_ids; + pq_batch_ids.reserve(pq_batch_size); + const _u64 pq_topk = k_search * kBruteForceTopkRefineExpansionFactor; + knowhere::ResultMaxHeap pq_max_heap(pq_topk); + T *data_buf = query_scratch->coord_scratch; + std::unordered_map<_u64, std::vector<_u64>> nodes_in_sectors_to_visit; + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(beam_width); + char *sector_scratch = query_scratch->sector_scratch; + _u64 &sector_scratch_idx = query_scratch->sector_idx; + knowhere::ResultMaxHeap max_heap(k_search); + Timer io_timer, query_timer; + + // scan un-marked points and calculate pq dists + for (_u64 id = 0; id < num_points; ++id) { + if (!bitset_view.test(id)) { + pq_batch_ids.push_back(id); + } + + if (pq_batch_ids.size() == pq_batch_size || id == num_points - 1) { + const size_t sz = pq_batch_ids.size(); + aggregate_coords(pq_batch_ids.data(), sz, this->data, this->n_chunks, + pq_coord_scratch); + pq_dist_lookup(pq_coord_scratch, sz, this->n_chunks, pq_dists, + dist_scratch); + for (size_t i = 0; i < sz; ++i) { + pq_max_heap.Push(dist_scratch[i], pq_batch_ids[i]); + } + pq_batch_ids.clear(); + } + } + + // deduplicate sectors by ids + while (const auto opt = pq_max_heap.Pop()) { + const auto [dist, id] = opt.value(); + + // check if in cache + if (coord_cache.find(id) != coord_cache.end()) { + float dist = dist_cmp(query, coord_cache.at(id), (size_t) aligned_dim); + max_heap.Push(dist, id); + continue; + } + + // deduplicate and prepare for I/O + const _u64 sector_offset = get_node_sector_offset(id); + 
nodes_in_sectors_to_visit[sector_offset].push_back(id); + } + + for (auto it = nodes_in_sectors_to_visit.cbegin(); + it != nodes_in_sectors_to_visit.cend();) { + const auto sector_offset = it->first; + frontier_read_reqs.emplace_back( + sector_offset, read_len_for_node, + sector_scratch + sector_scratch_idx * read_len_for_node); + ++sector_scratch_idx, ++it; + if (stats != nullptr) { + stats->n_4k++; + stats->n_ios++; + } + + // perform I/Os and calculate exact distances + if (frontier_read_reqs.size() == beam_width || + it == nodes_in_sectors_to_visit.cend()) { + io_timer.reset(); +#ifdef USE_BING_INFRA + reader->read(frontier_read_reqs, ctx, true); // async reader windows. +#else + reader->read(frontier_read_reqs, ctx); // synchronous IO linux +#endif + if (stats != nullptr) { + stats->io_us += (double) io_timer.elapsed(); + } + + T *node_fp_coords_copy = data_buf; + for (const auto &req : frontier_read_reqs) { + const auto offset = req.offset; + char *sector_buf = reinterpret_cast(req.buf); + for (const auto cur_id : nodes_in_sectors_to_visit[offset]) { + char *node_buf = get_offset_to_node(sector_buf, cur_id); + memcpy(node_fp_coords_copy, node_buf, + disk_bytes_per_point); // Do we really need memcpy here? 
+ float dist = + dist_cmp(query, node_fp_coords_copy, (size_t) aligned_dim); + max_heap.Push(dist, cur_id); + if (feder != nullptr) { + feder->visit_info_.AddTopCandidateInfo(cur_id, dist); + feder->id_set_.insert(cur_id); + } + } + } + frontier_read_reqs.clear(); + sector_scratch_idx = 0; + } + } + + for (_s64 i = k_search - 1; i >= 0; --i) { + if ((_u64) i >= max_heap.Size()) { + indices[i] = -1; + if (distances != nullptr) { + distances[i] = -1; + } + continue; + } + if (const auto op = max_heap.Pop()) { + const auto [dis, id] = op.value(); + indices[i] = id; + if (distances != nullptr) { + distances[i] = dis; + if (metric == diskann::Metric::INNER_PRODUCT) { + distances[i] = 1.0 - distances[i] / 2.0; + if (max_base_norm != 0) { + distances[i] *= (max_base_norm * query_norm); + } + } + } + } else { + LOG(ERROR) << "Size is incorrect"; + } + } + if (stats != nullptr) { + stats->total_us = (double) query_timer.elapsed(); + } + return; + } + + template + void PQFlashIndex::cached_beam_search( + const T *query1, const _u64 k_search, const _u64 l_search, _s64 *indices, + float *distances, const _u64 beam_width, const bool use_reorder_data, + QueryStats *stats, const knowhere::feder::diskann::FederResultUniq &feder_result, + faiss::BitsetView bitset_view, const float filter_ratio_in) { + if (beam_width > MAX_N_SECTOR_READS) + throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS", + -1, __FUNCSIG__, __FILE__, __LINE__); + + ThreadData data = this->thread_data.pop(); + while (data.scratch.sector_scratch == nullptr) { + this->thread_data.wait_for_push_notify(); + data = this->thread_data.pop(); + } + auto query_norm_opt = init_thread_data(data, query1); + if (!query_norm_opt.has_value()) { + // return an empty answer when calcu a zero point + this->thread_data.push(data); + this->thread_data.push_notify_all(); + return; + } + float query_norm = query_norm_opt.value(); IOContext &ctx = data.ctx; - auto query_scratch = &(data.scratch); - // reset query - 
query_scratch->reset(); + if (!bitset_view.empty()) { + const auto filter_threshold = filter_ratio_in < 0 + ? calcFilterThreshold(k_search) + : filter_ratio_in; + const auto bv_cnt = bitset_view.count(); + if (bitset_view.size() == bv_cnt) { + for (_u64 i = 0; i < k_search; i++) { + indices[i] = -1; + if (distances != nullptr) distances[i] = -1; + } + this->thread_data.push(data); // every point is filtered out: hand the scratch back like every other exit path + this->thread_data.push_notify_all(); + return; + } + + if (bv_cnt >= bitset_view.size() * filter_threshold) { + brute_force_beam_search(data, query_norm, k_search, indices, distances, + beam_width, ctx, stats, feder_result, bitset_view); + this->thread_data.push(data); + this->thread_data.push_notify_all(); + return; + } + } + + auto query_scratch = &(data.scratch); + const T *query = data.scratch.aligned_query_T; + const float *query_float = data.scratch.aligned_query_float; // pointers to buffers for data T * data_buf = query_scratch->coord_scratch; @@ -883,6 +1062,18 @@ namespace diskann { char *sector_scratch = query_scratch->sector_scratch; _u64 &sector_scratch_idx = query_scratch->sector_idx; + Timer io_timer, query_timer; + // cleared every iteration + std::vector frontier; + frontier.reserve(2 * beam_width); + std::vector> frontier_nhoods; + frontier_nhoods.reserve(2 * beam_width); + std::vector frontier_read_reqs; + frontier_read_reqs.reserve(2 * beam_width); + std::vector>> + cached_nhoods; + cached_nhoods.reserve(2 * beam_width); + // query <-> PQ chunk centers distances float *pq_dists = query_scratch->aligned_pqtable_dist_scratch; pq_table.populate_chunk_distances(query_float, pq_dists); @@ -900,7 +1091,7 @@ namespace diskann { ::pq_dist_lookup(pq_coord_scratch, n_ids, this->n_chunks, pq_dists, dists_out); }; - Timer query_timer, io_timer, cpu_timer; + Timer cpu_timer; std::vector retset(l_search + 1); tsl::robin_set<_u64> &visited = *(query_scratch->visited); @@ -937,17 +1128,6 @@ namespace diskann { unsigned num_ios = 0; unsigned k = 0; - // cleared every iteration - std::vector frontier; - frontier.reserve(2 * beam_width); - 
std::vector> frontier_nhoods; - frontier_nhoods.reserve(2 * beam_width); - std::vector frontier_read_reqs; - frontier_read_reqs.reserve(2 * beam_width); - std::vector>> - cached_nhoods; - cached_nhoods.reserve(2 * beam_width); - while (k < cur_list_size) { auto nk = cur_list_size; // clear iteration state diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index 4abd122b9..46336cf37 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -24,6 +24,7 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; +constexpr float kHnswBruteForceFilterRate = 0.93f; template class HierarchicalNSW : public AlgorithmInterface { @@ -1100,6 +1101,28 @@ class HierarchicalNSW : public AlgorithmInterface { if (cur_element_count == 0) return result; + if (!bitset.empty()) { + const auto bs_cnt = bitset.count(); + if (bs_cnt == cur_element_count) return {}; + if (bs_cnt >= (cur_element_count * kHnswBruteForceFilterRate)) { + assert(cur_element_count == bitset.size()); + for (labeltype id = 0; id < cur_element_count; ++id) { + if (!bitset.test(id)) { + dist_t dist = fstdistfunc_(query_data, getDataByInternalId(id), dist_func_param_); + if (result.size() < k) { + result.emplace(dist, id); + continue; + } + if (dist < result.top().first) { + result.pop(); + result.emplace(dist, id); + } + } + } + return result; + } + } + tableint currObj = enterpoint_node_; auto vec_hash = knowhere::utils::hash_vec((const float*)query_data, *(size_t*)dist_func_param_); if (!lru_cache.try_get(vec_hash, currObj)) { diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt index 802ffbf88..0a8a459ae 100644 --- a/unittest/CMakeLists.txt +++ b/unittest/CMakeLists.txt @@ -45,6 +45,7 @@ set( UTIL_SRCS set( ALL_INDEXING_SRCS ${UTIL_SRCS} test_common.cpp + test_utils.cpp ) set(ALL_INDEXING_SRCS diff --git a/unittest/test_common.cpp b/unittest/test_common.cpp index aefad7f95..4747d0c08 100644 --- 
a/unittest/test_common.cpp +++ b/unittest/test_common.cpp @@ -12,9 +12,12 @@ #include #include +#include "utils.h" + #include "knowhere/common/Dataset.h" #include "knowhere/common/Timer.h" #include "knowhere/common/Exception.h" +#include "knowhere/common/Heap.h" #include "knowhere/utils/BitsetView.h" /*Some unittest for knowhere/common, mainly for improve code coverage.*/ @@ -56,3 +59,24 @@ TEST(COMMON_TEST, BitsetView) { std::cout << bitset.to_string(0, N) << std::endl; } } + +namespace { +constexpr size_t kHeapSize = 10; +constexpr size_t kElementCount = 10000; +} // namespace + +TEST(COMMON_TEST, ResultMaxHeap) { + knowhere::ResultMaxHeap heap(kHeapSize); + auto pairs = GenerateRandomDistanceIdPair(kElementCount); + for (const auto& [dist, id] : pairs) { + heap.Push(dist, id); + } + ASSERT_EQ(heap.Size(), kHeapSize); + std::sort(pairs.begin(), pairs.end()); + for (int i = kHeapSize - 1; i >= 0; --i) { + auto op = heap.Pop(); + ASSERT_TRUE(op.has_value()); + ASSERT_EQ(op.value().second, pairs[i].second); + } + ASSERT_EQ(heap.Size(), 0); +} \ No newline at end of file diff --git a/unittest/test_diskann.cpp b/unittest/test_diskann.cpp index 1df3e8adc..2dfc70333 100644 --- a/unittest/test_diskann.cpp +++ b/unittest/test_diskann.cpp @@ -32,16 +32,12 @@ error "Missing the header." 
#include "knowhere/index/vector_index/IndexDiskANN.h" #include "knowhere/index/vector_index/IndexDiskANNConfig.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" -#include "knowhere/index/vector_index/helpers/RangeUtil.h" #include "unittest/LocalFileManager.h" #include "unittest/utils.h" using ::testing::Combine; using ::testing::TestWithParam; using ::testing::Values; -using IdDisPair = std::pair; -using GroundTruth = std::vector>; -using GroundTruthPtr = std::shared_ptr; namespace { @@ -113,88 +109,6 @@ GenLargeData(size_t num) { return data_p; } -struct DisPairLess { - bool - operator()(const IdDisPair& p1, const IdDisPair& p2) { - return p1.second < p2.second; - } -}; - -GroundTruthPtr -GenGroundTruth(const float* data_p, const float* query_p, const std::string metric, const uint32_t num_rows, - const uint32_t num_dims, const uint32_t num_queries, const faiss::BitsetView bitset = nullptr) { - GroundTruthPtr ground_truth = std::make_shared(); - ground_truth->resize(num_queries); - - for (uint32_t query_index = 0; query_index < num_queries; ++query_index) { // for each query - // use priority_queue to keep the topK; - std::priority_queue, DisPairLess> pq; - for (int64_t row = 0; row < num_rows; ++row) { // for each row - if (!bitset.empty() && bitset.test(row)) { - continue; - } - float dis = 0; - for (uint32_t dim = 0; dim < num_dims; ++dim) { // for every dim - if (metric == knowhere::metric::IP) { - dis -= (data_p[num_dims * row + dim] * query_p[query_index * num_dims + dim]); - } else { - dis += ((data_p[num_dims * row + dim] - query_p[query_index * num_dims + dim]) * - (data_p[num_dims * row + dim] - query_p[query_index * num_dims + dim])); - } - } - if (pq.size() < kK) { - pq.push(std::make_pair(row, dis)); - } else if (pq.top().second > dis) { - pq.pop(); - pq.push(std::make_pair(row, dis)); - } - } - - auto& result_ids = ground_truth->at(query_index); - - // write id in priority_queue to vector for sorting. 
- int pq_size = pq.size(); - for (uint32_t index = 0; index < pq_size; ++index) { - auto& id_dis_pair = pq.top(); - result_ids.push_back(id_dis_pair.first); - pq.pop(); - } - } - return ground_truth; -} - -GroundTruthPtr -GenRangeSearchGrounTruth(const float* data_p, const float* query_p, const std::string metric, const uint32_t num_rows, - const uint32_t num_dims, const uint32_t num_queries, const float radius, - const float range_filter, const faiss::BitsetView bitset = nullptr) { - GroundTruthPtr ground_truth = std::make_shared(); - ground_truth->resize(num_queries); - bool is_ip = (metric == knowhere::metric::IP); - for (uint32_t query_index = 0; query_index < num_queries; ++query_index) { - std::vector paris; - const float* xq = query_p + query_index * num_dims; - for (int64_t row = 0; row < num_rows; ++row) { // for each row - if (!bitset.empty() && bitset.test(row)) { - continue; - } - const float* xb = data_p + row * num_dims; - float dis = 0; - if (metric == knowhere::metric::IP) { - for (uint32_t dim = 0; dim < num_dims; ++dim) { // for every dim - dis += xb[dim] * xq[dim]; - } - } else { - for (uint32_t dim = 0; dim < num_dims; ++dim) { // for every dim - dis += std::pow(xb[dim] - xq[dim], 2); - } - } - if (knowhere::distance_in_range(dis, radius, range_filter, is_ip)) { - ground_truth->at(query_index).emplace_back(row); - } - } - } - return ground_truth; -} void WriteRawDataToDisk(const std::string data_path, const float* raw_data, const uint32_t num, const uint32_t dim) { @@ -205,42 +119,6 @@ WriteRawDataToDisk(const std::string data_path, const float* raw_data, const uin writer.close(); } -uint32_t -GetMatchedNum(const std::vector& ground_truth, const int64_t* result, const int32_t limit) { - uint32_t matched_num = 0; - int missed = 0; - for (uint32_t index = 0; index < limit; ++index) { - if (std::find(ground_truth.begin(), ground_truth.end(), result[index]) != ground_truth.end()) { - matched_num++; - } - } - return matched_num; -} - -float 
-CheckTopKRecall(GroundTruthPtr ground_truth, const int64_t* result, const int32_t k, const uint32_t num_queries) { - uint32_t recall = 0; - for (uint32_t n = 0; n < num_queries; ++n) { - recall += GetMatchedNum(ground_truth->at(n), result + (n * k), ground_truth->at(n).size()); - } - return ((float)recall) / ((float)num_queries * k); -} - -float -CheckRangeSearchRecall(GroundTruthPtr ground_truth, const int64_t* result, const size_t* limits, - const uint32_t num_queries) { - uint32_t recall = 0; - uint32_t total = 0; - for (uint32_t n = 0; n < num_queries; ++n) { - recall += GetMatchedNum(ground_truth->at(n), result + limits[n], limits[n + 1] - limits[n]); - total += ground_truth->at(n).size(); - } - if (total == 0) { - return 1; - } - return ((float)recall) / ((float)total); -} - template void CheckConfigError(DiskANNConfig& config_to_test) { @@ -345,9 +223,9 @@ class DiskANNTest : public TestWithParam> { // lr.close(); ip_ground_truth_ = - GenGroundTruth(global_raw_data_, global_query_data_, knowhere::metric::IP, kNumRows, kDim, kNumQueries); + GenGroundTruth(global_raw_data_, global_query_data_, knowhere::metric::IP, kNumRows, kDim, kNumQueries, kK); l2_ground_truth_ = - GenGroundTruth(global_raw_data_, global_query_data_, knowhere::metric::L2, kNumRows, kDim, kNumQueries); + GenGroundTruth(global_raw_data_, global_query_data_, knowhere::metric::L2, kNumRows, kDim, kNumQueries, kK); ip_range_search_ground_truth_ = GenRangeSearchGrounTruth( global_raw_data_, global_query_data_, knowhere::metric::IP, kNumRows, kDim, kNumQueries, kIPRadius, kIPRangeFilter); l2_range_search_ground_truth_ = GenRangeSearchGrounTruth( @@ -355,10 +233,10 @@ class DiskANNTest : public TestWithParam> { large_dim_ip_ground_truth_ = GenGroundTruth(global_large_dim_raw_data_, global_large_dim_query_data_, knowhere::metric::IP, - kLargeDimNumRows, kLargeDim, kLargeDimNumQueries); + kLargeDimNumRows, kLargeDim, kLargeDimNumQueries, kK); large_dim_l2_ground_truth_ = 
GenGroundTruth(global_large_dim_raw_data_, global_large_dim_query_data_, knowhere::metric::L2, - kLargeDimNumRows, kLargeDim, kLargeDimNumQueries); + kLargeDimNumRows, kLargeDim, kLargeDimNumQueries, kK); large_dim_ip_range_search_ground_truth_ = GenRangeSearchGrounTruth(global_large_dim_raw_data_, global_large_dim_query_data_, knowhere::metric::IP, kLargeDimNumRows, kLargeDim, kLargeDimNumQueries, kLargeDimIPRadius, kLargeDimIPRangeFilter); @@ -494,7 +372,7 @@ TEST_P(DiskANNTest, bitset_view_test) { set_bit(knn_bitset_data.data(), id_to_mask); } faiss::BitsetView knn_bitset(knn_bitset_data.data(), num_rows_); - auto ground_truth = GenGroundTruth(raw_data_, query_data_, metric_, num_rows_, dim_, num_queries_, knn_bitset); + auto ground_truth = GenGroundTruth(raw_data_, query_data_, metric_, num_rows_, dim_, num_queries_, kK, knn_bitset); // query with bitset view result = diskann->Query(data_set_ptr, cfg, knn_bitset); @@ -503,6 +381,33 @@ TEST_P(DiskANNTest, bitset_view_test) { auto recall = CheckTopKRecall(ground_truth, ids, kK, num_queries_); EXPECT_GT(recall, 0.8); + // test query with bitset view below/above threshold + knowhere::DiskANNQueryConfig bs_cfg = query_conf; + std::vector(size_t, size_t)>> gen_bitset_funcs = { + GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; + const auto bitset_percentages = {0.4f, 0.98f}; + const auto bitset_thresholds = {-1.0f, 0.9f}; + for (const float threshold : bitset_thresholds) { + bs_cfg.filter_threshold = threshold; + cfg.clear(); + knowhere::DiskANNQueryConfig::Set(cfg, bs_cfg); + for (const float percentage : bitset_percentages) { + for (const auto& gen_func : gen_bitset_funcs) { + auto bitset_data = gen_func(num_rows_, percentage * num_rows_); + faiss::BitsetView bs(bitset_data.data(), num_rows_); + auto result = diskann->Query(data_set_ptr, cfg, bs); + auto gt = GenGroundTruth(raw_data_, query_data_, metric_, num_rows_, dim_, num_queries_, kK, bs); + auto ids = knowhere::GetDatasetIDs(result); + 
float recall = CheckTopKRecall(gt, ids, kK, num_queries_); + if (percentage == 0.98f) { + EXPECT_GT(recall, 0.9f); + } else { + EXPECT_GT(recall, 0.8f); + } + } + } + } + // test for range search cfg.clear(); knowhere::DiskANNQueryByRangeConfig::Set(cfg, range_search_conf_); diff --git a/unittest/test_hnsw.cpp b/unittest/test_hnsw.cpp index b0de6d2f8..aa36edee2 100644 --- a/unittest/test_hnsw.cpp +++ b/unittest/test_hnsw.cpp @@ -11,7 +11,10 @@ #include #include +#include +#include +#include "BitsetView.h" #include "knowhere/common/Config.h" #include "knowhere/feder/HNSW.h" #include "knowhere/index/vector_index/ConfAdapterMgr.h" @@ -344,3 +347,34 @@ TEST_P(HNSWTest, hnsw_data_overflow) { auto result = index_->Query(base_dataset, conf_, nullptr); } + +namespace { + constexpr float kKnnRecallThreshold = 0.8f; + constexpr float kBruteForceRecallThreshold = 0.99f; +} + +TEST_P(HNSWTest, hnsw_bitset) { + index_->BuildAll(base_dataset, conf_); + const auto metric = knowhere::GetMetaMetricType(conf_); + std::vector(size_t, size_t)>> gen_bitset_funcs = { + GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; + const float threshold = hnswlib::kHnswBruteForceFilterRate; + const auto bitset_percentages = {0.4f, 0.98f}; + for (const float percentage : bitset_percentages) { + for (const auto& gen_func : gen_bitset_funcs) { + auto bitset_data = gen_func(nb, percentage * nb); + faiss::BitsetView bs(bitset_data.data(), nb); + float* data_p = (float*)knowhere::GetDatasetTensor(base_dataset); + auto result = index_->Query(query_dataset, conf_, bs); + float* query_p = (float*)knowhere::GetDatasetTensor(query_dataset); + auto gt = GenGroundTruth(data_p, query_p, metric, nb, dim, nq, k, bs); + auto result_p = knowhere::GetDatasetIDs(result); + float recall = CheckTopKRecall(gt, result_p, k, nq); + if (percentage > threshold) { + ASSERT_TRUE(recall > kBruteForceRecallThreshold); + } else { + ASSERT_TRUE(recall > kKnnRecallThreshold); + } + } + } +} \ No newline at end of 
file diff --git a/unittest/test_utils.cpp b/unittest/test_utils.cpp new file mode 100644 index 000000000..f686d6be0 --- /dev/null +++ b/unittest/test_utils.cpp @@ -0,0 +1,50 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include "unittest/utils.h" +#include "knowhere/utils/BitsetView.h" + +class UtilsBitsetTest : public ::testing::Test { + protected: + const std::vector kBitsetSizes{4, 8, 10, 64, 100, 500, 1024}; +}; + +TEST_F(UtilsBitsetTest, FirstTBits) { + for (const auto size : kBitsetSizes) { + for (size_t i = 0; i <= size; ++i) { + auto bitset_data = GenerateBitsetWithFirstTbitsSet(size, i); + faiss::BitsetView bitset(bitset_data.data(), size); + for (size_t j = 0; j < i; ++j) { + ASSERT_TRUE(bitset.test(j)); + } + for (size_t j = i; j < size; ++j) { + ASSERT_FALSE(bitset.test(j)); + } + } + } +} + +TEST_F(UtilsBitsetTest, RandomTBits) { + for (const auto size : kBitsetSizes) { + for (size_t i = 0; i <= size; ++i) { + auto bitset_data = GenerateBitsetWithRandomTbitsSet(size, i); + faiss::BitsetView bitset(bitset_data.data(), size); + size_t cnt = 0; + for (size_t j = 0; j < size; ++j) { + cnt += bitset.test(j); + } + ASSERT_EQ(cnt, i); + } + } +} diff --git a/unittest/utils.cpp b/unittest/utils.cpp index d5ad9af7d..4d52e64ee 100644 --- a/unittest/utils.cpp +++ b/unittest/utils.cpp @@ -11,12 +11,15 @@ #include #include +#include #include #include #include #include +#include 
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
+#include "knowhere/index/vector_index/helpers/RangeUtil.h"
 #include "knowhere/utils/BitsetView.h"
 #include "unittest/utils.h"
@@ -276,3 +279,147 @@ int64_t random() {
 }
 #endif
+
+// Orders (id, distance) pairs by distance only. Used as the std::priority_queue comparator so that the
+// pair with the largest distance is on top of the heap.
+struct DisPairLess {
+    bool
+    operator()(const IdDisPair& p1, const IdDisPair& p2) const {
+        return p1.second < p2.second;
+    }
+};
+
+// Counts how many of the first `limit` ids in `result` also appear in `ground_truth`.
+// Returns 0 when `limit` is zero or negative.
+uint32_t
+GetMatchedNum(const std::vector& ground_truth, const int64_t* result, const int32_t limit) {
+    uint32_t matched_num = 0;
+    // The index is signed like `limit`: with an unsigned index a negative `limit` would be converted to a
+    // huge unsigned bound and the loop would read far past the end of `result`.
+    for (int32_t index = 0; index < limit; ++index) {
+        if (std::find(ground_truth.begin(), ground_truth.end(), result[index]) != ground_truth.end()) {
+            matched_num++;
+        }
+    }
+    return matched_num;
+}
+
+// Brute-force top-k ground truth.
+// For each of the `num_queries` queries, scans all `num_rows` rows (row-major, `num_dims` floats per row),
+// skips rows whose bit is set in a non-empty `bitset`, and keeps the ids of the `topk` rows with the smallest
+// score. The score is the negated inner product when `metric` equals knowhere::metric::IP and the squared L2
+// distance otherwise. Ids are stored in heap-pop order, i.e. the largest kept score first.
+GroundTruthPtr
+GenGroundTruth(const float* data_p, const float* query_p, const std::string metric, const uint32_t num_rows,
+               const uint32_t num_dims, const uint32_t num_queries, const uint32_t topk, const faiss::BitsetView bitset) {
+    GroundTruthPtr ground_truth = std::make_shared();
+    ground_truth->resize(num_queries);
+    // The metric is loop-invariant: compare the string once instead of once per dimension.
+    const bool is_ip = (metric == knowhere::metric::IP);
+
+    for (uint32_t query_index = 0; query_index < num_queries; ++query_index) {  // for each query
+        const float* xq = query_p + query_index * num_dims;
+        // Max-heap on distance: the worst candidate kept so far is always on top.
+        std::priority_queue, DisPairLess> pq;
+        for (int64_t row = 0; row < num_rows; ++row) {  // for each row
+            if (!bitset.empty() && bitset.test(row)) {
+                continue;
+            }
+            const float* xb = data_p + row * num_dims;
+            float dis = 0;
+            for (uint32_t dim = 0; dim < num_dims; ++dim) {  // for every dim
+                if (is_ip) {
+                    dis -= (xb[dim] * xq[dim]);
+                } else {
+                    dis += ((xb[dim] - xq[dim]) * (xb[dim] - xq[dim]));
+                }
+            }
+            if (pq.size() < topk) {
+                pq.push(std::make_pair(row, dis));
+            } else if (!pq.empty() && pq.top().second > dis) {  // !empty(): topk == 0 leaves the heap empty
+                pq.pop();
+                pq.push(std::make_pair(row, dis));
+            }
+        }
+
+        auto& result_ids = ground_truth->at(query_index);
+        result_ids.reserve(pq.size());
+
+        // Drain the heap into the per-query id list.
+        while (!pq.empty()) {
+            result_ids.push_back(pq.top().first);
+            pq.pop();
+        }
+    }
+    return ground_truth;
+}
+
+// Brute-force range-search ground truth.
+// For each query, computes the inner product (when `metric` equals knowhere::metric::IP) or the squared L2
+// distance (otherwise) to every row not masked by a non-empty `bitset`, and records the row id whenever
+// knowhere::distance_in_range(dis, radius, range_filter, is_ip) returns true. Ids are in ascending row order.
+GroundTruthPtr
+GenRangeSearchGrounTruth(const float* data_p, const float* query_p, const std::string metric, const uint32_t num_rows,
+                         const uint32_t num_dims, const uint32_t num_queries, const float radius,
+                         const float range_filter, const faiss::BitsetView bitset) {
+    GroundTruthPtr ground_truth = std::make_shared();
+    ground_truth->resize(num_queries);
+    const bool is_ip = (metric == knowhere::metric::IP);
+    for (uint32_t query_index = 0; query_index < num_queries; ++query_index) {
+        const float* xq = query_p + query_index * num_dims;
+        auto& result_ids = ground_truth->at(query_index);
+        for (int64_t row = 0; row < num_rows; ++row) {  // for each row
+            if (!bitset.empty() && bitset.test(row)) {
+                continue;
+            }
+            const float* xb = data_p + row * num_dims;
+            float dis = 0;
+            for (uint32_t dim = 0; dim < num_dims; ++dim) {  // for every dim
+                if (is_ip) {
+                    dis += xb[dim] * xq[dim];
+                } else {
+                    // Plain float multiply, same arithmetic as GenGroundTruth; std::pow(x, 2) goes through double.
+                    const float diff = xb[dim] - xq[dim];
+                    dis += diff * diff;
+                }
+            }
+            if (knowhere::distance_in_range(dis, radius, range_filter, is_ip)) {
+                result_ids.emplace_back(row);
+            }
+        }
+    }
+    return ground_truth;
+}
+
+// Top-k recall: matched ids summed over all queries, divided by `num_queries * k`.
+// The results of query n are the `k` ids starting at `result + n * k`.
+float
+CheckTopKRecall(GroundTruthPtr ground_truth, const int64_t* result, const int32_t k, const uint32_t num_queries) {
+    uint32_t recall = 0;
+    for (uint32_t n = 0; n < num_queries; ++n) {
+        const auto& truth = ground_truth->at(n);
+        // A query owns exactly k result slots, so never compare more than k ids even if the ground-truth list
+        // is longer; a shorter list (heavily filtered data) limits the comparison to its own length.
+        const int32_t limit = std::min(k, static_cast<int32_t>(truth.size()));
+        recall += GetMatchedNum(truth, result + (n * k), limit);
+    }
+    return ((float)recall) / ((float)num_queries * k);
+}
+
+// Range-search recall: matched ids summed over all queries, divided by the total number of ground-truth ids.
+// The results of query n are the ids in `result[limits[n]]` .. `result[limits[n + 1] - 1]`.
+// Returns 1 when the ground truth contains no ids at all.
+float
+CheckRangeSearchRecall(GroundTruthPtr ground_truth, const int64_t* result, const size_t* limits,
+                       const uint32_t num_queries) {
+    uint32_t recall = 0;
+    uint32_t total = 0;
+    for (uint32_t n = 0; n < num_queries; ++n) {
+        recall += GetMatchedNum(ground_truth->at(n), result + limits[n], limits[n + 1] - limits[n]);
+        total += ground_truth->at(n).size();
+    }
+    if (total == 0) {
+        return 1;
+    }
+    return ((float)recall) / ((float)total);
+}
diff --git a/unittest/utils.h b/unittest/utils.h
index 676653c82..d26f1ad82 100644
--- a/unittest/utils.h
+++ b/unittest/utils.h
@@ -16,10 +16,13 @@
 #include
 #include
 #include
+#include
+#include
 #include "knowhere/archive/KnowhereConfig.h"
 #include "knowhere/common/Dataset.h"
 #include "knowhere/common/Log.h"
+#include "knowhere/utils/BitsetView.h"
 class DataGen {
  public:
@@ -202,3 +205,81 @@ int64_t random();
 #endif
+
+// Fixed seed for the std::mt19937 engines below, so the generated test data is identical on every run.
+constexpr int64_t kSeed = 42;
+
+// Return a n-bits bitset data with first t bits set to true.
+// The result holds (n + 7) / 8 elements; bit i lives in element i >> 3 at bit position i & 7. Requires t <= n.
+inline std::vector
+GenerateBitsetWithFirstTbitsSet(size_t n, size_t t) {
+    assert(t <= n);  // t is unsigned, so the former `t >= 0` half of this check was always true
+    std::vector data((n + 8 - 1) / 8, 0);
+    for (size_t i = 0; i < t; ++i) {
+        data[i >> 3] |= (0x1 << (i & 0x7));
+    }
+    return data;
+}
+
+// Return a n-bits bitset data with random t bits set to true.
+// Same layout as GenerateBitsetWithFirstTbitsSet. The positions come from a std::shuffle driven by a
+// std::mt19937 seeded with kSeed, so equal (n, t) always give the same bitset. Requires t <= n.
+inline std::vector
+GenerateBitsetWithRandomTbitsSet(size_t n, size_t t) {
+    assert(t <= n);  // t is unsigned, so the former `t >= 0` half of this check was always true
+    std::vector bits_shuffle(n, false);
+    for (size_t i = 0; i < t; ++i) {
+        bits_shuffle[i] = true;
+    }
+    std::mt19937 g(kSeed);
+    std::shuffle(bits_shuffle.begin(), bits_shuffle.end(), g);
+    std::vector data((n + 8 - 1) / 8, 0);
+    for (size_t i = 0; i < n; ++i) {
+        if (bits_shuffle[i]) {
+            data[i >> 3] |= (0x1 << (i & 0x7));
+        }
+    }
+    return data;
+}
+
+// Randomly generate n (distances, id) pairs.
+// Pair i gets id i and a distance drawn uniformly from [numeric_limits::min(), numeric_limits::max()) by a
+// std::mt19937 seeded with kSeed.
+// NOTE(review): if the limits are taken for a floating-point type, min() is the smallest positive normal
+// value, so no zero or negative distance is ever produced; use lowest() if those were intended -- confirm.
+inline std::vector>
+GenerateRandomDistanceIdPair(size_t n) {
+    std::mt19937 rng(kSeed);
+    std::uniform_real_distribution<> distrib(std::numeric_limits().min(), std::numeric_limits().max());
+    std::vector> res;
+    res.reserve(n);
+    for (size_t i = 0; i < n; ++i) {
+        res.emplace_back(distrib(rng), i);
+    }
+    return res;
+}
+
+// IdDisPair: `first` is the row id, `second` the distance. GroundTruth: one id list per query.
+using IdDisPair = std::pair;
+using GroundTruth = std::vector>;
+using GroundTruthPtr = std::shared_ptr;
+
+// Brute-force top-k ids per query; rows whose bit is set in a non-empty `bitset` are skipped.
+GroundTruthPtr
+GenGroundTruth(const float* data_p, const float* query_p, const std::string metric, const uint32_t num_rows,
+               const uint32_t num_dims, const uint32_t num_queries, const uint32_t topk, const faiss::BitsetView bitset = nullptr);
+
+// Brute-force ids per query accepted by knowhere::distance_in_range(); set bits in `bitset` skip rows as above.
+GroundTruthPtr
+GenRangeSearchGrounTruth(const float* data_p, const float* query_p, const std::string metric, const uint32_t num_rows,
+                         const uint32_t num_dims, const uint32_t num_queries, const float radius,
+                         const float range_filter, const faiss::BitsetView bitset = nullptr);
+
+// Matched ids divided by `num_queries * k`; query n owns the `k` ids starting at `result + n * k`.
+float
+CheckTopKRecall(GroundTruthPtr ground_truth, const int64_t* result, const int32_t k, const uint32_t num_queries);
+
+// Matched ids divided by the total ground-truth ids; query n owns result[limits[n]] .. result[limits[n + 1] - 1].
+float
+CheckRangeSearchRecall(GroundTruthPtr ground_truth, const int64_t* result, const size_t* limits,
+                       const uint32_t num_queries);