Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[BesTLA] Add new ISA support: AMX_FP16 #282

Merged
merged 7 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bestla/bestla/bestla.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum class BTLA_ISA : uint8_t {
AMX_INT8,
AVX512_FP16,
AVX512_BF16,
AMX_FP16,
ISA_COUNT,
};
enum class BTLA_DTYPE : uint32_t {
Expand Down
7 changes: 5 additions & 2 deletions bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class CpuDevice {
inline bool AMX_BF16() { return mHasAMX_BF16; }
inline bool AVX512_BF16() { return mHasAVX512_BF16; }
inline bool AVX512_FP16() { return mHasAVX512_FP16; }
inline bool AMX_FP16() { return mHasAMX_FP16; }
inline float* const getPE() { return PE; }
inline int getPcoreNum() { return static_cast<int>(P_core.size()); }
inline int getEcoreNum() { return static_cast<int>(E_core.size()); }
Expand All @@ -252,8 +253,9 @@ class CpuDevice {
ADD_FLAG(AMX_INT8);
ADD_FLAG(AVX512_BF16);
ADD_FLAG(AVX512_FP16);
ADD_FLAG(AMX_FP16);
numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
if (mHasAMX_BF16 || mHasAMX_INT8) {
if (mHasAMX_BF16 || mHasAMX_INT8 || mHasAMX_FP16) {
utils::request_perm_xtile_data();
}
static bool p = false;
Expand Down Expand Up @@ -470,7 +472,8 @@ class CpuDevice {
uint32_t L2Cache = 0, L1Cache = 0, L3Cache = 0;
bool mHybrid = false, mClient = false;
bool mHasAVX2 = false, mHasAVX_VNNI = false, mHasAVX = false, mHasAVX512_VNNI = false, mHasAMX_INT8 = false,
mHasAMX_BF16 = false, mHasAVX512F = false, mHasAVX512BW, mHasAVX512_BF16 = false, mHasAVX512_FP16 = false;
mHasAMX_BF16 = false, mHasAVX512F = false, mHasAVX512BW, mHasAVX512_BF16 = false, mHasAVX512_FP16 = false,
mHasAMX_FP16 = false;
int numcores = 0;
int numthreads = 0;
std::vector<int> P_core, E_core, SMT_core;
Expand Down
292 changes: 292 additions & 0 deletions bestla/bestla/bestla_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ enum class CompType : uint16_t {
COMP_FP32 = (tFP32 << SHIFT_A) | (tFP32 << SHIFT_B) | (tFP32 << SHIFT_C),
COMP_BF16_FP32 = (tBF16 << SHIFT_A) | (tBF16 << SHIFT_B) | (tFP32 << SHIFT_C),
COMP_FP16_FP16 = (tFP16 << SHIFT_A) | (tFP16 << SHIFT_B) | (tFP16 << SHIFT_C),
COMP_FP16_FP32 = (tFP16 << SHIFT_A) | (tFP16 << SHIFT_B) | (tFP32 << SHIFT_C),
COMP_INT8_US_INT32 = (tU8 << SHIFT_A) | (tS8 << SHIFT_B) | (tS32 << SHIFT_C),
COMP_INT8_UU_INT32 = (tU8 << SHIFT_A) | (tU8 << SHIFT_B) | (tS32 << SHIFT_C),
COMP_INT8_SS_INT32 = (tS8 << SHIFT_A) | (tS8 << SHIFT_B) | (tS32 << SHIFT_C),
Expand Down Expand Up @@ -2207,6 +2208,269 @@ class Amxbf16N16P2 : protected bestla::xbyak::JitAmxbf16 {
}
};

template <int _NTILE, int _MTILE = 0>
class Amxfp16N16P2 : protected bestla::xbyak::JitAmxbf16 {
public:
static int constexpr RegLen = 16, PackRow = 2;
static_assert(_NTILE % RegLen == 0);
static_assert(_MTILE % RegLen == 0);
static int constexpr NRegs = _NTILE / RegLen;
static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen;
static_assert(NRegs * MRegs + 2 <= TileCount);
static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32;
static int constexpr KUNROLL = 2;
static auto constexpr ISA = BTLA_ISA::AMX_FP16;
static auto constexpr COMPUTE = CompType::COMP_FP16_FP32;
typedef utils::fp16 AType;
typedef utils::fp16 BType;
typedef float CType;

struct params {
AType* matA;
int astride;
BType* matB;
int bstride;
CType* matC;
int cstride;
int k;
int n;
int init;
void* workspace;
};
typedef long long (*func_t)(params*);

int TmpRegCount = RegCount;
int TmpReg = 0;
int CTileCount = 0, ATileCount = 0, BTileCount = 0;
int CTile = 0, ATile = 0, BTile = 0;
static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
static int constexpr AKStepSize = KTILE * sizeof(AType);

void generate_code(int _mtile) {
assign_regs();
reset();
generate_mtile(_mtile);
ready();
mKernel = getCode<func_t>();
}
func_t mKernel = nullptr;

protected:
Xbyak::Reg64 parambase;
Xbyak::Reg64 reg_matAptr;
Xbyak::Reg64 reg_matBptr;
Xbyak::Reg64 reg_matCptr;
Xbyak::Reg64 reg_ksize;
Xbyak::Reg64 reg_nsize;
Xbyak::Reg64 reg_cstride;
Xbyak::Reg64 reg_astride;
Xbyak::Reg64 reg_iterk;
Xbyak::Reg64 reg_itern;
Xbyak::Reg64 reg_tmp;
Xbyak::Reg64 reg_tmp1;
Xbyak::Reg64 reg_tmp2;
Xbyak::Reg64 reg_tmp3;
Xbyak::Reg64 reg_ret = rax;

void assign_regs() {
CTileCount = NRegs * MRegs;
auto tile_re = TileCount - CTileCount;
if (tile_re - 1 >= NRegs) {
BTileCount = NRegs;
ATileCount = tile_re - BTileCount;
} else if (tile_re - 1 >= MRegs) {
ATileCount = MRegs;
BTileCount = tile_re - ATileCount;
} else {
ATileCount = 1;
BTileCount = tile_re - ATileCount;
}
CTile = 0;
ATile = CTile + CTileCount;
BTile = ATile + ATileCount;
}

void generate_mtile(int _mtile) {
inLocalLabel(); // use local label for multiple instance
Xbyak::util::StackFrame st(this, 1, 11, 16 * 10);
parambase = st.p[0];
reg_matAptr = st.t[0];
reg_matBptr = st.t[1];
reg_matCptr = st.t[0];
reg_ksize = st.t[2];
reg_astride = st.t[3];
reg_cstride = st.t[3];
reg_iterk = st.t[4];
reg_tmp = st.t[5];
reg_tmp1 = st.t[6];
reg_tmp2 = st.t[7];
reg_tmp3 = st.t[10];
reg_nsize = st.t[8];
reg_itern = st.t[9];
reg_ret = rax;

vreg_push(rsp);

load32(reg_ksize, ptr[parambase + OFFSET(k)]);
load32(reg_nsize, ptr[parambase + OFFSET(n)]);
xor_(reg_itern, reg_itern);
L(".nloop");
init_regs(_mtile);
mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
load32(reg_astride, ptr[parambase + OFFSET(astride)]);
mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
imul(reg_tmp, reg_itern);
lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
xor_(reg_iterk, reg_iterk);
generate_kloop(_mtile);
write_back(_mtile);
add(reg_itern, NTILE);
cmp(reg_itern, reg_nsize);
jb(".nloop");
mov(reg_ret, 0);
vreg_pop(rsp);

outLocalLabel(); // end of local label
}

void generate_kloop(int _mtile) {
inLocalLabel();
mov(reg_tmp, reg_ksize);
padto_le(reg_tmp, KUNROLL * KTILE);
cmp(reg_tmp, 0);
jz(".kloop", T_NEAR);
L(".unkloop");
generate_fma(_mtile, KUNROLL);
add(reg_matAptr, KUNROLL * AKStepSize);
add(reg_matBptr, KUNROLL * BKStepSize);
add(reg_iterk, KUNROLL * KTILE);
cmp(reg_iterk, reg_tmp); // k iteration variable
jb(".unkloop");
cmp(reg_tmp, reg_ksize);
jge(".kend", T_NEAR);
L(".kloop");
generate_fma(_mtile, 1);
add(reg_matAptr, 1 * AKStepSize);
add(reg_matBptr, 1 * BKStepSize);
add(reg_iterk, 1 * KTILE);
cmp(reg_iterk, reg_ksize); // k iteration variable
jb(".kloop");
L(".kend");
outLocalLabel();
}

void generate_fma(int _mtile, int kunrll) {
auto& reg_Bstride = reg_tmp1;
mov(reg_Bstride, NTILE * 4);
int mtiles = _mtile / RegLen;
for (int kk = 0; kk < kunrll; kk++) {
auto reg_Atmp = reg_tmp2;
if (mtiles == 1) {
reg_Atmp = reg_matAptr;
} else {
mov(reg_Atmp, reg_matAptr);
}
if (BTileCount == NRegs) {
for (int i = 0; i < NRegs; i++) {
tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
}
for (int mm = 0; mm < mtiles; mm++) {
tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
for (int i = 0; i < NRegs; i++) {
tdpfp16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i));
}
if (mm != mtiles - 1) {
lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
}
}
} else {
if (ATileCount == mtiles) {
for (int mm = 0; mm < mtiles; mm++) {
tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
if (mm != mtiles - 1) {
lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
}
}
for (int i = 0; i < NRegs; i++) {
tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
for (int mm = 0; mm < mtiles; mm++) {
tdpfp16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile));
}
}
} else {
for (int mm = 0; mm < mtiles; mm++) {
tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
for (int i = 0; i < NRegs; i++) {
tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
tdpfp16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile));
}
if (mm != mtiles - 1) {
lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
}
}
}
}
}
}

