Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[BesTLA] First-token inference optimization #271

Merged
merged 37 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a8ffea6
add per-channel kblock template
luoyu-intel May 23, 2024
e83a8d5
add gemv support for pckblock. revise all benchmark cases and UT cases.
luoyu-intel May 27, 2024
6d40043
fix bandwidth calc of CompFp32 and CompBf16
luoyu-intel May 27, 2024
af84b4f
use correct core number
luoyu-intel May 27, 2024
463aeaf
update thread pool
yuchengliu1 May 27, 2024
9b9777c
fix bug
luoyu-intel May 27, 2024
26140bc
fix bug
luoyu-intel May 27, 2024
95c8791
update kernels with gemm and qkv fusion
luoyu-intel May 27, 2024
c094c59
refactor epilogue (removing ISA from the class' template)
luoyu-intel May 27, 2024
20262f9
update fnn and ip_add
luoyu-intel May 27, 2024
81c9944
fix compile on gcc
luoyu-intel May 27, 2024
d000ba0
fix gcc template
luoyu-intel May 27, 2024
0308aea
fix compile
luoyu-intel May 27, 2024
05aecfc
update amx template
luoyu-intel May 27, 2024
1d55c37
fix UT compile
luoyu-intel May 27, 2024
2035d24
fix benchmark compile
luoyu-intel May 27, 2024
6db48e7
revert NTILE of amx_int8
luoyu-intel May 27, 2024
d9636a9
reduce templates
luoyu-intel May 28, 2024
5169105
fix deprecated UTs. optimize cache block strategy
luoyu-intel May 28, 2024
a78f35d
Enlarge stack size on windows
luoyu-intel May 28, 2024
499d459
revert NTILE of amx_int8
luoyu-intel May 28, 2024
76c1331
update cache config
luoyu-intel May 28, 2024
67eacff
add mul support
luoyu-intel May 29, 2024
f0a2ad7
add mul implementation
luoyu-intel May 29, 2024
b4bd875
support tensor mul tensor
luoyu-intel May 29, 2024
04d7f5d
fix compile on gcc
luoyu-intel May 29, 2024
3c591ab
clang-format
luoyu-intel May 29, 2024
12ef415
fix doc
luoyu-intel May 29, 2024
e8829a2
code scan fix
luoyu-intel May 29, 2024
41f146b
fix compile
luoyu-intel May 29, 2024
4e99482
fix batch bug
luoyu-intel May 30, 2024
18250c1
comment add
luoyu-intel May 30, 2024
dc8e2c1
comment mul
luoyu-intel May 30, 2024
3a79072
enable mul&add
luoyu-intel May 30, 2024
71b017c
clang-format
luoyu-intel May 30, 2024
86e0a94
fix the code bug of mul and add. use new kernels in custom::epilogue
luoyu-intel May 31, 2024
36b2063
clang-format
luoyu-intel May 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ endif()

if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS NOMINMAX)
add_compile_options(/bigobj)
if (BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
Expand Down
4 changes: 4 additions & 0 deletions bestla/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ if(UT_BUILD)
target_link_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address)
endif()
target_link_options(${PROJECT_NAME}_ut PRIVATE -lpthread)
else()
target_link_options(${PROJECT_NAME}_ut PUBLIC /STACK:5242880)
endif()

add_ut_flag(BTLA_UT_DEBUG)
Expand Down Expand Up @@ -137,6 +139,8 @@ if(BTLA_UT_BENCHMARK)
endif()
if(NOT WIN32)
target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread)
else()
target_link_options(${PROJECT_NAME}_benchmark PUBLIC /STACK:5242880)
endif()
target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME} ${sycl_libs})
endif(BTLA_UT_BENCHMARK)
3 changes: 2 additions & 1 deletion bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
#include "bestla_utils.h"
#ifdef _WIN32
#include <windows.h>
#define FIXED_CACHE 1
#else
#include <sched.h>
#define FIXED_CACHE 0
#endif

#define FIXED_CACHE_SIZE ((1 << 20) - (128 << 10))
#define FIXED_CACHE 1

