Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Support to get vector raw data from faiss index (#199)
Browse files Browse the repository at this point in the history
Signed-off-by: yudong.cai <[email protected]>
  • Loading branch information
cydrain authored May 31, 2022
1 parent 4c6ff55 commit a993cf2
Show file tree
Hide file tree
Showing 35 changed files with 422 additions and 66 deletions.
5 changes: 5 additions & 0 deletions knowhere/index/VecIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class VecIndex : public Index {
virtual void
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;

virtual DatasetPtr
GetVectorById(const DatasetPtr& dataset, const Config& config) {
KNOWHERE_THROW_MSG("GetVectorById not supported yet");
}

virtual DatasetPtr
Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView bitset) = 0;

Expand Down
29 changes: 29 additions & 0 deletions knowhere/index/vector_index/IndexBinaryIDMAP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,35 @@ BinaryIDMAP::Load(const BinarySet& index_binary) {
LoadImpl(index_binary, index_type_);
}

DatasetPtr
BinaryIDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}

GET_DATA_WITH_IDS(dataset_ptr)

uint8_t* p_x = nullptr;
auto release_when_exception = [&]() {
if (p_x != nullptr) {
free(p_x);
}
};

try {
p_x = (uint8_t*)malloc(sizeof(uint8_t) * (dim / 8) * rows);
auto bin_idmap_index = dynamic_cast<faiss::IndexBinaryFlat*>(index_.get());
bin_idmap_index->get_vector_by_id(rows, p_ids, p_x);
return GenResultDataset(p_x);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
}
}

DatasetPtr
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
if (!index_) {
Expand Down
3 changes: 3 additions & 0 deletions knowhere/index/vector_index/IndexBinaryIDMAP.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
void
AddWithoutIds(const DatasetPtr&, const Config&) override;

DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
29 changes: 29 additions & 0 deletions knowhere/index/vector_index/IndexBinaryIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,35 @@ BinaryIVF::Load(const BinarySet& index_binary) {
#endif
}

DatasetPtr
BinaryIVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

GET_DATA_WITH_IDS(dataset_ptr)

uint8_t* p_x = nullptr;
auto release_when_exception = [&]() {
if (p_x != nullptr) {
free(p_x);
}
};

try {
p_x = (uint8_t*)malloc(sizeof(uint8_t) * (dim / 8) * rows);
auto bin_ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
bin_ivf_index->get_vector_by_id(rows, p_ids, p_x);
return GenResultDataset(p_x);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
}
}

DatasetPtr
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
if (!index_ || !index_->is_trained) {
Expand Down
3 changes: 3 additions & 0 deletions knowhere/index/vector_index/IndexBinaryIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
void
AddWithoutIds(const DatasetPtr&, const Config&) override;

DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
29 changes: 29 additions & 0 deletions knowhere/index/vector_index/IndexIDMAP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,35 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
index_->add(rows, reinterpret_cast<const float*>(p_data));
}

DatasetPtr
IDMAP::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_) {
KNOWHERE_THROW_MSG("index not initialize");
}

GET_DATA_WITH_IDS(dataset_ptr)

float* p_x = nullptr;
auto release_when_exception = [&]() {
if (p_x != nullptr) {
free(p_x);
}
};

try {
p_x = (float*)malloc(sizeof(float) * dim * rows);
auto idmap_index = dynamic_cast<faiss::IndexFlat*>(index_.get());
idmap_index->get_vector_by_id(rows, p_ids, p_x);
return GenResultDataset(p_x);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
}
}

DatasetPtr
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
if (!index_) {
Expand Down
3 changes: 3 additions & 0 deletions knowhere/index/vector_index/IndexIDMAP.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
void
AddWithoutIds(const DatasetPtr&, const Config&) override;

DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
29 changes: 29 additions & 0 deletions knowhere/index/vector_index/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,35 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
index_->add(rows, reinterpret_cast<const float*>(p_data));
}

DatasetPtr
IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

GET_DATA_WITH_IDS(dataset_ptr)

float* p_x = nullptr;
auto release_when_exception = [&]() {
if (p_x != nullptr) {
free(p_x);
}
};

try {
p_x = (float*)malloc(sizeof(float) * dim * rows);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->get_vector_by_id(rows, p_ids, p_x);
return GenResultDataset(p_x);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
}
}

