From fad80b14fa804ac0f2bae613a2dd857b5b7592f4 Mon Sep 17 00:00:00 2001 From: "Wang, Zhe" Date: Thu, 4 Jan 2024 10:24:26 +0800 Subject: [PATCH] fix avx512-s8-dequant and asym related bug (#19) * fix avx512-s8-dequant and asym related bug --- .gitignore | 4 ++++ bestla/bestla/bestla_prologue_a.h | 4 ++++ bestla/bestla/kernel_avx2.h | 13 ++++++------- bestla/bestla/kernel_jit.h | 30 +++++++++++++++++------------- bestla/bestla/kernel_ref.h | 2 +- bestla/bestla/kernel_wrapper.h | 4 ++-- 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index 7d2913bd3..5ce59c23a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,7 @@ /.vs /out +.vscode/ +.vscode/* +bestla/build/ +bestla/build/* diff --git a/bestla/bestla/bestla_prologue_a.h b/bestla/bestla/bestla_prologue_a.h index 677e42203..d6e782c76 100644 --- a/bestla/bestla/bestla_prologue_a.h +++ b/bestla/bestla/bestla_prologue_a.h @@ -367,6 +367,8 @@ class ShuffleActivationKBlockBase : public ActivationKBlockBase<_GemmCore_T, ISA template using ShuffleActivationKBlockBaseF32 = ShuffleActivationKBlockBase<_GemmCore_T, ISA_T, float>; +template +using ShuffleActivationKBlockBaseBf16 = ShuffleActivationKBlockBase<_GemmCore_T, ISA_T, utils::bf16>; template struct ParamShuffleActivationKBlockQuantize : ParamActivationKBlockQuantize { @@ -422,6 +424,8 @@ class ShuffleActivationKBlockQuantize : public ActivationKBlockQuantize<_GemmCor template using ShuffleActivationKBlockQuantizeF32 = ShuffleActivationKBlockQuantize<_GemmCore_T, ISA_T, float>; +template +using ShuffleActivationKBlockQuantizeBf16 = ShuffleActivationKBlockQuantize<_GemmCore_T, ISA_T, utils::bf16>; } // namespace gemm } // namespace prologue_a } // namespace bestla diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h index c5b66b4bf..125dae41a 100644 --- a/bestla/bestla/kernel_avx2.h +++ b/bestla/bestla/kernel_avx2.h @@ -161,15 +161,14 @@ BTLA_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int auto s8_ymm_v = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)); auto s32_ymm_v = _mm256_cvtepi8_epi32(s8_ymm_v); if constexpr (WITH_ZP) { - s32_ymm_v = _mm256_sub_epi32( - s32_ymm_v, - _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + kpos * NPad + j)))); + auto zp_ymm = + _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + kpos * NPad + j / PACK_ROW))); + if constexpr (PACK_ROW == 4) zp_ymm = _mm256_permutevar8x32_epi32(zp_ymm, packrow4_permute_idx); + s32_ymm_v = _mm256_sub_epi32(s32_ymm_v, zp_ymm); } auto f32_ymm_v = _mm256_cvtepi32_ps(s32_ymm_v); auto scale_ymm = _mm256_loadu_ps(sptr + j / PACK_ROW); - if constexpr (PACK_ROW == 4) { - scale_ymm = _mm256_permutevar8x32_ps(scale_ymm, packrow4_permute_idx); - } + if constexpr (PACK_ROW == 4) scale_ymm = _mm256_permutevar8x32_ps(scale_ymm, packrow4_permute_idx); f32_ymm_v = _mm256_mul_ps(f32_ymm_v, scale_ymm); if constexpr (std::is_same_v<_DST_T, float>) { _mm256_storeu_ps(dstptr + i * ld_dst + j, f32_ymm_v); @@ -181,7 +180,7 @@ BTLA_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int } for (; j < col; j++) { float tmp = (float)(srcptr[i * ld_src + j]); - if constexpr (WITH_ZP) tmp -= (float)(zero_points[kpos * NPad + j]); + if constexpr (WITH_ZP) tmp -= (float)(zero_points[kpos * NPad + j / PACK_ROW]); dstptr[i * ld_dst + j] = tmp * sptr[j / PACK_ROW]; } } diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index 718c47e25..1e9d4146c 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -56,7 +56,7 @@ class DequanS8FP { void generate(BTLA_DTYPE dst_dt, int pack_row) { assert(pack_row == 1 || pack_row == 2 || pack_row == 4); - int scale_step = 64 / pack_row; + int zmm_scale_step = 64 / pack_row; Xbyak::Label data_label; inLocalLabel(); // use local label for multiple instance { @@ -104,19 +104,23 @@ class DequanS8FP { return 4; // f32 case. }; - auto generateNTile = [&](int N, BTLA_DTYPE dst_dt, int scale_step, std::string row_label) { + auto generateNTile = [&](int N, BTLA_DTYPE dst_dt, int zmm_scale_step, std::string row_label) { if (pack_row == 2) { vmovups(Xbyak::Zmm(RegTmp), ptr[rip + data_label + 8]); } else if (pack_row == 4) { vmovups(Xbyak::Zmm(RegTmp), ptr[rip + data_label + 72]); } for (int i = 0; i < N; i++) { - vmovups(Xbyak::Zmm(RegScale + i), ptr[reg_scaleptr + i * scale_step]); + vmovups(Xbyak::Zmm(RegScale + i), ptr[reg_scaleptr + i * zmm_scale_step]); if (pack_row == 2 || pack_row == 4) { vpermd(Xbyak::Zmm(RegScale + i), Xbyak::Zmm(RegTmp), Xbyak::Zmm(RegScale + i)); } if (!is_sym) { - vpmovsxbd(Xbyak::Zmm(RegZP + i), ptr[reg_zpptr + i * 16]); + vpmovsxbd(Xbyak::Zmm(RegZP + i), + ptr[reg_zpptr + i * zmm_scale_step / sizeof(float)]); // revert to zp_step. + if (pack_row == 2 || pack_row == 4) { + vpermd(Xbyak::Zmm(RegZP + i), Xbyak::Zmm(RegTmp), Xbyak::Zmm(RegZP + i)); + } } } xor_(reg_iterrow, reg_iterrow); @@ -162,32 +166,32 @@ class DequanS8FP { sub(reg_tmp, reg_itercol); cmp(reg_tmp, 64); jl(".proc48", T_NEAR); - generateNTile(4, dst_dt, scale_step, ".rowloop1"); + generateNTile(4, dst_dt, zmm_scale_step, ".rowloop1"); add(reg_itercol, 64); add(reg_srcptr, 1 * 64); add(reg_dstptr, get_dst_step() * 64); - add(reg_scaleptr, 4 * scale_step); - if (!is_sym) add(reg_zpptr, 1 * 64); + add(reg_scaleptr, 4 * 64 / pack_row); + if (!is_sym) add(reg_zpptr, 1 * 64 / pack_row); jmp(".colend", T_NEAR); L(".proc48"); cmp(reg_tmp, 48); jl(".proc32", T_NEAR); - generateNTile(3, dst_dt, scale_step, ".rowloop2"); + generateNTile(3, dst_dt, zmm_scale_step, ".rowloop2"); add(reg_itercol, 48); add(reg_srcptr, 1 * 48); add(reg_dstptr, get_dst_step() * 48); - add(reg_scaleptr, 4 * scale_step); - if (!is_sym) add(reg_zpptr, 1 * 48); + add(reg_scaleptr, 4 * 48 / pack_row); + if (!is_sym) add(reg_zpptr, 1 * 48 / pack_row); jmp(".colend", T_NEAR); L(".proc32"); - generateNTile(2, dst_dt, scale_step, ".rowloop3"); + generateNTile(2, dst_dt, zmm_scale_step, ".rowloop3"); add(reg_itercol, 32); add(reg_srcptr, 1 * 32); add(reg_dstptr, get_dst_step() * 32); - add(reg_scaleptr, 4 * scale_step); - if (!is_sym) add(reg_zpptr, 1 * 32); + add(reg_scaleptr, 4 * 32 / pack_row); + if (!is_sym) add(reg_zpptr, 1 * 32 / pack_row); L(".colend"); cmp(reg_itercol, reg_colsize); diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index f55f4fa45..d72bf22c4 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -335,7 +335,7 @@ inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row auto sptr = scales + kpos * NPad; for (int j = 0; j < col; j += 1) { float tmp = static_cast(srcptr[i * ld_src + j]); - if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j]); + if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); dstptr[i * ld_dst + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } } diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index d0832bd82..dd7d92ae5 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -715,10 +715,10 @@ class ColBlockReduceSum { template static inline BTLA_CODE forward(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, float* reduce, int ldr) { - if constexpr (utils::isa_base::avx512f) { + if constexpr (utils::isa_base::avx512f && std::is_same_v) { return avx512f::col_block_reduce_sum(srcptr, ldsrc, row, col, blocksize, reduce, ldr); } - if constexpr (utils::isa_base::avx2) { + if constexpr (utils::isa_base::avx2 && std::is_same_v) { return avx2::col_block_reduce_sum(srcptr, ldsrc, row, col, blocksize, reduce, ldr); } return ref::col_block_reduce_sum(srcptr, ldsrc, row, col, blocksize, reduce, ldr);