Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Add interface to access Dataset (#172)
Browse files Browse the repository at this point in the history
Signed-off-by: yudong.cai <[email protected]>
  • Loading branch information
cydrain authored May 10, 2022
1 parent 9c981fc commit 9f86cd6
Show file tree
Hide file tree
Showing 33 changed files with 306 additions and 707 deletions.
18 changes: 7 additions & 11 deletions knowhere/common/Dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ class Dataset {
// Destructor: walks every stored entry and, for the IDS, DISTANCE and LIMS
// keys, fetches the stored pointer through Get<>() and releases it with free().
// NOTE(review): this span is a rendered diff hunk, so adjacent near-duplicate
// statements appear to be the before/after versions of the same line — confirm
// against the repository before treating it as compilable code.
// NOTE(review): Get<>() is called from a destructor; if it can throw (missing
// key or mismatched stored type) the exception would escape the destructor —
// verify that every writer stores exactly the type requested here.
~Dataset() {
for (auto const& d : data_) {
if (d.first == meta::IDS) {
auto ids = Get<int64_t*>(meta::IDS);
auto ids = Get<const int64_t*>(meta::IDS);
// the space of ids must be allocated through malloc
free(ids);
free((void*)ids);
}
if (d.first == meta::DISTANCE) {
auto distances = Get<float*>(meta::DISTANCE);
auto distances = Get<const float*>(meta::DISTANCE);
// the space of distance must be allocated through malloc
free(distances);
free((void*)distances);
}
if (d.first == meta::LIMS) {
auto lims = Get<size_t*>(meta::LIMS);
auto lims = Get<const size_t*>(meta::LIMS);
// the space of lims must be allocated through malloc
free(lims);
free((void*)lims);
}
}
}
Expand All @@ -58,11 +58,7 @@ class Dataset {
T
Get(const std::string_view& k) {
std::lock_guard<std::mutex> lk(mutex_);
try {
return std::any_cast<T>(*(data_.at(std::string(k))));
} catch (...) {
throw std::logic_error("Can't find this key");
}
return std::any_cast<T>(*(data_.at(std::string(k))));
}

const std::map<std::string, ValuePtr>&
Expand Down
7 changes: 5 additions & 2 deletions knowhere/index/vector_index/ConfAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ static const std::vector<MetricType> default_binary_metric_array{metric::HAMMING
// Reports whether the config value stored under `key` lies in the closed
// interval [min, max].
// NOTE(review): this span is a rendered diff hunk; the first two body lines
// appear to be the pre-change version of the guarded lines below — confirm.
template<typename T>
inline bool
CheckValueInRange(const Config& cfg, const std::string_view& key, T min, T max) {
T value = GetValueFromConfig<T>(cfg, std::string(key));
return (value >= min && value <= max);
// Only read the value when the key is present in cfg.
if (cfg.contains(std::string(key))) {
T value = GetValueFromConfig<T>(cfg, key);
return (value >= min && value <= max);
}
// A key that is absent from cfg is reported as out of range.
return false;
}

inline bool
Expand Down
5 changes: 1 addition & 4 deletions knowhere/index/vector_index/IndexAnnoy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, const fai
}
}

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
}

int64_t
Expand Down
12 changes: 2 additions & 10 deletions knowhere/index/vector_index/IndexBinaryIDMAP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const fa

QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down Expand Up @@ -106,12 +103,7 @@ BinaryIDMAP::QueryByRange(const DatasetPtr& dataset,

try {
QueryByRangeImpl(rows, reinterpret_cast<const uint8_t*>(p_data), radius, p_dist, p_id, p_lims, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
ret_ds->Set(meta::LIMS, p_lims);
return ret_ds;
return GenResultDataset(p_id, p_dist, p_lims);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down
12 changes: 2 additions & 10 deletions knowhere/index/vector_index/IndexBinaryIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais

QueryImpl(rows, reinterpret_cast<const uint8_t*>(p_data), k, p_dist, p_id, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down Expand Up @@ -119,12 +116,7 @@ BinaryIVF::QueryByRange(const DatasetPtr& dataset,

try {
QueryByRangeImpl(rows, reinterpret_cast<const uint8_t*>(p_data), radius, p_dist, p_id, p_lims, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
ret_ds->Set(meta::LIMS, p_lims);
return ret_ds;
return GenResultDataset(p_id, p_dist, p_lims);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down
5 changes: 1 addition & 4 deletions knowhere/index/vector_index/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fais
// LOG_KNOWHERE_DEBUG_ << "IndexHNSW::Query finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
}

int64_t
Expand Down
12 changes: 2 additions & 10 deletions knowhere/index/vector_index/IndexIDMAP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::B

QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down Expand Up @@ -139,12 +136,7 @@ IDMAP::QueryByRange(const DatasetPtr& dataset,

try {
QueryByRangeImpl(rows, reinterpret_cast<const float*>(p_data), radius, p_dist, p_id, p_lims, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
ret_ds->Set(meta::LIMS, p_lims);
return ret_ds;
return GenResultDataset(p_id, p_dist, p_lims);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down
12 changes: 2 additions & 10 deletions knowhere/index/vector_index/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::Bit

QueryImpl(rows, reinterpret_cast<const float*>(p_data), k, p_dist, p_id, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down Expand Up @@ -162,12 +159,7 @@ IVF::QueryByRange(const DatasetPtr& dataset,

try {
QueryByRangeImpl(rows, reinterpret_cast<const float*>(p_data), radius, p_dist, p_id, p_lims, config, bitset);

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
ret_ds->Set(meta::LIMS, p_lims);
return ret_ds;
return GenResultDataset(p_id, p_dist, p_lims);
} catch (faiss::FaissException& e) {
release_when_exception();
KNOWHERE_THROW_MSG(e.what());
Expand Down
5 changes: 1 addition & 4 deletions knowhere/index/vector_index/IndexNGT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,7 @@ IndexNGT::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss
index_->deleteObject(object);
}

auto res_ds = std::make_shared<Dataset>();
res_ds->Set(meta::IDS, p_id);
res_ds->Set(meta::DISTANCE, p_dist);
return res_ds;
return GenResultDataset(p_id, p_dist);
}

int64_t
Expand Down
5 changes: 1 addition & 4 deletions knowhere/index/vector_index/IndexRHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ IndexRHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, const fai
// LOG_KNOWHERE_DEBUG_ << "IndexRHNSW::Load finished, show statistics:";
// LOG_KNOWHERE_DEBUG_ << GetStatistics()->ToString();

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
}

int64_t
Expand Down
2 changes: 1 addition & 1 deletion knowhere/index/vector_index/IndexSPTAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ DatasetPtr
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, const faiss::BitsetView bitset) {
SetParameters(config);

float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);
float* p_data = (float*)GetDatasetTensor(dataset_ptr);
for (auto i = 0; i < 10; ++i) {
for (auto j = 0; j < 10; ++j) {
std::cout << p_data[i * 10 + j] << " ";
Expand Down
7 changes: 2 additions & 5 deletions knowhere/index/vector_index/adapter/SptagAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace knowhere {

std::shared_ptr<SPTAG::MetadataSet>
ConvertToMetadataSet(const DatasetPtr& dataset_ptr) {
auto elems = dataset_ptr->Get<int64_t>(meta::ROWS);
auto elems = GetDatasetRows(dataset_ptr);

auto p_id = new int64_t[elems];
for (int64_t i = 0; i < elems; ++i) p_id[i] = i;
Expand Down Expand Up @@ -81,10 +81,7 @@ ConvertToDataset(std::vector<SPTAG::QueryResult> query_results, std::shared_ptr<
}
}

auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::IDS, p_id);
ret_ds->Set(meta::DISTANCE, p_dist);
return ret_ds;
return GenResultDataset(p_id, p_dist);
}

} // namespace knowhere
23 changes: 20 additions & 3 deletions knowhere/index/vector_index/adapter/VectorAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,26 @@ namespace knowhere {
// Creates a new Dataset and records the row count `nb`, the dimension `dim`
// and the tensor pointer `xb` in it, then returns the Dataset.
// Only the pointer value of `xb` is passed on; no copy of the pointed-to data
// is made in this function.
// NOTE(review): this span is a rendered diff hunk; the three ret_ds->Set(...)
// lines and the three SetDataset*(...) lines appear to be the before/after
// versions of the same statements — confirm against the repository.
DatasetPtr
GenDataset(const int64_t nb, const int64_t dim, const void* xb) {
auto ret_ds = std::make_shared<Dataset>();
ret_ds->Set(meta::ROWS, nb);
ret_ds->Set(meta::DIM, dim);
ret_ds->Set(meta::TENSOR, xb);
SetDatasetRows(ret_ds, nb);
SetDatasetDim(ret_ds, dim);
SetDatasetTensor(ret_ds, xb);
return ret_ds;
}

// Packs an id buffer and a distance buffer into a newly created Dataset.
//
// ids      - pointer recorded via SetDatasetIDs (key meta::IDS)
// distance - pointer recorded via SetDatasetDistance (key meta::DISTANCE)
//
// Returns the new Dataset. Only the pointer values are recorded here; the
// buffers themselves are not copied by this function.
// NOTE(review): Dataset's destructor calls free() on the IDS and DISTANCE
// entries, so callers presumably must pass malloc-allocated buffers and give
// up ownership — confirm against Dataset.h.
DatasetPtr
GenResultDataset(const int64_t* ids, const float* distance) {
    DatasetPtr result_ds = std::make_shared<Dataset>();

    SetDatasetIDs(result_ds, ids);
    SetDatasetDistance(result_ds, distance);

    return result_ds;
}

// Packs an id buffer, a distance buffer and a lims buffer into a newly
// created Dataset.
//
// ids      - pointer recorded via SetDatasetIDs (key meta::IDS)
// distance - pointer recorded via SetDatasetDistance (key meta::DISTANCE)
// lims     - pointer recorded via SetDatasetLims (key meta::LIMS)
//
// Returns the new Dataset. Only the pointer values are recorded here; the
// buffers themselves are not copied by this function.
// NOTE(review): Dataset's destructor calls free() on the IDS, DISTANCE and
// LIMS entries, so callers presumably must pass malloc-allocated buffers and
// give up ownership — confirm against Dataset.h.
DatasetPtr
GenResultDataset(const int64_t* ids, const float* distance, const size_t* lims) {
    DatasetPtr result_ds = std::make_shared<Dataset>();

    SetDatasetIDs(result_ds, ids);
    SetDatasetDistance(result_ds, distance);
    SetDatasetLims(result_ds, lims);

    return result_ds;
}

Expand Down
50 changes: 44 additions & 6 deletions knowhere/index/vector_index/adapter/VectorAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,53 @@

namespace knowhere {

#define GET_TENSOR_DATA(dataset_ptr) \
int64_t rows = dataset_ptr->Get<int64_t>(meta::ROWS); \
const void* p_data = dataset_ptr->Get<const void*>(meta::TENSOR);
// Generates `inline T func_name(const DatasetPtr& ds_ptr)`, which returns
// ds_ptr->Get<T>(key).
//   func_name - name of the generated accessor
//   key       - Dataset key passed to Get
//   T         - value type requested from Get and returned to the caller
// The generated function dereferences ds_ptr without a null check.
#define DEFINE_DATASET_GETTER(func_name, key, T) \
inline T func_name(const DatasetPtr& ds_ptr) { \
return ds_ptr->Get<T>(key); \
}

#define GET_TENSOR_DATA_DIM(dataset_ptr) \
GET_TENSOR_DATA(dataset_ptr) \
int64_t dim = dataset_ptr->Get<int64_t>(meta::DIM);
// Generates `inline void func_name(DatasetPtr& ds_ptr, T value)`, which calls
// ds_ptr->Set(key, value).
//   func_name - name of the generated mutator
//   key       - Dataset key passed to Set
//   T         - parameter type of `value`
// ds_ptr is taken by non-const lvalue reference, so a temporary DatasetPtr
// cannot be passed. The generated function dereferences ds_ptr without a null
// check.
#define DEFINE_DATASET_SETTER(func_name, key, T) \
inline void func_name(DatasetPtr& ds_ptr, T value) { \
ds_ptr->Set(key, value); \
}

///////////////////////////////////////////////////////////////////////////////

// Typed accessors for individual Dataset keys. Each pair below expands to one
// inline Get* function and one inline Set* function bound to a single meta::
// key and a single value type.
// meta::DIM, accessed as int64_t.
DEFINE_DATASET_GETTER(GetDatasetDim, meta::DIM, int64_t);
DEFINE_DATASET_SETTER(SetDatasetDim, meta::DIM, int64_t);

// meta::TENSOR, accessed as const void*.
DEFINE_DATASET_GETTER(GetDatasetTensor, meta::TENSOR, const void*);
DEFINE_DATASET_SETTER(SetDatasetTensor, meta::TENSOR, const void*);

// meta::ROWS, accessed as int64_t.
DEFINE_DATASET_GETTER(GetDatasetRows, meta::ROWS, int64_t);
DEFINE_DATASET_SETTER(SetDatasetRows, meta::ROWS, int64_t);

// meta::IDS, accessed as const int64_t*.
DEFINE_DATASET_GETTER(GetDatasetIDs, meta::IDS, const int64_t*);
DEFINE_DATASET_SETTER(SetDatasetIDs, meta::IDS, const int64_t*);

// meta::DISTANCE, accessed as const float*.
DEFINE_DATASET_GETTER(GetDatasetDistance, meta::DISTANCE, const float*);
DEFINE_DATASET_SETTER(SetDatasetDistance, meta::DISTANCE, const float*);

// meta::LIMS, accessed as const size_t*.
DEFINE_DATASET_GETTER(GetDatasetLims, meta::LIMS, const size_t*);
DEFINE_DATASET_SETTER(SetDatasetLims, meta::LIMS, const size_t*);

///////////////////////////////////////////////////////////////////////////////

// Declares two locals in the scope where the macro is expanded:
//   rows   - result of GetDatasetRows(ds_ptr)
//   p_data - result of GetDatasetTensor(ds_ptr)
// Because it introduces these names, it cannot be expanded twice in the same
// scope, and the surrounding code must not already declare `rows` or `p_data`.
#define GET_TENSOR_DATA(ds_ptr) \
auto rows = GetDatasetRows(ds_ptr); \
auto p_data = GetDatasetTensor(ds_ptr);

// Expands GET_TENSOR_DATA(ds_ptr) and additionally declares a local `dim`
// holding the result of GetDatasetDim(ds_ptr).
#define GET_TENSOR_DATA_DIM(ds_ptr) \
GET_TENSOR_DATA(ds_ptr) \
auto dim = GetDatasetDim(ds_ptr);

// Builds a Dataset from a row count `nb`, a dimension `dim` and a tensor
// pointer `xb`.
extern DatasetPtr
GenDataset(const int64_t nb, const int64_t dim, const void* xb);

// Builds a result Dataset from an id buffer and a distance buffer.
extern DatasetPtr
GenResultDataset(const int64_t* ids, const float* distance);

// Builds a result Dataset from an id buffer, a distance buffer and a lims
// buffer.
extern DatasetPtr
GenResultDataset(const int64_t* ids, const float* distance, const size_t* lims);

} // namespace knowhere
Loading

0 comments on commit 9f86cd6

Please sign in to comment.