void init_regs(int _mtile) {
inLocalLabel();
load32(reg_tmp, ptr[parambase + OFFSET(init)]);
cmp(reg_tmp, 0);
je(".read", T_NEAR);
for (int i = 0; i < CTileCount; i++) {
tilezero(Xbyak::Tmm(CTile + i));
}
jmp(".end", T_NEAR);
L(".read");
mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
int mtnum = _mtile / 16;
for (int mm = 0; mm < mtnum; mm++) {
for (int i = 0; i < NRegs; i++) {
tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]);
}
if (mm != mtnum - 1) {
lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
}
}
L(".end");
outLocalLabel();
}

void write_back(int _mtile) {
inLocalLabel();
mov(reg_tmp, dword[parambase + OFFSET(workspace)]);
mov(reg_tmp1, NTILE * 4);
for (int mm = 0; mm < MRegs; mm++) {
for (int i = 0; i < NRegs; i++) {
tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i));
}
}
mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
int zunroll = TmpRegCount / NRegs;
for (int i = 0; i < _mtile; i += zunroll) {
int m_re = utils::remainsize(i, _mtile, zunroll);
for (int im = 0; im < m_re; im++) {
for (int j = 0; j < NRegs; j++) {
vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]);
vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j));
}
add(reg_matCptr, reg_cstride);
}
}
outLocalLabel();
}
};

