diff --git a/knowhere/index/vector_index/ConfAdapter.cpp b/knowhere/index/vector_index/ConfAdapter.cpp index 121436754..6768d5a2d 100644 --- a/knowhere/index/vector_index/ConfAdapter.cpp +++ b/knowhere/index/vector_index/ConfAdapter.cpp @@ -47,14 +47,22 @@ static const std::vector default_binary_metric_array{metric::HAMMING metric::TANIMOTO, metric::SUBSTRUCTURE, metric::SUPERSTRUCTURE}; -template inline bool -CheckValueInRange(const Config& cfg, const std::string_view& key, T min, T max) { - if (cfg.contains(std::string(key))) { - T value = GetValueFromConfig(cfg, key); - return (value >= min && value <= max); +CheckIntegerRange(const Config& cfg, const std::string_view& key, int64_t min, int64_t max) { + if (!cfg.contains(std::string(key)) || !cfg[std::string(key)].is_number_integer()) { + return false; + } + int64_t value = GetValueFromConfig(cfg, key); + return (value >= min && value <= max); +} + +inline bool +CheckFloatRange(const Config& cfg, const std::string_view& key, float min, float max) { + if (!cfg.contains(std::string(key)) || !cfg[std::string(key)].is_number_float()) { + return false; } - return false; + float value = GetValueFromConfig(cfg, key); + return (value >= min && value <= max); } inline bool @@ -68,7 +76,7 @@ ConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { if (!CheckMetricType(cfg, default_metric_array)) { return false; } - if (!CheckValueInRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) { + if (!CheckIntegerRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) { return false; } return true; @@ -78,7 +86,7 @@ bool ConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { const int64_t DEFAULT_MIN_K = 1; const int64_t DEFAULT_MAX_K = 16384; - return CheckValueInRange(cfg, meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K); + return CheckIntegerRange(cfg, meta::TOPK, DEFAULT_MIN_K - 1, DEFAULT_MAX_K); } int64_t @@ -113,7 +121,7 @@ MatchNbits(int64_t size, int64_t nbits) { bool IVFConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) { + if (!CheckIntegerRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) { return false; } @@ -133,7 +141,7 @@ IVFConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode m max_nprobe = faiss::gpu::getMaxKSelection(); } #endif - if (!CheckValueInRange(cfg, indexparam::NPROBE, MIN_NPROBE, max_nprobe)) { + if (!CheckIntegerRange(cfg, indexparam::NPROBE, MIN_NPROBE, max_nprobe)) { return false; } @@ -151,7 +159,7 @@ IVFPQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { if (!IVFConfAdapter::CheckTrain(cfg, mode)) { return false; } - if (!CheckValueInRange(cfg, indexparam::NBITS, MIN_NBITS, MAX_NBITS)) { + if (!CheckIntegerRange(cfg, indexparam::NBITS, MIN_NBITS, MAX_NBITS)) { return false; } @@ -209,16 +217,16 @@ IVFPQConfAdapter::CheckCPUPQParams(int64_t dimension, int64_t m) { bool IVFHNSWConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { // HNSW param check - if (!CheckValueInRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, + if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION)) { return false; } - if (!CheckValueInRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { + if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { return false; } // IVF param check - if (!CheckValueInRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) { + if (!CheckIntegerRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) { return false; } @@ -233,12 +241,12 @@ IVFHNSWConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool IVFHNSWConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { // HNSW param check - if (!CheckValueInRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { + if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { return false; } // IVF param check - if (!CheckValueInRange(cfg, indexparam::NPROBE, MIN_NPROBE, MAX_NPROBE)) { + if (!CheckIntegerRange(cfg, indexparam::NPROBE, MIN_NPROBE, MAX_NPROBE)) { return false; } @@ -247,11 +255,11 @@ IVFHNSWConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMo bool HNSWConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, + if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION)) { return false; } - if (!CheckValueInRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { + if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { return false; } return ConfAdapter::CheckTrain(cfg, mode); @@ -259,7 +267,7 @@ HNSWConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool HNSWConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { + if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { return false; } return ConfAdapter::CheckSearch(cfg, type, mode); @@ -267,11 +275,11 @@ HNSWConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode bool RHNSWFlatConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, + if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION)) { return false; } - if (!CheckValueInRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { + if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { return false; } return ConfAdapter::CheckTrain(cfg, mode); @@ -279,7 +287,7 @@ RHNSWFlatConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool RHNSWFlatConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { + if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { return false; } return ConfAdapter::CheckSearch(cfg, type, mode); @@ -287,11 +295,11 @@ RHNSWFlatConfAdapter::CheckSearch(Config& cfg, const IndexType type, const Index bool RHNSWPQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, + if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION)) { return false; } - if (!CheckValueInRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { + if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { return false; } @@ -304,7 +312,7 @@ RHNSWPQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool RHNSWPQConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { + if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { return false; } return ConfAdapter::CheckSearch(cfg, type, mode); @@ -312,11 +320,11 @@ RHNSWPQConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMo bool RHNSWSQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, + if (!CheckIntegerRange(cfg, indexparam::EFCONSTRUCTION, HNSW_MIN_EFCONSTRUCTION, HNSW_MAX_EFCONSTRUCTION)) { return false; } - if (!CheckValueInRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { + if (!CheckIntegerRange(cfg, indexparam::HNSW_M, HNSW_MIN_M, HNSW_MAX_M)) { return false; } return ConfAdapter::CheckTrain(cfg, mode); @@ -324,7 +332,7 @@ RHNSWSQConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool RHNSWSQConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { + if (!CheckIntegerRange(cfg, indexparam::EF, GetMetaTopk(cfg), HNSW_MAX_EF)) { return false; } return ConfAdapter::CheckSearch(cfg, type, mode); @@ -335,7 +343,7 @@ BinIDMAPConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { if (!CheckMetricType(cfg, default_binary_metric_array)) { return false; } - if (!CheckValueInRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) { + if (!CheckIntegerRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) { return false; } return true; @@ -348,10 +356,10 @@ BinIVFConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { if (!CheckMetricType(cfg, metric_array)) { return false; } - if (!CheckValueInRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) { + if (!CheckIntegerRange(cfg, meta::DIM, DEFAULT_MIN_DIM, DEFAULT_MAX_DIM)) { return false; } - if (!CheckValueInRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) { + if (!CheckIntegerRange(cfg, indexparam::NLIST, MIN_NLIST, MAX_NLIST)) { return false; } @@ -369,7 +377,7 @@ ANNOYConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { // too large of n_trees takes much time, if there is real requirement, change this threshold. static int64_t MAX_N_TREES = 1024; - if (!CheckValueInRange(cfg, indexparam::N_TREES, MIN_N_TREES, MAX_N_TREES)) { + if (!CheckIntegerRange(cfg, indexparam::N_TREES, MIN_N_TREES, MAX_N_TREES)) { return false; } return ConfAdapter::CheckTrain(cfg, mode); @@ -377,20 +385,25 @@ ANNOYConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool ANNOYConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { + static int64_t MIN_SEARCH_K = std::numeric_limits::min(); + static int64_t MAX_SEARCH_K = std::numeric_limits::max(); + if (!CheckIntegerRange(cfg, indexparam::SEARCH_K, MIN_SEARCH_K, MAX_SEARCH_K)) { + return false; + } return ConfAdapter::CheckSearch(cfg, type, mode); } #ifdef KNOWHERE_SUPPORT_NGT bool NGTPANNGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { + if (!CheckIntegerRange(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { return false; } - if (!CheckValueInRange(cfg, indexparam::FORCEDLY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE, + if (!CheckIntegerRange(cfg, indexparam::FORCEDLY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { return false; } - if (!CheckValueInRange(cfg, indexparam::SELECTIVELY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE, + if (!CheckIntegerRange(cfg, indexparam::SELECTIVELY_PRUNED_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { return false; } @@ -403,10 +416,10 @@ NGTPANNGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool NGTPANNGConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) { + if (!CheckIntegerRange(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) { return false; } - if (!CheckValueInRange(cfg, indexparam::EPSILON, -1.0, 1.0)) { + if (!CheckFloatRange(cfg, indexparam::EPSILON, -1.0, 1.0)) { return false; } return ConfAdapter::CheckSearch(cfg, type, mode); @@ -414,13 +427,13 @@ NGTPANNGConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexM bool NGTONNGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { + if (!CheckIntegerRange(cfg, indexparam::EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { return false; } - if (!CheckValueInRange(cfg, indexparam::OUTGOING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { + if (!CheckIntegerRange(cfg, indexparam::OUTGOING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { return false; } - if (!CheckValueInRange(cfg, indexparam::INCOMING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { + if (!CheckIntegerRange(cfg, indexparam::INCOMING_EDGE_SIZE, NGT_MIN_EDGE_SIZE, NGT_MAX_EDGE_SIZE)) { return false; } return ConfAdapter::CheckTrain(cfg, mode); @@ -428,10 +441,10 @@ NGTONNGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { bool NGTONNGConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode mode) { - if (!CheckValueInRange(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) { + if (!CheckIntegerRange(cfg, indexparam::MAX_SEARCH_EDGES, -1, NGT_MAX_EDGE_SIZE)) { return false; } - if (!CheckValueInRange(cfg, indexparam::EPSILON, -1.0, 1.0)) { + if (!CheckFloatRange(cfg, indexparam::EPSILON, -1.0, 1.0)) { return false; } return ConfAdapter::CheckSearch(cfg, type, mode); @@ -453,16 +466,16 @@ NSGConfAdapter::CheckTrain(Config& cfg, const IndexMode mode) { if (!CheckMetricType(cfg, default_metric_array)) { return false; } - if (!CheckValueInRange(cfg, indexparam::KNNG, MIN_KNNG, MAX_KNNG)) { + if (!CheckIntegerRange(cfg, indexparam::KNNG, MIN_KNNG, MAX_KNNG)) { return false; } - if (!CheckValueInRange(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) { + if (!CheckIntegerRange(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) { return false; } - if (!CheckValueInRange(cfg, indexparam::OUT_DEGREE, MIN_OUT_DEGREE, MAX_OUT_DEGREE)) { + if (!CheckIntegerRange(cfg, indexparam::OUT_DEGREE, MIN_OUT_DEGREE, MAX_OUT_DEGREE)) { return false; } - if (!CheckValueInRange(cfg, indexparam::CANDIDATE, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE)) { + if (!CheckIntegerRange(cfg, indexparam::CANDIDATE, MIN_CANDIDATE_POOL_SIZE, MAX_CANDIDATE_POOL_SIZE)) { return false; } @@ -480,7 +493,7 @@ NSGConfAdapter::CheckSearch(Config& cfg, const IndexType type, const IndexMode m static int64_t MIN_SEARCH_LENGTH = 1; static int64_t MAX_SEARCH_LENGTH = 300; - if (!CheckValueInRange(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) { + if (!CheckIntegerRange(cfg, indexparam::SEARCH_LENGTH, MIN_SEARCH_LENGTH, MAX_SEARCH_LENGTH)) { return false; } return ConfAdapter::CheckSearch(cfg, type, mode); diff --git a/unittest/test_annoy.cpp b/unittest/test_annoy.cpp index 50ec154ae..88608642b 100644 --- a/unittest/test_annoy.cpp +++ b/unittest/test_annoy.cpp @@ -10,13 +10,11 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include -#include "knowhere/index/vector_index/helpers/IndexParameter.h" -#include -#include #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" #include "knowhere/index/vector_index/IndexAnnoy.h" - +#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "unittest/utils.h" using ::testing::Combine; @@ -27,7 +25,6 @@ class AnnoyTest : public DataGen, public TestWithParam { protected: void SetUp() override { - IndexType = GetParam(); Generate(128, 10000, 10); index_ = std::make_shared(); conf = knowhere::Config{ @@ -42,8 +39,9 @@ class AnnoyTest : public DataGen, public TestWithParam { protected: knowhere::Config conf; + knowhere::IndexMode index_mode_ = knowhere::IndexMode::MODE_CPU; + knowhere::IndexType index_type_ = knowhere::IndexEnum::INDEX_ANNOY; std::shared_ptr index_ = nullptr; - std::string IndexType; }; INSTANTIATE_TEST_CASE_P(AnnoyParameters, AnnoyTest, Values("Annoy")); @@ -65,6 +63,9 @@ TEST_P(AnnoyTest, annoy_basic) { ASSERT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Dim(), dim); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); + ASSERT_TRUE(adapter->CheckSearch(conf, index_type_, index_mode_)); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); } diff --git a/unittest/test_binaryidmap.cpp b/unittest/test_binaryidmap.cpp index ea283998e..90fe0d171 100644 --- a/unittest/test_binaryidmap.cpp +++ b/unittest/test_binaryidmap.cpp @@ -13,6 +13,7 @@ #include #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" #include "knowhere/index/vector_index/IndexBinaryIDMAP.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "unittest/utils.h" @@ -38,6 +39,7 @@ class BinaryIDMAPTest : public DataGen, {knowhere::INDEX_FILE_SLICE_SIZE_IN_MEGABYTE, knowhere::index_file_slice_size}, }; index_mode_ = GetParam(); + index_type_ = knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP; index_ = std::make_shared(); } @@ -95,6 +97,7 @@ class BinaryIDMAPTest : public DataGen, knowhere::Config conf_; knowhere::BinaryIDMAPPtr index_ = nullptr; knowhere::IndexMode index_mode_; + knowhere::IndexType index_type_; }; INSTANTIATE_TEST_CASE_P( @@ -121,6 +124,10 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); ASSERT_TRUE(index_->GetRawVectors() != nullptr); + + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); + ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); diff --git a/unittest/test_binaryivf.cpp b/unittest/test_binaryivf.cpp index 3d11f4e05..47ecb66e2 100644 --- a/unittest/test_binaryivf.cpp +++ b/unittest/test_binaryivf.cpp @@ -15,6 +15,7 @@ #include "knowhere/common/Exception.h" #include "knowhere/common/Timer.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" #include "knowhere/index/vector_index/IndexBinaryIVF.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "unittest/Helper.h" @@ -41,6 +42,7 @@ class BinaryIVFTest : public DataGen, }; index_mode_ = GetParam(); + index_type_ = knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT; index_ = std::make_shared(); } @@ -64,6 +66,7 @@ class BinaryIVFTest : public DataGen, protected: knowhere::Config conf_; knowhere::IndexMode index_mode_; + knowhere::IndexType index_type_; knowhere::BinaryIVFIndexPtr index_ = nullptr; }; @@ -86,6 +89,9 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); + ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, knowhere::GetMetaTopk(conf_)); // PrintResult(result, nq, k); diff --git a/unittest/test_hnsw.cpp b/unittest/test_hnsw.cpp index 27130cc73..7915b2cd1 100644 --- a/unittest/test_hnsw.cpp +++ b/unittest/test_hnsw.cpp @@ -15,6 +15,7 @@ #include "knowhere/common/Config.h" #include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" #include "knowhere/index/vector_index/IndexHNSW.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h" @@ -28,8 +29,6 @@ class HNSWTest : public DataGen, public TestWithParam { protected: void SetUp() override { - IndexType = GetParam(); - std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 index_ = std::make_shared(); conf = knowhere::Config{ @@ -41,8 +40,9 @@ class HNSWTest : public DataGen, public TestWithParam { protected: knowhere::Config conf; + knowhere::IndexMode index_mode_ = knowhere::IndexMode::MODE_CPU; + knowhere::IndexType index_type_ = knowhere::IndexEnum::INDEX_HNSW; std::shared_ptr index_ = nullptr; - std::string IndexType; }; INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWTest, Values("HNSW")); @@ -79,6 +79,9 @@ TEST_P(HNSWTest, HNSW_basic) { index_->Load(bs); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); + ASSERT_TRUE(adapter->CheckSearch(conf, index_type_, index_mode_)); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); diff --git a/unittest/test_idmap.cpp b/unittest/test_idmap.cpp index e44ba886d..aa5496975 100644 --- a/unittest/test_idmap.cpp +++ b/unittest/test_idmap.cpp @@ -16,6 +16,7 @@ #include "knowhere/common/Exception.h" #include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" #include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #ifdef KNOWHERE_GPU_VERSION @@ -42,6 +43,7 @@ class IDMAPTest : public DataGen, public TestWithParam { knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); #endif index_mode_ = GetParam(); + index_type_ = knowhere::IndexEnum::INDEX_FAISS_IDMAP; index_ = std::make_shared(); } @@ -143,6 +145,7 @@ class IDMAPTest : public DataGen, public TestWithParam { protected: knowhere::IDMAPPtr index_ = nullptr; knowhere::IndexMode index_mode_; + knowhere::IndexType index_type_; }; INSTANTIATE_TEST_CASE_P( @@ -176,6 +179,10 @@ TEST_P(IDMAPTest, idmap_basic) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); ASSERT_TRUE(index_->GetRawVectors() != nullptr); + + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); + ASSERT_TRUE(adapter->CheckSearch(conf, index_type_, index_mode_)); + auto result = index_->Query(query_dataset, conf, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); diff --git a/unittest/test_ivf.cpp b/unittest/test_ivf.cpp index d1162fbe3..91eca2fb2 100644 --- a/unittest/test_ivf.cpp +++ b/unittest/test_ivf.cpp @@ -20,6 +20,7 @@ #include "knowhere/common/Exception.h" #include "knowhere/index/IndexType.h" #include "knowhere/index/VecIndexFactory.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #ifdef KNOWHERE_GPU_VERSION @@ -102,6 +103,9 @@ TEST_P(IVFTest, ivf_basic) { EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); + ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, k); // PrintResult(result, nq, k); diff --git a/unittest/test_ivf_nm.cpp b/unittest/test_ivf_nm.cpp index d8ab6cdbe..60140aec0 100644 --- a/unittest/test_ivf_nm.cpp +++ b/unittest/test_ivf_nm.cpp @@ -21,6 +21,7 @@ #include "knowhere/common/Exception.h" #include "knowhere/index/IndexType.h" #include "knowhere/index/VecIndexFactory.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #ifdef KNOWHERE_GPU_VERSION @@ -115,6 +116,9 @@ TEST_P(IVFNMTest, ivfnm_basic) { LoadRawData(index_, base_dataset, conf_); + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_type_); + ASSERT_TRUE(adapter->CheckSearch(conf_, index_type_, index_mode_)); + auto result = index_->Query(query_dataset, conf_, nullptr); AssertAnns(result, nq, k);