enhance: BF functions support real fp16/bf16 calculate #980
Conversation
[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: cqy123456

The full list of commands accepted by this bot can be found here. The pull request process is described here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing `/approve` in a comment.
@cqy123456 🔍 Important: PR Classification Needed! For efficient project management and a seamless review process, it's essential to classify your PR correctly. Here's how:

For any PR outside the kind/improvement category, ensure you link to the associated issue using the format: “issue: #”. Thanks for your efforts and contribution to the community!
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@           Coverage Diff            @@
##            main     #980       +/-  ##
=========================================
+ Coverage       0   73.87%   +73.87%
=========================================
  Files          0       82       +82
  Lines          0     6916     +6916
=========================================
+ Hits           0     5109     +5109
- Misses         0     1807     +1807
@@ -159,19 +159,17 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
float
fp16_vec_inner_product_avx512(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
    __m512 m512_res = _mm512_setzero_ps();
    __m512 m512_res_0 = _mm512_setzero_ps();
`res_0` is used for increasing the instruction-level parallelism in the following loop. Please confirm with godbolt.org, or let me know if I need to check this.
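For reference, a minimal sketch of that dual-accumulator pattern (hypothetical code, not the exact knowhere implementation; the half floats are passed as raw uint16_t bits to keep the example self-contained). Two independent FMA dependency chains let the CPU overlap FMA latency:

```cpp
#include <immintrin.h>
#include <cstddef>
#include <cstdint>

float fp16_inner_product_two_chains(const uint16_t* x, const uint16_t* y, size_t d) {
    __m512 m512_res = _mm512_setzero_ps();
    __m512 m512_res_0 = _mm512_setzero_ps();
    size_t i = 0;
    for (; i + 32 <= d; i += 32) {
        // convert 16 half floats per chain to fp32, then accumulate
        __m512 xa = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x + i)));
        __m512 ya = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(y + i)));
        __m512 xb = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x + i + 16)));
        __m512 yb = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(y + i + 16)));
        m512_res = _mm512_fmadd_ps(xa, ya, m512_res);      // chain 1
        m512_res_0 = _mm512_fmadd_ps(xb, yb, m512_res_0);  // chain 2
    }
    // tail handling omitted for brevity
    return _mm512_reduce_add_ps(_mm512_add_ps(m512_res, m512_res_0));
}
```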
There is a slight difference between fp16_vec_inner_product_avx512 and fp16_vec_inner_product_avx512_batch_4:
fp16_vec_inner_product_avx512: sum of (round(a*b + m512_res)) + sum of (round(a*b + m512_res_0))
fp16_vec_inner_product_avx512_batch_4: sum of (round(a*b + m512_res))
Precision loss may be caused by fmadd: round(a*b + c)
vs. mul and add: round(a*b) + c
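A scalar illustration of this double rounding (a hypothetical standalone example, assuming IEEE-754 float, with compiler contraction disabled so the split form is not silently fused):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // a*a = 1 + 2^-11 + 2^-24 is not representable as a float, so the
    // split form rounds it (ties-to-even) to 1 + 2^-11 before the add.
    float a = 1.0f + 0x1.0p-12f;
    float c = -(1.0f + 0x1.0p-11f);
    float fused = std::fma(a, a, c);  // round(a*a + c) = 2^-24
    float split = a * a + c;          // round(round(a*a) + c) = 0
    // build with -ffp-contract=off so the compiler does not fuse `split`
    std::printf("fma = %g, mul+add = %g\n", fused, split);
    return 0;
}
```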
I see. It sounds reasonable, but please add comments in the corresponding distance_XYZ.cc files, after the headers but before `namespace faiss {`, with the explanation of why it is done. Otherwise, someone may wish to 'optimize' the code back.
Alternatively, it is possible (and it is a faster solution for the single-op version) to extend the batch-4 version to match the single-op version instead. So, let the batch-4 loop perform 8 FMA operations instead of 4, and let the single-op version perform 2 FMA operations, as it is now in the baseline.
I leave it up to you to decide whether you'd like to change it, because the hot spot is the batch-4 version anyway.
I tested the batch-4 loop performing 8 FMA operations; a slight performance degradation showed up in BF float16 search.
It is possible that the number of registers in AVX512 is not enough for that much parallelism.
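(AVX-512 provides 32 zmm registers; doubling the accumulator chains in the batch-4 loop also doubles the live accumulators and per-iteration load/convert temporaries, so register spills are a plausible cause.)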
src/simd/distances_neon.cc
Outdated
res.val[1] = vmlaq_f32(res.val[1], a.val[1], a.val[1]);
res.val[2] = vmlaq_f32(res.val[2], a.val[2], a.val[2]);
res.val[3] = vmlaq_f32(res.val[3], a.val[3], a.val[3]);
res.val[0] = vaddq_f32(vmulq_f32(a.val[0], a.val[0]), res.val[0]);
why replace FMA with ADD+MUL?
update.
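For context, a minimal sketch contrasting the two NEON forms in question (hypothetical helpers, not the actual knowhere code):

```cpp
#include <arm_neon.h>

// Multiply-accumulate: the compiler may lower vmlaq_f32 to a fused
// FMLA, which rounds once: round(a*a + res).
float32x4_t acc_mla(float32x4_t res, float32x4_t a) {
    return vmlaq_f32(res, a, a);
}

// Explicit mul then add: rounds twice, round(round(a*a) + res),
// matching the precision behaviour of the split x86 path.
float32x4_t acc_split(float32x4_t res, float32x4_t a) {
    return vaddq_f32(vmulq_f32(a, a), res);
}
```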
src/common/comp/brute_force.cc
Outdated
// use build thread pool to compute norms
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<folly::Unit>> futs;
constexpr int64_t chunk_size = 128;
this is way too small, tbh
I'd make it something like 8192
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update.
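For reference, a sketch of the chunked norm computation in the context of the snippet above (hypothetical: `nb`, `dim`, `xb`, and `norms` come from the surrounding function; assumes knowhere's `ThreadPool::push`, `faiss::fvec_norm_L2sqr`, and the 8192 chunk size suggested in review):

```cpp
constexpr int64_t chunk_size = 8192;
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<folly::Unit>> futs;
const int64_t chunk_num = (nb + chunk_size - 1) / chunk_size;
futs.reserve(chunk_num);
for (int64_t c = 0; c < chunk_num; ++c) {
    const int64_t begin = c * chunk_size;
    const int64_t end = std::min(begin + chunk_size, nb);
    futs.emplace_back(pool->push([&, begin, end] {
        // each task computes the L2 norms for its own chunk of base vectors
        for (int64_t i = begin; i < end; ++i) {
            norms[i] = std::sqrt(faiss::fvec_norm_L2sqr(xb + i * dim, dim));
        }
    }));
}
// wait for all chunks to finish
folly::collectAll(futs.begin(), futs.end()).wait();
```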
src/common/comp/brute_force.cc
Outdated
auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<folly::Unit>> futs;
constexpr int64_t chunk_size = 128;
auto chunk_num = std::ceil(float(nb) / float(chunk_size));
auto chunk_num = (nb + chunk_size - 1) / chunk_size;
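(Integer ceiling division avoids the float round-trip; e.g. with nb = 20000 and chunk_size = 8192, (20000 + 8191) / 8192 = 3, matching ceil(20000 / 8192.0).)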
update.
faiss::float_maxheap_array_t buf{(size_t)1, (size_t)topk, cur_labels, cur_distances};
faiss::knn_L2sqr(cur_query, (const float*)xb, dim, 1, nb, &buf, nullptr, id_selector);
} else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck<DataType>::value) {
    faiss::half_precision_floating_point_knn_L2sqr(cur_query, (const DataType*)xb, dim, 1, nb, topk,
Why is it not possible to make `faiss::knn_L2sqr` templated instead? Basically, why does the half-precision data type require cloning the knn code?
Will templating lots of the distances.h functions increase the complexity of upgrading faiss? If it's okay, I can move the implementation of fp16/bf16 to distances.h.
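For illustration, a minimal sketch of the templated shape the reviewer is suggesting (hypothetical; `knn_L2sqr_typed` is not faiss's actual API and omits faiss's heap/result-handler machinery):

```cpp
#include <cstddef>
#include <cstdint>
#include <queue>
#include <utility>

// DataT could be float, knowhere::fp16, or knowhere::bf16 —
// anything convertible to float.
template <typename DataT>
void knn_L2sqr_typed(const float* x, const DataT* y, size_t d, size_t nx,
                     size_t ny, size_t k, float* distances, int64_t* labels) {
    for (size_t qi = 0; qi < nx; ++qi) {
        // max-heap of (distance, id): pop the worst when over capacity
        std::priority_queue<std::pair<float, int64_t>> heap;
        for (size_t bi = 0; bi < ny; ++bi) {
            float dist = 0.0f;
            for (size_t j = 0; j < d; ++j) {
                float diff = x[qi * d + j] - static_cast<float>(y[bi * d + j]);
                dist += diff * diff;
            }
            heap.emplace(dist, (int64_t)bi);
            if (heap.size() > k) heap.pop();
        }
        // emit ascending by distance; slots beyond heap.size() stay untouched
        for (size_t r = heap.size(); r-- > 0;) {
            distances[qi * k + r] = heap.top().first;
            labels[qi * k + r] = heap.top().second;
            heap.pop();
        }
    }
}
```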
@@ -397,7 +465,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
     faiss::MetricType faiss_metric_type;
     sparse::DocValueComputer<float> sparse_computer;
-    if (!is_sparse) {
+    if constexpr (!std::is_same_v<DataType, knowhere::sparse::SparseRow<float>>) {
This line changes the previous logic: from `if sparse` to `if float sparse`. Just want to double check that this is intended.
bool is_sparse = std::is_same<DataType, knowhere::sparse::SparseRow<float>>::value;
} else if constexpr (KnowhereHalfPrecisionFloatPointTypeCheck<DataType>::value) {
    // normalizing the query vector may cause precision loss, so divide by the query norms in the apply function
    faiss::half_precision_floating_point_range_search_cosine(
Same here: why not templatize `faiss::range_search_cosine`?
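For context, a scalar sketch of the precision point in the code comment above (hypothetical helper, not knowhere's SIMD code): rather than normalizing a half-precision query in place, which re-rounds every component, the inner product is computed first and divided by the norms afterwards.

```cpp
#include <cstddef>

float cosine_by_post_division(const float* q, const float* b, size_t d,
                              float q_norm, float b_norm) {
    float ip = 0.0f;
    for (size_t j = 0; j < d; ++j) {
        ip += q[j] * b[j];  // accumulate in float
    }
    return ip / (q_norm * b_norm);  // single division; q itself is never re-rounded
}
```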
faiss::range_search_inner_product(cur_query, (const float*)xb, dim, 1, nb, radius, &res,
                                  id_selector);
} else {
    // else not sparse:
The `else` here now means "not float sparse". The logic for this branch got changed.
Signed-off-by: cqy123456 <[email protected]>
@@ -239,6 +239,14 @@ GetRangeSearchRecall(const knowhere::DataSet& gt, const knowhere::DataSet& resul
    return (1 + precision) * recall / 2;
}

inline float
GetRelativeLoss(float gt_res, float res) {
    if (gt_res == 0.0 || std::abs(gt_res) < 0.000001) {
Recommend using the float epsilon here instead of a magic constant.
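A sketch of what that might look like (hypothetical; the near-zero fallback is illustrative, since the original function body is truncated above):

```cpp
#include <cmath>
#include <limits>

inline float
GetRelativeLoss(float gt_res, float res) {
    // use the float machine epsilon instead of the 0.000001 magic number
    if (std::abs(gt_res) < std::numeric_limits<float>::epsilon()) {
        return std::abs(gt_res - res);  // hypothetical fallback for near-zero ground truth
    }
    return std::abs((gt_res - res) / gt_res);
}
```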
/kind improvement
/lgtm
issue: #909
The purpose of this PR is to remove the memory copy of input data and prepare for NM index search.
simple test with cohere-768d-cosine:
8cpu, avx512
fp16/bf16: rowcount = 174762, total size = 128MB
QPS of float16: 170.49 --> before: 75.5
QPS of bfloat16: 158.87 --> before: 73.95
fp32: rowcount = 174762, total size = 256MB
QPS of float32: 132.029
Used VecTool to test with gist1M-960d-L2:
bf vs. flat (converted to fp32)
bf16:
I1216 02:53:15.839870 69100 metric.cpp:141] case = gist_query_FLAT_k_100_0.00 | repeat = 1 | nq = 1000 | k = 100 | time = 66.1252s | qps = 15.1228 | avg_recall = 1 | min_recall = 1 | avg_ncdg = -nan | min_ncdg = 3.40282e+38 | avg_disterr = 0 | max_disterr = 1.17549e-38
fp16:
I1216 02:59:14.644295 70049 metric.cpp:141] case = gist_query_FLAT_k_100_0.00 | repeat = 1 | nq = 1000 | k = 100 | time = 67.0043s | qps = 14.9244 | avg_recall = 1 | min_recall = 1 | avg_ncdg = -nan | min_ncdg = 3.40282e+38 | avg_disterr = 0 | max_disterr = 1.17549e-38