template <typename AT, typename BT, int _NTILE, int _MTILE = 0>
class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 {
public:
Expand Down Expand Up @@ -4960,6 +5224,34 @@ class HCoreRowNAmxbf16 : public CoreCodeBaseAMX<code::Amxbf16N16P2, _NTILE, _MTI
}
};

template <int _NTILE, int _MTILE = 0>
class HCoreRowNAmxfp16 : public CoreCodeBaseAMX<code::Amxfp16N16P2, _NTILE, _MTILE> {
public:
using Base = CoreCodeBaseAMX<code::Amxfp16N16P2, _NTILE, _MTILE>;
using Code = typename Base::Code;
using AType = typename Code::AType;
using BType = typename Code::BType;
using CType = typename Code::CType;

static void configure(int _M, int _N, int _K) {
code::AmxConfigure::configure(_M < 16 ? _M : 16, 16, Code::KTILE, sizeof(BType),
Base::getInstance()->mCodes[0].ATileCount, Base::getInstance()->mCodes[0].BTileCount,
Base::getInstance()->mCodes[0].CTileCount);
}

static void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride,
int _cstride, int kpos, void* tmpcache, size_t cachesize) {
auto param =
typename Code::params{matA, _astride, matB, _bstride, matC, _cstride, _k, _n, kpos == 0 ? 1 : 0, tmpcache};
if (_m <= Code::MTILE) {
int idx = utils::updiv(_m, 16) - 1;
Base::getInstance()->mCodes[idx].mKernel(&param);
} else {
assert(0);
}
}
};

template <int _NTILE, int _MTILE = 0>
class ICoreRowNAvx512vnni : public CoreCodeBase<code::Avx512vnniN16P4, _NTILE, _MTILE> {
public:
Expand Down
Loading
Loading