diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc index 61cf373c9..c646c5eaa 100644 --- a/src/index/sparse/sparse_index_node.cc +++ b/src/index/sparse/sparse_index_node.cc @@ -19,7 +19,6 @@ #include "knowhere/config.h" #include "knowhere/dataset.h" #include "knowhere/expected.h" -#include "knowhere/feature.h" #include "knowhere/index/index_factory.h" #include "knowhere/index/index_node.h" #include "knowhere/log.h" @@ -125,7 +124,7 @@ class SparseInvertedIndexNode : public IndexNode { public: RefineIterator(const sparse::BaseInvertedIndex* index, sparse::SparseRow&& query, std::shared_ptr precomputed_it, - const sparse::DocValueComputer& computer, const float refine_ratio = 0.5f) + const sparse::DocValueComputer& computer, const float refine_ratio = 0.5f) : IndexIterator(true, refine_ratio), index_(index), query_(std::move(query)), @@ -155,7 +154,7 @@ class SparseInvertedIndexNode : public IndexNode { private: const sparse::BaseInvertedIndex* index_; sparse::SparseRow query_; - const sparse::DocValueComputer computer_; + const sparse::DocValueComputer computer_; std::shared_ptr precomputed_it_; bool first_return_ = true; }; @@ -251,7 +250,7 @@ class SparseInvertedIndexNode : public IndexNode { return index_or.error(); } index_ = index_or.value(); - return index_->Load(reader); + return index_->Load(reader, 0, ""); } Status @@ -327,7 +326,8 @@ class SparseInvertedIndexNode : public IndexNode { expected*> CreateIndex(const SparseInvertedIndexConfig& cfg) const { if (IsMetricType(cfg.metric_type.value(), metric::BM25)) { - auto idx = new sparse::InvertedIndex(); + // quantize float to uint16_t when BM25 metric type is used. + auto idx = new sparse::InvertedIndex(); if (!cfg.bm25_k1.has_value() || !cfg.bm25_b.has_value() || !cfg.bm25_avgdl.has_value()) { return expected*>::Err( Status::invalid_args, "BM25 parameters k1, b, and avgdl must be set when building/loading"); @@ -339,7 +339,7 @@ class SparseInvertedIndexNode : public IndexNode { idx->SetBM25Params(k1, b, avgdl, max_score_ratio); return idx; } else { - return new sparse::InvertedIndex(); + return new sparse::InvertedIndex(); } } diff --git a/src/index/sparse/sparse_inverted_index.h b/src/index/sparse/sparse_inverted_index.h index c696fd1dc..59a294b82 100644 --- a/src/index/sparse/sparse_inverted_index.h +++ b/src/index/sparse/sparse_inverted_index.h @@ -20,12 +20,14 @@ #include #include #include +#include #include #include #include "index/sparse/sparse_inverted_index_config.h" #include "io/memory_io.h" #include "knowhere/bitsetview.h" +#include "knowhere/comp/index_param.h" #include "knowhere/expected.h" #include "knowhere/log.h" #include "knowhere/sparse_utils.h" @@ -43,7 +45,7 @@ class BaseInvertedIndex { // supplement_target_filename: when in mmap mode, we need an extra file to store the mmaped index data structure. // this file will be created during loading and deleted in the destructor. virtual Status - Load(MemoryIOReader& reader, int map_flags = MAP_SHARED, const std::string& supplement_target_filename = "") = 0; + Load(MemoryIOReader& reader, int map_flags, const std::string& supplement_target_filename) = 0; virtual Status Train(const SparseRow* data, size_t rows) = 0; @@ -65,7 +67,7 @@ class BaseInvertedIndex { virtual expected> GetDocValueComputer(const SparseInvertedIndexConfig& cfg) const = 0; - virtual bool + [[nodiscard]] virtual bool IsApproximated() const = 0; [[nodiscard]] virtual size_t @@ -78,8 +80,8 @@ class BaseInvertedIndex { n_cols() const = 0; }; -template -class InvertedIndex : public BaseInvertedIndex { +template +class InvertedIndex : public BaseInvertedIndex { public: explicit InvertedIndex() { } @@ -102,12 +104,15 @@ class InvertedIndex : public BaseInvertedIndex { } } + template + using Vector = std::conditional_t, std::vector>; + void SetBM25Params(float k1, float b, float avgdl, float max_score_ratio) { bm25_params_ = std::make_unique(k1, b, avgdl, max_score_ratio); } - expected> + expected> GetDocValueComputer(const SparseInvertedIndexConfig& cfg) const override { // if metric_type is set in config, it must match with how the index was built. auto metric_type = cfg.metric_type; @@ -115,33 +120,34 @@ class InvertedIndex : public BaseInvertedIndex { if (metric_type.has_value() && !IsMetricType(metric_type.value(), metric::IP)) { auto msg = "metric type not match, expected: " + std::string(metric::IP) + ", got: " + metric_type.value(); - return expected>::Err(Status::invalid_metric_type, msg); + return expected>::Err(Status::invalid_metric_type, msg); } - return GetDocValueOriginalComputer(); + return GetDocValueOriginalComputer(); } if (metric_type.has_value() && !IsMetricType(metric_type.value(), metric::BM25)) { auto msg = "metric type not match, expected: " + std::string(metric::BM25) + ", got: " + metric_type.value(); - return expected>::Err(Status::invalid_metric_type, msg); + return expected>::Err(Status::invalid_metric_type, msg); } // avgdl must be supplied during search if (!cfg.bm25_avgdl.has_value()) { - return expected>::Err(Status::invalid_args, "avgdl must be supplied during searching"); + return expected>::Err(Status::invalid_args, + "avgdl must be supplied during searching"); } auto avgdl = cfg.bm25_avgdl.value(); if constexpr (use_wand) { // wand: search time k1/b must equal load time config. if ((cfg.bm25_k1.has_value() && cfg.bm25_k1.value() != bm25_params_->k1) || ((cfg.bm25_b.has_value() && cfg.bm25_b.value() != bm25_params_->b))) { - return expected>::Err( + return expected>::Err( Status::invalid_args, "search time k1/b must equal load time config for WAND index."); } - return GetDocValueBM25Computer(bm25_params_->k1, bm25_params_->b, avgdl); + return GetDocValueBM25Computer(bm25_params_->k1, bm25_params_->b, avgdl); } else { // inverted index: search time k1/b may override load time config. auto k1 = cfg.bm25_k1.has_value() ? cfg.bm25_k1.value() : bm25_params_->k1; auto b = cfg.bm25_b.has_value() ? cfg.bm25_b.value() : bm25_params_->b; - return GetDocValueBM25Computer(k1, b, avgdl); + return GetDocValueBM25Computer(k1, b, avgdl); } } @@ -159,47 +165,46 @@ class InvertedIndex : public BaseInvertedIndex { * 1. table_t idx * 2. T val * - * inverted_lut_ and max_score_in_dim_ not serialized, they will be - * constructed dynamically during deserialization. + * inverted_index_ids_, inverted_index_vals_ and max_score_in_dim_ are + * not serialized, they will be constructed dynamically during + * deserialization. * * Data are densely packed in serialized bytes and no padding is added. */ - T deprecated_value_threshold = 0; + DType deprecated_value_threshold = 0; writeBinaryPOD(writer, n_rows_internal_); writeBinaryPOD(writer, max_dim_); writeBinaryPOD(writer, deprecated_value_threshold); BitsetView bitset(nullptr, 0); - std::vector> cursors; + std::vector> cursors; - for (size_t i = 0; i < inverted_lut_.size(); ++i) { - cursors.emplace_back(inverted_lut_[i], n_rows_internal_, 0, 0, bitset); + for (size_t i = 0; i < inverted_index_ids_.size(); ++i) { + cursors.emplace_back(inverted_index_ids_[i], inverted_index_vals_[i], n_rows_internal_, 0, 0, bitset); } for (table_t vec_id = 0; vec_id < n_rows_internal_; ++vec_id) { - std::vector> vec_row; - for (size_t i = 0; i < inverted_lut_.size(); ++i) { + std::vector> vec_row; + for (size_t i = 0; i < inverted_index_ids_.size(); ++i) { if (cursors[i].cur_vec_id_ == vec_id) { vec_row.emplace_back(dim_map_reverse_[i], cursors[i].cur_vec_val()); cursors[i].next(); } } - SparseRow raw_row(vec_row); + SparseRow raw_row(vec_row); writeBinaryPOD(writer, raw_row.size()); if (raw_row.size() == 0) { continue; } - writer.write(raw_row.data(), raw_row.size() * SparseRow::element_size()); + writer.write(raw_row.data(), raw_row.size() * SparseRow::element_size()); } return Status::success; } - Status - Load(MemoryIOReader& reader, int map_flags = MAP_SHARED, - const std::string& supplement_target_filename = "") override { - T deprecated_value_threshold; + Load(MemoryIOReader& reader, int map_flags, const std::string& supplement_target_filename) override { + DType deprecated_value_threshold; int64_t rows; readBinaryPOD(reader, rows); // previous versions used the signness of rows to indicate whether to @@ -220,14 +225,14 @@ class InvertedIndex : public BaseInvertedIndex { for (int64_t i = 0; i < rows; ++i) { size_t count; readBinaryPOD(reader, count); - SparseRow raw_row; + SparseRow raw_row; if constexpr (mmapped) { - raw_row = std::move(SparseRow(count, reader.data() + reader.tellg(), false)); - reader.advance(count * SparseRow::element_size()); + raw_row = std::move(SparseRow(count, reader.data() + reader.tellg(), false)); + reader.advance(count * SparseRow::element_size()); } else { - raw_row = std::move(SparseRow(count)); + raw_row = std::move(SparseRow(count)); if (count > 0) { - reader.read(raw_row.data(), count * SparseRow::element_size()); + reader.read(raw_row.data(), count * SparseRow::element_size()); } } add_row_to_index(raw_row, i); @@ -242,7 +247,7 @@ class InvertedIndex : public BaseInvertedIndex { Status PrepareMmap(MemoryIOReader& reader, size_t rows, int map_flags, const std::string& supplement_target_filename) { const auto initial_reader_location = reader.tellg(); - const auto nnz = (reader.remaining() - (rows * sizeof(size_t))) / SparseRow::element_size(); + const auto nnz = (reader.remaining() - (rows * sizeof(size_t))) / SparseRow::element_size(); // count raw vector idx occurrences std::unordered_map idx_counts; @@ -257,18 +262,23 @@ class InvertedIndex : public BaseInvertedIndex { readBinaryPOD(reader, idx); idx_counts[idx]++; // skip value - reader.advance(sizeof(T)); + reader.advance(sizeof(DType)); } } // reset reader to the beginning reader.seekg(initial_reader_location); - auto inverted_lut_byte_size = idx_counts.size() * sizeof(typename decltype(inverted_lut_)::value_type); - auto luts_byte_size = nnz * sizeof(typename decltype(inverted_lut_)::value_type::value_type); + auto inverted_index_ids_byte_size = + idx_counts.size() * sizeof(typename decltype(inverted_index_ids_)::value_type); + auto inverted_index_vals_byte_size = + idx_counts.size() * sizeof(typename decltype(inverted_index_vals_)::value_type); + auto plists_ids_byte_size = nnz * sizeof(typename decltype(inverted_index_ids_)::value_type::value_type); + auto plists_vals_byte_size = nnz * sizeof(typename decltype(inverted_index_vals_)::value_type::value_type); auto max_score_in_dim_byte_size = idx_counts.size() * sizeof(typename decltype(max_score_in_dim_)::value_type); size_t row_sums_byte_size = 0; - map_byte_size_ = inverted_lut_byte_size + luts_byte_size; + map_byte_size_ = + inverted_index_ids_byte_size + inverted_index_vals_byte_size + plists_ids_byte_size + plists_vals_byte_size; if constexpr (use_wand) { map_byte_size_ += max_score_in_dim_byte_size; } @@ -313,8 +323,10 @@ class InvertedIndex : public BaseInvertedIndex { char* ptr = map_; // initialize containers memory. - inverted_lut_.initialize(ptr, inverted_lut_byte_size); - ptr += inverted_lut_byte_size; + inverted_index_ids_.initialize(ptr, inverted_index_ids_byte_size); + ptr += inverted_index_ids_byte_size; + inverted_index_vals_.initialize(ptr, inverted_index_vals_byte_size); + ptr += inverted_index_vals_byte_size; if constexpr (use_wand) { max_score_in_dim_.initialize(ptr, max_score_in_dim_byte_size); @@ -326,15 +338,23 @@ class InvertedIndex : public BaseInvertedIndex { ptr += row_sums_byte_size; } + for (const auto& [idx, count] : idx_counts) { + auto& plist_ids = inverted_index_ids_.emplace_back(); + auto plist_ids_byte_size = count * sizeof(typename decltype(inverted_index_ids_)::value_type::value_type); + plist_ids.initialize(ptr, plist_ids_byte_size); + ptr += plist_ids_byte_size; + } + for (const auto& [idx, count] : idx_counts) { + auto& plist_vals = inverted_index_vals_.emplace_back(); + auto plist_vals_byte_size = count * sizeof(typename decltype(inverted_index_vals_)::value_type::value_type); + plist_vals.initialize(ptr, plist_vals_byte_size); + ptr += plist_vals_byte_size; + } size_t dim_id = 0; for (const auto& [idx, count] : idx_counts) { dim_map_[idx] = dim_id; - auto& lut = inverted_lut_.emplace_back(); - auto lut_byte_size = count * sizeof(typename decltype(inverted_lut_)::value_type::value_type); - lut.initialize(ptr, lut_byte_size); - ptr += lut_byte_size; if constexpr (use_wand) { - max_score_in_dim_.emplace_back(0); + max_score_in_dim_.emplace_back(0.0f); } ++dim_id; } @@ -347,7 +367,7 @@ class InvertedIndex : public BaseInvertedIndex { // Non zero drop ratio is only supported for static index, i.e. data should // include all rows that'll be added to the index. Status - Train(const SparseRow* data, size_t rows) override { + Train(const SparseRow* data, size_t rows) override { if constexpr (mmapped) { throw std::invalid_argument("mmapped InvertedIndex does not support Train"); } else { @@ -356,7 +376,7 @@ class InvertedIndex : public BaseInvertedIndex { } Status - Add(const SparseRow* data, size_t rows, int64_t dim) override { + Add(const SparseRow* data, size_t rows, int64_t dim) override { if constexpr (mmapped) { throw std::invalid_argument("mmapped InvertedIndex does not support Add"); } else { @@ -378,8 +398,8 @@ class InvertedIndex : public BaseInvertedIndex { } void - Search(const SparseRow& query, size_t k, float drop_ratio_search, float* distances, label_t* labels, - size_t refine_factor, const BitsetView& bitset, const DocValueComputer& computer) const override { + Search(const SparseRow& query, size_t k, float drop_ratio_search, float* distances, label_t* labels, + size_t refine_factor, const BitsetView& bitset, const DocValueComputer& computer) const override { // initially set result distances to NaN and labels to -1 std::fill(distances, distances + k, std::numeric_limits::quiet_NaN()); std::fill(labels, labels + k, -1); @@ -387,7 +407,7 @@ class InvertedIndex : public BaseInvertedIndex { return; } - std::vector values(query.size()); + std::vector values(query.size()); for (size_t i = 0; i < query.size(); ++i) { values[i] = std::abs(query[i].val); } @@ -397,7 +417,7 @@ class InvertedIndex : public BaseInvertedIndex { if (drop_ratio_search == 0) { refine_factor = 1; } - MaxMinHeap heap(k * refine_factor); + MaxMinHeap heap(k * refine_factor); if constexpr (!use_wand) { search_brute_force(query, q_threshold, heap, bitset, computer); } else { @@ -413,12 +433,12 @@ class InvertedIndex : public BaseInvertedIndex { // Returned distances are inaccurate based on the drop_ratio. std::vector - GetAllDistances(const SparseRow& query, float drop_ratio_search, const BitsetView& bitset, - const DocValueComputer& computer) const override { + GetAllDistances(const SparseRow& query, float drop_ratio_search, const BitsetView& bitset, + const DocValueComputer& computer) const override { if (query.size() == 0) { return {}; } - std::vector values(query.size()); + std::vector values(query.size()); for (size_t i = 0; i < query.size(); ++i) { values[i] = std::abs(query[i].val); } @@ -435,8 +455,8 @@ class InvertedIndex : public BaseInvertedIndex { } float - GetRawDistance(const label_t vec_id, const SparseRow& query, - const DocValueComputer& computer) const override { + GetRawDistance(const label_t vec_id, const SparseRow& query, + const DocValueComputer& computer) const override { float distance = 0.0f; for (size_t i = 0; i < query.size(); ++i) { @@ -445,11 +465,12 @@ class InvertedIndex : public BaseInvertedIndex { if (dim_id == dim_map_.end()) { continue; } - auto& lut = inverted_lut_[dim_id->second]; - auto it = - std::lower_bound(lut.begin(), lut.end(), vec_id, [](const auto& x, table_t y) { return x.id < y; }); - if (it != lut.end() && it->id == vec_id) { - distance += val * computer(it->val, bm25 ? bm25_params_->row_sums.at(vec_id) : 0); + auto& plist_ids = inverted_index_ids_[dim_id->second]; + auto it = std::lower_bound(plist_ids.begin(), plist_ids.end(), vec_id, + [](const auto& x, table_t y) { return x < y; }); + if (it != plist_ids.end() && *it == vec_id) { + distance += val * computer(inverted_index_vals_[dim_id->second][it - plist_ids.begin()], + bm25 ? bm25_params_->row_sums.at(vec_id) : 0); } } @@ -465,9 +486,15 @@ class InvertedIndex : public BaseInvertedIndex { if constexpr (mmapped) { return res + map_byte_size_; } else { - res += sizeof(typename decltype(inverted_lut_)::value_type) * inverted_lut_.capacity(); - for (size_t i = 0; i < inverted_lut_.size(); ++i) { - res += sizeof(typename decltype(inverted_lut_)::value_type::value_type) * inverted_lut_[i].capacity(); + res += sizeof(typename decltype(inverted_index_ids_)::value_type) * inverted_index_ids_.capacity(); + for (size_t i = 0; i < inverted_index_ids_.size(); ++i) { + res += sizeof(typename decltype(inverted_index_ids_)::value_type::value_type) * + inverted_index_ids_[i].capacity(); + } + res += sizeof(typename decltype(inverted_index_vals_)::value_type) * inverted_index_vals_.capacity(); + for (size_t i = 0; i < inverted_index_vals_.size(); ++i) { + res += sizeof(typename decltype(inverted_index_vals_)::value_type::value_type) * + inverted_index_vals_[i].capacity(); } if constexpr (use_wand) { res += sizeof(typename decltype(max_score_in_dim_)::value_type) * max_score_in_dim_.capacity(); @@ -495,8 +522,8 @@ class InvertedIndex : public BaseInvertedIndex { // Given a vector of values, returns the threshold value. // All values strictly smaller than the threshold will be ignored. // values will be modified in this function. - inline T - get_threshold(std::vector& values, float drop_ratio) const { + inline DType + get_threshold(std::vector& values, float drop_ratio) const { // drop_ratio is in [0, 1) thus drop_count is guaranteed to be less // than values.size(). auto drop_count = static_cast(drop_ratio * values.size()); @@ -509,7 +536,8 @@ class InvertedIndex : public BaseInvertedIndex { } std::vector - compute_all_distances(const SparseRow& q_vec, T q_threshold, const DocValueComputer& computer) const { + compute_all_distances(const SparseRow& q_vec, DType q_threshold, + const DocValueComputer& computer) const { std::vector scores(n_rows_internal_, 0.0f); for (size_t idx = 0; idx < q_vec.size(); ++idx) { auto [i, v] = q_vec[idx]; @@ -520,24 +548,27 @@ class InvertedIndex : public BaseInvertedIndex { if (dim_id == dim_map_.end()) { continue; } - auto& lut = inverted_lut_[dim_id->second]; + auto& plist_ids = inverted_index_ids_[dim_id->second]; + auto& plist_vals = inverted_index_vals_[dim_id->second]; // TODO: improve with SIMD - for (size_t j = 0; j < lut.size(); j++) { - auto [doc_id, val] = lut[j]; - T val_sum = bm25 ? bm25_params_->row_sums.at(doc_id) : 0; + for (size_t j = 0; j < plist_ids.size(); ++j) { + auto doc_id = plist_ids[j]; + auto val = plist_vals[j]; + float val_sum = bm25 ? bm25_params_->row_sums.at(doc_id) : 0; scores[doc_id] += v * computer(val, val_sum); } } return scores; } - // LUT supports size() and operator[] which returns an SparseIdVal. - template + template struct Cursor { public: - Cursor(const LUT& lut, size_t num_vec, float max_score, float q_value, DocIdFilter filter) - : lut_(lut), - lut_size_(lut.size()), + Cursor(const Vector& plist_ids, const Vector& plist_vals, size_t num_vec, float max_score, + float q_value, DocIdFilter filter) + : plist_ids_(plist_ids), + plist_vals_(plist_vals), + plist_size_(plist_ids.size()), total_num_vec_(num_vec), max_score_(max_score), q_value_(q_value), @@ -555,23 +586,23 @@ class InvertedIndex : public BaseInvertedIndex { update_cur_vec_id(); } - // advance loc until cur_vec_id_ >= vec_id void seek(table_t vec_id) { - while (loc_ < lut_size_ && lut_[loc_].id < vec_id) { + while (loc_ < plist_size_ && plist_ids_[loc_] < vec_id) { ++loc_; } skip_filtered_ids(); update_cur_vec_id(); } - T + QType cur_vec_val() const { - return lut_[loc_].val; + return plist_vals_[loc_]; } - const LUT& lut_; - const size_t lut_size_; + const Vector& plist_ids_; + const Vector& plist_vals_; + const size_t plist_size_; size_t loc_ = 0; size_t total_num_vec_ = 0; float max_score_ = 0.0f; @@ -582,16 +613,12 @@ class InvertedIndex : public BaseInvertedIndex { private: inline void update_cur_vec_id() { - if (loc_ >= lut_size_) { - cur_vec_id_ = total_num_vec_; - } else { - cur_vec_id_ = lut_[loc_].id; - } + cur_vec_id_ = (loc_ >= plist_size_) ? total_num_vec_ : plist_ids_[loc_]; } inline void skip_filtered_ids() { - while (loc_ < lut_size_ && !filter_.empty() && filter_.test(lut_[loc_].id)) { + while (loc_ < plist_size_ && !filter_.empty() && filter_.test(plist_ids_[loc_])) { ++loc_; } } @@ -602,8 +629,8 @@ class InvertedIndex : public BaseInvertedIndex { // TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed. template void - search_brute_force(const SparseRow& q_vec, T q_threshold, MaxMinHeap& heap, DocIdFilter& filter, - const DocValueComputer& computer) const { + search_brute_force(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, + const DocValueComputer& computer) const { auto scores = compute_all_distances(q_vec, q_threshold, computer); for (size_t i = 0; i < n_rows_internal_; ++i) { if ((filter.empty() || !filter.test(i)) && scores[i] != 0) { @@ -615,11 +642,10 @@ class InvertedIndex : public BaseInvertedIndex { // any value in q_vec that is smaller than q_threshold will be ignored. template void - search_wand(const SparseRow& q_vec, T q_threshold, MaxMinHeap& heap, DocIdFilter& filter, - const DocValueComputer& computer) const { + search_wand(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, + const DocValueComputer& computer) const { auto q_dim = q_vec.size(); - std::vector>> cursors( - q_dim); + std::vector>> cursors(q_dim); size_t valid_q_dim = 0; for (size_t i = 0; i < q_dim; ++i) { auto [idx, val] = q_vec[i]; @@ -627,9 +653,10 @@ class InvertedIndex : public BaseInvertedIndex { if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) { continue; } - auto& lut = inverted_lut_[dim_id->second]; - cursors[valid_q_dim++] = std::make_shared>( - lut, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter); + auto& plist_ids = inverted_index_ids_[dim_id->second]; + auto& plist_vals = inverted_index_vals_[dim_id->second]; + cursors[valid_q_dim++] = std::make_shared>( + plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter); } if (valid_q_dim == 0) { return; @@ -645,7 +672,7 @@ class InvertedIndex : public BaseInvertedIndex { size_t pivot; bool found_pivot = false; for (pivot = 0; pivot < valid_q_dim; ++pivot) { - if (cursors[pivot]->loc_ >= cursors[pivot]->lut_size_) { + if (cursors[pivot]->loc_ >= cursors[pivot]->plist_size_) { break; } upper_bound += cursors[pivot]->max_score_; @@ -664,7 +691,7 @@ class InvertedIndex : public BaseInvertedIndex { if (cursor->cur_vec_id_ != pivot_id) { break; } - T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0; + float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0; score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum); cursor->next(); } @@ -686,15 +713,14 @@ class InvertedIndex : public BaseInvertedIndex { } void - refine_and_collect(const SparseRow& q_vec, MaxMinHeap& inacc_heap, size_t k, float* distances, - label_t* labels, const DocValueComputer& computer) const { - MaxMinHeap heap(k); + refine_and_collect(const SparseRow& q_vec, MaxMinHeap& inacc_heap, size_t k, float* distances, + label_t* labels, const DocValueComputer& computer) const { + MaxMinHeap heap(k); std::vector docids; while (!inacc_heap.empty()) { - auto [u, d] = inacc_heap.top(); + auto u = inacc_heap.pop(); docids.push_back(u); - inacc_heap.pop(); } DocIdFilterByVector filter(std::move(docids)); @@ -718,8 +744,8 @@ class InvertedIndex : public BaseInvertedIndex { } inline void - add_row_to_index(const SparseRow& row, table_t vec_id) { - [[maybe_unused]] T row_sum = 0; + add_row_to_index(const SparseRow& row, table_t vec_id) { + [[maybe_unused]] float row_sum = 0; for (size_t j = 0; j < row.size(); ++j) { auto [idx, val] = row[j]; if constexpr (bm25) { @@ -737,14 +763,16 @@ class InvertedIndex : public BaseInvertedIndex { } dim_it = dim_map_.insert({idx, next_dim_id_++}).first; dim_map_reverse_[next_dim_id_ - 1] = idx; - inverted_lut_.emplace_back(); + inverted_index_ids_.emplace_back(); + inverted_index_vals_.emplace_back(); if constexpr (use_wand) { - max_score_in_dim_.emplace_back(0); + max_score_in_dim_.emplace_back(0.0f); } } - inverted_lut_[dim_it->second].emplace_back(vec_id, val); + inverted_index_ids_[dim_it->second].emplace_back(vec_id); + inverted_index_vals_[dim_it->second].emplace_back(get_quant_val(val)); if constexpr (use_wand) { - auto score = val; + auto score = static_cast(val); if constexpr (bm25) { score = bm25_params_->max_score_ratio * bm25_params_->wand_max_score_computer(val, row_sum); } @@ -756,16 +784,30 @@ class InvertedIndex : public BaseInvertedIndex { } } + inline QType + get_quant_val(DType val) const { + if constexpr (!std::is_same_v) { + const DType max_val = static_cast(std::numeric_limits::max()); + if (val >= max_val) { + return std::numeric_limits::max(); + } else if (val <= std::numeric_limits::min()) { + return std::numeric_limits::min(); + } else { + return static_cast(val); + } + } else { + return val; + } + } + // key is raw sparse vector dim/idx, value is the mapped dim/idx id in the index. std::unordered_map dim_map_; std::unordered_map dim_map_reverse_; - template - using Vector = std::conditional_t, std::vector>; - // reserve, [], size, emplace_back - Vector>> inverted_lut_; - Vector max_score_in_dim_; + Vector> inverted_index_ids_; + Vector> inverted_index_vals_; + Vector max_score_in_dim_; size_t n_rows_internal_ = 0; size_t max_dim_ = 0; @@ -780,17 +822,17 @@ class InvertedIndex : public BaseInvertedIndex { float b; // row_sums is used to cache the sum of values of each row, which // corresponds to the document length of each doc in the BM25 formula. - Vector row_sums; + Vector row_sums; // below are used only for WAND index. float max_score_ratio; - DocValueComputer wand_max_score_computer; + DocValueComputer wand_max_score_computer; BM25Params(float k1, float b, float avgdl, float max_score_ratio) : k1(k1), b(b), max_score_ratio(max_score_ratio), - wand_max_score_computer(GetDocValueBM25Computer(k1, b, avgdl)) { + wand_max_score_computer(GetDocValueBM25Computer(k1, b, avgdl)) { } }; // struct BM25Params diff --git a/tests/ut/test_sparse.cc b/tests/ut/test_sparse.cc index 83a73eb29..c84886b32 100644 --- a/tests/ut/test_sparse.cc +++ b/tests/ut/test_sparse.cc @@ -45,9 +45,10 @@ TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") { auto topk = 5; int64_t nq = 10; - auto drop_ratio_search = GENERATE(0.0, 0.3); - auto metric = GENERATE(knowhere::metric::IP, knowhere::metric::BM25); + + auto drop_ratio_search = metric == knowhere::metric::BM25 ? GENERATE(0.0, 0.1) : GENERATE(0.0, 0.3); + auto version = GenTestVersionList(); auto base_gen = [=, dim = dim]() { @@ -67,9 +68,16 @@ TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") { return json; }; - const auto train_ds = GenSparseDataSet(nb, dim, doc_sparsity); - // it is possible the query has more dims than the train dataset. - const auto query_ds = GenSparseDataSet(nq, dim + 20, query_sparsity); + auto sparse_dataset_gen = [&](int nr, int dim, float sparsity) -> knowhere::DataSetPtr { + if (metric == knowhere::metric::BM25) { + return GenSparseDataSetWithMaxVal(nr, dim, sparsity, 256, true); + } else { + return GenSparseDataSet(nr, dim, sparsity); + } + }; + + auto train_ds = sparse_dataset_gen(nb, dim, doc_sparsity); + auto query_ds = sparse_dataset_gen(nq, dim + 20, query_sparsity); const knowhere::Json conf = { {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, {knowhere::meta::BM25_K1, 1.2}, @@ -251,10 +259,15 @@ TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") { REQUIRE(idx.Size() > 0); REQUIRE(idx.Count() == nb); - auto [radius, range_filter] = GENERATE(table({ - {0.5, 1}, - {1, 1.5}, - })); + auto [radius, range_filter] = metric == knowhere::metric::BM25 ? GENERATE(table({ + {80.0, 100.0}, + {100.0, 200.0}, + })) + : GENERATE(table({ + {0.5, 1}, + {1, 1.5}, + })); + json[knowhere::meta::RADIUS] = radius; json[knowhere::meta::RANGE_FILTER] = range_filter; diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 3bbc3d794..3b62ba7ef 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -380,6 +380,35 @@ GenSparseDataSet(int32_t rows, int32_t cols, float sparsity, int seed = 42) { return GenSparseDataSet(data, cols); } +// Generate a sparse dataset with given sparsity and max value. +inline knowhere::DataSetPtr +GenSparseDataSetWithMaxVal(int32_t rows, int32_t cols, float sparsity, float max_val, bool use_bm25 = false, + int seed = 42) { + int32_t num_elements = static_cast(rows * cols * (1.0f - sparsity)); + + std::mt19937 rng(seed); + auto real_distrib = std::uniform_real_distribution(0, max_val); + auto row_distrib = std::uniform_int_distribution(0, rows - 1); + auto col_distrib = std::uniform_int_distribution(0, cols - 1); + + std::vector> data(rows); + + for (int32_t i = 0; i < num_elements; ++i) { + auto row = row_distrib(rng); + while (data[row].size() == (size_t)cols) { + row = row_distrib(rng); + } + auto col = col_distrib(rng); + while (data[row].find(col) != data[row].end()) { + col = col_distrib(rng); + } + auto val = use_bm25 ? static_cast(static_cast(real_distrib(rng))) : real_distrib(rng); + data[row][col] = val; + } + + return GenSparseDataSet(data, cols); +} + // a timer struct StopWatch { using timepoint_t = std::chrono::time_point;