Skip to content

Commit

Permalink
enhance: scann support iterator (zilliztech#992)
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 authored Dec 24, 2024
1 parent 0341b62 commit d448fd6
Show file tree
Hide file tree
Showing 9 changed files with 324 additions and 8 deletions.
24 changes: 21 additions & 3 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class IvfIndexNode : public IndexNode {
}

private:
// only support IVFFlat,IVFFlatCC, IVFSQ and IVFSQCC
// only support IVFFlat,IVFFlatCC, IVFSQ, IVFSQCC and SCANN
// iterator will own the copied_norm_query
// TODO: iterator should copy and own query data.
// TODO: if SCANN supports Iterator, the raw_distance() function should be overridden.
Expand Down Expand Up @@ -339,6 +339,18 @@ class IvfIndexNode : public IndexNode {
workspace_->dists.clear();
}

// Returns the value produced by the workspace's refine distance computer
// (workspace_->dis_refine) for vector `id`.
// Only IndexScaNN instantiations with refine_ enabled can answer; every other
// case throws std::runtime_error.
float
raw_distance(int64_t id) override {
    if constexpr (!std::is_same_v<IndexType, faiss::IndexScaNN>) {
        throw std::runtime_error("raw_distance not implemented");
    } else {
        if (!refine_) {
            throw std::runtime_error("raw_distance should not be called if refine == false");
        }
        return (*workspace_->dis_refine)(id);
    }
}

private:
const IndexType* index_ = nullptr;
std::unique_ptr<faiss::IVFIteratorWorkspace> workspace_ = nullptr;
Expand Down Expand Up @@ -923,9 +935,10 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSetPtr dataset, std::un
if constexpr (!std::is_same<faiss::IndexIVFFlatCC, IndexType>::value &&
!std::is_same<faiss::IndexIVFFlat, IndexType>::value &&
!std::is_same<faiss::IndexIVFScalarQuantizer, IndexType>::value &&
!std::is_same<faiss::IndexIVFScalarQuantizerCC, IndexType>::value) {
!std::is_same<faiss::IndexIVFScalarQuantizerCC, IndexType>::value &&
!std::is_same<faiss::IndexScaNN, IndexType>::value) {
LOG_KNOWHERE_WARNING_ << "Current index_type: " << Type()
<< ", only IVFFlat, IVFFlatCC, IVF_SQ8 and IVF_SQ_CC support Iterator.";
<< ", only IVFFlat, IVFFlatCC, IVF_SQ8, IVF_SQ_CC and SCANN support Iterator.";
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::not_implemented, "index not supported");
} else {
auto dim = dataset->GetDim();
Expand All @@ -942,6 +955,11 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSetPtr dataset, std::un
// set iterator_refine_ratio = 0.0. If quantizer != flat, faiss:indexivf will not keep raw data;
// TODO: if SCANN supports Iterator, iterator_refine_ratio should be set.
float iterator_refine_ratio = 0.0f;
if constexpr (std::is_same_v<IndexType, faiss::IndexScaNN>) {
if (HasRawData(ivf_cfg.metric_type.value())) {
iterator_refine_ratio = ivf_cfg.iterator_refine_ratio.value();
}
}
try {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(rows);
Expand Down
20 changes: 20 additions & 0 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,20 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
return json;
};

auto scann_gen = [ivf_base_gen]() {
knowhere::Json json = ivf_base_gen();
json[knowhere::indexparam::NPROBE] = 14;
json[knowhere::indexparam::REORDER_K] = 200;
json[knowhere::indexparam::WITH_RAW_DATA] = true;
return json;
};

auto scann_gen2 = [ivf_base_gen]() {
knowhere::Json json = ivf_base_gen();
json[knowhere::indexparam::WITH_RAW_DATA] = false;
return json;
};

auto rand = GENERATE(1, 2);

const auto train_ds = GenDataSet(nb, dim, rand);
Expand All @@ -209,6 +223,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -295,6 +311,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -341,6 +359,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down
12 changes: 12 additions & 0 deletions thirdparty/faiss/faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
return list_no;
}

/*****************************************
* IVFIteratorWorkspace implementation
******************************************/
IVFIteratorWorkspace::IVFIteratorWorkspace(
const float* query_data,
const IVFSearchParameters* search_params)
: query_data(query_data),
search_params(search_params),
dis_refine(nullptr) {}