DatasetPtr
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
if (!index_ || !index_->is_trained) {
Expand Down
3 changes: 3 additions & 0 deletions knowhere/index/vector_index/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class IVF : public VecIndex, public FaissBaseIndex {
void
AddWithoutIds(const DatasetPtr&, const Config&) override;

DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
16 changes: 16 additions & 0 deletions knowhere/index/vector_index/adapter/VectorAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ GenDataset(const int64_t nb, const int64_t dim, const void* xb) {
return ret_ds;
}

DatasetPtr
GenDatasetWithIds(const int64_t n, const int64_t dim, const int64_t* ids) {
auto ret_ds = std::make_shared<Dataset>();
SetDatasetRows(ret_ds, n);
SetDatasetDim(ret_ds, dim);
SetDatasetInputIDs(ret_ds, ids);
return ret_ds;
}

DatasetPtr
GenResultDataset(const void* tensor) {
auto ret_ds = std::make_shared<Dataset>();
SetDatasetTensor(ret_ds, tensor);
return ret_ds;
}

DatasetPtr
GenResultDataset(const int64_t* ids, const float* distance) {
auto ret_ds = std::make_shared<Dataset>();
Expand Down
22 changes: 18 additions & 4 deletions knowhere/index/vector_index/adapter/VectorAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ inline void func_name(DatasetPtr& ds_ptr, T value) { \

///////////////////////////////////////////////////////////////////////////////

DEFINE_DATASET_GETTER(GetDatasetDim, meta::DIM, int64_t);
DEFINE_DATASET_SETTER(SetDatasetDim, meta::DIM, int64_t);
DEFINE_DATASET_GETTER(GetDatasetDim, meta::DIM, const int64_t);
DEFINE_DATASET_SETTER(SetDatasetDim, meta::DIM, const int64_t);

DEFINE_DATASET_GETTER(GetDatasetTensor, meta::TENSOR, const void*);
DEFINE_DATASET_SETTER(SetDatasetTensor, meta::TENSOR, const void*);

DEFINE_DATASET_GETTER(GetDatasetRows, meta::ROWS, int64_t);
DEFINE_DATASET_SETTER(SetDatasetRows, meta::ROWS, int64_t);
DEFINE_DATASET_GETTER(GetDatasetRows, meta::ROWS, const int64_t);
DEFINE_DATASET_SETTER(SetDatasetRows, meta::ROWS, const int64_t);

DEFINE_DATASET_GETTER(GetDatasetIDs, meta::IDS, const int64_t*);
DEFINE_DATASET_SETTER(SetDatasetIDs, meta::IDS, const int64_t*);
Expand All @@ -48,8 +48,16 @@ DEFINE_DATASET_SETTER(SetDatasetDistance, meta::DISTANCE, const float*);
DEFINE_DATASET_GETTER(GetDatasetLims, meta::LIMS, const size_t*);
DEFINE_DATASET_SETTER(SetDatasetLims, meta::LIMS, const size_t*);

DEFINE_DATASET_GETTER(GetDatasetInputIDs, meta::INPUT_IDS, const int64_t*);
DEFINE_DATASET_SETTER(SetDatasetInputIDs, meta::INPUT_IDS, const int64_t*);

///////////////////////////////////////////////////////////////////////////////

#define GET_DATA_WITH_IDS(ds_ptr) \
auto rows = knowhere::GetDatasetRows(ds_ptr); \
auto dim = knowhere::GetDatasetDim(ds_ptr); \
auto p_ids = knowhere::GetDatasetInputIDs(ds_ptr);

#define GET_TENSOR_DATA(ds_ptr) \
auto rows = knowhere::GetDatasetRows(ds_ptr); \
auto p_data = knowhere::GetDatasetTensor(ds_ptr);
Expand All @@ -61,6 +69,12 @@ DEFINE_DATASET_SETTER(SetDatasetLims, meta::LIMS, const size_t*);
extern DatasetPtr
GenDataset(const int64_t nb, const int64_t dim, const void* xb);

extern DatasetPtr
GenDatasetWithIds(const int64_t n, const int64_t dim, const int64_t* ids);

extern DatasetPtr
GenResultDataset(const void* tensor);

extern DatasetPtr
GenResultDataset(const int64_t* ids, const float* distance);

Expand Down
1 change: 1 addition & 0 deletions knowhere/index/vector_index/helpers/IndexParameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ constexpr MetaType DISTANCE = "distance";
constexpr MetaType LIMS = "lims";
constexpr MetaType TOPK = "k";
constexpr MetaType RADIUS = "radius";
constexpr MetaType INPUT_IDS = "input_ids";
constexpr MetaType DEVICE_ID = "gpu_id";
}; // namespace meta

Expand Down
29 changes: 29 additions & 0 deletions knowhere/index/vector_offset_index/IndexIVF_NM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,35 @@ IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
index_->add_without_codes(rows, reinterpret_cast<const float*>(p_data));
}

DatasetPtr
IVF_NM::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) {
if (!index_ || !index_->is_trained) {
KNOWHERE_THROW_MSG("index not initialize or trained");
}

GET_DATA_WITH_IDS(dataset_ptr)

float* p_x = nullptr;
auto release_when_exception = [&]() {
if (p_x != nullptr) {
free(p_x);
}
};

try {
p_x = (float*)malloc(sizeof(float) * dim * rows);
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->get_vector_by_id_without_codes(rows, p_ids, data_.get(), prefix_sum_.get(), p_x);
return GenResultDataset(p_x);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
} catch (std::exception& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
}
}

DatasetPtr
IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
if (!index_ || !index_->is_trained) {
Expand Down
3 changes: 3 additions & 0 deletions knowhere/index/vector_offset_index/IndexIVF_NM.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex {
void
AddWithoutIds(const DatasetPtr&, const Config&) override;

DatasetPtr
GetVectorById(const DatasetPtr&, const Config&) override;

DatasetPtr
Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override;

Expand Down
24 changes: 24 additions & 0 deletions thirdparty/faiss/faiss/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ void Index::add_with_ids_without_codes(
FAISS_THROW_MSG("add_with_ids_without_codes not implemented for this type of index");
}

void Index::get_vector_by_id(
idx_t n,
const idx_t* xids,
float* x) {
FAISS_THROW_MSG("get_vector_by_id not implemented for this type of index");
}

void Index::get_vector_by_id_without_codes(
idx_t n,
const idx_t* xids,
const uint8_t* arranged_codes,
const size_t* prefix_sum,
float* x) {
FAISS_THROW_MSG("get_vector_by_id_without_codes not implemented for this type of index");
}

size_t Index::remove_ids(const IDSelector& /*sel*/) {
FAISS_THROW_MSG("remove_ids not implemented for this type of index");
return -1;
Expand All @@ -67,6 +83,14 @@ void Index::reconstruct(idx_t, float*) const {
FAISS_THROW_MSG("reconstruct not implemented for this type of index");
}

void Index::reconstruct_without_codes(
idx_t,
const uint8_t*,
const size_t*,
float*) const {
FAISS_THROW_MSG("reconstruct_without_codes not implemented for this type of index");
}

void Index::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
for (idx_t i = 0; i < ni; i++) {
reconstruct(i0 + i, recons + i * d);
Expand Down
23 changes: 23 additions & 0 deletions thirdparty/faiss/faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,23 @@ struct Index {
*/
virtual void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids);

/** query n raw vectors from the index by ids.
*
* return n raw vectors.
*
* @param n input num of xids
* @param xids input labels of the NNs, size n
* @param x output raw vectors, size n * d
*/
virtual void get_vector_by_id(idx_t n, const idx_t* xids, float* x);

virtual void get_vector_by_id_without_codes(
idx_t n,
const idx_t* xids,
const uint8_t* arranged_codes,
const size_t* prefix_sum,
float* x);

/** query n vectors of dimension d to the index.
*
* return at most k vectors. If there are not enough results for a
Expand Down Expand Up @@ -178,6 +195,12 @@ struct Index {
*/
virtual void reconstruct(idx_t key, float* recons) const;

virtual void reconstruct_without_codes(
idx_t key,
const uint8_t* arranged_codes,
const size_t* prefix_sum,
float* recons) const;

/** Reconstruct vectors i0 to i0 + ni - 1
*
* this function may not be defined for some indexes
Expand Down
Loading

0 comments on commit a993cf2

Please sign in to comment.