Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
fix avx512-s8-dequant and asym-related bugs (#19)
Browse files Browse the repository at this point in the history
* fix avx512-s8-dequant and asym-related bugs
  • Loading branch information
zhewang1-intc authored Jan 4, 2024
1 parent 1b9330b commit fad80b1
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 23 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@

/.vs
/out
.vscode/
.vscode/*
bestla/build/
bestla/build/*
4 changes: 4 additions & 0 deletions bestla/bestla/bestla_prologue_a.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ class ShuffleActivationKBlockBase : public ActivationKBlockBase<_GemmCore_T, ISA

template <class _GemmCore_T, BTLA_ISA ISA_T>
using ShuffleActivationKBlockBaseF32 = ShuffleActivationKBlockBase<_GemmCore_T, ISA_T, float>;
template <class _GemmCore_T, BTLA_ISA ISA_T>
using ShuffleActivationKBlockBaseBf16 = ShuffleActivationKBlockBase<_GemmCore_T, ISA_T, utils::bf16>;

template <typename AType>
struct ParamShuffleActivationKBlockQuantize : ParamActivationKBlockQuantize<AType> {
Expand Down Expand Up @@ -422,6 +424,8 @@ class ShuffleActivationKBlockQuantize : public ActivationKBlockQuantize<_GemmCor

template <class _GemmCore_T, BTLA_ISA ISA_T>
using ShuffleActivationKBlockQuantizeF32 = ShuffleActivationKBlockQuantize<_GemmCore_T, ISA_T, float>;
template <class _GemmCore_T, BTLA_ISA ISA_T>
using ShuffleActivationKBlockQuantizeBf16 = ShuffleActivationKBlockQuantize<_GemmCore_T, ISA_T, utils::bf16>;
} // namespace gemm
} // namespace prologue_a
} // namespace bestla
13 changes: 6 additions & 7 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,14 @@ BTLA_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int
auto s8_ymm_v = _mm_loadl_epi64(reinterpret_cast<__m128i*>(srcptr + i * ld_src + j));
auto s32_ymm_v = _mm256_cvtepi8_epi32(s8_ymm_v);
if constexpr (WITH_ZP) {
s32_ymm_v = _mm256_sub_epi32(
s32_ymm_v,
_mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + kpos * NPad + j))));
auto zp_ymm =
_mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(zero_points + kpos * NPad + j / PACK_ROW)));
if constexpr (PACK_ROW == 4) zp_ymm = _mm256_permutevar8x32_epi32(zp_ymm, packrow4_permute_idx);
s32_ymm_v = _mm256_sub_epi32(s32_ymm_v, zp_ymm);
}
auto f32_ymm_v = _mm256_cvtepi32_ps(s32_ymm_v);
auto scale_ymm = _mm256_loadu_ps(sptr + j / PACK_ROW);
if constexpr (PACK_ROW == 4) {
scale_ymm = _mm256_permutevar8x32_ps(scale_ymm, packrow4_permute_idx);
}
if constexpr (PACK_ROW == 4) scale_ymm = _mm256_permutevar8x32_ps(scale_ymm, packrow4_permute_idx);
f32_ymm_v = _mm256_mul_ps(f32_ymm_v, scale_ymm);
if constexpr (std::is_same_v<_DST_T, float>) {
_mm256_storeu_ps(dstptr + i * ld_dst + j, f32_ymm_v);
Expand All @@ -181,7 +180,7 @@ BTLA_CODE dequant_kblock_s8_fp_fwd(int8_t* srcptr, _DST_T* dstptr, int row, int
}
for (; j < col; j++) {
float tmp = (float)(srcptr[i * ld_src + j]);
if constexpr (WITH_ZP) tmp -= (float)(zero_points[kpos * NPad + j]);
if constexpr (WITH_ZP) tmp -= (float)(zero_points[kpos * NPad + j / PACK_ROW]);
dstptr[i * ld_dst + j] = tmp * sptr[j / PACK_ROW];
}
}
Expand Down
30 changes: 17 additions & 13 deletions bestla/bestla/kernel_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class DequanS8FP {

void generate(BTLA_DTYPE dst_dt, int pack_row) {
assert(pack_row == 1 || pack_row == 2 || pack_row == 4);
int scale_step = 64 / pack_row;
int zmm_scale_step = 64 / pack_row;
Xbyak::Label data_label;
inLocalLabel(); // use local label for multiple instance
{
Expand Down Expand Up @@ -104,19 +104,23 @@ class DequanS8FP {
return 4; // f32 case.
};

auto generateNTile = [&](int N, BTLA_DTYPE dst_dt, int scale_step, std::string row_label) {
auto generateNTile = [&](int N, BTLA_DTYPE dst_dt, int zmm_scale_step, std::string row_label) {
if (pack_row == 2) {
vmovups(Xbyak::Zmm(RegTmp), ptr[rip + data_label + 8]);
} else if (pack_row == 4) {
vmovups(Xbyak::Zmm(RegTmp), ptr[rip + data_label + 72]);
}
for (int i = 0; i < N; i++) {
vmovups(Xbyak::Zmm(RegScale + i), ptr[reg_scaleptr + i * scale_step]);
vmovups(Xbyak::Zmm(RegScale + i), ptr[reg_scaleptr + i * zmm_scale_step]);
if (pack_row == 2 || pack_row == 4) {
vpermd(Xbyak::Zmm(RegScale + i), Xbyak::Zmm(RegTmp), Xbyak::Zmm(RegScale + i));
}
if (!is_sym) {
vpmovsxbd(Xbyak::Zmm(RegZP + i), ptr[reg_zpptr + i * 16]);
vpmovsxbd(Xbyak::Zmm(RegZP + i),
ptr[reg_zpptr + i * zmm_scale_step / sizeof(float)]); // revert to zp_step.
if (pack_row == 2 || pack_row == 4) {
vpermd(Xbyak::Zmm(RegZP + i), Xbyak::Zmm(RegTmp), Xbyak::Zmm(RegZP + i));
}
}
}
xor_(reg_iterrow, reg_iterrow);
Expand Down Expand Up @@ -162,32 +166,32 @@ class DequanS8FP {
sub(reg_tmp, reg_itercol);
cmp(reg_tmp, 64);
jl(".proc48", T_NEAR);
generateNTile(4, dst_dt, scale_step, ".rowloop1");
generateNTile(4, dst_dt, zmm_scale_step, ".rowloop1");
add(reg_itercol, 64);
add(reg_srcptr, 1 * 64);
add(reg_dstptr, get_dst_step() * 64);
add(reg_scaleptr, 4 * scale_step);
if (!is_sym) add(reg_zpptr, 1 * 64);
add(reg_scaleptr, 4 * 64 / pack_row);
if (!is_sym) add(reg_zpptr, 1 * 64 / pack_row);
jmp(".colend", T_NEAR);

L(".proc48");
cmp(reg_tmp, 48);
jl(".proc32", T_NEAR);
generateNTile(3, dst_dt, scale_step, ".rowloop2");
generateNTile(3, dst_dt, zmm_scale_step, ".rowloop2");
add(reg_itercol, 48);
add(reg_srcptr, 1 * 48);
add(reg_dstptr, get_dst_step() * 48);
add(reg_scaleptr, 4 * scale_step);
if (!is_sym) add(reg_zpptr, 1 * 48);
add(reg_scaleptr, 4 * 48 / pack_row);
if (!is_sym) add(reg_zpptr, 1 * 48 / pack_row);
jmp(".colend", T_NEAR);

L(".proc32");
generateNTile(2, dst_dt, scale_step, ".rowloop3");
generateNTile(2, dst_dt, zmm_scale_step, ".rowloop3");
add(reg_itercol, 32);
add(reg_srcptr, 1 * 32);
add(reg_dstptr, get_dst_step() * 32);
add(reg_scaleptr, 4 * scale_step);
if (!is_sym) add(reg_zpptr, 1 * 32);
add(reg_scaleptr, 4 * 32 / pack_row);
if (!is_sym) add(reg_zpptr, 1 * 32 / pack_row);

L(".colend");
cmp(reg_itercol, reg_colsize);
Expand Down
2 changes: 1 addition & 1 deletion bestla/bestla/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ inline BTLA_CODE decompress_kblock_s8_fp(int8_t* srcptr, _DST_T* dstptr, int row
auto sptr = scales + kpos * NPad;
for (int j = 0; j < col; j += 1) {
float tmp = static_cast<float>(srcptr[i * ld_src + j]);
if (zero_points != nullptr) tmp -= static_cast<float>(zero_points[kpos * NPad + j]);
if (zero_points != nullptr) tmp -= static_cast<float>(zero_points[kpos * NPad + j / _PACK_ROW]);
dstptr[i * ld_dst + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]);
}
}
Expand Down
4 changes: 2 additions & 2 deletions bestla/bestla/kernel_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,10 @@ class ColBlockReduceSum {
template <BTLA_ISA ISA_T, typename SRC_T>
static inline BTLA_CODE forward(const SRC_T* srcptr, int ldsrc, int row, int col, int blocksize, float* reduce,
int ldr) {
if constexpr (utils::isa_base<ISA_T>::avx512f) {
if constexpr (utils::isa_base<ISA_T>::avx512f && std::is_same_v<SRC_T, float>) {
return avx512f::col_block_reduce_sum<SRC_T>(srcptr, ldsrc, row, col, blocksize, reduce, ldr);
}
if constexpr (utils::isa_base<ISA_T>::avx2) {
if constexpr (utils::isa_base<ISA_T>::avx2 && std::is_same_v<SRC_T, float>) {
return avx2::col_block_reduce_sum<SRC_T>(srcptr, ldsrc, row, col, blocksize, reduce, ldr);
}
return ref::col_block_reduce_sum<SRC_T>(srcptr, ldsrc, row, col, blocksize, reduce, ldr);
Expand Down

0 comments on commit fad80b1

Please sign in to comment.