
enhance: BF functions support real fp16/bf16 calculate #980

Merged (1 commit into zilliztech:main on Dec 16, 2024)

Conversation

cqy123456 (Collaborator) commented Dec 10, 2024

issue: #909
The purpose of this PR is to remove the memory copy of input data and to prepare for NM index search.

Simple test with cohere-768d-cosine (8 CPUs, AVX512):

fp16/bf16: row count = 174762, total size = 128MB
QPS of float16: 170.49 (before: 75.5)
QPS of bfloat16: 158.87 (before: 73.95)

fp32: row count = 174762, total size = 256MB
QPS of float32: 132.029

VecTool test with gist1M-960d-L2, brute force vs. FLAT (data converted to fp32):

bf16:
I1216 02:53:15.839870 69100 metric.cpp:141] case = gist_query_FLAT_k_100_0.00 | repeat = 1 | nq = 1000 | k = 100 | time = 66.1252s | qps = 15.1228 | avg_recall = 1 | min_recall = 1 | avg_ncdg = -nan | min_ncdg = 3.40282e+38 | avg_disterr = 0 | max_disterr = 1.17549e-38
fp16:
I1216 02:59:14.644295 70049 metric.cpp:141] case = gist_query_FLAT_k_100_0.00 | repeat = 1 | nq = 1000 | k = 100 | time = 67.0043s | qps = 14.9244 | avg_recall = 1 | min_recall = 1 | avg_ncdg = -nan | min_ncdg = 3.40282e+38 | avg_disterr = 0 | max_disterr = 1.17549e-38

sre-ci-robot (Collaborator):

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: cqy123456

The full list of commands accepted by this bot can be found here.

The pull request process is described here

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

mergify bot commented Dec 10, 2024

@cqy123456 🔍 Important: PR Classification Needed!

For efficient project management and a seamless review process, it's essential to classify your PR correctly. Here's how:

  1. If you're fixing a bug, label it as kind/bug.
  2. For small tweaks (less than 20 lines without altering any functionality), please use kind/improvement.
  3. Significant changes that don't modify existing functionalities should be tagged as kind/enhancement.
  4. Adjusting APIs or changing functionality? Go with kind/feature.

For any PR outside the kind/improvement category, ensure you link to the associated issue using the format: “issue: #”.

Thanks for your efforts and contribution to the community!

codecov bot commented Dec 10, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 73.87%. Comparing base (3c46f4c) to head (e089ab8).
Report is 268 commits behind head on main.

Additional details and impacted files:

@@            Coverage Diff            @@
##           main     #980       +/-   ##
=========================================
+ Coverage      0   73.87%   +73.87%     
=========================================
  Files         0       82       +82     
  Lines         0     6916     +6916     
=========================================
+ Hits          0     5109     +5109     
- Misses        0     1807     +1807     

see 82 files with indirect coverage changes

@@ -159,19 +159,17 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
float
fp16_vec_inner_product_avx512(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
__m512 m512_res = _mm512_setzero_ps();
__m512 m512_res_0 = _mm512_setzero_ps();
Collaborator:

res_0 is used for increasing the instruction level parallelism in the following loop.
Please confirm with godbolt.org, or let me know if I need to check this.
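A minimal fp32 sketch of that dual-accumulator pattern (the PR's real kernels operate on fp16/bf16; the fp32 analogue and the function name here are illustrative only). Two independent accumulators break the loop-carried dependency on a single register, so consecutive FMAs can overlap in the pipeline:

#include <immintrin.h>
#include <cstddef>

// Illustrative fp32 analogue; requires AVX512F.
float
vec_inner_product_avx512_ilp(const float* x, const float* y, size_t d) {
    __m512 res = _mm512_setzero_ps();
    __m512 res_0 = _mm512_setzero_ps();  // second accumulator -> second independent FMA chain
    size_t i = 0;
    for (; i + 32 <= d; i += 32) {
        res = _mm512_fmadd_ps(_mm512_loadu_ps(x + i), _mm512_loadu_ps(y + i), res);
        res_0 = _mm512_fmadd_ps(_mm512_loadu_ps(x + i + 16), _mm512_loadu_ps(y + i + 16), res_0);
    }
    float sum = _mm512_reduce_add_ps(_mm512_add_ps(res, res_0));
    for (; i < d; ++i) {  // scalar tail
        sum += x[i] * y[i];
    }
    return sum;
}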

Collaborator Author:

There is a slight difference between fp16_vec_inner_product_avx512 and fp16_vec_inner_product_avx512_batch_4:
fp16_vec_inner_product_avx512: sum of (round(a*b + m512_res)) + sum of (round(a*b + m512_res_0))
fp16_vec_inner_product_avx512_batch_4: sum of (round(a*b + m512_res))

Collaborator Author (@cqy123456, Dec 11, 2024):

Precision loss may be caused by fmadd: round(a*b + c)
vs. mul and add: round(a*b) + c
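A tiny self-contained demo of that rounding difference (the inputs are hand-picked so that a*b is not exactly representable in float; compile with -ffp-contract=off on GCC/Clang so the compiler does not fuse the mul+add itself):

#include <cmath>
#include <cstdio>

int main() {
    const float x = 1.0f / 8192.0f;  // 2^-13
    const float a = 1.0f + x, b = 1.0f - x, c = -1.0f;
    // a*b = 1 - 2^-26 exactly; float rounds that product to 1.0f
    const float fused = std::fma(a, b, c);  // round(a*b + c): one rounding, gives -2^-26
    const float split = a * b + c;          // round(round(a*b) + c): two roundings, gives 0
    std::printf("fused = %g, split = %g\n", fused, split);
    return 0;
}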

Collaborator:

I see. It sounds reasonable, but please add comments in the corresponding distance_XYZ.cc files (after the headers but before the namespace faiss { line) explaining why it is done. Otherwise, someone may wish to 'optimize' the code back.
Alternatively, it is possible (and it is a faster solution for the single-op version) to extend the batch-4 version to match the single-op version instead: let the batch-4 loop perform 8 FMA operations instead of 4 and let the single-op version perform 2 FMA operations, as it is now in the baseline.

I leave it up to you to decide whether you'd like to change it, because the hot spot is the batch-4 version anyway.

Collaborator Author:

I tested the batch-4 loop performing 8 FMA operations; a slight performance degradation showed up in BF float16 search.
It is possible that the number of AVX512 registers is not enough for that much parallelism.

res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]);
res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]);
res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]);
res.val[0] = vaddq_f32(vmulq_f32(a.val[0], a.val[0]), res.val[0]);
Collaborator:

Why replace FMA with ADD+MUL?

Collaborator Author:

update.

// use build thread pool to compute norms
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<folly::Unit>> futs;
constexpr int64_t chunk_size = 128;
Collaborator:

this is way too small, tbh
I'd make it something like 8192
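A portable sketch of the chunked-norm pattern with the suggested 8192 chunk size (std::async stands in for knowhere's ThreadPool and folly futures; compute_norms and its signature are made up for illustration):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <future>
#include <vector>

std::vector<float>
compute_norms(const float* data, int64_t nb, int64_t dim) {
    constexpr int64_t chunk_size = 8192;                           // reviewer-suggested size
    const int64_t chunk_num = (nb + chunk_size - 1) / chunk_size;  // integer ceiling division
    std::vector<float> norms(nb);
    std::vector<std::future<void>> futs;
    futs.reserve(chunk_num);
    for (int64_t c = 0; c < chunk_num; ++c) {
        futs.emplace_back(std::async(std::launch::async, [&, c] {
            const int64_t begin = c * chunk_size;
            const int64_t end = std::min(begin + chunk_size, nb);
            for (int64_t i = begin; i < end; ++i) {
                float sum = 0.0f;
                for (int64_t j = 0; j < dim; ++j) {
                    sum += data[i * dim + j] * data[i * dim + j];
                }
                norms[i] = std::sqrt(sum);
            }
        }));
    }
    for (auto& f : futs) {
        f.get();  // wait for all chunks and propagate any exception
    }
    return norms;
}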

Collaborator Author:

update.

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<folly::Unit>> futs;
constexpr int64_t chunk_size = 128;
auto chunk_num = std::ceil(float(nb) / float(chunk_size));
Collaborator:

auto chunk_num = (nb + chunk_size - 1) / chunk_size;
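Both forms compute ceil(nb / chunk_size), but the integer form avoids the float round-trip, which can misbehave once nb exceeds 2^24 (float cannot represent every integer beyond that). A small comparison:

#include <cmath>
#include <cstdint>

const int64_t nb = 174762, chunk_size = 8192;
const auto by_float = static_cast<int64_t>(std::ceil(float(nb) / float(chunk_size)));  // 22
const auto by_int = (nb + chunk_size - 1) / chunk_size;                                // 22, no float involved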

Collaborator Author:

update.

faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
} else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck<DataType>::value) {
faiss::half_precision_floating_point_knn_L2sqr(cur_query, (const DataType*)xb, dim, 1, nb, topk,
Collaborator:

Why is it not possible to make faiss::knn_L2sqr templated instead? Basically, why does the half-precision data type require cloning the knn code?

Collaborator Author:

Will templating lots of the distances.h functions increase the complexity of future faiss upgrades? If it's okay, I can move the fp16/bf16 implementation to distances.h.
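A hypothetical scalar sketch of the templating the reviewer has in mind (not the real faiss API; assumes knowhere::fp16/bf16 are convertible to float, and shows k=1 for brevity):

#include <cstddef>
#include <cstdint>
#include <limits>

template <typename DataT>  // DataT in {float, knowhere::fp16, knowhere::bf16}
int64_t
nn_L2sqr_typed(const float* x, const DataT* y, size_t d, size_t ny) {
    float best = std::numeric_limits<float>::max();
    int64_t best_id = -1;
    for (size_t i = 0; i < ny; ++i) {
        float dist = 0.0f;
        for (size_t j = 0; j < d; ++j) {
            const float diff = x[j] - static_cast<float>(y[i * d + j]);
            dist += diff * diff;
        }
        if (dist < best) {
            best = dist;
            best_id = static_cast<int64_t>(i);
        }
    }
    return best_id;  // index of the nearest base vector under squared L2
}

Since templates have to live in headers, this is exactly the faiss-upgrade-complexity trade-off raised above.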

@@ -397,7 +465,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da

faiss::MetricType faiss_metric_type;
sparse::DocValueComputer<float> sparse_computer;
if (!is_sparse) {
if constexpr (!std::is_same_v<DataType, knowhere::sparse::SparseRow<float>>) {
Collaborator:

This line changes the previous logic from "if sparse" to "if float sparse". Just want to double check that this is intended.

Collaborator Author:

bool is_sparse = std::is_same<DataType, knowhere::sparse::SparseRow<float>>::value;

i.e. is_sparse was already defined as this exact comparison, so the condition is unchanged.

} else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck<DataType>::value) {
// normalizing the query vector may cause precision loss, so divide by the query norms in the apply function

faiss::half_precision_floating_point_range_search_cosine(
Collaborator:

same here: why not templatize faiss::range_search_cosine?

faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
id_selector);
} else {
// else not sparse:
Collaborator:

else not float sparse.
The logic for this branch got changed.


@@ -239,6 +239,14 @@ GetRangeSearchRecall(const knowhere::DataSet& gt, const knowhere::DataSet& resul
return (1 + precision) * recall / 2;
}

inline float
GetRelativeLoss(float gt_res, float res) {
if (gt_res == 0.0 || std::abs(gt_res) < 0.000001) {
Collaborator:

Recommend using float epsilon.
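A sketch of the suggested change (the body past the guard is assumed from the function name; the epsilon guard itself is what the review asks for):

#include <cmath>
#include <limits>

inline float
GetRelativeLoss(float gt_res, float res) {
    // float machine epsilon instead of a hard-coded 1e-6 literal;
    // this also subsumes the separate == 0.0 check
    if (std::abs(gt_res) < std::numeric_limits<float>::epsilon()) {
        return std::abs(gt_res - res);  // assumed fallback: absolute error near zero
    }
    return std::abs((gt_res - res) / gt_res);
}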

foxspy (Collaborator) commented Dec 16, 2024

/kind improvement

foxspy (Collaborator) commented Dec 16, 2024

/lgtm

sre-ci-robot merged commit a9d6992 into zilliztech:main on Dec 16, 2024
14 checks passed