Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Check param data type (#173)
Browse files Browse the repository at this point in the history
Signed-off-by: yudong.cai <[email protected]>
  • Loading branch information
cydrain authored May 10, 2022
1 parent 9f86cd6 commit fb1027e
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 56 deletions.
107 changes: 60 additions & 47 deletions knowhere/index/vector_index/ConfAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,22 @@ static const std::vector<MetricType> default_binary_metric_array{metric::HAMMING
metric::TANIMOTO, metric::SUBSTRUCTURE,
metric::SUPERSTRUCTURE};

template<typename T>
inline bool
CheckValueInRange(const Config& cfg, const std::string_view& key, T min, T max) {
if (cfg.contains(std::string(key))) {
T value = GetValueFromConfig<T>(cfg, key);
return (value >= min && value <= max);
CheckIntegerRange(const Config& cfg, const std::string_view& key, int64_t min, int64_t max) {
if (!cfg.contains(std::string(key)) || !cfg[std::string(key)].is_number_integer()) {
return false;
}
int64_t value = GetValueFromConfig<int64_t>(cfg, key);
return (value >= min && value <= max);
}

inline bool
CheckFloatRange(const Config& cfg, const std::string_view& key, float min, float max) {
if (!cfg.contains(std::string(key)) || !cfg[std::string(key)].is_number_float()) {
return false;
}
return false;
float value = GetValueFromConfig<float>(cfg, key);
return (value >= min && value <= max);
}

inline bool
Expand All @@ -68,7 +76,7 @@ ConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckMetricType(cfg, default_metric_array)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) {
if (!CheckIntegerRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) {
return false;
}
return true;
Expand All @@ -78,7 +86,7 @@ bool
ConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
const int64_t DEFAULT_MIN_K = 1;
const int64_t DEFAULT_MAX_K = 16384;
return CheckValueInRange<int64_t>(cfg, meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K);
return CheckIntegerRange(cfg, meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K);
}

int64_t
Expand Down Expand Up @@ -113,7 +121,7 @@ MatchNbits(int64_t size, int64_t nbits) {

bool
IVFConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) {
if (!CheckIntegerRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) {
return false;
}

Expand All @@ -133,7 +141,7 @@ IVFConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode m
max_nprobe = faiss::gpu::getMaxKSelection();
}
#endif
if (!CheckValueInRange<int64_t>(cfg, indexparam::NPROBE, MIN_NPROBE, max_nprobe)) {
if (!CheckIntegerRange(cfg, indexparam::NPROBE, MIN_NPROBE, max_nprobe)) {
return false;
}

Expand All @@ -151,7 +159,7 @@ IVFPQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!IVFConfAdapter::CheckTrain(cfg, mode)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::NBITS, MIN_NBITS, MAX_NBITS)) {
if (!CheckIntegerRange(cfg, indexparam::NBITS, MIN_NBITS, MAX_NBITS)) {
return false;
}

Expand Down Expand Up @@ -209,16 +217,16 @@ IVFPQConfAdapter::CheckCPUPQParams(int64_t dimension, int64_t m) {
bool
IVFHNSWConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
// HNSW param check
if (!CheckValueInRange<int64_t>(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
HNSW_MAX_EFCONSTRUCTION)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
return false;
}

// IVF param check
if (!CheckValueInRange<int64_t>(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) {
if (!CheckIntegerRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) {
return false;
}

Expand All @@ -233,12 +241,12 @@ IVFHNSWConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
bool
IVFHNSWConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
// HNSW param check
if (!CheckValueInRange<int64_t>(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
return false;
}

// IVF param check
if (!CheckValueInRange<int64_t>(cfg, indexparam::NPROBE, MIN_NPROBE, MAX_NPROBE)) {
if (!CheckIntegerRange(cfg, indexparam::NPROBE, MIN_NPROBE, MAX_NPROBE)) {
return false;
}

Expand All @@ -247,51 +255,51 @@ IVFHNSWConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMo

bool
HNSWConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
HNSW_MAX_EFCONSTRUCTION)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
return false;
}
return ConfAdapter::CheckTrain(cfg, mode);
}

bool
HNSWConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
}

bool
RHNSWFlatConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
HNSW_MAX_EFCONSTRUCTION)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
return false;
}
return ConfAdapter::CheckTrain(cfg, mode);
}

bool
RHNSWFlatConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
}

bool
RHNSWPQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
HNSW_MAX_EFCONSTRUCTION)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
return false;
}

Expand All @@ -304,27 +312,27 @@ RHNSWPQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {

bool
RHNSWPQConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
}

bool
RHNSWSQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION,
HNSW_MAX_EFCONSTRUCTION)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) {
return false;
}
return ConfAdapter::CheckTrain(cfg, mode);
}

