diff --git a/knowhere/archive/BruteForce.cpp b/knowhere/archive/BruteForce.cpp index 86d5a9673..19cd9603f 100644 --- a/knowhere/archive/BruteForce.cpp +++ b/knowhere/archive/BruteForce.cpp @@ -15,11 +15,13 @@ // limitations under the License. #include +#include #include "faiss/utils/BinaryDistance.h" #include "faiss/utils/distances.h" #include "knowhere/archive/BruteForce.h" #include "knowhere/common/Exception.h" +#include "knowhere/common/ThreadPool.h" #include "knowhere/common/Log.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "knowhere/index/vector_index/helpers/RangeUtil.h" @@ -29,9 +31,7 @@ namespace knowhere { /** knowhere wrapper API to call faiss brute force search for all metric types */ DatasetPtr -BruteForce::Search(const DatasetPtr base_dataset, - const DatasetPtr query_dataset, - const Config& config, +BruteForce::Search(const DatasetPtr base_dataset, const DatasetPtr query_dataset, const Config& config, const faiss::BitsetView bitset) { auto xb = GetDatasetTensor(base_dataset); auto nb = GetDatasetRows(base_dataset); @@ -48,50 +48,54 @@ BruteForce::Search(const DatasetPtr base_dataset, auto labels = new int64_t[nq * topk]; auto distances = new float[nq * topk]; - switch (faiss_metric_type) { - case faiss::METRIC_L2: { - faiss::float_maxheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances}; - faiss::knn_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, &buf, nullptr, bitset); - break; - } - case faiss::METRIC_INNER_PRODUCT: { - faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances}; - faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, bitset); - break; - } - case faiss::METRIC_Jaccard: - case faiss::METRIC_Tanimoto: { - faiss::float_maxheap_array_t res = {size_t(nq), size_t(topk), labels, distances}; - binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8, - bitset); - - if (faiss_metric_type == faiss::METRIC_Tanimoto) { - for (int i = 0; i < topk * nq; i++) { - distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]); + auto pool = ThreadPool::GetGlobalThreadPool(); + auto future = pool->push([&] { + switch (faiss_metric_type) { + case faiss::METRIC_L2: { + faiss::float_maxheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances}; + faiss::knn_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, &buf, nullptr, bitset); + break; + } + case faiss::METRIC_INNER_PRODUCT: { + faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances}; + faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, bitset); + break; + } + case faiss::METRIC_Jaccard: + case faiss::METRIC_Tanimoto: { + faiss::float_maxheap_array_t res = {size_t(nq), size_t(topk), labels, distances}; + binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8, + bitset); + + if (faiss_metric_type == faiss::METRIC_Tanimoto) { + for (int i = 0; i < topk * nq; i++) { + distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]); + } } + break; } - break; - } - case faiss::METRIC_Hamming: { - std::vector int_distances(nq * topk); - faiss::int_maxheap_array_t res = {size_t(nq), size_t(topk), labels, int_distances.data()}; - binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8, - bitset); - for (int i = 0; i < nq * topk; ++i) { - distances[i] = int_distances[i]; + case faiss::METRIC_Hamming: { + std::vector int_distances(nq * topk); + faiss::int_maxheap_array_t res = {size_t(nq), size_t(topk), labels, int_distances.data()}; + binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8, + bitset); + for (int i = 0; i < nq * topk; ++i) { + distances[i] = int_distances[i]; + } + break; } - break; - } - case faiss::METRIC_Substructure: - case faiss::METRIC_Superstructure: { - // only matched ids will be chosen, not to use heap - binary_distance_knn_mc(faiss_metric_type, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, topk, dim / 8, - distances, labels, bitset); - break; + case faiss::METRIC_Substructure: + case faiss::METRIC_Superstructure: { + // only matched ids will be chosen, not to use heap + binary_distance_knn_mc(faiss_metric_type, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, topk, dim / 8, + distances, labels, bitset); + break; + } + default: + KNOWHERE_THROW_MSG("BruteForce search not support metric type: " + metric_type); } - default: - KNOWHERE_THROW_MSG("BruteForce search not support metric type: " + metric_type); - } + }); + future.get(); return GenResultDataset(labels, distances); } @@ -117,29 +121,37 @@ BruteForce::RangeSearch(const DatasetPtr base_dataset, faiss::RangeSearchResult res(nq); - switch (faiss_metric_type) { - case faiss::METRIC_L2: - faiss::range_search_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset); - break; - case faiss::METRIC_INNER_PRODUCT: - is_ip = true; - faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset); - break; - case faiss::METRIC_Jaccard: - faiss::binary_range_search, float>(faiss::METRIC_Jaccard, - (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset); - break; - case faiss::METRIC_Tanimoto: - faiss::binary_range_search, float>(faiss::METRIC_Tanimoto, - (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset); - break; - case faiss::METRIC_Hamming: - faiss::binary_range_search, int>(faiss::METRIC_Hamming, - (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, (int)radius, dim / 8, &res, bitset); - break; - default: - KNOWHERE_THROW_MSG("BruteForce range search not support metric type: " + metric_type); - } + auto pool = ThreadPool::GetGlobalThreadPool(); + auto future = pool->push([&] { + switch (faiss_metric_type) { + case faiss::METRIC_L2: + faiss::range_search_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset); + break; + case faiss::METRIC_INNER_PRODUCT: + is_ip = true; + faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, + bitset); + break; + case faiss::METRIC_Jaccard: + faiss::binary_range_search, float>(faiss::METRIC_Jaccard, + (const uint8_t*)xq, (const uint8_t*)xb, + nq, nb, radius, dim / 8, &res, bitset); + break; + case faiss::METRIC_Tanimoto: + faiss::binary_range_search, float>(faiss::METRIC_Tanimoto, + (const uint8_t*)xq, (const uint8_t*)xb, + nq, nb, radius, dim / 8, &res, bitset); + break; + case faiss::METRIC_Hamming: + faiss::binary_range_search, int>(faiss::METRIC_Hamming, (const uint8_t*)xq, + (const uint8_t*)xb, nq, nb, (int)radius, + dim / 8, &res, bitset); + break; + default: + KNOWHERE_THROW_MSG("BruteForce range search not support metric type: " + metric_type); + } + }); + future.get(); float* distances = nullptr; int64_t* labels = nullptr; diff --git a/thirdparty/ctpl/ctpl-std.h b/thirdparty/ctpl/ctpl-std.h index 5956cf095..38d3edf5b 100644 --- a/thirdparty/ctpl/ctpl-std.h +++ b/thirdparty/ctpl/ctpl-std.h @@ -20,6 +20,7 @@ #ifndef __ctpl_stl_thread_pool_H__ #define __ctpl_stl_thread_pool_H__ +#include #include #include #include @@ -209,6 +210,7 @@ namespace ctpl { void set_thread(int i) { std::shared_ptr> flag(this->flags[i]); // a copy of the shared ptr to the flag auto f = [this, i, flag/* a copy of the shared ptr to the flag */]() { + omp_set_num_threads(1); std::atomic & _flag = *flag; std::function * _f; bool isPop = this->q.pop(_f);