From cae8d1ffc3be4c6a5afb3228cf5b8db02a597c8c Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Tue, 17 Dec 2024 14:36:38 +0800 Subject: [PATCH] Optimize feature registry (#990) Signed-off-by: Cai Yudong --- include/knowhere/index/index_factory.h | 44 ++++++++++---------------- src/index/index_factory.cc | 2 +- tests/ut/test_index_check.cc | 14 +++++++- 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/include/knowhere/index/index_factory.h b/include/knowhere/index/index_factory.h index 029ea08da..332a57093 100644 --- a/include/knowhere/index/index_factory.h +++ b/include/knowhere/index/index_factory.h @@ -103,11 +103,11 @@ class IndexFactory { // Please review carefully and select with caution // register vector index supporting ALL_TYPE(binary, bf16, fp16, fp32, sparse_float32) data types -#define KNOWHERE_SIMPLE_REGISTER_ALL_GLOBAL(name, index_node, features, ...) \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_TYPE), ##__VA_ARGS__); +#define KNOWHERE_SIMPLE_REGISTER_ALL_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__); // register vector index supporting sparse_float32 #define KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(name, index_node, features, ...) \ @@ -115,35 +115,25 @@ class IndexFactory { ##__VA_ARGS__); // register vector index supporting ALL_DENSE_TYPE(binary, bf16, fp16, fp32) data types -#define KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(name, index_node, features, ...) \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::ALL_DENSE_TYPE), \ - ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_TYPE), \ - ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_TYPE), \ - ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_TYPE), \ - ##__VA_ARGS__); +#define KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__); // register vector index supporting binary data type #define KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(name, index_node, features, ...) \ KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__); // register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types #define KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ - ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ - ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ - ##__VA_ARGS__); + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__); // register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types, but mocked bf16 and fp16 -#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \ - KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ - ##__VA_ARGS__); \ - KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ - ##__VA_ARGS__); \ - KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::ALL_DENSE_FLOAT_TYPE), \ - ##__VA_ARGS__); +#define KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \ + KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, bf16, (features | knowhere::feature::BF16), ##__VA_ARGS__); \ + KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, fp16, (features | knowhere::feature::FP16), ##__VA_ARGS__); \ + KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, fp32, (features | knowhere::feature::FLOAT32), ##__VA_ARGS__); #define KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(name, index_node, data_type, features, thread_size) \ KNOWHERE_REGISTER_STATIC(name, index_node, data_type) \ diff --git a/src/index/index_factory.cc b/src/index/index_factory.cc index 7e9f8086b..621786116 100644 --- a/src/index/index_factory.cc +++ b/src/index/index_factory.cc @@ -84,7 +84,7 @@ IndexFactory::Register(const std::string& name, std::function(c feature_mapping_[name] = features; } else { // All data types should have the same features; please try to avoid breaking this rule. - feature_mapping_[name] = feature_mapping_[name] & features; + feature_mapping_[name] |= features; } return *this; } diff --git a/tests/ut/test_index_check.cc b/tests/ut/test_index_check.cc index 39dd73189..9c5a9ff39 100644 --- a/tests/ut/test_index_check.cc +++ b/tests/ut/test_index_check.cc @@ -332,11 +332,11 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") { REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_BIN_IVFFLAT, knowhere::feature::SPARSE_FLOAT32)); + // HNSW Index REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::FLOAT32)); REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::FP16)); REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::BF16)); - // HNSW Index #ifdef KNOWHERE_WITH_CARDINAL REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::BINARY)); REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::SPARSE_FLOAT32)); @@ -344,6 +344,18 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") { REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::SPARSE_FLOAT32)); #endif + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::FLOAT32)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::FP16)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::BF16)); + + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::FLOAT32)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::FP16)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::BF16)); + + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::FLOAT32)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::FP16)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::BF16)); + // Sparse Indexes REQUIRE_FALSE( IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_SPARSE_INVERTED_INDEX, knowhere::feature::FLOAT32));