bool
RHNSWSQConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
Expand All @@ -335,7 +343,7 @@ BinIDMAPConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckMetricType(cfg, default_binary_metric_array)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) {
if (!CheckIntegerRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) {
return false;
}
return true;
Expand All @@ -348,10 +356,10 @@ BinIVFConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckMetricType(cfg, metric_array)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) {
if (!CheckIntegerRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) {
if (!CheckIntegerRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) {
return false;
}

Expand All @@ -369,28 +377,33 @@ ANNOYConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
// too large of n_trees takes much time, if there is real requirement, change this threshold.
static int64_t MAX_N_TREES = 1024;

if (!CheckValueInRange<int64_t>(cfg, indexparam::N_TREES, MIN_N_TREES, MAX_N_TREES)) {
if (!CheckIntegerRange(cfg, indexparam::N_TREES, MIN_N_TREES, MAX_N_TREES)) {
return false;
}
return ConfAdapter::CheckTrain(cfg, mode);
}

bool
ANNOYConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
static int64_t MIN_SEARCH_K = std::numeric_limits<int64_t>::min();
static int64_t MAX_SEARCH_K = std::numeric_limits<int64_t>::max();
if (!CheckIntegerRange(cfg, indexparam::SEARCH_K, MIN_SEARCH_K, MAX_SEARCH_K)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
}

#ifdef KNOWHERE_SUPPORT_NGT
bool
NGTPANNGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
if (!CheckIntegerRange(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::FORCEDLY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE,
if (!CheckIntegerRange(cfg, indexparam::FORCEDLY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE,
NGT_MAX_EDGE_SIZE)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::SELECTIVELY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE,
if (!CheckIntegerRange(cfg, indexparam::SELECTIVELY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE,
NGT_MAX_EDGE_SIZE)) {
return false;
}
Expand All @@ -403,35 +416,35 @@ NGTPANNGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {

bool
NGTPANNGConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) {
if (!CheckIntegerRange(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) {
return false;
}
if (!CheckValueInRange<float>(cfg, indexparam::EPSILON, -1.0, 1.0)) {
if (!CheckFloatRange(cfg, indexparam::EPSILON, -1.0, 1.0)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
}

bool
NGTONNGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
if (!CheckIntegerRange(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::OUTGOING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
if (!CheckIntegerRange(cfg, indexparam::OUTGOING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::INCOMING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
if (!CheckIntegerRange(cfg, indexparam::INCOMING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) {
return false;
}
return ConfAdapter::CheckTrain(cfg, mode);
}

bool
NGTONNGConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) {
if (!CheckValueInRange<int64_t>(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) {
if (!CheckIntegerRange(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) {
return false;
}
if (!CheckValueInRange<float>(cfg, indexparam::EPSILON, -1.0, 1.0)) {
if (!CheckFloatRange(cfg, indexparam::EPSILON, -1.0, 1.0)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
Expand All @@ -453,16 +466,16 @@ NSGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) {
if (!CheckMetricType(cfg, default_metric_array)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::KNNG, MIN_KNNG, MAX_KNNG)) {
if (!CheckIntegerRange(cfg, indexparam::KNNG, MIN_KNNG, MAX_KNNG)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) {
if (!CheckIntegerRange(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::OUT_DEGREE, MIN_OUT_DEGREE, MAX_OUT_DEGREE)) {
if (!CheckIntegerRange(cfg, indexparam::OUT_DEGREE, MIN_OUT_DEGREE, MAX_OUT_DEGREE)) {
return false;
}
if (!CheckValueInRange<int64_t>(cfg, indexparam::CANDIDATE, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE)) {
if (!CheckIntegerRange(cfg, indexparam::CANDIDATE, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE)) {
return false;
}

Expand All @@ -480,7 +493,7 @@ NSGConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode m
static int64_t MIN_SEARCH_LENGTH = 1;
static int64_t MAX_SEARCH_LENGTH = 300;

if (!CheckValueInRange<int64_t>(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) {
if (!CheckIntegerRange(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) {
return false;
}
return ConfAdapter::CheckSearch(cfg, type, mode);
Expand Down
13 changes: 7 additions & 6 deletions unittest/test_annoy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include <gtest/gtest.h>
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include <iostream>
#include <sstream>

#include "knowhere/common/Exception.h"
#include "knowhere/index/vector_index/ConfAdapterMgr.h"
#include "knowhere/index/vector_index/IndexAnnoy.h"

#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "unittest/utils.h"

using ::testing::Combine;
Expand All @@ -27,7 +25,6 @@ class AnnoyTest : public DataGen, public TestWithParam<std::string> {
protected:
void
SetUp() override {
IndexType = GetParam();
Generate(128, 10000, 10);
index_ = std::make_shared<knowhere::IndexAnnoy>();
conf = knowhere::Config{
Expand All @@ -42,8 +39,9 @@ class AnnoyTest : public DataGen, public TestWithParam<std::string> {

protected:
knowhere::Config conf;
knowhere::IndexMode index_mode_ = knowhere::IndexMode::MODE_CPU;
knowhere::IndexType index_type_ = knowhere::IndexEnum::INDEX_ANNOY;
std::shared_ptr<knowhere::IndexAnnoy> index_ = nullptr;
std::string IndexType;
};

INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values("Annoy"));
Expand All @@ -65,6 +63,9 @@ TEST_P(AnnoyTest, annoy_basic) {
ASSERT_EQ(index_->Count(), nb);
ASSERT_EQ(index_->Dim(), dim);

auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_);
ASSERT_TRUE(adapter->CheckSearch(conf, index_type_, index_mode_));

auto result = index_->Query(query_dataset, conf, nullptr);
AssertAnns(result, nq, k);
}
Expand Down
Loading

0 comments on commit fb1027e

Please sign in to comment.