diff --git a/knowhere/index/VecIndex.h b/knowhere/index/VecIndex.h index a7a6adcb1..c443640e1 100644 --- a/knowhere/index/VecIndex.h +++ b/knowhere/index/VecIndex.h @@ -43,6 +43,11 @@ class VecIndex : public Index { virtual void AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0; + virtual DatasetPtr + GetVectorById(const DatasetPtr& dataset, const Config& config) { + KNOWHERE_THROW_MSG("GetVectorById not supported yet"); + } + virtual DatasetPtr Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView bitset) = 0; diff --git a/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/knowhere/index/vector_index/IndexBinaryIDMAP.cpp index bd28d05b6..33f93991b 100644 --- a/knowhere/index/vector_index/IndexBinaryIDMAP.cpp +++ b/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -36,6 +36,35 @@ BinaryIDMAP::Load(const BinarySet& index_binary) { LoadImpl(index_binary, index_type_); } +DatasetPtr +BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + GET_DATA_WITH_IDS(dataset_ptr) + + uint8_t* p_x = nullptr; + auto release_when_exception = [&]() { + if (p_x != nullptr) { + free(p_x); + } + }; + + try { + p_x = (uint8_t*)malloc(sizeof(uint8_t) * (dim / 8) * rows); + auto bin_idmap_index = dynamic_cast(index_.get()); + bin_idmap_index->get_vector_by_id(rows, p_ids, p_x); + return GenResultDataset(p_x); + } catch (faiss::FaissException& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } +} + DatasetPtr BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) { if (!index_) { diff --git a/knowhere/index/vector_index/IndexBinaryIDMAP.h b/knowhere/index/vector_index/IndexBinaryIDMAP.h index 26a7d1791..5e6fa00b4 100644 --- a/knowhere/index/vector_index/IndexBinaryIDMAP.h +++ b/knowhere/index/vector_index/IndexBinaryIDMAP.h @@ -42,6 +42,9 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { void AddWithoutIds(const DatasetPtr&, const Config&) override; + DatasetPtr + GetVectorById(const DatasetPtr&, const Config&) override; + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexBinaryIVF.cpp b/knowhere/index/vector_index/IndexBinaryIVF.cpp index 66bad5f1a..986f590e1 100644 --- a/knowhere/index/vector_index/IndexBinaryIVF.cpp +++ b/knowhere/index/vector_index/IndexBinaryIVF.cpp @@ -47,6 +47,35 @@ BinaryIVF::Load(const BinarySet& index_binary) { #endif } +DatasetPtr +BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_DATA_WITH_IDS(dataset_ptr) + + uint8_t* p_x = nullptr; + auto release_when_exception = [&]() { + if (p_x != nullptr) { + free(p_x); + } + }; + + try { + p_x = (uint8_t*)malloc(sizeof(uint8_t) * (dim / 8) * rows); + auto bin_ivf_index = dynamic_cast(index_.get()); + bin_ivf_index->get_vector_by_id(rows, p_ids, p_x); + return GenResultDataset(p_x); + } catch (faiss::FaissException& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } +} + DatasetPtr BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) { if (!index_ || !index_->is_trained) { diff --git a/knowhere/index/vector_index/IndexBinaryIVF.h b/knowhere/index/vector_index/IndexBinaryIVF.h index 193560dcf..e8e2b6065 100644 --- a/knowhere/index/vector_index/IndexBinaryIVF.h +++ b/knowhere/index/vector_index/IndexBinaryIVF.h @@ -47,6 +47,9 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { void AddWithoutIds(const DatasetPtr&, const Config&) override; + DatasetPtr + GetVectorById(const DatasetPtr&, const Config&) override; + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexIDMAP.cpp b/knowhere/index/vector_index/IndexIDMAP.cpp index f3cd7933b..20310ff31 100644 --- a/knowhere/index/vector_index/IndexIDMAP.cpp +++ b/knowhere/index/vector_index/IndexIDMAP.cpp @@ -69,6 +69,35 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { index_->add(rows, reinterpret_cast(p_data)); } +DatasetPtr +IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + GET_DATA_WITH_IDS(dataset_ptr) + + float* p_x = nullptr; + auto release_when_exception = [&]() { + if (p_x != nullptr) { + free(p_x); + } + }; + + try { + p_x = (float*)malloc(sizeof(float) * dim * rows); + auto idmap_index = dynamic_cast(index_.get()); + idmap_index->get_vector_by_id(rows, p_ids, p_x); + return GenResultDataset(p_x); + } catch (faiss::FaissException& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } +} + DatasetPtr IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) { if (!index_) { diff --git a/knowhere/index/vector_index/IndexIDMAP.h b/knowhere/index/vector_index/IndexIDMAP.h index 0379a7322..e942e59f9 100644 --- a/knowhere/index/vector_index/IndexIDMAP.h +++ b/knowhere/index/vector_index/IndexIDMAP.h @@ -42,6 +42,9 @@ class IDMAP : public VecIndex, public FaissBaseIndex { void AddWithoutIds(const DatasetPtr&, const Config&) override; + DatasetPtr + GetVectorById(const DatasetPtr&, const Config&) override; + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexIVF.cpp b/knowhere/index/vector_index/IndexIVF.cpp index ec16f318f..444e9b6a0 100644 --- a/knowhere/index/vector_index/IndexIVF.cpp +++ b/knowhere/index/vector_index/IndexIVF.cpp @@ -90,6 +90,35 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { index_->add(rows, reinterpret_cast(p_data)); } +DatasetPtr +IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_DATA_WITH_IDS(dataset_ptr) + + float* p_x = nullptr; + auto release_when_exception = [&]() { + if (p_x != nullptr) { + free(p_x); + } + }; + + try { + p_x = (float*)malloc(sizeof(float) * dim * rows); + auto ivf_index = dynamic_cast(index_.get()); + ivf_index->get_vector_by_id(rows, p_ids, p_x); + return GenResultDataset(p_x); + } catch (faiss::FaissException& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } +} + DatasetPtr IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) { if (!index_ || !index_->is_trained) { diff --git a/knowhere/index/vector_index/IndexIVF.h b/knowhere/index/vector_index/IndexIVF.h index 4295cb428..abed91a5e 100644 --- a/knowhere/index/vector_index/IndexIVF.h +++ b/knowhere/index/vector_index/IndexIVF.h @@ -47,6 +47,9 @@ class IVF : public VecIndex, public FaissBaseIndex { void AddWithoutIds(const DatasetPtr&, const Config&) override; + DatasetPtr + GetVectorById(const DatasetPtr&, const Config&) override; + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/adapter/VectorAdapter.cpp b/knowhere/index/vector_index/adapter/VectorAdapter.cpp index c8860fe09..c95915e01 100644 --- a/knowhere/index/vector_index/adapter/VectorAdapter.cpp +++ b/knowhere/index/vector_index/adapter/VectorAdapter.cpp @@ -26,6 +26,22 @@ GenDataset(const int64_t nb, const int64_t dim, const void* xb) { return ret_ds; } +DatasetPtr +GenDatasetWithIds(const int64_t n, const int64_t dim, const int64_t* ids) { + auto ret_ds = std::make_shared(); + SetDatasetRows(ret_ds, n); + SetDatasetDim(ret_ds, dim); + SetDatasetInputIDs(ret_ds, ids); + return ret_ds; +} + +DatasetPtr +GenResultDataset(const void* tensor) { + auto ret_ds = std::make_shared(); + SetDatasetTensor(ret_ds, tensor); + return ret_ds; +} + DatasetPtr GenResultDataset(const int64_t* ids, const float* distance) { auto ret_ds = std::make_shared(); diff --git a/knowhere/index/vector_index/adapter/VectorAdapter.h b/knowhere/index/vector_index/adapter/VectorAdapter.h index 7e35b77d8..7c8c23c1a 100644 --- a/knowhere/index/vector_index/adapter/VectorAdapter.h +++ b/knowhere/index/vector_index/adapter/VectorAdapter.h @@ -30,14 +30,14 @@ inline void func_name(DatasetPtr& ds_ptr, T value) { \ /////////////////////////////////////////////////////////////////////////////// -DEFINE_DATASET_GETTER(GetDatasetDim, meta::DIM, int64_t); -DEFINE_DATASET_SETTER(SetDatasetDim, meta::DIM, int64_t); +DEFINE_DATASET_GETTER(GetDatasetDim, meta::DIM, const int64_t); +DEFINE_DATASET_SETTER(SetDatasetDim, meta::DIM, const int64_t); DEFINE_DATASET_GETTER(GetDatasetTensor, meta::TENSOR, const void*); DEFINE_DATASET_SETTER(SetDatasetTensor, meta::TENSOR, const void*); -DEFINE_DATASET_GETTER(GetDatasetRows, meta::ROWS, int64_t); -DEFINE_DATASET_SETTER(SetDatasetRows, meta::ROWS, int64_t); +DEFINE_DATASET_GETTER(GetDatasetRows, meta::ROWS, const int64_t); +DEFINE_DATASET_SETTER(SetDatasetRows, meta::ROWS, const int64_t); DEFINE_DATASET_GETTER(GetDatasetIDs, meta::IDS, const int64_t*); DEFINE_DATASET_SETTER(SetDatasetIDs, meta::IDS, const int64_t*); @@ -48,8 +48,16 @@ DEFINE_DATASET_SETTER(SetDatasetDistance, meta::DISTANCE, const float*); DEFINE_DATASET_GETTER(GetDatasetLims, meta::LIMS, const size_t*); DEFINE_DATASET_SETTER(SetDatasetLims, meta::LIMS, const size_t*); +DEFINE_DATASET_GETTER(GetDatasetInputIDs, meta::INPUT_IDS, const int64_t*); +DEFINE_DATASET_SETTER(SetDatasetInputIDs, meta::INPUT_IDS, const int64_t*); + /////////////////////////////////////////////////////////////////////////////// +#define GET_DATA_WITH_IDS(ds_ptr) \ + auto rows = knowhere::GetDatasetRows(ds_ptr); \ + auto dim = knowhere::GetDatasetDim(ds_ptr); \ + auto p_ids = knowhere::GetDatasetInputIDs(ds_ptr); + #define GET_TENSOR_DATA(ds_ptr) \ auto rows = knowhere::GetDatasetRows(ds_ptr); \ auto p_data = knowhere::GetDatasetTensor(ds_ptr); @@ -61,6 +69,12 @@ DEFINE_DATASET_SETTER(SetDatasetLims, meta::LIMS, const size_t*); extern DatasetPtr GenDataset(const int64_t nb, const int64_t dim, const void* xb); +extern DatasetPtr +GenDatasetWithIds(const int64_t n, const int64_t dim, const int64_t* ids); + +extern DatasetPtr +GenResultDataset(const void* tensor); + extern DatasetPtr GenResultDataset(const int64_t* ids, const float* distance); diff --git a/knowhere/index/vector_index/helpers/IndexParameter.h b/knowhere/index/vector_index/helpers/IndexParameter.h index 4b25143db..85bd108b8 100644 --- a/knowhere/index/vector_index/helpers/IndexParameter.h +++ b/knowhere/index/vector_index/helpers/IndexParameter.h @@ -30,6 +30,7 @@ constexpr MetaType DISTANCE = "distance"; constexpr MetaType LIMS = "lims"; constexpr MetaType TOPK = "k"; constexpr MetaType RADIUS = "radius"; +constexpr MetaType INPUT_IDS = "input_ids"; constexpr MetaType DEVICE_ID = "gpu_id"; }; // namespace meta diff --git a/knowhere/index/vector_offset_index/IndexIVF_NM.cpp b/knowhere/index/vector_offset_index/IndexIVF_NM.cpp index 215feccd1..c19bd693a 100644 --- a/knowhere/index/vector_offset_index/IndexIVF_NM.cpp +++ b/knowhere/index/vector_offset_index/IndexIVF_NM.cpp @@ -133,6 +133,35 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { index_->add_without_codes(rows, reinterpret_cast(p_data)); } +DatasetPtr +IVF_NM::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_DATA_WITH_IDS(dataset_ptr) + + float* p_x = nullptr; + auto release_when_exception = [&]() { + if (p_x != nullptr) { + free(p_x); + } + }; + + try { + p_x = (float*)malloc(sizeof(float) * dim * rows); + auto ivf_index = dynamic_cast(index_.get()); + ivf_index->get_vector_by_id_without_codes(rows, p_ids, data_.get(), prefix_sum_.get(), p_x); + return GenResultDataset(p_x); + } catch (faiss::FaissException& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + release_when_exception(); + KNOWHERE_THROW_MSG(e.what()); + } +} + DatasetPtr IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) { if (!index_ || !index_->is_trained) { diff --git a/knowhere/index/vector_offset_index/IndexIVF_NM.h b/knowhere/index/vector_offset_index/IndexIVF_NM.h index d5c2bd351..6b5fce883 100644 --- a/knowhere/index/vector_offset_index/IndexIVF_NM.h +++ b/knowhere/index/vector_offset_index/IndexIVF_NM.h @@ -48,6 +48,9 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex { void AddWithoutIds(const DatasetPtr&, const Config&) override; + DatasetPtr + GetVectorById(const DatasetPtr&, const Config&) override; + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/thirdparty/faiss/faiss/Index.cpp b/thirdparty/faiss/faiss/Index.cpp index f73d060cf..0ea7f74b8 100644 --- a/thirdparty/faiss/faiss/Index.cpp +++ b/thirdparty/faiss/faiss/Index.cpp @@ -58,6 +58,22 @@ void Index::add_with_ids_without_codes( FAISS_THROW_MSG("add_with_ids_without_codes not implemented for this type of index"); } +void Index::get_vector_by_id( + idx_t n, + const idx_t* xids, + float* x) { + FAISS_THROW_MSG("get_vector_by_id not implemented for this type of index"); +} + +void Index::get_vector_by_id_without_codes( + idx_t n, + const idx_t* xids, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* x) { + FAISS_THROW_MSG("get_vector_by_id_without_codes not implemented for this type of index"); +} + size_t Index::remove_ids(const IDSelector& /*sel*/) { FAISS_THROW_MSG("remove_ids not implemented for this type of index"); return -1; @@ -67,6 +83,14 @@ void Index::reconstruct(idx_t, float*) const { FAISS_THROW_MSG("reconstruct not implemented for this type of index"); } +void Index::reconstruct_without_codes( + idx_t, + const uint8_t*, + const size_t*, + float*) const { + FAISS_THROW_MSG("reconstruct_without_codes not implemented for this type of index"); +} + void Index::reconstruct_n(idx_t i0, idx_t ni, float* recons) const { for (idx_t i = 0; i < ni; i++) { reconstruct(i0 + i, recons + i * d); diff --git a/thirdparty/faiss/faiss/Index.h b/thirdparty/faiss/faiss/Index.h index 765afe41d..1b3fa24de 100644 --- a/thirdparty/faiss/faiss/Index.h +++ b/thirdparty/faiss/faiss/Index.h @@ -115,6 +115,23 @@ struct Index { */ virtual void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids); + /** query n raw vectors from the index by ids. + * + * return n raw vectors. + * + * @param n input num of xids + * @param xids input labels of the NNs, size n + * @param x output raw vectors, size n * d + */ + virtual void get_vector_by_id(idx_t n, const idx_t* xids, float* x); + + virtual void get_vector_by_id_without_codes( + idx_t n, + const idx_t* xids, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* x); + /** query n vectors of dimension d to the index. * * return at most k vectors. If there are not enough results for a @@ -178,6 +195,12 @@ struct Index { */ virtual void reconstruct(idx_t key, float* recons) const; + virtual void reconstruct_without_codes( + idx_t key, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* recons) const; + /** Reconstruct vectors i0 to i0 + ni - 1 * * this function may not be defined for some indexes diff --git a/thirdparty/faiss/faiss/IndexBinary.cpp b/thirdparty/faiss/faiss/IndexBinary.cpp index 75f10c7da..c84c8ed58 100644 --- a/thirdparty/faiss/faiss/IndexBinary.cpp +++ b/thirdparty/faiss/faiss/IndexBinary.cpp @@ -36,6 +36,10 @@ void IndexBinary::add_with_ids(idx_t, const uint8_t*, const idx_t*) { FAISS_THROW_MSG("add_with_ids not implemented for this type of index"); } +void IndexBinary::get_vector_by_id(idx_t n, const idx_t* xids, uint8_t* x) { + FAISS_THROW_MSG("get_vector_by_id not implemented for this type of index"); +} + size_t IndexBinary::remove_ids(const IDSelector&) { FAISS_THROW_MSG("remove_ids not implemented for this type of index"); return 0; diff --git a/thirdparty/faiss/faiss/IndexBinary.h b/thirdparty/faiss/faiss/IndexBinary.h index a39360f5f..e66f0d5e5 100644 --- a/thirdparty/faiss/faiss/IndexBinary.h +++ b/thirdparty/faiss/faiss/IndexBinary.h @@ -83,6 +83,16 @@ struct IndexBinary { */ virtual void add_with_ids(idx_t n, const uint8_t* x, const idx_t* xids); + /** Query n raw vectors from the index by ids. + * + * return n raw vectors. + * + * @param n input num of xids + * @param xids input labels of the NNs, size n + * @param x output raw vectors, size n * d + */ + virtual void get_vector_by_id(idx_t n, const idx_t* xids, uint8_t* x); + /** Query n vectors of dimension d to the index. * * return at most k vectors. If there are not enough results for a diff --git a/thirdparty/faiss/faiss/IndexBinaryFlat.cpp b/thirdparty/faiss/faiss/IndexBinaryFlat.cpp index a140b18c8..22b493e59 100644 --- a/thirdparty/faiss/faiss/IndexBinaryFlat.cpp +++ b/thirdparty/faiss/faiss/IndexBinaryFlat.cpp @@ -31,6 +31,12 @@ void IndexBinaryFlat::add(idx_t n, const uint8_t* x) { ntotal += n; } +void IndexBinaryFlat::get_vector_by_id(idx_t n, const idx_t* xids, uint8_t* x) { + for (idx_t i = 0; i < n; i++) { + memcpy(x + i * code_size, xb.data() + xids[i] * code_size, code_size); + } +} + void IndexBinaryFlat::reset() { xb.clear(); ntotal = 0; diff --git a/thirdparty/faiss/faiss/IndexBinaryFlat.h b/thirdparty/faiss/faiss/IndexBinaryFlat.h index 619c07110..1f036e602 100644 --- a/thirdparty/faiss/faiss/IndexBinaryFlat.h +++ b/thirdparty/faiss/faiss/IndexBinaryFlat.h @@ -35,6 +35,8 @@ struct IndexBinaryFlat : IndexBinary { void add(idx_t n, const uint8_t* x) override; + void get_vector_by_id(idx_t n, const idx_t* xids, uint8_t* x) override; + void reset() override; void search( diff --git a/thirdparty/faiss/faiss/IndexBinaryIVF.cpp b/thirdparty/faiss/faiss/IndexBinaryIVF.cpp index d64a66327..4cbfaf6c6 100644 --- a/thirdparty/faiss/faiss/IndexBinaryIVF.cpp +++ b/thirdparty/faiss/faiss/IndexBinaryIVF.cpp @@ -132,6 +132,13 @@ void IndexBinaryIVF::add_core( ntotal += n_add; } +void IndexBinaryIVF::get_vector_by_id(idx_t n, const idx_t* xids, uint8_t* x) { + make_direct_map(true); + for (idx_t i = 0; i < n; i++) { + reconstruct(xids[i], x + i * d / 8); + } +} + void IndexBinaryIVF::make_direct_map(bool b) { if (b) { direct_map.set_type(DirectMap::Array, invlists, ntotal); diff --git a/thirdparty/faiss/faiss/IndexBinaryIVF.h b/thirdparty/faiss/faiss/IndexBinaryIVF.h index cc14889e6..25491a710 100644 --- a/thirdparty/faiss/faiss/IndexBinaryIVF.h +++ b/thirdparty/faiss/faiss/IndexBinaryIVF.h @@ -90,6 +90,8 @@ struct IndexBinaryIVF : IndexBinary { const idx_t* xids, const idx_t* precomputed_idx); + void get_vector_by_id(idx_t n, const idx_t* xids, uint8_t* x) override; + /** Search a set of vectors, that are pre-quantized by the IVF * quantizer. Fill in the corresponding heaps with the query * results. search() calls this. diff --git a/thirdparty/faiss/faiss/IndexFlat.cpp b/thirdparty/faiss/faiss/IndexFlat.cpp index 573400bf3..28cb5d95c 100644 --- a/thirdparty/faiss/faiss/IndexFlat.cpp +++ b/thirdparty/faiss/faiss/IndexFlat.cpp @@ -23,6 +23,12 @@ namespace faiss { IndexFlat::IndexFlat(idx_t d, MetricType metric) : IndexFlatCodes(sizeof(float) * d, d, metric) {} +void IndexFlat::get_vector_by_id(idx_t n, const idx_t* xids, float* x) { + for (idx_t i = 0; i < n; i++) { + memcpy(x + i * d, get_xb() + xids[i] * d, code_size); + } +} + void IndexFlat::search( idx_t n, const float* x, diff --git a/thirdparty/faiss/faiss/IndexFlat.h b/thirdparty/faiss/faiss/IndexFlat.h index 297e316ae..d305ec6be 100644 --- a/thirdparty/faiss/faiss/IndexFlat.h +++ b/thirdparty/faiss/faiss/IndexFlat.h @@ -24,6 +24,8 @@ struct IndexFlat : IndexFlatCodes { explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2); + void get_vector_by_id(idx_t n, const idx_t* xids, float* x) override; + void search( idx_t n, const float* x, diff --git a/thirdparty/faiss/faiss/IndexIVF.cpp b/thirdparty/faiss/faiss/IndexIVF.cpp index bb7c392b1..94340e88c 100644 --- a/thirdparty/faiss/faiss/IndexIVF.cpp +++ b/thirdparty/faiss/faiss/IndexIVF.cpp @@ -1447,6 +1447,20 @@ void IndexIVF::reconstruct(idx_t key, float* recons) const { reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons); } +void IndexIVF::reconstruct_without_codes( + idx_t key, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* recons) const { + idx_t lo = direct_map.get(key); + reconstruct_from_offset_without_codes( + lo_listno(lo), + lo_offset(lo), + arranged_codes, + prefix_sum, + recons); +} + void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const { FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal)); @@ -1539,6 +1553,15 @@ void IndexIVF::reconstruct_from_offset( FAISS_THROW_MSG("reconstruct_from_offset not implemented"); } +void IndexIVF::reconstruct_from_offset_without_codes( + int64_t /*list_no*/, + int64_t /*offset*/, + const uint8_t* /*arranged_codes*/, + const size_t* /*prefix_sum*/, + float* /*recons*/) const { + FAISS_THROW_MSG("reconstruct_from_offset_without_codes not implemented"); +} + void IndexIVF::reset() { direct_map.clear(); invlists->reset(); diff --git a/thirdparty/faiss/faiss/IndexIVF.h b/thirdparty/faiss/faiss/IndexIVF.h index 0aa923d47..fba3158eb 100644 --- a/thirdparty/faiss/faiss/IndexIVF.h +++ b/thirdparty/faiss/faiss/IndexIVF.h @@ -310,6 +310,12 @@ struct IndexIVF : Index, Level1Quantizer { */ void reconstruct(idx_t key, float* recons) const override; + void reconstruct_without_codes( + idx_t key, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* recons) const override; + /** Update a subset of vectors. * * The index must have a direct_map @@ -360,6 +366,13 @@ struct IndexIVF : Index, Level1Quantizer { int64_t offset, float* recons) const; + virtual void reconstruct_from_offset_without_codes( + int64_t list_no, + int64_t offset, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* recons) const; + /// Dataset manipulation functions size_t remove_ids(const IDSelector& sel) override; diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.cpp b/thirdparty/faiss/faiss/IndexIVFFlat.cpp index fc9256d27..f6a78dfdc 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.cpp +++ b/thirdparty/faiss/faiss/IndexIVFFlat.cpp @@ -117,6 +117,28 @@ void IndexIVFFlat::add_with_ids_without_codes( ntotal += n; } +void IndexIVFFlat::get_vector_by_id( + idx_t n, + const idx_t* xids, + float* x) { + make_direct_map(true); + for (idx_t i = 0; i < n; i++) { + reconstruct(xids[i], x + i * d); + } +} + +void IndexIVFFlat::get_vector_by_id_without_codes( + idx_t n, + const idx_t* xids, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* x) { + make_direct_map(true); + for (idx_t i = 0; i < n; i++) { + reconstruct_without_codes(xids[i], arranged_codes, prefix_sum, x + i * d); + } +} + void IndexIVFFlat::encode_vectors( idx_t n, const float* x, @@ -247,6 +269,15 @@ void IndexIVFFlat::reconstruct_from_offset( memcpy(recons, invlists->get_single_code(list_no, offset), code_size); } +void IndexIVFFlat::reconstruct_from_offset_without_codes( + int64_t list_no, + int64_t offset, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* recons) const { + memcpy(recons, arranged_codes + (prefix_sum[list_no] + offset) * code_size, code_size); +} + /***************************************** * IndexIVFFlatDedup implementation ******************************************/ diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.h b/thirdparty/faiss/faiss/IndexIVFFlat.h index 90b3a1c2d..00ce850b4 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.h +++ b/thirdparty/faiss/faiss/IndexIVFFlat.h @@ -40,6 +40,15 @@ struct IndexIVFFlat : IndexIVF { const float* x, const idx_t* xids) override; + void get_vector_by_id(idx_t n, const idx_t* xids, float* x) override; + + void get_vector_by_id_without_codes( + idx_t n, + const idx_t* xids, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* x) override; + void encode_vectors( idx_t n, const float* x, @@ -53,6 +62,13 @@ struct IndexIVFFlat : IndexIVF { void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons) const override; + void reconstruct_from_offset_without_codes( + int64_t list_no, + int64_t offset, + const uint8_t* arranged_codes, + const size_t* prefix_sum, + float* recons) const override; + void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; IndexIVFFlat() {} diff --git a/unittest/test_binaryidmap.cpp b/unittest/test_binaryidmap.cpp index 3c45b8f94..7e36261b8 100644 --- a/unittest/test_binaryidmap.cpp +++ b/unittest/test_binaryidmap.cpp @@ -77,11 +77,14 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { ASSERT_TRUE(index_->GetRawVectors() != nullptr); ASSERT_GT(index_->Size(), 0); + auto result = index_->GetVectorById(id_dataset, conf_); + AssertBinVec(result, base_dataset, id_dataset, nq, dim); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); - auto result = index_->Query(query_dataset, conf_, nullptr); - AssertAnns(result, nq, k); + auto result1 = index_->Query(query_dataset, conf_, nullptr); + AssertAnns(result1, nq, k); // PrintResult(result, nq, k); auto binaryset = index_->Serialize(conf_); diff --git a/unittest/test_binaryivf.cpp b/unittest/test_binaryivf.cpp index fa5be29ea..f26b95441 100644 --- a/unittest/test_binaryivf.cpp +++ b/unittest/test_binaryivf.cpp @@ -75,11 +75,14 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); + auto result = index_->GetVectorById(id_dataset, conf_); + AssertBinVec(result, base_dataset, id_dataset, nq, dim); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); - auto result = index_->Query(query_dataset, conf_, nullptr); - AssertAnns(result, nq, knowhere::GetMetaTopk(conf_)); + auto result1 = index_->Query(query_dataset, conf_, nullptr); + AssertAnns(result1, nq, knowhere::GetMetaTopk(conf_)); // PrintResult(result, nq, k); auto result2 = index_->Query(query_dataset, conf_, *bitset); diff --git a/unittest/test_idmap.cpp b/unittest/test_idmap.cpp index 20391c90e..115ef54a6 100644 --- a/unittest/test_idmap.cpp +++ b/unittest/test_idmap.cpp @@ -91,13 +91,15 @@ TEST_P(IDMAPTest, idmap_basic) { ASSERT_TRUE(index_->GetRawVectors() != nullptr); ASSERT_GT(index_->Size(), 0); + auto result = index_->GetVectorById(id_dataset, conf); + AssertVec(result, base_dataset, id_dataset, nq, dim); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); ASSERT_TRUE(adapter->CheckSearch(conf, index_type_, index_mode_)); - auto result = index_->Query(query_dataset, conf, nullptr); - AssertAnns(result, nq, k); - // PrintResult(result, nq, k); - + auto result1 = index_->Query(query_dataset, conf, nullptr); + AssertAnns(result1, nq, k); + // PrintResult(result1, nq, k); #ifdef KNOWHERE_GPU_VERSION if (index_mode_ == knowhere::IndexMode::MODE_GPU) { diff --git a/unittest/test_ivf.cpp b/unittest/test_ivf.cpp index c530b0514..733bc30ca 100644 --- a/unittest/test_ivf.cpp +++ b/unittest/test_ivf.cpp @@ -94,6 +94,8 @@ TEST_P(IVFTest, ivf_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); + ASSERT_ANY_THROW(index_->GetVectorById(id_dataset, conf_)); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); diff --git a/unittest/test_ivf_nm.cpp b/unittest/test_ivf_nm.cpp index 4b31c37f4..899543bc8 100644 --- a/unittest/test_ivf_nm.cpp +++ b/unittest/test_ivf_nm.cpp @@ -103,11 +103,14 @@ TEST_P(IVFNMTest, ivfnm_basic) { LoadRawData(index_, base_dataset, conf_); + auto result = index_->GetVectorById(id_dataset, conf_); + AssertVec(result, base_dataset, id_dataset, nq, dim); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); - auto result = index_->Query(query_dataset, conf_, nullptr); - AssertAnns(result, nq, k); + auto result1 = index_->Query(query_dataset, conf_, nullptr); + AssertAnns(result1, nq, k); #ifdef KNOWHERE_GPU_VERSION // copy cpu to gpu diff --git a/unittest/utils.cpp b/unittest/utils.cpp index af96c8944..2d4e84e77 100644 --- a/unittest/utils.cpp +++ b/unittest/utils.cpp @@ -32,7 +32,7 @@ DataGen::Generate(const int dim, const int nb, const int nq, const bool is_binar this->nq = nq; if (!is_binary) { - GenAll(dim, nb, xb, ids, xids, nq, xq); + GenAll(dim, nb, xb, ids, nq, xq); assert(xb.size() == (size_t)dim * nb); assert(xq.size() == (size_t)dim * nq); @@ -40,7 +40,7 @@ DataGen::Generate(const int dim, const int nb, const int nq, const bool is_binar query_dataset = knowhere::GenDataset(nq, dim, xq.data()); } else { int64_t dim_x = dim / 8; - GenAll(dim_x, nb, xb_bin, ids, xids, nq, xq_bin); + GenAll(dim_x, nb, xb_bin, ids, nq, xq_bin); assert(xb_bin.size() == (size_t)dim_x * nb); assert(xq_bin.size() == (size_t)dim_x * nq); @@ -48,8 +48,8 @@ DataGen::Generate(const int dim, const int nb, const int nq, const bool is_binar query_dataset = knowhere::GenDataset(nq, dim, xq_bin.data()); } - id_dataset = knowhere::GenDataset(nq, dim, nullptr); - xid_dataset = knowhere::GenDataset(nq, dim, nullptr); + // used to test GetVectorById [0, nq-1] + id_dataset = knowhere::GenDatasetWithIds(nq, dim, ids.data()); bitset_data.resize(nb/8); for (int64_t i = 0; i < nq; ++i) { @@ -63,14 +63,12 @@ GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq) { xb.resize(nb * dim); xq.resize(nq * dim); ids.resize(nb); - xids.resize(1); - GenBase(dim, nb, xb.data(), ids.data(), nq, xq.data(), xids.data(), false); + GenBase(dim, nb, xb.data(), ids.data(), nq, xq.data(), false); } void @@ -78,14 +76,12 @@ GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq) { xb.resize(nb * dim); xq.resize(nq * dim); ids.resize(nb); - xids.resize(1); - GenBase(dim, nb, xb.data(), ids.data(), nq, xq.data(), xids.data(), true); + GenBase(dim, nb, xb.data(), ids.data(), nq, xq.data(), true); } void @@ -95,7 +91,6 @@ GenBase(const int64_t dim, int64_t* ids, const int64_t nq, const void* xq, - int64_t* xids, bool is_binary) { if (!is_binary) { float* xb_f = (float*)xb; @@ -124,7 +119,6 @@ GenBase(const int64_t dim, xq_u[i] = xb_u[i]; } } - xids[0] = 3; // pseudo random } void @@ -145,53 +139,44 @@ AssertAnns(const knowhere::DatasetPtr& result, const int nq, const int k, const } } -#if 0 void -AssertVec(const knowhere::DatasetPtr& result, const knowhere::DatasetPtr& base_dataset, - const knowhere::DatasetPtr& id_dataset, const int n, const int dim, const CheckMode check_mode) { +AssertVec(const knowhere::DatasetPtr& result, + const knowhere::DatasetPtr& base_dataset, + const knowhere::DatasetPtr& id_dataset, + const int n, + const int dim) { float* base = (float*)knowhere::GetDatasetTensor(base_dataset); - auto ids = knowhere::GetDatasetIDs(id_dataset); + auto ids = knowhere::GetDatasetInputIDs(id_dataset); auto x = (float*)knowhere::GetDatasetTensor(result); for (auto i = 0; i < n; i++) { auto id = ids[i]; for (auto j = 0; j < dim; j++) { - switch (check_mode) { - case CheckMode::CHECK_EQUAL: { - ASSERT_EQ(*(base + id * dim + j), *(x + i * dim + j)); - break; - } - case CheckMode::CHECK_NOT_EQUAL: { - ASSERT_NE(*(base + id * dim + j), *(x + i * dim + j)); - break; - } - case CheckMode::CHECK_APPROXIMATE_EQUAL: { - float a = *(base + id * dim + j); - float b = *(x + i * dim + j); - ASSERT_TRUE((std::fabs(a - b) / std::fabs(a)) < 0.1); - break; - } - default: - ASSERT_TRUE(false); - break; - } + float va = *(base + id * dim + j); + float vb = *(x + i * dim + j); + ASSERT_EQ(va, vb); } } } void -AssertBinVec(const knowhere::DatasetPtr& result, const knowhere::DatasetPtr& base_dataset, - const knowhere::DatasetPtr& id_dataset, const int n, const int dim, const CheckMode check_mode) { +AssertBinVec(const knowhere::DatasetPtr& result, + const knowhere::DatasetPtr& base_dataset, + const knowhere::DatasetPtr& id_dataset, + const int n, + const int dim) { auto base = (uint8_t*)knowhere::GetDatasetTensor(base_dataset); - auto ids = knowhere::GetDatasetIDs(id_dataset; - auto x = (float*)knowhere::GetDatasetTensor(result); - for (auto i = 0; i < 1; i++) { + auto ids = knowhere::GetDatasetInputIDs(id_dataset); + auto x = (uint8_t*)knowhere::GetDatasetTensor(result); + int dim_uint8 = dim / 8; + for (auto i = 0; i < n; i++) { auto id = ids[i]; - for (auto j = 0; j < dim; j++) { - ASSERT_EQ(*(base + id * dim + j), *(x + i * dim + j)); + for (auto j = 0; j < dim_uint8; j++) { + uint8_t va = *(base + id * dim_uint8 + j); + uint8_t vb = *(x + i * dim_uint8 + j); + ASSERT_EQ(va, vb); } } } -#endif void PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k) { diff --git a/unittest/utils.h b/unittest/utils.h index bd929ddda..fbb48d8ba 100644 --- a/unittest/utils.h +++ b/unittest/utils.h @@ -46,11 +46,9 @@ class DataGen { std::vector xb_bin; std::vector xq_bin; std::vector ids; - std::vector xids; knowhere::DatasetPtr base_dataset = nullptr; knowhere::DatasetPtr query_dataset = nullptr; knowhere::DatasetPtr id_dataset = nullptr; - knowhere::DatasetPtr xid_dataset = nullptr; std::vector bitset_data; faiss::BitsetViewPtr bitset; @@ -61,7 +59,6 @@ GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq); @@ -70,7 +67,6 @@ GenAll(const int64_t dim, const int64_t nb, std::vector& xb, std::vector& ids, - std::vector& xids, const int64_t nq, std::vector& xq); @@ -81,7 +77,6 @@ GenBase(const int64_t dim, int64_t* ids, const int64_t nq, const void* xq, - int64_t* xids, const bool is_binary); enum class CheckMode { @@ -101,16 +96,14 @@ AssertVec(const knowhere::DatasetPtr& result, const knowhere::DatasetPtr& base_dataset, const knowhere::DatasetPtr& id_dataset, const int n, - const int dim, - const CheckMode check_mode = CheckMode::CHECK_EQUAL); + const int dim); void AssertBinVec(const knowhere::DatasetPtr& result, const knowhere::DatasetPtr& base_dataset, const knowhere::DatasetPtr& id_dataset, const int n, - const int dim, - const CheckMode check_mode = CheckMode::CHECK_EQUAL); + const int dim); void PrintResult(const knowhere::DatasetPtr& result, const int& nq, const int& k);