namespace bestla {

Expand Down
209 changes: 101 additions & 108 deletions bestla/bestla/bestla_epilogue.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,102 @@ namespace bestla {
namespace epilogue {
namespace gemm {

// Quantization metadata for the per-channel (PcKBlock) int8-compute epilogue.
// Consumed by PcKBlockCompInt8Epilogue::forward to dequantize the int32 GEMM
// accumulator back to fp32 and to remove zero-point bias terms.
struct ParamPcKBlockCompInt8Epilogue {
void* scalesB;  // per-channel scales of B; element type given by scaleBdtype (F32 or BF16)
BTLA_DTYPE scaleBdtype;
float* scalesA;  // per-row scales of A, indexed by M_offset
// optional if A asym
uint8_t* zpA = nullptr;
void* reduceB = nullptr;  // per-channel reduction of B; element type given by reduceBdtype
BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32;
// optional if B asym
int8_t* zpB = nullptr;
float* reduceA = nullptr;  // per-row reduction of A; used when B is asymmetric
int K = 1;  // reduction length; only used when both A and B are asymmetric
};
// Per-channel (PcKBlock) int8-compute epilogue: dequantizes the int32 GEMM
// accumulator tile back to fp32 in place, removes zero-point bias terms where
// A and/or B are asymmetric, then hands the fp32 tile to the wrapped
// Fp32Epilogue for the final write-back.
template <class Fp32Epilogue>
class PcKBlockCompInt8Epilogue {
 public:
  using Fp32Param = typename Fp32Epilogue::Param;
  // param1: quantization metadata; param2: forwarded unchanged to Fp32Epilogue.
  struct Param {
    ParamPcKBlockCompInt8Epilogue param1;
    Fp32Param param2;
  };
  using Fp32Epi = Fp32Epilogue;

  // srcptr: int32 accumulator tile (M x N, leading dimension cachestep); the
  //   fp32 result is written in place over the same memory.
  // tmpcache/cachesize: scratch for on-the-fly BF16->FP32 conversion of the B
  //   scales and (optionally) B reductions; must hold at least 2*N floats.
  // Returns the status of the last kernel invoked (asserted in debug builds).
  template <BTLA_ISA ISA_T>
  static BTLA_CODE forward(const int32_t* srcptr, const int cachestep, const int M_offset, const int N_offset,
                           const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
    BTLA_CODE ret = BTLA_CODE::NotSupport;
    float* scab = nullptr;
    size_t ScaleBTmpSize = N * sizeof(float);
    size_t ReduceBTmpSize = N * sizeof(float);
    assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize));
    auto& param1 = _param.param1;
    // Resolve B scales to fp32: convert BF16 into the first half of tmpcache,
    // or point directly at the stored fp32 scales.
    if (param1.scaleBdtype == BTLA_DTYPE::BF16) {
      auto scache = reinterpret_cast<float*>(tmpcache);
      ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
          reinterpret_cast<utils::bf16*>(param1.scalesB) + N_offset, scache, 1, N, N, N, false);
      assert(ret == BTLA_CODE::Success);
      scab = scache;
    } else if (param1.scaleBdtype == BTLA_DTYPE::F32) {
      scab = reinterpret_cast<float*>(param1.scalesB) + N_offset;
    }
    // Resolve optional B reductions the same way (second half of tmpcache).
    float* redb = nullptr;
    if (param1.reduceB) {
      if (param1.reduceBdtype == BTLA_DTYPE::BF16) {
        auto rcache = reinterpret_cast<float*>(reinterpret_cast<char*>(tmpcache) + ScaleBTmpSize);
        ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
            reinterpret_cast<utils::bf16*>(param1.reduceB) + N_offset, rcache, 1, N, N, N, false);
        assert(ret == BTLA_CODE::Success);
        redb = rcache;
      } else if (param1.reduceBdtype == BTLA_DTYPE::F32) {
        redb = reinterpret_cast<float*>(param1.reduceB) + N_offset;
      }
    }
    // Dequantize int32 -> fp32 in place: tmpfp32ptr aliases srcptr.
    auto tmpfp32ptr = reinterpret_cast<float*>(const_cast<int32_t*>(srcptr));
    ret = kernel::wrapper::DequanS32Fp32::template forward<ISA_T>(srcptr, cachestep, tmpfp32ptr, cachestep, M, N,
                                                                  param1.scalesA + M_offset, 1, scab);
    assert(ret == BTLA_CODE::Success);

    // Remove zero-point bias depending on which side(s) are asymmetric.
    if (param1.zpA == nullptr) {
      if (param1.zpB == nullptr) {
        // Both sides symmetric: no correction needed.
      } else {
        ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei<ISA_T>(
            tmpfp32ptr, cachestep, M, N, param1.zpB + N_offset, scab, 1, param1.reduceA + M_offset);
      }
    } else {
      if (param1.zpB == nullptr) {
        ret = kernel::wrapper::RemoveZeroPointBias::template forward_act<ISA_T>(
            tmpfp32ptr, cachestep, M, N, param1.zpA + M_offset, param1.scalesA + M_offset, 1, redb);
      } else {
        ret = kernel::wrapper::RemoveZeroPointBias::template forward_both<ISA_T>(
            tmpfp32ptr, cachestep, M, N, param1.zpA + M_offset, param1.zpB + N_offset, param1.scalesA + M_offset, scab,
            1, param1.K, param1.reduceA + M_offset, redb);
      }
    }
    // BUGFIX: capture the fp32 epilogue's status instead of discarding it, so
    // failures in the final write-back propagate to the caller like every
    // other kernel call in this function.
    ret = Fp32Epilogue::template forward<ISA_T>(tmpfp32ptr, cachestep, M_offset, N_offset, M, N, _param.param2,
                                                tmpcache, cachesize);
    return ret;
  }
};

