Skip to content

Commit

Permalink
Support IVF search: Part 2 (#1984)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Support IVF search:
* L2, inner product (IP), and cosine metrics
* real-time index

Close #1917

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
- [x] Test cases
  • Loading branch information
yangzq50 authored Oct 8, 2024
1 parent ead2460 commit edd490a
Show file tree
Hide file tree
Showing 19 changed files with 413 additions and 264 deletions.
2 changes: 1 addition & 1 deletion src/executor/operator/physical_scan/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ void PhysicalKnnScan::ExecuteInternalByColumnDataTypeAndQueryDataType(QueryConte
switch (segment_index_entry->table_index_entry()->index_base()->index_type_) {
case IndexType::kIVF: {
const SegmentOffset max_segment_offset = block_index->GetSegmentOffset(segment_id);
const auto ivf_search_params = IVF_Search_Params::Make(knn_scan_shared_data);
const auto ivf_search_params = IVF_Search_Params::Make(knn_scan_function_data);
auto ivf_result_handler =
GetIVFSearchHandler<t, C, DistanceDataType>(ivf_search_params, use_bitmask, bitmask, max_segment_offset);
ivf_result_handler->Begin();
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/knn_filter.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export class AppendFilter final : public FilterBase<SegmentOffset> {
public:
AppendFilter(SegmentOffset max_segment_offset) : max_segment_offset_(max_segment_offset) {}

bool operator()(const SegmentOffset &segment_offset) const final { return segment_offset <= max_segment_offset_; }
// Returns true only for offsets strictly below max_segment_offset_, i.e. the bound itself is excluded.
bool operator()(const SegmentOffset &segment_offset) const final { return segment_offset < max_segment_offset_; }

private:
const SegmentOffset max_segment_offset_;
Expand Down
4 changes: 2 additions & 2 deletions src/function/table/knn_scan_data.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ class KnnDistance1 : public KnnDistanceBase1 {
public:
KnnDistance1(KnnDistanceType dist_type);

// Computes the distance between `query` and each of the `data_count` vectors in `datas`.
// Vector i is read starting at `datas + i * dim`, so the vectors are expected to be stored
// back-to-back with `dim` elements each. The returned vector has `data_count` entries, in
// the same order as the input vectors.
Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim) {
Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim) const {
Vector<DistType> res(data_count);
for (SizeT i = 0; i < data_count; ++i) {
// dist_func_ receives the query first, then the i-th stored vector.
res[i] = dist_func_(query, datas + i * dim, dim);
}
return res;
}

Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim, Bitmask &bitmask) {
Vector<DistType> Calculate(const QueryDataType *datas, SizeT data_count, const QueryDataType *query, SizeT dim, Bitmask &bitmask) const {
Vector<DistType> res(data_count);
for (SizeT i = 0; i < data_count; ++i) {
if (bitmask.IsTrue(i)) {
Expand Down
1 change: 1 addition & 0 deletions src/storage/knn_index/knn_ivf/ivf_index_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import column_vector;
import logger;
import kmeans_partition;
import logical_type;
import ivf_index_util_func;

namespace infinity {

Expand Down
102 changes: 78 additions & 24 deletions src/storage/knn_index/knn_ivf/ivf_index_data_in_mem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <vector>
module ivf_index_data_in_mem;

import stl;
Expand All @@ -37,26 +38,10 @@ import search_top_1;
import column_vector;
import ivf_index_data;
import buffer_handle;
import knn_scan_data;
import ivf_index_util_func;

namespace infinity {
// Produces an f32 view of `src_data_cnt` elements starting at `src_data_ptr`.
//
// Returns a pair {data pointer, owning buffer}:
//   - f32 input: the data pointer aliases `src_data_ptr` and the owning buffer stays empty,
//     so the result is only valid while the caller's source buffer is.
//   - any other element type: a new f32 buffer of `src_data_cnt` elements is allocated and
//     filled with the converted values; the data pointer refers to that buffer, whose
//     lifetime is tied to the returned UniquePtr.
template <IsAnyOf<u8, i8, f64, f32, Float16T, BFloat16T> ColumnEmbeddingElementT>
Pair<const f32 *, UniquePtr<f32[]>> GetF32Ptr(const ColumnEmbeddingElementT *src_data_ptr, const u32 src_data_cnt) {
    Pair<const f32 *, UniquePtr<f32[]>> result;
    if constexpr (!std::is_same_v<f32, ColumnEmbeddingElementT>) {
        // Conversion path: materialize an owned f32 copy of the source elements.
        result.second = MakeUniqueForOverwrite<f32[]>(src_data_cnt);
        f32 *converted = result.second.get();
        result.first = converted;
        for (u32 idx = 0; idx < src_data_cnt; ++idx) {
            if constexpr (std::is_same_v<f64, ColumnEmbeddingElementT>) {
                // f64 -> f32 narrows, so the cast is spelled out explicitly.
                converted[idx] = static_cast<f32>(src_data_ptr[idx]);
            } else {
                converted[idx] = src_data_ptr[idx];
            }
        }
    } else {
        // Already f32: hand back the caller's pointer without copying.
        result.first = src_data_ptr;
    }
    return result;
}

IVFIndexInMem::IVFIndexInMem(const RowID begin_row_id,
const IndexIVFOption &ivf_option,
Expand Down Expand Up @@ -169,6 +154,7 @@ class IVFIndexInMemT final : public IVFIndexInMem {
}

void BuildIndex() {
LOG_TRACE("Start building in-memory IVF index");
if (have_ivf_index_.test(std::memory_order_acquire)) {
UnrecoverableError("Already have index");
}
Expand Down Expand Up @@ -210,11 +196,78 @@ class IVFIndexInMemT final : public IVFIndexInMem {
return new_chunk_index_entry;
}

// Runtime-to-compile-time dispatcher for the in-memory search.
// Switches on the runtime `query_element_type` and forwards to SearchIndexInMemT<...> with
// `query_ptr` cast to the matching C++ element type. Query types other than
// kElemFloat / kElemUInt8 / kElemInt8 end in UnrecoverableError("Invalid EmbeddingDataType").
void SearchIndexInMem(KnnDistanceType knn_distance_type,
void SearchIndexInMem(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
EmbeddingDataType query_element_type,
const EmbeddingDataType query_element_type,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const override {
// TODO
// ReturnT<Q>() accepts a query element type Q only in two cases:
//   * Q is kElemFloat and the column element type is one of f64 / f32 / Float16T / BFloat16T;
//   * Q equals embedding_data_type (presumably the column's own element type) and is
//     kElemInt8 or kElemUInt8.
// Every other combination reports "Invalid Query EmbeddingDataType".
auto ReturnT = [&]<EmbeddingDataType query_element_type> {
if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf<ColumnEmbeddingElementT, f64, f32, Float16T, BFloat16T>) ||
(query_element_type == embedding_data_type &&
(query_element_type == EmbeddingDataType::kElemInt8 || query_element_type == EmbeddingDataType::kElemUInt8))) {
return SearchIndexInMemT<query_element_type>(knn_distance,
static_cast<const EmbeddingDataTypeToCppTypeT<query_element_type> *>(query_ptr),
satisfy_filter_func,
add_result_func);
} else {
UnrecoverableError("Invalid Query EmbeddingDataType");
}
};
// Map the runtime enum value onto the matching template instantiation of ReturnT.
switch (query_element_type) {
case EmbeddingDataType::kElemFloat: {
return ReturnT.template operator()<EmbeddingDataType::kElemFloat>();
}
case EmbeddingDataType::kElemUInt8: {
return ReturnT.template operator()<EmbeddingDataType::kElemUInt8>();
}
case EmbeddingDataType::kElemInt8: {
return ReturnT.template operator()<EmbeddingDataType::kElemInt8>();
}
default: {
UnrecoverableError("Invalid EmbeddingDataType");
}
}
}

template <EmbeddingDataType query_element_type>
void SearchIndexInMemT(const KnnDistanceBase1 *knn_distance,
const EmbeddingDataTypeToCppTypeT<query_element_type> *query_ptr,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const {
using QueryDataType = EmbeddingDataTypeToCppTypeT<query_element_type>;
auto knn_distance_1 = dynamic_cast<const KnnDistance1<QueryDataType, f32> *>(knn_distance);
if (!knn_distance_1) [[unlikely]] {
UnrecoverableError("Invalid KnnDistance1");
}
if constexpr (column_logical_type == LogicalType::kEmbedding) {
auto dist_func = knn_distance_1->dist_func_;
for (u32 i = 0; i < in_mem_storage_.source_offsets_.size(); ++i) {
const auto segment_offset = in_mem_storage_.source_offsets_[i];
if (!satisfy_filter_func(segment_offset)) {
continue;
}
auto v_ptr = in_mem_storage_.raw_source_data_.data() + i * embedding_dimension();
auto [calc_ptr, _] = GetSearchCalcPtr<QueryDataType>(v_ptr, embedding_dimension());
auto d = dist_func(calc_ptr, query_ptr, embedding_dimension());
add_result_func(d, segment_offset);
}
} else if constexpr (column_logical_type == LogicalType::kMultiVector) {
for (u32 i = 0; i < in_mem_storage_.source_offsets_.size(); ++i) {
const auto segment_offset = in_mem_storage_.source_offsets_[i];
if (!satisfy_filter_func(segment_offset)) {
continue;
}
auto mv_ptr = in_mem_storage_.raw_source_data_.data() + in_mem_storage_.multi_vector_data_start_pos_[i];
auto mv_num = in_mem_storage_.multi_vector_embedding_num_[i];
auto [calc_ptr, _] = GetSearchCalcPtr<QueryDataType>(mv_ptr, mv_num * embedding_dimension());
auto dists = knn_distance_1->Calculate(calc_ptr, mv_num, query_ptr, embedding_dimension());
for (const auto d : dists) {
add_result_func(d, segment_offset);
}
}
} else {
static_assert(false);
}
}
};

Expand Down Expand Up @@ -267,16 +320,17 @@ SharedPtr<IVFIndexInMem> IVFIndexInMem::NewIVFIndexInMem(const ColumnDef *column
return {};
}

// Public search entry point. Chooses between the built IVF storage and the raw in-memory
// data depending on whether have_ivf_index_ is set. `nprobe` is only passed to the
// ivf_index_storage_ search; the SearchIndexInMem fallback does not take it.
void IVFIndexInMem::SearchIndex(const KnnDistanceType knn_distance_type,
void IVFIndexInMem::SearchIndex(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
const EmbeddingDataType query_element_type,
const u32 nprobe,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const {
// rw_mutex_ is held in shared mode for the duration of the search.
std::shared_lock lock(rw_mutex_);
if (have_ivf_index_.test(std::memory_order_acquire)) {
ivf_index_storage_->SearchIndex(knn_distance_type, query_ptr, query_element_type, nprobe, add_result_func);
ivf_index_storage_->SearchIndex(knn_distance, query_ptr, query_element_type, nprobe, satisfy_filter_func, add_result_func);
} else {
// Flag not set: defer to the subclass-provided virtual SearchIndexInMem.
SearchIndexInMem(knn_distance_type, query_ptr, query_element_type, add_result_func);
SearchIndexInMem(knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func);
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/storage/knn_index/knn_ivf/ivf_index_data_in_mem.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import ivf_index_storage;
import column_def;
import logical_type;
import buffer_handle;
import knn_expr;

namespace infinity {

Expand All @@ -32,6 +31,7 @@ class BufferManager;
class ChunkIndexEntry;
class SegmentIndexEntry;
class IndexBase;
class KnnDistanceBase1;

export class IVFIndexInMem {
protected:
Expand Down Expand Up @@ -61,17 +61,19 @@ public:
u32 row_offset,
u32 row_count) = 0;
virtual SharedPtr<ChunkIndexEntry> Dump(SegmentIndexEntry *segment_index_entry, BufferManager *buffer_mgr) = 0;
void SearchIndex(KnnDistanceType knn_distance_type,
void SearchIndex(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
EmbeddingDataType query_element_type,
u32 nprobe,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const;
static SharedPtr<IVFIndexInMem> NewIVFIndexInMem(const ColumnDef *column_def, const IndexBase *index_base, RowID begin_row_id);

private:
virtual void SearchIndexInMem(KnnDistanceType knn_distance_type,
virtual void SearchIndexInMem(const KnnDistanceBase1 *knn_distance,
const void *query_ptr,
EmbeddingDataType query_element_type,
std::function<bool(SegmentOffset)> satisfy_filter_func,
std::function<void(f32, SegmentOffset)> add_result_func) const = 0;
};

Expand Down
5 changes: 4 additions & 1 deletion src/storage/knn_index/knn_ivf/ivf_index_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ import ivf_index_storage;

namespace infinity {

IVF_Search_Params IVF_Search_Params::Make(const KnnScanSharedData *knn_scan_shared_data) {
IVF_Search_Params IVF_Search_Params::Make(const KnnScanFunctionData *knn_scan_function_data) {
IVF_Search_Params params;
params.knn_distance_ = knn_scan_function_data->knn_distance_.get();
const auto *knn_scan_shared_data = knn_scan_function_data->knn_scan_shared_data_;
params.knn_scan_shared_data_ = knn_scan_shared_data;
if (knn_scan_shared_data->query_count_ != 1) {
RecoverableError(Status::SyntaxError(fmt::format("Invalid query_count: {} which is not 1.", knn_scan_shared_data->query_count_)));
}
Expand Down
31 changes: 23 additions & 8 deletions src/storage/knn_index/knn_ivf/ivf_index_search.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ import search_top_k;
namespace infinity {

// Inputs for one IVF search, produced by Make().
// Make() sets knn_distance_ from knn_scan_function_data->knn_distance_.get() and
// knn_scan_shared_data_ from knn_scan_function_data->knn_scan_shared_data_, so both are raw
// pointers into objects held elsewhere. It raises a RecoverableError when the shared data's
// query_count_ is not 1.
export struct IVF_Search_Params {
KnnScanSharedData *knn_scan_shared_data_{};
const KnnDistanceBase1 *knn_distance_{};
const KnnScanSharedData *knn_scan_shared_data_{};
i64 topk_{};
void *query_embedding_{};
const void *query_embedding_{};
// Element type of the buffer behind query_embedding_; kElemInvalid until it is assigned.
EmbeddingDataType query_elem_type_{EmbeddingDataType::kElemInvalid};
KnnDistanceType knn_distance_type_{KnnDistanceType::kInvalid};
// Defaults to 1; passed as the `nprobe` argument of the index SearchIndex calls.
i32 nprobe_{1};

static IVF_Search_Params Make(const KnnScanSharedData *knn_scan_shared_data);
static IVF_Search_Params Make(const KnnScanFunctionData *knn_scan_function_data);
};

export template <typename DistanceDataType>
Expand Down Expand Up @@ -81,14 +82,14 @@ template <>
// Bitmask-backed specialization: the row filter is a BitmaskFilter built from `bitmask`.
// `max_segment_offset` is accepted so both specializations share a constructor signature,
// but it is not used here.
struct IVF_Filter<true> {
BitmaskFilter<SegmentOffset> filter_;
IVF_Filter(const Bitmask &bitmask, const SegmentOffset max_segment_offset) : filter_(bitmask) {}
bool operator()(const SegmentOffset &segment_offset) const { return filter_(segment_offset); }
bool operator()(const SegmentOffset segment_offset) const { return filter_(segment_offset); }
};

// Specialization without a bitmask: the row filter is an AppendFilter built from
// `max_segment_offset`. `bitmask` is accepted so both specializations share a constructor
// signature, but it is not used here.
template <>
struct IVF_Filter<false> {
AppendFilter filter_;
IVF_Filter(const Bitmask &bitmask, const SegmentOffset max_segment_offset) : filter_(max_segment_offset) {}
bool operator()(const SegmentOffset &segment_offset) const { return filter_(segment_offset); }
bool operator()(const SegmentOffset segment_offset) const { return filter_(segment_offset); }
};

template <LogicalType t,
Expand All @@ -98,6 +99,7 @@ template <LogicalType t,
bool use_bitmask,
typename MultiVectorInnerTopnIndexType = void>
class IVF_Search_HandlerT final : public IVF_Search_Handler<DistanceDataType> {
static_assert(std::is_same_v<DistanceDataType, f32>); // KnnDistanceBase1 type?
static_assert(t == LogicalType::kEmbedding || t == LogicalType::kMultiVector);
static constexpr bool NEED_FLIP = !std::is_same_v<CompareMax<DistanceDataType, SegmentOffset>, C<DistanceDataType, SegmentOffset>>;
using ResultHandler = std::conditional_t<t == LogicalType::kEmbedding,
Expand All @@ -113,20 +115,27 @@ public:
void Begin() override { result_handler_.Begin(); }
void Search(const IVFIndexInChunk *ivf_index_in_chunk) override {
const auto *ivf_index_storage = ivf_index_in_chunk->GetIVFIndexStoragePtr();
ivf_index_storage->SearchIndex(this->ivf_params_.knn_distance_type_,
ivf_index_storage->SearchIndex(this->ivf_params_.knn_distance_,
this->ivf_params_.query_embedding_,
this->ivf_params_.query_elem_type_,
this->ivf_params_.nprobe_,
std::bind(&IVF_Search_HandlerT::SatisfyFilter, this, std::placeholders::_1),
std::bind(&IVF_Search_HandlerT::AddResult, this, std::placeholders::_1, std::placeholders::_2));
}
void Search(const IVFIndexInMem *ivf_index_in_mem) override {
ivf_index_in_mem->SearchIndex(this->ivf_params_.knn_distance_type_,
ivf_index_in_mem->SearchIndex(this->ivf_params_.knn_distance_,
this->ivf_params_.query_embedding_,
this->ivf_params_.query_elem_type_,
this->ivf_params_.nprobe_,
std::bind(&IVF_Search_HandlerT::SatisfyFilter, this, std::placeholders::_1),
std::bind(&IVF_Search_HandlerT::AddResult, this, std::placeholders::_1, std::placeholders::_2));
}
bool SatisfyFilter(SegmentOffset i) { return filter_(i); }
void AddResult(DistanceDataType d, SegmentOffset i) {
assert(SatisfyFilter(i));
if constexpr (NEED_FLIP) {
d = -d;
}
if constexpr (t == LogicalType::kEmbedding) {
result_handler_.AddResult(0, d, i);
} else {
Expand All @@ -136,7 +145,13 @@ public:
}
// Finishes result collection without sorting and returns result_handler_.GetSize(0).
SizeT EndWithoutSortAndGetResultSize() override {
result_handler_.EndWithoutSort();
return result_handler_.GetSize(0);
const auto result_cnt = result_handler_.GetSize(0);
// AddResult negates each distance when NEED_FLIP is set; negate the first result_cnt
// entries of distance_output_ptr_ again so they carry the original sign.
// NOTE(review): assumes distance_output_ptr_ is the buffer result_handler_ wrote into — confirm.
if constexpr (NEED_FLIP) {
for (u32 i = 0; i < result_cnt; ++i) {
this->distance_output_ptr_[i] = -(this->distance_output_ptr_[i]);
}
}
return result_cnt;
}
};

Expand Down
Loading

0 comments on commit edd490a

Please sign in to comment.