Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Update omp dynamic (#662)
Browse files Browse the repository at this point in the history
Signed-off-by: zh Wang <[email protected]>
  • Loading branch information
hhy3 authored Feb 7, 2023
1 parent a5e4834 commit 4124cab
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 66 deletions.
144 changes: 78 additions & 66 deletions knowhere/archive/BruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
// limitations under the License.

#include <vector>
#include <omp.h>

#include "faiss/utils/BinaryDistance.h"
#include "faiss/utils/distances.h"
#include "knowhere/archive/BruteForce.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/ThreadPool.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/RangeUtil.h"
Expand All @@ -29,9 +31,7 @@ namespace knowhere {
/** knowhere wrapper API to call faiss brute force search for all metric types
*/
DatasetPtr
BruteForce::Search(const DatasetPtr base_dataset,
const DatasetPtr query_dataset,
const Config& config,
BruteForce::Search(const DatasetPtr base_dataset, const DatasetPtr query_dataset, const Config& config,
const faiss::BitsetView bitset) {
auto xb = GetDatasetTensor(base_dataset);
auto nb = GetDatasetRows(base_dataset);
Expand All @@ -48,50 +48,54 @@ BruteForce::Search(const DatasetPtr base_dataset,
auto labels = new int64_t[nq * topk];
auto distances = new float[nq * topk];

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
faiss::float_maxheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, &buf, nullptr, bitset);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
faiss::float_maxheap_array_t res = {size_t(nq), size_t(topk), labels, distances};
binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk * nq; i++) {
distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]);
auto pool = ThreadPool::GetGlobalThreadPool();
auto future = pool->push([&] {
switch (faiss_metric_type) {
case faiss::METRIC_L2: {
faiss::float_maxheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, &buf, nullptr, bitset);
break;
}
case faiss::METRIC_INNER_PRODUCT: {
faiss::float_minheap_array_t buf{(size_t)nq, (size_t)topk, labels, distances};
faiss::knn_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, &buf, bitset);
break;
}
case faiss::METRIC_Jaccard:
case faiss::METRIC_Tanimoto: {
faiss::float_maxheap_array_t res = {size_t(nq), size_t(topk), labels, distances};
binary_distance_knn_hc(faiss::METRIC_Jaccard, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);

if (faiss_metric_type == faiss::METRIC_Tanimoto) {
for (int i = 0; i < topk * nq; i++) {
distances[i] = faiss::Jaccard_2_Tanimoto(distances[i]);
}
}
break;
}
break;
}
case faiss::METRIC_Hamming: {
std::vector<int32_t> int_distances(nq * topk);
faiss::int_maxheap_array_t res = {size_t(nq), size_t(topk), labels, int_distances.data()};
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);
for (int i = 0; i < nq * topk; ++i) {
distances[i] = int_distances[i];
case faiss::METRIC_Hamming: {
std::vector<int32_t> int_distances(nq * topk);
faiss::int_maxheap_array_t res = {size_t(nq), size_t(topk), labels, int_distances.data()};
binary_distance_knn_hc(faiss::METRIC_Hamming, &res, (const uint8_t*)xq, (const uint8_t*)xb, nb, dim / 8,
bitset);
for (int i = 0; i < nq * topk; ++i) {
distances[i] = int_distances[i];
}
break;
}
break;
}
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
binary_distance_knn_mc(faiss_metric_type, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, topk, dim / 8,
distances, labels, bitset);
break;
case faiss::METRIC_Substructure:
case faiss::METRIC_Superstructure: {
// only matched ids will be chosen, not to use heap
binary_distance_knn_mc(faiss_metric_type, (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, topk, dim / 8,
distances, labels, bitset);
break;
}
default:
KNOWHERE_THROW_MSG("BruteForce search not support metric type: " + metric_type);
}
default:
KNOWHERE_THROW_MSG("BruteForce search not support metric type: " + metric_type);
}
});
future.get();
return GenResultDataset(labels, distances);
}

Expand All @@ -117,29 +121,37 @@ BruteForce::RangeSearch(const DatasetPtr base_dataset,

faiss::RangeSearchResult res(nq);

switch (faiss_metric_type) {
case faiss::METRIC_L2:
faiss::range_search_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset);
break;
case faiss::METRIC_INNER_PRODUCT:
is_ip = true;
faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset);
break;
case faiss::METRIC_Jaccard:
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(faiss::METRIC_Jaccard,
(const uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset);
break;
case faiss::METRIC_Tanimoto:
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(faiss::METRIC_Tanimoto,
(const uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset);
break;
case faiss::METRIC_Hamming:
faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming,
(const uint8_t*)xq, (const uint8_t*)xb, nq, nb, (int)radius, dim / 8, &res, bitset);
break;
default:
KNOWHERE_THROW_MSG("BruteForce range search not support metric type: " + metric_type);
}
auto pool = ThreadPool::GetGlobalThreadPool();
auto future = pool->push([&] {
switch (faiss_metric_type) {
case faiss::METRIC_L2:
faiss::range_search_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset);
break;
case faiss::METRIC_INNER_PRODUCT:
is_ip = true;
faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res,
bitset);
break;
case faiss::METRIC_Jaccard:
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(faiss::METRIC_Jaccard,
(const uint8_t*)xq, (const uint8_t*)xb,
nq, nb, radius, dim / 8, &res, bitset);
break;
case faiss::METRIC_Tanimoto:
faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(faiss::METRIC_Tanimoto,
(const uint8_t*)xq, (const uint8_t*)xb,
nq, nb, radius, dim / 8, &res, bitset);
break;
case faiss::METRIC_Hamming:
faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(faiss::METRIC_Hamming, (const uint8_t*)xq,
(const uint8_t*)xb, nq, nb, (int)radius,
dim / 8, &res, bitset);
break;
default:
KNOWHERE_THROW_MSG("BruteForce range search not support metric type: " + metric_type);
}
});
future.get();

float* distances = nullptr;
int64_t* labels = nullptr;
Expand Down
2 changes: 2 additions & 0 deletions thirdparty/ctpl/ctpl-std.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef __ctpl_stl_thread_pool_H__
#define __ctpl_stl_thread_pool_H__

#include <omp.h>
#include <functional>
#include <thread>
#include <atomic>
Expand Down Expand Up @@ -209,6 +210,7 @@ namespace ctpl {
void set_thread(int i) {
std::shared_ptr<std::atomic<bool>> flag(this->flags[i]); // a copy of the shared ptr to the flag
auto f = [this, i, flag/* a copy of the shared ptr to the flag */]() {
omp_set_num_threads(1);
std::atomic<bool> & _flag = *flag;
std::function<void(int id)> * _f;
bool isPop = this->q.pop(_f);
Expand Down

0 comments on commit 4124cab

Please sign in to comment.