// Parameters for the plain accumulator write-back epilogue: destination
// matrix C with leading dimension ldc.
template <typename DT>
struct ParamAccumulatorWriteBack {
DT* C;
int ldc;
void* elt_const_v;  // opaque constants for element-wise epilogue variants; semantics depend on the op — TODO confirm
};

template <BTLA_ISA ISA_T, typename _SRC_T, typename _DST_T>
template <typename _SRC_T, typename _DST_T>
class AccumulatorWriteBack {
public:
using SType = _SRC_T;
using DType = _DST_T;
using Param = ParamAccumulatorWriteBack<DType>;
using PcCompInt8Epi = bestla::epilogue::gemm::PcKBlockCompInt8Epilogue<AccumulatorWriteBack<_SRC_T, _DST_T>>;

template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand All @@ -52,10 +134,13 @@ class AccumulatorWriteBack {
}
};

template <BTLA_ISA ISA_T, typename _SRC_T, typename _DST_T, BTLA_ELTWISEOP _OP>
template <typename _SRC_T, typename _DST_T, BTLA_ELTWISEOP _OP>
class CustomAccumulatorWriteBackWithEltop {
public:
using PcCompInt8Epi =
bestla::epilogue::gemm::PcKBlockCompInt8Epilogue<CustomAccumulatorWriteBackWithEltop<_SRC_T, _DST_T, _OP>>;
using Param = ParamAccumulatorWriteBack<_DST_T>;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand All @@ -68,39 +153,29 @@ class CustomAccumulatorWriteBackWithEltop {
}
}
};
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp32 = AccumulatorWriteBack<ISA_T, float, float>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackInt32 = AccumulatorWriteBack<ISA_T, int, int>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackBf16 = AccumulatorWriteBack<ISA_T, utils::bf16, utils::bf16>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp16 = AccumulatorWriteBack<ISA_T, utils::fp16, utils::fp16>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack<ISA_T, utils::bf16, float>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack<ISA_T, utils::fp16, float>;
template <BTLA_ISA ISA_T>
using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack<ISA_T, float, utils::bf16>;
using AccumulatorWriteBackFp32 = AccumulatorWriteBack<float, float>;
using AccumulatorWriteBackInt32 = AccumulatorWriteBack<int, int>;
using AccumulatorWriteBackBf16 = AccumulatorWriteBack<utils::bf16, utils::bf16>;
using AccumulatorWriteBackFp16 = AccumulatorWriteBack<utils::fp16, utils::fp16>;
using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack<utils::bf16, float>;
using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack<utils::fp16, float>;
using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack<float, utils::bf16>;

