diff --git a/bestla/bestla/bestla.h b/bestla/bestla/bestla.h
index ee2ddd3c8..f053e4049 100644
--- a/bestla/bestla/bestla.h
+++ b/bestla/bestla/bestla.h
@@ -32,6 +32,7 @@ enum class BTLA_ISA : uint8_t {
   AMX_INT8,
   AVX512_FP16,
   AVX512_BF16,
+  AMX_FP16,
   ISA_COUNT,
 };
 enum class BTLA_DTYPE : uint32_t {
diff --git a/bestla/bestla/bestla_device.h b/bestla/bestla/bestla_device.h
index 596065a66..9cb641ed5 100644
--- a/bestla/bestla/bestla_device.h
+++ b/bestla/bestla/bestla_device.h
@@ -232,6 +232,7 @@ class CpuDevice {
   inline bool AMX_BF16() { return mHasAMX_BF16; }
   inline bool AVX512_BF16() { return mHasAVX512_BF16; }
   inline bool AVX512_FP16() { return mHasAVX512_FP16; }
+  inline bool AMX_FP16() { return mHasAMX_FP16; }
   inline float* const getPE() { return PE; }
   inline int getPcoreNum() { return static_cast<int>(P_core.size()); }
   inline int getEcoreNum() { return static_cast<int>(E_core.size()); }
@@ -252,8 +253,9 @@ class CpuDevice {
     ADD_FLAG(AMX_INT8);
     ADD_FLAG(AVX512_BF16);
     ADD_FLAG(AVX512_FP16);
+    ADD_FLAG(AMX_FP16);
     numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
-    if (mHasAMX_BF16 || mHasAMX_INT8) {
+    if (mHasAMX_BF16 || mHasAMX_INT8 || mHasAMX_FP16) {
       utils::request_perm_xtile_data();
     }
     static bool p = false;
@@ -470,7 +472,8 @@ class CpuDevice {
   uint32_t L2Cache = 0, L1Cache = 0, L3Cache = 0;
   bool mHybrid = false, mClient = false;
   bool mHasAVX2 = false, mHasAVX_VNNI = false, mHasAVX = false, mHasAVX512_VNNI = false, mHasAMX_INT8 = false,
-       mHasAMX_BF16 = false, mHasAVX512F = false, mHasAVX512BW, mHasAVX512_BF16 = false, mHasAVX512_FP16 = false;
+       mHasAMX_BF16 = false, mHasAVX512F = false, mHasAVX512BW, mHasAVX512_BF16 = false, mHasAVX512_FP16 = false,
+       mHasAMX_FP16 = false;
   int numcores = 0;
   int numthreads = 0;
   std::vector<int> P_core, E_core, SMT_core;
diff --git a/bestla/bestla/bestla_gemm.h b/bestla/bestla/bestla_gemm.h
index 88356ab25..67d2c4d5d 100644
--- a/bestla/bestla/bestla_gemm.h
+++ b/bestla/bestla/bestla_gemm.h
@@ -37,6 +37,7 @@ enum class CompType : uint16_t {
   COMP_FP32 = (tFP32 << SHIFT_A) | (tFP32 << SHIFT_B) | (tFP32 << SHIFT_C),
   COMP_BF16_FP32 = (tBF16 << SHIFT_A) | (tBF16 << SHIFT_B) | (tFP32 << SHIFT_C),
   COMP_FP16_FP16 = (tFP16 << SHIFT_A) | (tFP16 << SHIFT_B) | (tFP16 << SHIFT_C),
+  COMP_FP16_FP32 = (tFP16 << SHIFT_A) | (tFP16 << SHIFT_B) | (tFP32 << SHIFT_C),
   COMP_INT8_US_INT32 = (tU8 << SHIFT_A) | (tS8 << SHIFT_B) | (tS32 << SHIFT_C),
   COMP_INT8_UU_INT32 = (tU8 << SHIFT_A) | (tU8 << SHIFT_B) | (tS32 << SHIFT_C),
   COMP_INT8_SS_INT32 = (tS8 << SHIFT_A) | (tS8 << SHIFT_B) | (tS32 << SHIFT_C),
@@ -2207,6 +2208,269 @@ class Amxbf16N16P2 : protected bestla::xbyak::JitAmxbf16 {
   }
 };
 
+template <int _NTILE, int _MTILE>
+class Amxfp16N16P2 : protected bestla::xbyak::JitAmxbf16 {
+ public:
+  static int constexpr RegLen = 16, PackRow = 2;
+  static_assert(_NTILE % RegLen == 0);
+  static_assert(_MTILE % RegLen == 0);
+  static int constexpr NRegs = _NTILE / RegLen;
+  static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen;
+  static_assert(NRegs * MRegs + 2 <= TileCount);
+  static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32;
+  static int constexpr KUNROLL = 2;
+  static auto constexpr ISA = BTLA_ISA::AMX_FP16;
+  static auto constexpr COMPUTE = CompType::COMP_FP16_FP32;
+  typedef utils::fp16 AType;
+  typedef utils::fp16 BType;
+  typedef float CType;
+
+  struct params {
+    AType* matA;
+    int astride;
+    BType* matB;
+    int bstride;
+    CType* matC;
+    int cstride;
+    int k;
+    int n;
+    int init;
+    void* workspace;
+  };
+  typedef long long (*func_t)(params*);
+
+  int TmpRegCount = RegCount;
+  int TmpReg = 0;
+  int CTileCount = 0, ATileCount = 0, BTileCount = 0;
+  int CTile = 0, ATile = 0, BTile = 0;
+  static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+  static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+  void generate_code(int _mtile) {
+    assign_regs();
+    reset();
+    generate_mtile(_mtile);
+    ready();
+    mKernel = getCode<func_t>();
+  }
+  func_t mKernel = nullptr;
+
+ protected:
+  Xbyak::Reg64 parambase;
+  Xbyak::Reg64 reg_matAptr;
+  Xbyak::Reg64 reg_matBptr;
+  Xbyak::Reg64 reg_matCptr;
+  Xbyak::Reg64 reg_ksize;
+  Xbyak::Reg64 reg_nsize;
+  Xbyak::Reg64 reg_cstride;
+  Xbyak::Reg64 reg_astride;
+  Xbyak::Reg64 reg_iterk;
+  Xbyak::Reg64 reg_itern;
+  Xbyak::Reg64 reg_tmp;
+  Xbyak::Reg64 reg_tmp1;
+  Xbyak::Reg64 reg_tmp2;
+  Xbyak::Reg64 reg_tmp3;
+  Xbyak::Reg64 reg_ret = rax;
+
+  void assign_regs() {
+    CTileCount = NRegs * MRegs;
+    auto tile_re = TileCount - CTileCount;
+    if (tile_re - 1 >= NRegs) {
+      BTileCount = NRegs;
+      ATileCount = tile_re - BTileCount;
+    } else if (tile_re - 1 >= MRegs) {
+      ATileCount = MRegs;
+      BTileCount = tile_re - ATileCount;
+    } else {
+      ATileCount = 1;
+      BTileCount = tile_re - ATileCount;
+    }
+    CTile = 0;
+    ATile = CTile + CTileCount;
+    BTile = ATile + ATileCount;
+  }
+
+  void generate_mtile(int _mtile) {
+    inLocalLabel();  // use local label for multiple instance
+    Xbyak::util::StackFrame st(this, 1, 11, 16 * 10);
+    parambase = st.p[0];
+    reg_matAptr = st.t[0];
+    reg_matBptr = st.t[1];
+    reg_matCptr = st.t[0];
+    reg_ksize = st.t[2];
+    reg_astride = st.t[3];
+    reg_cstride = st.t[3];
+    reg_iterk = st.t[4];
+    reg_tmp = st.t[5];
+    reg_tmp1 = st.t[6];
+    reg_tmp2 = st.t[7];
+    reg_tmp3 = st.t[10];
+    reg_nsize = st.t[8];
+    reg_itern = st.t[9];
+    reg_ret = rax;
+
+    vreg_push(rsp);
+
+    load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+    load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+    xor_(reg_itern, reg_itern);
+    L(".nloop");
+    init_regs(_mtile);
+    mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+    load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+    mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+    load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+    imul(reg_tmp, reg_itern);
+    lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+    xor_(reg_iterk, reg_iterk);
+    generate_kloop(_mtile);
+    write_back(_mtile);
+    add(reg_itern, NTILE);
+    cmp(reg_itern, reg_nsize);
+    jb(".nloop");
+    mov(reg_ret, 0);
+    vreg_pop(rsp);
+
+    outLocalLabel();  // end of local label
+  }
+
+  void generate_kloop(int _mtile) {
+    inLocalLabel();
+    mov(reg_tmp, reg_ksize);
+    padto_le(reg_tmp, KUNROLL * KTILE);
+    cmp(reg_tmp, 0);
+    jz(".kloop", T_NEAR);
+    L(".unkloop");
+    generate_fma(_mtile, KUNROLL);
+    add(reg_matAptr, KUNROLL * AKStepSize);
+    add(reg_matBptr, KUNROLL * BKStepSize);
+    add(reg_iterk, KUNROLL * KTILE);
+    cmp(reg_iterk, reg_tmp);  // k iteration variable
+    jb(".unkloop");
+    cmp(reg_tmp, reg_ksize);
+    jge(".kend", T_NEAR);
+    L(".kloop");
+    generate_fma(_mtile, 1);
+    add(reg_matAptr, 1 * AKStepSize);
+    add(reg_matBptr, 1 * BKStepSize);
+    add(reg_iterk, 1 * KTILE);
+    cmp(reg_iterk, reg_ksize);  // k iteration variable
+    jb(".kloop");
+    L(".kend");
+    outLocalLabel();
+  }
+
+  void generate_fma(int _mtile, int kunrll) {
+    auto& reg_Bstride = reg_tmp1;
+    mov(reg_Bstride, NTILE * 4);
+    int mtiles = _mtile / RegLen;
+    for (int kk = 0; kk < kunrll; kk++) {
+      auto reg_Atmp = reg_tmp2;
+      if (mtiles == 1) {
+        reg_Atmp = reg_matAptr;
+      } else {
+        mov(reg_Atmp, reg_matAptr);
+      }
+      if (BTileCount == NRegs) {
+        for (int i = 0; i < NRegs; i++) {
+          tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+        }
+        for (int mm = 0; mm < mtiles; mm++) {
+          tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+          for (int i = 0; i < NRegs; i++) {
+            tdpfp16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i));
+          }
+          if (mm != mtiles - 1) {
+            lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+            lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+          }
+        }
+      } else {
+        if (ATileCount == mtiles) {
+          for (int mm = 0; mm < mtiles; mm++) {
+            tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+            if (mm != mtiles - 1) {
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+            }
+          }
+          for (int i = 0; i < NRegs; i++) {
+            tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+            for (int mm = 0; mm < mtiles; mm++) {
+              tdpfp16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile));
+            }
+          }
+        } else {
+          for (int mm = 0; mm < mtiles; mm++) {
+            tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+            for (int i = 0; i < NRegs; i++) {
+              tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+              tdpfp16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile));
+            }
+            if (mm != mtiles - 1) {
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  void init_regs(int _mtile) {
+    inLocalLabel();
+    load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+    cmp(reg_tmp, 0);
+    je(".read", T_NEAR);
+    for (int i = 0; i < CTileCount; i++) {
+      tilezero(Xbyak::Tmm(CTile + i));
+    }
+    jmp(".end", T_NEAR);
+    L(".read");
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    int mtnum = _mtile / 16;
+    for (int mm = 0; mm < mtnum; mm++) {
+      for (int i = 0; i < NRegs; i++) {
+        tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]);
+      }
+      if (mm != mtnum - 1) {
+        lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+        lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+      }
+    }
+    L(".end");
+    outLocalLabel();
+  }
+
+  void write_back(int _mtile) {
+    inLocalLabel();
+    mov(reg_tmp, dword[parambase + OFFSET(workspace)]);
+    mov(reg_tmp1, NTILE * 4);
+    for (int mm = 0; mm < MRegs; mm++) {
+      for (int i = 0; i < NRegs; i++) {
+        tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i));
+      }
+    }
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    int zunroll = TmpRegCount / NRegs;
+    for (int i = 0; i < _mtile; i += zunroll) {
+      int m_re = utils::remainsize(i, _mtile, zunroll);
+      for (int im = 0; im < m_re; im++) {
+        for (int j = 0; j < NRegs; j++) {
+          vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]);
+          vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j));
+        }
+        add(reg_matCptr, reg_cstride);
+      }
+    }
+    outLocalLabel();
+  }
+};
+
 template
 class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 {
  public:
@@ -4960,6 +5224,34 @@ class HCoreRowNAmxbf16 : public CoreCodeBaseAMX
   }
 };
 
+template <int NTILE, int MTILE>
+class HCoreRowNAmxfp16 : public CoreCodeBaseAMX<code::Amxfp16N16P2, NTILE, MTILE> {
+ public:
+  using Base = CoreCodeBaseAMX<code::Amxfp16N16P2, NTILE, MTILE>;
+  using Code = typename Base::Code;
+  using AType = typename Code::AType;
+  using BType = typename Code::BType;
+  using CType = typename Code::CType;
+
+  static void configure(int _M, int _N, int _K) {
+    code::AmxConfigure::configure(_M < 16 ? _M : 16, 16, Code::KTILE, sizeof(BType),
+                                  Base::getInstance()->mCodes[0].ATileCount, Base::getInstance()->mCodes[0].BTileCount,
+                                  Base::getInstance()->mCodes[0].CTileCount);
+  }
+
+  static void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride,
+                      int _cstride, int kpos, void* tmpcache, size_t cachesize) {
+    auto param =
+        typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, _k, _n, kpos == 0 ? 1 : 0, tmpcache};
+    if (_m <= Code::MTILE) {
+      int idx = utils::updiv(_m, 16) - 1;
+      Base::getInstance()->mCodes[idx].mKernel(&param);
+    } else {
+      assert(0);
+    }
+  }
+};
+
 template
 class ICoreRowNAvx512vnni : public CoreCodeBase {
  public:
diff --git a/bestla/bestla/kernel_avx2.h b/bestla/bestla/kernel_avx2.h
index 0c3b23887..a340e5276 100644
--- a/bestla/bestla/kernel_avx2.h
+++ b/bestla/bestla/kernel_avx2.h
@@ -3013,24 +3013,29 @@ static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr,
   int constexpr VLen = 8;
   auto vff = _mm256_set1_epi32(255);
   auto v0 = _mm256_set1_epi32(0);
+  int constexpr Unroll = 2;
+  int vblocksize_un = utils::padto_le(blocksize, VLen * Unroll);
   int vblocksize = utils::padto_le(blocksize, VLen);
   int colblk = utils::padto_le(col, blocksize);
-  for (int i = 0; i < row; i++) {
+  for (size_t i = 0; i < row; i++) {
     size_t j = 0;
     for (; j < colblk; j += blocksize) {
       __m256 vmaxval = _mm256_set1_ps(0.f);
       __m256 vminval = _mm256_set1_ps(0.f);
       size_t ij = 0;
-      for (; ij < vblocksize; ij += VLen) {
-        __m256 vsrc;
-        if constexpr (std::is_same_v<SRC_T, float>) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]);
-        if constexpr (std::is_same_v<SRC_T, utils::bf16>) {
-          auto vtmp =
-              _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src])));
-          vsrc = ymm_cvt_bf16_fp32(vtmp);
+      for (; ij < vblocksize_un; ij += VLen * Unroll) {
+        for (size_t iu = 0; iu < Unroll; iu++) {
+          __m256 vsrc = load_T_fp32(&srcptr[(j + ij) + i * ld_src + iu * VLen]);
+          vmaxval = _mm256_max_ps(vmaxval, vsrc);
+          vminval = _mm256_min_ps(vminval, vsrc);
+        }
+      }
+      if (ij + VLen < vblocksize) {
+        for (; ij < vblocksize; ij += VLen) {
+          __m256 vsrc = load_T_fp32(&srcptr[(j + ij) + i * ld_src]);
+          vmaxval = _mm256_max_ps(vmaxval, vsrc);
+          vminval = _mm256_min_ps(vminval, vsrc);
         }
-        vmaxval = _mm256_max_ps(vmaxval, vsrc);
-        vminval = _mm256_min_ps(vminval, vsrc);
       }
       auto maxval = avx2_reduce_ps(vmaxval);
       auto minval = avx2_reduce_ps(vminval);
@@ -3052,13 +3057,7 @@ static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr,
       ij = 0;
       if (blkreduce) {
         for (; ij < vblocksize; ij += VLen) {
-          __m256 vsrc;
-          if constexpr (std::is_same_v<SRC_T, float>) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]);
-          if constexpr (std::is_same_v<SRC_T, utils::bf16>) {
-            auto vtmp =
-                _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src])));
-            vsrc = ymm_cvt_bf16_fp32(vtmp);
-          }
+          __m256 vsrc = load_T_fp32(&srcptr[(j + ij) + i * ld_src]);
           vsrc = _mm256_mul_ps(vsrc, vrscale);
           auto vdsrc = _mm256_cvtps_epi32(vsrc);
           sum += avx2_reduce_epi32(vdsrc);
@@ -3069,21 +3068,29 @@ static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr,
           _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc);
         }
       } else {
-        for (; ij < vblocksize; ij += VLen) {
-          __m256 vsrc;
-          if constexpr (std::is_same_v<SRC_T, float>) vsrc = _mm256_loadu_ps(&srcptr[(j + ij) + i * ld_src]);
-          if constexpr (std::is_same_v<SRC_T, utils::bf16>) {
-            auto vtmp =
-                _mm_loadu_si128(reinterpret_cast<__m128i*>(const_cast(&srcptr[(j + ij) + i * ld_src])));
-            vsrc = ymm_cvt_bf16_fp32(vtmp);
+        for (; ij < vblocksize_un; ij += VLen * Unroll) {
+          for (size_t iu = 0; iu < Unroll; iu++) {
+            __m256 vsrc = load_T_fp32(&srcptr[(j + ij) + i * ld_src + iu * VLen]);
+            vsrc = _mm256_mul_ps(vsrc, vrscale);
+            auto vdsrc = _mm256_cvtps_epi32(vsrc);
+            vdsrc = _mm256_add_epi32(vdsrc, vdzp);
+            vdsrc = _mm256_min_epi32(vdsrc, vff);
+            vdsrc = _mm256_max_epi32(vdsrc, v0);
+            auto vbsrc = avx2_cvtepi32_epu8(vdsrc);
+            _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst + iu * VLen]), vbsrc);
+          }
+        }
+        if (ij + VLen < vblocksize) {
+          for (; ij < vblocksize; ij += VLen) {
+            __m256 vsrc = load_T_fp32(&srcptr[(j + ij) + i * ld_src]);
+            vsrc = _mm256_mul_ps(vsrc, vrscale);
+            auto vdsrc = _mm256_cvtps_epi32(vsrc);
+            vdsrc = _mm256_add_epi32(vdsrc, vdzp);
+            vdsrc = _mm256_min_epi32(vdsrc, vff);
+            vdsrc = _mm256_max_epi32(vdsrc, v0);
+            auto vbsrc = avx2_cvtepi32_epu8(vdsrc);
+            _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc);
           }
-          vsrc = _mm256_mul_ps(vsrc, vrscale);
-          auto vdsrc = _mm256_cvtps_epi32(vsrc);
-          vdsrc = _mm256_add_epi32(vdsrc, vdzp);
-          vdsrc = _mm256_min_epi32(vdsrc, vff);
-          vdsrc = _mm256_max_epi32(vdsrc, v0);
-          auto vbsrc = avx2_cvtepi32_epu8(vdsrc);
-          _mm_storel_epi64(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc);
         }
       }
       for (; ij < blocksize; ij++) {
diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h
index b976b52d5..6ffbf898e 100644
--- a/bestla/bestla/kernel_avx512f.h
+++ b/bestla/bestla/kernel_avx512f.h
@@ -296,6 +296,31 @@ static inline void vec_broadcast_epi32_2_4(__m512i* dst4regs, __m512i* src2regs) {
   vec_broadcast_epi32_1_2(dst4regs + 2, src2regs + 1);
 }
 
+template <typename T>
+static inline __m512 load_T_fp32(const T* srcptr) {
+  __m512 vtmp;
+  if constexpr (std::is_same_v<T, float>) {
+    vtmp = _mm512_loadu_ps(srcptr);
+  } else if constexpr (std::is_same_v<T, utils::bf16>) {
+    vtmp = load_bf16_fp32(srcptr);
+  } else {
+    assert(0);
+  }
+  return vtmp;
+}
+
+static inline __m512 load_s8_fp32(int8_t* srcptr) {
+  auto src_y = load_s8_s32(srcptr);
+  auto dst_y = _mm512_cvtepi32_ps(src_y);
+  return dst_y;
+}
+
+static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
+  __m512i zero = _mm512_setzero_si512();
+  __mmask64 blt0 = _mm512_movepi8_mask(b);
+  return _mm512_mask_sub_epi8(a, blt0, zero, a);
+}
+
 template
 static inline BTLA_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr, _DT* dstptr, int row, int col,
                                                         int ld_src, int ld_dst, _ST* scales, int8_t* zero_points,
@@ -1212,27 +1237,33 @@ static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr,
                                                 int ld_dst, float* scales, int ld_scale, uint8_t* zps, int blocksize,
                                                 float* blkreduce) {
   int constexpr VLen = 16;
+  int constexpr Unroll = 2;
   auto vff = _mm512_set1_epi32(255);
   auto v0 = _mm512_set1_epi32(0);
   int vblocksize = utils::padto_le(blocksize, VLen);
+  int vblocksize_un = utils::padto_le(blocksize, VLen * Unroll);
   int colblk = utils::padto_le(col, blocksize);
-  for (int i = 0; i < row; i += 1) {
+  for (size_t i = 0; i < row; i += 1) {
     size_t j = 0;
     for (; j < colblk; j += blocksize) {
      __m512 vmaxval = _mm512_set1_ps(0.f);
      __m512 vminval = _mm512_set1_ps(0.f);
      size_t ij = 0;
-      for (; ij < vblocksize; ij += VLen) {
-        __m512 vsrc;
-        if constexpr (std::is_same_v<SRC_T, float>) vsrc = _mm512_loadu_ps(&srcptr[(j + ij) + i * ld_src]);
-
-        if constexpr (std::is_same_v<SRC_T, utils::bf16>) {
-          auto tmp = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(srcptr + j + ij + i * ld_src));
-          vsrc = zmm_cvt_bf16_fp32(tmp);
+      for (; ij < vblocksize_un; ij += VLen * Unroll) {
+        for (size_t iu = 0; iu < Unroll; iu++) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src + iu * VLen);
+          vmaxval = _mm512_max_ps(vmaxval, vsrc);
+          vminval = _mm512_min_ps(vminval, vsrc);
+        }
+      }
+      if (ij + VLen < vblocksize) {
+        for (; ij < vblocksize; ij += VLen) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src);
+          vmaxval = _mm512_max_ps(vmaxval, vsrc);
+          vminval = _mm512_min_ps(vminval, vsrc);
         }
-        vmaxval = _mm512_max_ps(vmaxval, vsrc);
-        vminval = _mm512_min_ps(vminval, vsrc);
       }
+
       auto maxval = _mm512_reduce_max_ps(vmaxval);
       auto minval = _mm512_reduce_min_ps(vminval);
       if (ij < blocksize) {
@@ -1251,23 +1282,35 @@ static inline BTLA_CODE quantize_fp_u8_colblock(int row, int col, const SRC_T* srcptr, int ld_src, uint8_t* dstptr,
       auto vdzp = _mm512_set1_epi32(zp);
       int sum = 0;
       ij = 0;
-      for (; ij < vblocksize; ij += VLen) {
-        __m512 vsrc;
-        if constexpr (std::is_same_v<SRC_T, float>) vsrc = _mm512_loadu_ps(&srcptr[(j + ij) + i * ld_src]);
-        if constexpr (std::is_same_v<SRC_T, utils::bf16>) {
-          auto tmp = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(srcptr + j + ij + i * ld_src));
-          vsrc = zmm_cvt_bf16_fp32(tmp);
+      for (; ij < vblocksize_un; ij += VLen * Unroll) {
+        for (size_t iu = 0; iu < Unroll; iu++) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src + iu * VLen);
+          vsrc = _mm512_mul_ps(vsrc, vrscale);
+          auto vdsrc = _mm512_cvtps_epi32(vsrc);
+          if (blkreduce) {
+            sum += _mm512_reduce_add_epi32(vdsrc);
+          }
+          vdsrc = _mm512_add_epi32(vdsrc, vdzp);
+          vdsrc = _mm512_min_epi32(vdsrc, vff);
+          vdsrc = _mm512_max_epi32(vdsrc, v0);
+          auto vbsrc = _mm512_cvtepi32_epi8(vdsrc);
+          _mm_storeu_si128(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst + iu * VLen]), vbsrc);
+        }
+      }
+      if (ij + VLen < vblocksize) {
+        for (; ij < vblocksize; ij += VLen) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src);
+          vsrc = _mm512_mul_ps(vsrc, vrscale);
+          auto vdsrc = _mm512_cvtps_epi32(vsrc);
+          if (blkreduce) {
+            sum += _mm512_reduce_add_epi32(vdsrc);
+          }
+          vdsrc = _mm512_add_epi32(vdsrc, vdzp);
+          vdsrc = _mm512_min_epi32(vdsrc, vff);
+          vdsrc = _mm512_max_epi32(vdsrc, v0);
+          auto vbsrc = _mm512_cvtepi32_epi8(vdsrc);
+          _mm_storeu_si128(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc);
         }
-        vsrc = _mm512_mul_ps(vsrc, vrscale);
-        auto vdsrc = _mm512_cvtps_epi32(vsrc);
-        if (blkreduce) {
-          sum += _mm512_reduce_add_epi32(vdsrc);
-        }
-        vdsrc = _mm512_add_epi32(vdsrc, vdzp);
-        vdsrc = _mm512_min_epi32(vdsrc, vff);
-        vdsrc = _mm512_max_epi32(vdsrc, v0);
-        auto vbsrc = _mm512_cvtepi32_epi8(vdsrc);
-        _mm_storeu_si128(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc);
       }
       for (; ij < blocksize; ij++) {
        auto srcval = static_cast<float>(srcptr[(j + ij) + i * ld_src]);
@@ -1321,6 +1364,7 @@ static inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr,
   int constexpr VLen = 16;
   auto vpos = _mm512_set1_epi32(127);
   auto vneg = _mm512_set1_epi32(-128);
+  int VBlockSizeU3 = utils::padto_le(blocksize, VLen * 3);
   int VBlockSize = utils::padto_le(blocksize, VLen);
   int colblk = utils::padto_le(col, blocksize);
   for (int i = 0; i < row; i += 1) {
@@ -1328,16 +1372,21 @@ static inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr,
     size_t j = 0;
     for (; j < colblk; j += blocksize) {
       __m512 vmaxval = _mm512_set1_ps(std::numeric_limits::min());
       size_t ij = 0;
-      for (; ij < VBlockSize; ij += VLen) {
-        __m512 vsrc;
-        if constexpr (std::is_same_v<SRC_T, float>) vsrc = _mm512_loadu_ps(&srcptr[(j + ij) + i * ld_src]);
-        if constexpr (std::is_same_v<SRC_T, utils::bf16>) {
-          auto tmp = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(srcptr + j + ij + i * ld_src));
-          vsrc = zmm_cvt_bf16_fp32(tmp);
+      for (; ij < VBlockSizeU3; ij += VLen * 3) {
+        for (int iu = 0; iu < 3; iu++) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src + iu * VLen);
+          vsrc = _mm512_abs_ps(vsrc);
+          vmaxval = _mm512_max_ps(vmaxval, vsrc);
         }
-        vsrc = _mm512_abs_ps(vsrc);
-        vmaxval = _mm512_max_ps(vmaxval, vsrc);
       }
+      if (ij + VLen < VBlockSize) {
+        for (; ij < VBlockSize; ij += VLen) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src);
+          vsrc = _mm512_abs_ps(vsrc);
+          vmaxval = _mm512_max_ps(vmaxval, vsrc);
+        }
+      }
+
       auto maxval = _mm512_reduce_max_ps(vmaxval);
       if (ij < blocksize) {
         for (; ij < blocksize; ij++) {
@@ -1351,21 +1400,29 @@ static inline BTLA_CODE quantize_fp_s8_colblock(int row, int col, const SRC_T* srcptr,
       auto vrscale = _mm512_set1_ps(rscale);
       ij = 0;
       int sum = 0;
-
-      for (; ij < VBlockSize; ij += VLen) {
-        __m512 vsrc;
-        if constexpr (std::is_same_v<SRC_T, float>) vsrc = _mm512_loadu_ps(&srcptr[(j + ij) + i * ld_src]);
-        if constexpr (std::is_same_v<SRC_T, utils::bf16>) {
-          auto tmp = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(srcptr + j + ij + i * ld_src));
-          vsrc = zmm_cvt_bf16_fp32(tmp);
+      for (; ij < VBlockSizeU3; ij += VLen * 3) {
+        for (int iu = 0; iu < 3; iu++) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src + iu * VLen);
+          vsrc = _mm512_mul_ps(vsrc, vrscale);
+          auto vdsrc = _mm512_cvtps_epi32(vsrc);
+          sum += _mm512_reduce_add_epi32(vdsrc);
+          vdsrc = _mm512_min_epi32(vdsrc, vpos);
+          vdsrc = _mm512_max_epi32(vdsrc, vneg);
+          auto vbsrc = _mm512_cvtepi32_epi8(vdsrc);
+          _mm_storeu_si128(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst + iu * VLen]), vbsrc);
+        }
+      }
+      if (ij + VLen < VBlockSize) {
+        for (; ij < VBlockSize; ij += VLen) {
+          __m512 vsrc = load_T_fp32(srcptr + j + ij + i * ld_src);
+          vsrc = _mm512_mul_ps(vsrc, vrscale);
+          auto vdsrc = _mm512_cvtps_epi32(vsrc);
+          sum += _mm512_reduce_add_epi32(vdsrc);
+          vdsrc = _mm512_min_epi32(vdsrc, vpos);
+          vdsrc = _mm512_max_epi32(vdsrc, vneg);
+          auto vbsrc = _mm512_cvtepi32_epi8(vdsrc);
+          _mm_storeu_si128(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc);
         }
-        vsrc = _mm512_mul_ps(vsrc, vrscale);
-        auto vdsrc = _mm512_cvtps_epi32(vsrc);
-        sum += _mm512_reduce_add_epi32(vdsrc);
-        vdsrc = _mm512_min_epi32(vdsrc, vpos);
-        vdsrc = _mm512_max_epi32(vdsrc, vneg);
-        auto vbsrc = _mm512_cvtepi32_epi8(vdsrc);
-        _mm_storeu_si128(reinterpret_cast<__m128i*>(&dstptr[(j + ij) + i * ld_dst]), vbsrc);
       }
       if (ij < blocksize) {
         for (; ij < blocksize; ij++) {
@@ -4674,31 +4731,6 @@ inline BTLA_CODE decompress_kblock_s7_fp(utils::bit4x2* b4ptr, utils::bit2x4* b2ptr,
   return ret;
 }
 
-template <typename T>
-static inline __m512 load_T_fp32(const T* srcptr) {
-  __m512 vtmp;
-  if constexpr (std::is_same_v<T, float>) {
-    vtmp = _mm512_loadu_ps(srcptr);
-  } else if constexpr (std::is_same_v<T, utils::bf16>) {
-    vtmp = load_bf16_fp32(srcptr);
-  } else {
-    assert(0);
-  }
-  return vtmp;
-}
-
-static inline __m512 load_s8_fp32(int8_t* srcptr) {
-  auto src_y = load_s8_s32(srcptr);
-  auto dst_y = _mm512_cvtepi32_ps(src_y);
-  return dst_y;
-}
-
-static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
-  __m512i zero = _mm512_setzero_si512();
-  __mmask64 blt0 = _mm512_movepi8_mask(b);
-  return _mm512_mask_sub_epi8(a, blt0, zero, a);
-}
-
 template
 static inline void gemv_dequant_s32fp32(const float* asptr, int ldzp, const ScaleT* bsptr, __m512i* iacc,
                                         __m512* facc) {
diff --git a/bestla/bestla/ut/bestla_benchmark.cpp b/bestla/bestla/ut/bestla_benchmark.cpp
index c9d1f0195..f3c09ac03 100644
--- a/bestla/bestla/ut/bestla_benchmark.cpp
+++ b/bestla/bestla/ut/bestla_benchmark.cpp
@@ -352,6 +352,85 @@ class Benchmark_Bf16Bf16Fp32 {
 static Benchmark_Bf16Bf16Fp32 sBenchmark_Bf16Bf16Fp32;
 #endif
 
+class Benchmark_Fp16Fp16Fp32 {
+ public:
+  Benchmark_Fp16Fp16Fp32() {
+    UT_START();
+    benchmark_all(1, 4096, 4096);
+    benchmark_all(1024, 4096, 4096);
+  }
+
+  using AType = utils::fp16;
+  using BType = utils::fp16;
+  using CType = float;
+  template <typename Core_T, typename LOG_T>
+  void benchmark(int m, int n, int k, int batch, AType* A, BType* B, CType* C, float timems, int threads) {
+    LOG_T log;
+    using Parallel = parallel::gemm::SchedulerBase<Core_T>;
+    using Launcher = wrapper::gemm::LauncherBase;
+
+    UT_Threading::set_threads(threads);
+    auto corestr = gemm::CoreAttr::to_str(Core_T::ID);
+    utils::timer tm;
+    auto tmpB = Launcher::PrologueB::createStorage(n, k);
+    std::vector packBs(batch, 0);
+    avector<int8_t> bufB(tmpB.mSize * batch);
+    for (size_t i = 0; i < batch; i++) {
+      packBs[i] = tmpB;
+      packBs[i].assign(bufB.data() + i * tmpB.mSize);
+      Launcher::PrologueB::packWeight(n, k, {B + i * n * k, n, &packBs[i]}, UT_Threading::get());
+    }
+    auto psize = (size_t)m * n * k * 2;
+    tm.start();
+    while (tm.stop() < timems) {
+      for (size_t i = 0; i < batch; i++) {
+        log.start();
+        GemmProblem gp(1, m, n, k);
+        typename Launcher::Param args{gp, {A + i * m * k, k}, {0, 0, &packBs[i]}, {C + i * m * n, n}};
+        parallel::GemmRun<Parallel>(args, UT_Threading::get());
+        log.stop();
+        if (tm.stop() >= timems) {
+          break;
+        }
+      }
+    }
+    log.record();
+    double flops = double(psize) / log.min_val / 1e6;
+    printf("Threads %d %s %s Flops:%.3f PerCoreFlops:%.3f\n", threads, corestr, log.get_log_str(), flops,
+           flops / threads);
+  }
+
+  void benchmark_all(int m, int n, int k) {
+    auto memsize = gemm_memsize(m, n, k, BTLA_DTYPE::F16, BTLA_DTYPE::F16, BTLA_DTYPE::F32);
+    auto batch = auto_batch(memsize);
+    printf("%d %d %d %d %s %s %s\n", m, n, k, batch, bestla_dtype_str(BTLA_DTYPE::F16),
+           bestla_dtype_str(BTLA_DTYPE::F16), bestla_dtype_str(BTLA_DTYPE::F32));
+    avector<AType> A(size_t(m) * k * batch);
+    avector<BType> B(size_t(k) * n * batch);
+    avector<CType> C(size_t(m) * n * batch);
+    fill_buffer_randn(A.data(), k * m, AType(-0.5f), AType(0.5f));
+    fill_buffer_randn(B.data(), k * n, AType(-0.5f), AType(0.5f));
+    for (size_t i = 0; i < batch - 1; i++) {
+      memcpy(A.data() + i * m * k, A.data(), m * k * sizeof(AType));
+      memcpy(B.data() + i * n * k, B.data(), n * k * sizeof(BType));
+    }
+    using LOG = timer_statistics_logger;
+    float testtime = float(TestMs);
+    GetCPUDevice();
+    auto threads_cfg = UT_Threading::get_threads_config();
+    for (auto threads : threads_cfg) {
+      if (_cd->AMX_FP16()) {
+        benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads);
+        benchmark, LOG>(m, n, k, batch, A.data(), B.data(), C.data(), testtime, threads);
+      }
+    }
+  }
+};
+#ifdef BTLA_UT_WRAPPER
+static Benchmark_Fp16Fp16Fp32 sBenchmark_Fp16Fp16Fp32;
+#endif
+
 class Benchmark_Fp16Fp16Fp16 {
  public:
  Benchmark_Fp16Fp16Fp16() {
diff --git a/bestla/bestla/ut/bestla_gemm.cpp b/bestla/bestla/ut/bestla_gemm.cpp
index b4a885cb4..313cbbcf9 100644
--- a/bestla/bestla/ut/bestla_gemm.cpp
+++ b/bestla/bestla/ut/bestla_gemm.cpp
@@ -65,6 +65,36 @@ void ref_fp16(utils::fp16* matA, utils::fp16* matB, utils::fp16* matC, int _m, int _n, int _k, int _astride,
   }
 }
 
+template <int NTILE>
+void ref_fp16_fp32(utils::fp16* matA, utils::fp16* matB, float* matC, int _m, int _n, int _k, int _astride,
+                   int _bstride, int _cstride, int kpos) {
+  int lda = _astride / sizeof(utils::fp16);
+  int ldb = _bstride / sizeof(utils::fp16);
+  int ldc = _cstride / sizeof(float);
+  int constexpr KPack = 4 / sizeof(utils::fp16);
+  for (int i = 0; i < _m; i++) {
+    for (int j = 0; j < _n; j += NTILE) {
+      for (int ij = 0; ij < NTILE; ij++) {
+        if (j + ij >= _n) {
+          continue;
+        }
+        float tmp = 0;
+        for (int k = 0; k < _k; k += KPack) {
+          for (int ik = 0; ik < KPack; ik++) {
+            if (k + ik >= _k) {
+              continue;
+            }
+            auto tmpA = utils::cast(utils::fp16{matA[i * lda + k + ik]});
+            auto tmpB = utils::cast(utils::fp16{matB[k * NTILE + ij * KPack + ik + j * ldb]});
+            tmp += tmpA * tmpB;
+          }
+        }
+        matC[i * ldc + j + ij] = tmp;
+      }
+    }
+  }
+}
+
 template
 void ref_fp32(float* matA, float* matB, float* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride,
               int kpos) {
@@ -778,7 +808,6 @@ class UT_GEMM_AMXINT8_KBLOCK {
   UT_GEMM_AMXINT8_KBLOCK() {
     UT_START();
     CheckISA(AMX_INT8);
-    request_perm_xtile_data();
     ut_splitblock<48, 16>(16, 144, 128, 64);
     ut_splitblock<48, 16>(16, 144, 128, 128);
     ut_splitblock<48, 16>(16, 144, 256, 128);
@@ -1035,7 +1064,6 @@ class UT_GEMM_AMXBF16 {
   UT_GEMM_AMXBF16() {
     UT_START();
     CheckISA(AMX_BF16);
-    request_perm_xtile_data();
     ut<32, 32>(32, 32, 64);
     ut<32, 32>(4, 96, 96);
     ut<48, 0>(4, 96, 96);
@@ -1070,12 +1098,50 @@ class UT_GEMM_AMXBF16 {
 static UT_GEMM_AMXBF16 sUT_GEMM_AMXBF16;
 #endif
 
+class UT_GEMM_AMXFP16 {
+ public:
+  UT_GEMM_AMXFP16() {
+    UT_START();
+    CheckISA(AMX_FP16);
+    ut<32, 32>(32, 32, 64);
+    ut<32, 32>(4, 96, 96);
+    ut<48, 0>(4, 96, 96);
+    ut<64, 16>(4, 128, 96);
+  }
+
+  template <int NTILE, int MTILE>
+  void ut(int m, int n, int k) {
+    printf("Test Case: %d %d %d\n", m, n, k);
+    using Core = gemm::HCoreRowNAmxfp16<NTILE, MTILE>;
+    static Core gemm;
+    if (n % Core::Code::NTILE != 0) {
+      return;
+    }
+    if (k % Core::Code::KTILE != 0) {
+      return;
+    }
+
+    avector<utils::fp16> matAfp16(m * k), matBfp16(k * n);
+    avector<float> matC(Core::Code::MTILE * n), refC(Core::Code::MTILE * n);
+    fill_buffer_randn(matAfp16.data(), matAfp16.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
+    fill_buffer_randn(matBfp16.data(), matBfp16.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
+    ref_fp16_fp32(matAfp16.data(), matBfp16.data(), refC.data(), m, n, k, k * 2, k * 2, n * 4, 0);
+    gemm.configure(m, n, k);
+
+    gemm.forward(matAfp16.data(), matBfp16.data(), matC.data(), m, n, k, k * sizeof(fp16), k * sizeof(fp16),
+                 n * sizeof(float), 0, cache, CacheSize);
+    ut::buffer_error(refC.data(), matC.data(), m * n, 0.001f);
+  }
+};
+#ifdef BTLA_UT_GEMM
+static UT_GEMM_AMXFP16 sUT_GEMM_AMXFP16;
+#endif
+
 class UT_GEMM_AMXINT8 {
  public:
   UT_GEMM_AMXINT8() {
     UT_START();
     CheckISA(AMX_INT8);
-    request_perm_xtile_data();
     ut<32, 32>(32, 64, 64 * 3);
     ut<48, 16>(16, 96, 64 * 3);
     ut<32, 32>(4, 64, 64 * 3);
diff --git a/bestla/bestla/ut/bestla_ut.h b/bestla/bestla/ut/bestla_ut.h
index 99109a05f..b6118c3d4 100644
--- a/bestla/bestla/ut/bestla_ut.h
+++ b/bestla/bestla/ut/bestla_ut.h
@@ -20,6 +20,7 @@ namespace bestla {
 namespace ut {
 using sAVX512F = gemm::SCoreRowNAvx512f<48, 8>;
 using sAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>;
+using sAMX_FP16 = gemm::HCoreRowNAmxfp16<64, 16>;
 using sAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>;
 using sAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>;
 using sAVX_VNNI_SS = gemm::ICoreRowNAvxvnniSS<24, 4>;
@@ -584,6 +585,20 @@ static inline void gemmref_fp16fp16fp16(int m, int n, int k, utils::fp16* A, utils::fp16* B, utils::fp16* C,
   }
 }
 
+static inline void gemmref_fp16fp16fp32(int m, int n, int k, utils::fp16* A, utils::fp16* B, float* C, int lda, int ldb,
+                                        int ldc) {
+#pragma omp parallel for
+  for (int j = 0; j < m; j++) {
+    for (int i = 0; i < n; i += 1) {
+      float tmp = 0;
+      for (int ik = 0; ik < k; ik++) {
+        tmp += float(A[ik + j * lda]) * float(B[ik * ldb + i]);
+      }
+      C[i + j * ldc] = tmp;
+    }
+  }
+}
+
 struct UT_GEMMData_Row_fp16 {
   utils::aligned_vector matA, matB, matC, matD;
   int M, N, K, LDA, LDB, LDC, LDD;
diff --git a/bestla/bestla/ut/bestla_wrapper.cpp b/bestla/bestla/ut/bestla_wrapper.cpp
index 9f20b7ec8..eb96eb69a 100644
--- a/bestla/bestla/ut/bestla_wrapper.cpp
+++ b/bestla/bestla/ut/bestla_wrapper.cpp
@@ -284,5 +284,45 @@ class UT_Fp16Fp16Fp16 {
 #ifdef BTLA_UT_WRAPPER
 static UT_Fp16Fp16Fp16 sUT_Fp16Fp16Fp16;
 #endif
+
+class UT_Fp16Fp16Fp32 {
+ public:
+  UT_Fp16Fp16Fp32() {
+    UT_START();
+    CheckISA(AMX_FP16);
+    ut(1, 1, 1);
+    ut(8, 48, 2);
+    ut(8, 4096, 4096);
+    ut(384, 768, 768);
+    ut(1024, 1024, 1024);
+    ut(1024, 1536, 1536);
+  }
+
+  template <class GemmCore_T>
+  void ut(int m, int n, int k) {
+    printf("Test Case %s: %d %d %d core:%s\n", __FUNCTION__, m, n, k, gemm::CoreAttr::to_str(GemmCore_T::ID));
+    using Launcher =
+        wrapper::gemm::LauncherBase;
+
+    using Parallel = parallel::gemm::SchedulerBase<GemmCore_T>;
+    auto packw = Launcher::PrologueB::createStorage(n, k);
+    avector<int8_t> buffer(packw.mSize);
+    packw.assign(buffer.data());
+    avector<utils::fp16> matAbf16(m * k), matBbf16(k * n);
+    avector<float> matC(m * n), refC(m * n);
+    fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
+    fill_buffer_randn(matBbf16.data(), matBbf16.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
+    Launcher::PrologueB::packWeight(n, k, {matBbf16.data(), n, &packw}, UT_Threading::get());
+    gemmref_fp16fp16fp32(m, n, k, matAbf16.data(), matBbf16.data(), refC.data(), k, n, n);
+    GemmProblem gp(1, m, n, k);
+    typename Launcher::Param args{gp, {matAbf16.data(), k}, {matBbf16.data(), n, &packw}, {matC.data(), n}};
+    parallel::GemmRun<Parallel>(args, UT_Threading::get());
+    buffer_error(refC.data(), matC.data(), refC.size(), 0.0002f * k);
+  }
+};
+#ifdef BTLA_UT_WRAPPER
+static UT_Fp16Fp16Fp32 sUT_Fp16Fp16Fp32;
+#endif
 }  // namespace ut
 }  // namespace bestla
diff --git a/neural_speed/core/layers/bestla_defs.h b/neural_speed/core/layers/bestla_defs.h
index 392999989..38ac2954c 100644
--- a/neural_speed/core/layers/bestla_defs.h
+++ b/neural_speed/core/layers/bestla_defs.h
@@ -39,18 +39,18 @@ using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>;
 using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>;
 using tAVX512BW = gemm::ICoreRowNAvx512bw<48, 8>;
 using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>;
-using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>;
-using tAVX512_BF16 = gemm::HCoreRowNAvx512bf16<64, 4>;
+using tAMX_BF16 = gemm::HCoreRowNAmxbf16<48, 16>;
+using tAVX512_BF16 = gemm::HCoreRowNAvx512bf16<48, 8>;
 using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>;
-using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>;
-using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>;
+using tAMX_INT8_US = gemm::ICoreRowNAmxint8<48, 16>;
+using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<48, 16>;
 using tAVX2_VNNI_KBlock = gemm::ICoreRowNAvx2vnniKBlock<24, 2>;
 using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>;
 using tAVX512BW_KBlock = gemm::ICoreRowNAvx512bwKBlock<48, 8>;
 using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>;
-using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<64, 16>;
-using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<64, 16>;
+using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>;
+using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>;
 
 template
 using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger;
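
Usage note (not part of the patch): the sketch below mirrors UT_GEMM_AMXFP16::ut<32, 32>() from bestla_gemm.cpp above and shows the minimal call sequence for the new core: check the AMX_FP16 flag, program the AMX tile configuration for the problem shape, then run the JIT kernel on fp16 inputs with an fp32 output. GetCPUDevice/_cd, avector, fill_buffer_randn and the cache/CacheSize scratch buffer are the bestla_ut.h test helpers; treat this as an assumption-laden smoke test, not a public API.

// Sketch only, assuming the bestla::ut test helpers are available.
using Core = bestla::gemm::HCoreRowNAmxfp16<32, 32>;  // NTILE=32, MTILE=32, KTILE=32, fp16 in / fp32 out

void amxfp16_smoke(int m, int n, int k) {
  using namespace bestla;
  GetCPUDevice();
  if (!_cd->AMX_FP16()) return;  // CpuDevice also requests XTILE data permission for AMX cores
  // m must stay <= Core::Code::MTILE; B is treated as already packed, exactly as in the unit test.
  ut::avector<utils::fp16> A(size_t(m) * k), B(size_t(k) * n);
  ut::avector<float> C(size_t(Core::Code::MTILE) * n);
  ut::fill_buffer_randn(A.data(), A.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
  ut::fill_buffer_randn(B.data(), B.size(), utils::fp16(-0.5f), utils::fp16(0.5f));
  Core::configure(m, n, k);  // loads the AMX tile config for this shape
  Core::forward(A.data(), B.data(), C.data(), m, n, k, k * sizeof(utils::fp16), k * sizeof(utils::fp16),
                n * sizeof(float), 0, ut::cache, ut::CacheSize);  // kpos == 0 zeroes the accumulator tiles first
}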