IVFIteratorWorkspace::~IVFIteratorWorkspace() {}

/*****************************************
* IndexIVF implementation
******************************************/
Expand Down
12 changes: 7 additions & 5 deletions thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,13 @@ struct SearchParametersIVF : SearchParameters {

// the new convention puts the index type after SearchParameters
using IVFSearchParameters = SearchParametersIVF;

struct DistanceComputer;
struct IVFIteratorWorkspace {
IVFIteratorWorkspace() = default;
IVFIteratorWorkspace(
const float* query_data,
const IVFSearchParameters* search_params)
: query_data(query_data), search_params(search_params) {}
const IVFSearchParameters* search_params);
virtual ~IVFIteratorWorkspace();

const float* query_data = nullptr; // single query
const IVFSearchParameters* search_params = nullptr;
Expand All @@ -112,6 +113,7 @@ struct IVFIteratorWorkspace {
nullptr; // backup coarse centroids ids (heap)
std::unique_ptr<size_t[]> coarse_list_sizes =
nullptr; // snapshot of the list_size
std::unique_ptr<DistanceComputer> dis_refine;
};

struct InvertedListScanner;
Expand Down Expand Up @@ -245,7 +247,7 @@ struct IndexIVF : Index, IndexIVFInterface {
size_t code_size,
MetricType metric = METRIC_L2);

std::unique_ptr<IVFIteratorWorkspace> getIteratorWorkspace(
virtual std::unique_ptr<IVFIteratorWorkspace> getIteratorWorkspace(
const float* query_data,
const IVFSearchParameters* ivfsearchParams) const;

Expand All @@ -255,7 +257,7 @@ struct IndexIVF : Index, IndexIVFInterface {
// iterator `Next()` operation.
// When there are not enough nodes in the heap, iterator will scan the
// next coarse list.
void getIteratorNextBatch(
virtual void getIteratorNextBatch(
IVFIteratorWorkspace* workspace,
size_t current_backup_count) const;

Expand Down
100 changes: 100 additions & 0 deletions thirdparty/faiss/faiss/IndexIVFFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,59 @@ void IndexIVFFastScan::range_search(
range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params);
}

// Creates the iterator workspace for a single query.
// The generic IVF workspace is built first and its state is handed over to a
// FastScan-specific workspace, which is then filled with the uint8 look-up
// tables, biases and normalizers computed for this query.
std::unique_ptr<IVFIteratorWorkspace> IndexIVFFastScan::getIteratorWorkspace(
        const float* query_data,
        const IVFSearchParameters* ivfsearchParams) const {
    auto ws = std::make_unique<IVFFastScanIteratorWorkspace>(
            IndexIVF::getIteratorWorkspace(query_data, ivfsearchParams));

    ws->dim12 = ksub * M2;

    // Coarse quantization results were already produced by the base workspace.
    CoarseQuantized cq{ws->nprobe, ws->coarse_dis.get(), ws->coarse_idx.get()};

    // One query only.
    compute_LUT_uint8(
            1,
            ws->query_data,
            cq,
            ws->dis_tables,
            ws->biases,
            ws->normalizers);
    return ws;
}

// Refills workspace->dists with the next batch of candidates.
// The result handler is chosen from the metric: a CMin-based collector for
// similarity metrics, a CMax-based collector otherwise. Scanning itself is
// delegated to get_interator_next_batch_implem_10().
// NOTE(review): `workspace` is assumed to be an IVFFastScanIteratorWorkspace;
// the dynamic_cast result is used without a null check.
void IndexIVFFastScan::getIteratorNextBatch(
        IVFIteratorWorkspace* workspace,
        size_t current_backup_count) const {
    using MaxCollector =
            SingleQueryResultCollectHandler<CMax<uint16_t, int64_t>, true>;
    using MinCollector =
            SingleQueryResultCollectHandler<CMin<uint16_t, int64_t>, true>;

    auto fs_workspace = dynamic_cast<IVFFastScanIteratorWorkspace*>(workspace);
    fs_workspace->dists.clear();

    // May be null, meaning no id filtering.
    auto selector = fs_workspace->search_params->sel;

    std::unique_ptr<SIMDResultHandlerToFloat> handler;
    if (is_similarity_metric(metric_type)) {
        handler.reset(new MinCollector(fs_workspace->dists, ntotal, selector));
    } else {
        handler.reset(new MaxCollector(fs_workspace->dists, ntotal, selector));
    }

    get_interator_next_batch_implem_10(
            *handler, fs_workspace, current_backup_count);
}

namespace {

template <class C>
Expand Down Expand Up @@ -1701,6 +1754,53 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
}
}

// Scans further coarse lists for one query (implem 10, no qbs) until either
// `current_backup_count + workspace->dists.size()` reaches
// `workspace->backup_count_threshold` or all `nlist` coarse entries have been
// visited. Results are pushed through `handler`.
//
// @param handler               SIMD result handler that receives the distances
// @param workspace             FastScan iterator state (LUTs, biases, coarse
//                              ids, cursor into the coarse list order)
// @param current_backup_count  number of candidates the caller already holds
void IndexIVFFastScan::get_interator_next_batch_implem_10(
        SIMDResultHandlerToFloat& handler,
        IVFFastScanIteratorWorkspace* workspace,
        size_t current_backup_count) const {
    const bool single_LUT = !lookup_table_is_3d();
    handler.begin(skip & 16 ? nullptr : workspace->normalizers);
    const auto dim12 = workspace->dim12;

    // With a single (2D) LUT the same table serves every list; otherwise the
    // table is re-pointed per visited list inside the loop.
    const uint8_t* LUT = nullptr;
    if (single_LUT) {
        LUT = workspace->dis_tables.get();
    }

    while (current_backup_count + workspace->dists.size() <
                   workspace->backup_count_threshold &&
           workspace->next_visit_coarse_list_idx < nlist) {
        // Advance the cursor first so that `continue` below cannot loop
        // forever on the same entry.
        const auto next_list_idx = workspace->next_visit_coarse_list_idx;
        workspace->next_visit_coarse_list_idx++;

        if (!single_LUT) {
            LUT = workspace->dis_tables.get() + next_list_idx * dim12;
        }
        invlists->prefetch_lists(
                workspace->coarse_idx.get() + next_list_idx, 1);
        if (workspace->biases.get()) {
            handler.dbias = workspace->biases.get() + next_list_idx;
        }

        const idx_t list_no = workspace->coarse_idx[next_list_idx];
        // Reject invalid coarse ids BEFORE touching the inverted lists:
        // a negative id must never be used as a list index.
        if (list_no < 0) {
            continue;
        }
        const size_t ls = invlists->list_size(list_no);
        if (ls == 0) {
            continue;
        }

        InvertedLists::ScopedCodes codes(invlists, list_no);
        InvertedLists::ScopedIds ids(invlists, list_no);
        handler.ntotal = ls;
        handler.id_map = ids.get();
        pq4_accumulate_loop(
                1,
                roundup(ls, bbs),
                bbs,
                M2,
                codes.get(),
                LUT,
                handler,
                nullptr);
    }
    handler.end();
}

// IVFFastScanStats IVFFastScan_stats;

} // namespace faiss
40 changes: 40 additions & 0 deletions thirdparty/faiss/faiss/IndexIVFFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,34 @@ struct SIMDResultHandlerToFloat;
* For range search, only 10 and 12 are supported.
* add 100 to the implem to force single-thread scanning (the coarse quantizer
* may still use multiple threads).
*
* For the search iterator, only implem 10 is supported (one query, no qbs).
*/

// Iterator workspace for IndexIVFFastScan. Extends the generic IVF iterator
// workspace with the per-query FastScan look-up data.
struct IVFFastScanIteratorWorkspace : IVFIteratorWorkspace {
    IVFFastScanIteratorWorkspace() = default;

    // Forwards the query and search parameters to the base workspace.
    IVFFastScanIteratorWorkspace(
            const float* query_data,
            const IVFSearchParameters* search_params)
            : IVFIteratorWorkspace(query_data, search_params) {}

    // Takes over the state of an already prepared base workspace and then
    // releases it (the caller's pointer is reset to null).
    // NOTE(review): only the fields listed below are transferred; any other
    // base-workspace state is dropped together with the released object --
    // confirm that is intended.
    IVFFastScanIteratorWorkspace(
            std::unique_ptr<IVFIteratorWorkspace>&& base_workspace) {
        this->query_data = base_workspace->query_data;
        this->search_params = base_workspace->search_params;
        this->nprobe = base_workspace->nprobe;
        this->backup_count_threshold = base_workspace->backup_count_threshold;
        this->coarse_dis = std::move(base_workspace->coarse_dis);
        this->coarse_idx = std::move(base_workspace->coarse_idx);
        this->coarse_list_sizes = std::move(base_workspace->coarse_list_sizes);
        base_workspace = nullptr;
    }

    // Stride between per-list look-up tables; zero until a workspace is set up.
    size_t dim12 = 0;
    // Quantized look-up tables for the query.
    AlignedTable<uint8_t> dis_tables;
    // Per-list biases for the query.
    AlignedTable<uint16_t> biases;
    // Normalizers handed to the result handler; zero until computed.
    float normalizers[2] = {0.0f, 0.0f};
};

struct IndexIVFFastScan : IndexIVF {
// size of the kernel
int bbs; // set at build time
Expand Down Expand Up @@ -147,6 +173,14 @@ struct IndexIVFFastScan : IndexIVF {
const IVFSearchParameters* params = nullptr,
IndexIVFStats* stats = nullptr) const override;

std::unique_ptr<IVFIteratorWorkspace> getIteratorWorkspace(
const float* query_data,
const IVFSearchParameters* ivfsearchParams) const override;

void getIteratorNextBatch(
IVFIteratorWorkspace* workspace,
size_t current_backup_count) const override;

// range_search implementation was introduced in Knowhere,
// diff 73f03354568b4bf5a370df6f37e8d56dfc3a9c85
void range_search(
Expand Down Expand Up @@ -243,6 +277,12 @@ struct IndexIVFFastScan : IndexIVF {
const NormTableScaler* scaler,
const IVFSearchParameters* params = nullptr) const;

// one query call, no qbs
void get_interator_next_batch_implem_10(
SIMDResultHandlerToFloat& handler,
IVFFastScanIteratorWorkspace* workspace,
size_t current_backup_count) const;

// implem 14 is multithreaded internally across nprobes and queries
void search_implem_14(
idx_t n,
Expand Down
35 changes: 35 additions & 0 deletions thirdparty/faiss/faiss/IndexScaNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/utils.h>
#include <faiss/FaissHook.h>
#include <faiss/IndexCosine.h>

namespace faiss {

Expand Down Expand Up @@ -255,4 +257,37 @@ void IndexScaNN::range_search(
result->lims[1] = current;
}

// Builds the iterator workspace through the underlying IVFPQFastScan index
// and, when a refine index is present, attaches a distance computer
// (dis_refine) bound to `query_data`:
//   - cosine base index: the flat computer wrapped with the stored norms
//   - otherwise: the flat-codes distance computer of the refine index
// Without a refine index dis_refine is left null.
std::unique_ptr<IVFIteratorWorkspace> IndexScaNN::getIteratorWorkspace(
        const float* query_data,
        const IVFSearchParameters* ivfsearchParams) const {
    auto fast_scan = dynamic_cast<const IndexIVFPQFastScan*>(base_index);
    auto workspace = fast_scan->getIteratorWorkspace(query_data, ivfsearchParams);

    if (!refine_index) {
        workspace->dis_refine = nullptr;
        return workspace;
    }

    auto flat = dynamic_cast<const IndexFlat*>(refine_index);
    if (fast_scan->is_cosine) {
        workspace->dis_refine.reset(new faiss::WithCosineNormDistanceComputer(
                fast_scan->norms.data(),
                fast_scan->d,
                std::unique_ptr<faiss::DistanceComputer>(
                        flat->get_distance_computer())));
    } else {
        workspace->dis_refine.reset(flat->get_FlatCodesDistanceComputer());
    }
    workspace->dis_refine->set_query(query_data);

    return workspace;
}

// Forwards the next-batch request to the underlying IVFPQFastScan base index.
void IndexScaNN::getIteratorNextBatch(
        IVFIteratorWorkspace* workspace,
        size_t current_backup_count) const {
    auto fast_scan = dynamic_cast<const IndexIVFPQFastScan*>(base_index);
    fast_scan->getIteratorNextBatch(workspace, current_backup_count);
}

} // namespace faiss
Loading

0 comments on commit d448fd6

Please sign in to comment.