template <BTLA_ISA ISA_T>
using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop<ISA_T, float, float, BTLA_ELTWISEOP::GELU>;
using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop<float, float, BTLA_ELTWISEOP::GELU>;

template <BTLA_ISA ISA_T>
using AccumulatorWriteBackWithSwishFp32 =
CustomAccumulatorWriteBackWithEltop<ISA_T, float, float, BTLA_ELTWISEOP::SWISH>;
using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop<float, float, BTLA_ELTWISEOP::SWISH>;

template <typename DT>
struct ParamAlphaBetaProcess {
DT *C, *D;
int ldc, ldd;
float alpha, beta;
};
template <BTLA_ISA ISA_T>
class AlphaBetaProcessFp32 {
public:
using Param = ParamAlphaBetaProcess<float>;

template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto DOffset = M_offset * _param.ldd + N_offset;
Expand All @@ -120,10 +195,10 @@ struct ParamCompFp32BlockEpilogue {
float* reduce = nullptr;
int ldra;
};
template <BTLA_ISA ISA_T>
class CompFp32BlockEpilogue {
public:
using Param = ParamCompFp32BlockEpilogue;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset,
const int N_offset, const int K_offset, const int M, const int N, const Param& _param,
void* tmpcache, size_t cachesize) {
Expand Down Expand Up @@ -171,10 +246,10 @@ struct ParamDequantInt32ToFp32 {
float* scalesA;
float* scalesB;
};
template <BTLA_ISA ISA_T>
class DequantInt32ToFp32 {
public:
using Param = ParamDequantInt32ToFp32;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand All @@ -185,88 +260,6 @@ class DequantInt32ToFp32 {
}
};

// Quantization metadata for the k-blocked int8-compute epilogue
// (CompInt8BlockEpilogue). Unlike the per-channel variant, scales and
// zero-points are 2-D: indexed by K-block via ldsb (B side) / ldsa (A side).
struct ParamCompInt8BlockEpilogue {
void* scalesB;  // B scales; element type given by scaleBdtype (F32 or BF16)
BTLA_DTYPE scaleBdtype;
int ldsb;  // leading dimension (stride per K-block) of the B-side arrays
float* scalesA;
int ldsa;  // leading dimension (stride per row) of the A-side arrays
// optional if A asym
uint8_t* zpA = nullptr;
void* reduceB = nullptr;  // B reductions; element type given by reduceBdtype
BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32;
// optional if B asym
int8_t* zpB = nullptr;
float* reduceA = nullptr;
int K = 1;  // reduction length; only used when both A and B are asymmetric
};
// K-blocked int8-compute epilogue: dequantizes one K-block's int32 partial
// accumulator to fp32 (in place over srcptr), accumulates it into dstptr, and
// removes zero-point bias terms when A and/or B are asymmetric.
// tmpcache holds up to 2*N floats of scratch for BF16->FP32 scale/reduce
// conversion. Returns the status of the last kernel invoked.
template <BTLA_ISA ISA_T>
class CompInt8BlockEpilogue {
public:
using Param = ParamCompInt8BlockEpilogue;
static BTLA_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset,
const int N_offset, const int K_offset, const int M, const int N, const Param& _param,
void* tmpcache, size_t cachesize) {
BTLA_CODE ret = BTLA_CODE::NotSupport;
float* scab = nullptr;
size_t ScaleBTmpSize = N * sizeof(float);
size_t ReduceBTmpSize = N * sizeof(float);
assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize));
// Resolve this K-block's B scales to fp32 (convert BF16 into tmpcache, or
// point at the stored fp32 data).
if (_param.scaleBdtype == BTLA_DTYPE::BF16) {
auto scache = reinterpret_cast<float*>(tmpcache);
ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
reinterpret_cast<utils::bf16*>(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N,
false);
assert(ret == BTLA_CODE::Success);
scab = scache;
} else if (_param.scaleBdtype == BTLA_DTYPE::F32) {
scab = reinterpret_cast<float*>(_param.scalesB) + N_offset + K_offset * _param.ldsb;
}
// Resolve optional B reductions the same way (second half of tmpcache).
float* redb = nullptr;
if (_param.reduceB) {
if (_param.reduceBdtype == BTLA_DTYPE::BF16) {
auto rcache = reinterpret_cast<float*>(reinterpret_cast<char*>(tmpcache) + ScaleBTmpSize);
ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
reinterpret_cast<utils::bf16*>(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N,
false);
assert(ret == BTLA_CODE::Success);
redb = rcache;
} else if (_param.reduceBdtype == BTLA_DTYPE::F32) {
redb = reinterpret_cast<float*>(_param.reduceB) + N_offset + K_offset * _param.ldsb;
}
}
// Dequantize the int32 partial sums to fp32 in place (output aliases srcptr),
// then accumulate the fp32 tile into the running dstptr total.
ret = kernel::wrapper::DequanS32Fp32::template forward<ISA_T>(
srcptr, cachestep, reinterpret_cast<float*>(const_cast<int32_t*>(srcptr)), cachestep, M, N,
_param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab);
assert(ret == BTLA_CODE::Success);
ret = kernel::wrapper::AccumulateFp32::template forward<ISA_T>(reinterpret_cast<const float*>(srcptr), cachestep,
dstptr, cachestep, M, N);
assert(ret == BTLA_CODE::Success);

// Remove zero-point bias from dstptr depending on which side(s) are asymmetric.
if (_param.zpA == nullptr) {
if (_param.zpB == nullptr) {
// Both sides symmetric: no correction needed.
return ret;
} else {
ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei<ISA_T>(
dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa,
_param.reduceA + M_offset * _param.ldsa + K_offset);
}
} else {
if (_param.zpB == nullptr) {
ret = kernel::wrapper::RemoveZeroPointBias::template forward_act<ISA_T>(
dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
_param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb);
} else {
ret = kernel::wrapper::RemoveZeroPointBias::template forward_both<ISA_T>(
dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
_param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab,
_param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb);
}
}
return ret;
}
};

