enhance: BF functions support real fp16/bf16 calculate #980

Merged (1 commit) on Dec 16, 2024
2 changes: 2 additions & 0 deletions include/knowhere/operands.h
@@ -170,6 +170,8 @@
template <typename InType>
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16>;
template <typename InType>
using KnowhereFloatTypeCheck = TypeMatch<InType, fp16, fp32, bf16>;
+template <typename InType>
+using KnowhereHalfPrecisionFloatPointTypeCheck = TypeMatch<InType, fp16, bf16>;

template <typename T>
struct MockData {
397 changes: 251 additions & 146 deletions src/common/comp/brute_force.cc

Large diffs are not rendered by default.

285 changes: 244 additions & 41 deletions src/simd/distances_avx.cc

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions src/simd/distances_avx.h
@@ -65,6 +65,16 @@
fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
@@ -73,6 +83,16 @@
void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

199 changes: 196 additions & 3 deletions src/simd/distances_avx512.cc
@@ -159,19 +159,17 @@
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
float
fp16_vec_inner_product_avx512(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
    __m512 m512_res = _mm512_setzero_ps();
-   __m512 m512_res_0 = _mm512_setzero_ps();

Collaborator: res_0 is used for increasing the instruction-level parallelism in the following loop. Please confirm with godbolt.org, or let me know if I need to check this.

Collaborator (author): There is a slight difference between fp16_vec_inner_product_avx512 and fp16_vec_inner_product_avx512_batch_4:
fp16_vec_inner_product_avx512: sum of round(a*b + m512_res) plus sum of round(a*b + m512_res_0)
fp16_vec_inner_product_avx512_batch_4: sum of round(a*b + m512_res)

Collaborator (author, @cqy123456, Dec 11, 2024): The precision loss may be caused by fmadd, which performs a fused round(a*b + c); a separate mul and add instead compute round(a*b) + c.

Collaborator: I see. It sounds reasonable, but please add comments in the corresponding distance_XYZ.cc files (after the headers but before namespace faiss {) explaining why this is done. Otherwise, someone may wish to 'optimize' the code back.
Alternatively, it is possible (and it is a faster solution for the single-op version) to extend the batch-4 version to match the single-op version instead: let the batch-4 loop perform 8 FMA operations instead of 4, and let the single-op version perform 2 FMA operations, as it does now in the baseline.
I leave it up to you to decide whether you'd like to change it, because the hot spot is the batch-4 version anyway.

Collaborator (author): I tested the batch-4 loop performing 8 FMA operations; a small performance degradation appeared in BF float16 search. It is possible that AVX512 does not have enough registers for that much parallelism.

    while (d >= 32) {
        auto mx_0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
        auto my_0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y));
        auto mx_1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + 16)));
        auto my_1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + 16)));
        m512_res = _mm512_fmadd_ps(mx_0, my_0, m512_res);
-       m512_res_0 = _mm512_fmadd_ps(mx_1, my_1, m512_res_0);
+       m512_res = _mm512_fmadd_ps(mx_1, my_1, m512_res);
        x += 32;
        y += 32;
        d -= 32;
    }
-   m512_res = m512_res + m512_res_0;
    if (d >= 16) {
        auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
        auto my = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y));
@@ -408,6 +406,201 @@
fvec_inner_product_batch_4_avx512_bf16_patch(const float* __restrict x, const fl
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

void
fp16_vec_inner_product_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}

void
bf16_vec_inner_product_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}

void
fp16_vec_L2sqr_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y3));
my0 = _mm512_sub_ps(mx, my0);
my1 = _mm512_sub_ps(mx, my1);
my2 = _mm512_sub_ps(mx, my2);
my3 = _mm512_sub_ps(mx, my3);
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y3));
my0 = _mm512_sub_ps(mx, my0);
my1 = _mm512_sub_ps(mx, my1);
my2 = _mm512_sub_ps(mx, my2);
my3 = _mm512_sub_ps(mx, my3);
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}

void
bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y3));
my0 = _mm512_sub_ps(mx, my0);
my1 = _mm512_sub_ps(mx, my1);
my2 = _mm512_sub_ps(mx, my2);
my3 = _mm512_sub_ps(mx, my3);
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y3));
my0 = _mm512_sub_ps(mx, my0);
my1 = _mm512_sub_ps(mx, my1);
my2 = _mm512_sub_ps(mx, my2);
my3 = _mm512_sub_ps(mx, my3);
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}
// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
Expand Down
20 changes: 20 additions & 0 deletions src/simd/distances_avx512.h
@@ -64,6 +64,16 @@
fvec_inner_product_batch_4_avx512_bf16_patch(const float* x, const float* y0, co
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fp16_vec_inner_product_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_inner_product_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
@@ -72,6 +82,16 @@
void
fvec_L2sqr_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fp16_vec_L2sqr_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d);