struct ParamZpDequantInt32ToFp32 {
// necessary
float* C;
Expand All @@ -282,10 +275,10 @@ struct ParamZpDequantInt32ToFp32 {
float* reduceA = nullptr;
int K = 1;
};
template <BTLA_ISA ISA_T>
class ZpDequantInt32ToFp32 {
public:
using Param = ParamZpDequantInt32ToFp32;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand Down Expand Up @@ -323,10 +316,10 @@ struct ParamAlphaBetaProcessS32U8 {
float scaleAcc, scaleC;
int zpC;
};
template <BTLA_ISA ISA_T>
class AlphaBetaProcessS32U8 {
public:
using Param = ParamAlphaBetaProcessS32U8;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
Expand Down
4 changes: 2 additions & 2 deletions bestla/bestla/bestla_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -4816,7 +4816,7 @@ class CoreCodeBase {
static auto constexpr KTILE = Code::KTILE;
static auto constexpr PACK_ROW = Code::PackRow;
static auto constexpr COMP = Code::COMPUTE;
static int constexpr PREFERRED_N = NTILE * 3;
static int constexpr PREFERRED_N = NTILE * 4;
static auto constexpr ISA = Code::ISA;
static auto constexpr ID = CoreAttr::make_core_id(NTILE, PACK_ROW, COMP, ISA);
void configure(int _M, int _N, int _K) { (void)(0); }
Expand All @@ -4842,7 +4842,7 @@ class CoreCodeBaseAMX {
static auto constexpr KTILE = Code::KTILE;
static auto constexpr PACK_ROW = Code::PackRow;
static auto constexpr COMP = Code::COMPUTE;
static int constexpr PREFERRED_N = NTILE * 3;
static int constexpr PREFERRED_N = NTILE * 4;
static auto constexpr ISA = Code::ISA;
static auto constexpr ID = CoreAttr::make_core_id(_NTILE, PACK_ROW, COMP, ISA);
Xbyak::CodeGenerator cfgcode;
Expand Down
Loading
Loading