Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Fix the swapped ref::add/ref::mul tail fallbacks in mul and add; use the new kernels in custom::epilogue
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed May 31, 2024
1 parent 71b017c commit 86e0a94
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 33 deletions.
22 changes: 11 additions & 11 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ static inline BTLA_CODE decompress_kblock_s4_s8(utils::int4x2* srcptr, int8_t* z
int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp,
size_t tmpsize) {
if (zpptr) {
typedef BTLA_CODE (*decompfunc)(utils::int4x2 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp,
typedef BTLA_CODE (*decompfunc)(utils::int4x2* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp,
int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize);
decompfunc func = nullptr;
if (col == NTILE) {
Expand Down Expand Up @@ -764,7 +764,7 @@ static inline BTLA_CODE decompress_kblock_s2_s8(utils::bit2x4* bit2ptr, int8_t*
int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp,
size_t tmpsize) {
if (zpptr) {
typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * srcptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp,
typedef BTLA_CODE (*decompfunc)(utils::bit2x4* srcptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp,
int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize);
decompfunc func = nullptr;
if (col == NTILE) {
Expand Down Expand Up @@ -1022,7 +1022,7 @@ static inline BTLA_CODE decompress_kblock_s3_s8(utils::bit2x4* bit2ptr, utils::b
int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset,
int row, int col, int8_t* tmp, size_t tmpsize) {
if (zpptr) {
typedef BTLA_CODE (*decompfunc)(utils::bit2x4 * bit2ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr,
typedef BTLA_CODE (*decompfunc)(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr,
int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp,
size_t tmpsize);
decompfunc func = nullptr;
Expand Down Expand Up @@ -1247,7 +1247,7 @@ static inline BTLA_CODE decompress_kblock_s1_s8(utils::bit1x8* bit1ptr, int8_t*
int ldzp, int n_offset, int k_offset, int row, int col, int8_t* tmp,
size_t tmpsize) {
if (zpptr) {
typedef BTLA_CODE (*decompfunc)(utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp,
typedef BTLA_CODE (*decompfunc)(utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp,
int n_offset, int k_offset, int row, int8_t* tmp, size_t tmpsize);
decompfunc func = nullptr;
if (col == NTILE) {
Expand Down Expand Up @@ -1500,7 +1500,7 @@ static inline BTLA_CODE decompress_kblock_s5_s8(utils::bit4x2* bit4ptr, utils::b
int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset,
int row, int col, int8_t* tmp, size_t tmpsize) {
if (zpptr) {
typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit1x8 * bit1ptr, int8_t * zpptr, int8_t * dstptr,
typedef BTLA_CODE (*decompfunc)(utils::bit4x2* bit4ptr, utils::bit1x8* bit1ptr, int8_t* zpptr, int8_t* dstptr,
int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp,
size_t tmpsize);
decompfunc func = nullptr;
Expand Down Expand Up @@ -1814,9 +1814,9 @@ static inline BTLA_CODE decompress_kblock_s7_s8(utils::bit4x2* bit4ptr, utils::b
int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset,
int k_offset, int row, int col, int8_t* tmp, size_t tmpsize) {
if (zpptr) {
typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit2x4 * bit2ptr, utils::bit1x8 * bit1ptr,
int8_t * zpptr, int8_t * dstptr, int blocksize, int ldzp, int n_offset,
int k_offset, int row, int8_t* tmp, size_t tmpsize);
typedef BTLA_CODE (*decompfunc)(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr,
int8_t* zpptr, int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset,
int row, int8_t* tmp, size_t tmpsize);
decompfunc func = nullptr;
if (col == NTILE) {
if constexpr (PackRow == 1) {
Expand Down Expand Up @@ -2077,7 +2077,7 @@ static inline BTLA_CODE decompress_kblock_s6_s8(utils::bit4x2* bit4ptr, utils::b
int8_t* dstptr, int blocksize, int ldzp, int n_offset, int k_offset,
int row, int col, int8_t* tmp, size_t tmpsize) {
if (zpptr) {
typedef BTLA_CODE (*decompfunc)(utils::bit4x2 * bit4ptr, utils::bit2x4 * bit2ptr, int8_t * zpptr, int8_t * dstptr,
typedef BTLA_CODE (*decompfunc)(utils::bit4x2* bit4ptr, utils::bit2x4* bit2ptr, int8_t* zpptr, int8_t* dstptr,
int blocksize, int ldzp, int n_offset, int k_offset, int row, int8_t* tmp,
size_t tmpsize);
decompfunc func = nullptr;
Expand Down Expand Up @@ -6651,7 +6651,7 @@ static inline BTLA_CODE mul(const T* src0ptr, const T* src1ptr, T* dstptr, size_
i = size - VLen;
vfunc();
} else {
ref::add(src0ptr + i, src1ptr + i, dstptr + i, size - i);
ref::mul(src0ptr + i, src1ptr + i, dstptr + i, size - i);
}
}
return BTLA_CODE::Success;
Expand All @@ -6674,7 +6674,7 @@ static inline BTLA_CODE add(const T* src0ptr, const T* src1ptr, T* dstptr, size_
i = size - VLen;
vfunc();
} else {
ref::mul(src0ptr + i, src1ptr + i, dstptr + i, size - i);
ref::add(src0ptr + i, src1ptr + i, dstptr + i, size - i);
}
}
return BTLA_CODE::Success;
Expand Down
52 changes: 30 additions & 22 deletions neural_speed/core/layers/bestla_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,21 +128,21 @@ class Add {
using Param = ParamAdd<_T>;

template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
static inline BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
auto DOffset = M_offset * _param.ldd + N_offset;
auto cptr = _param.C + COffset;
auto dptr = _param.D + DOffset;
// for (int i = 0; i < M; i++) {
// ne_vec_add_f32(N, cptr + i * _param.ldc,dptr + i * _param.ldd, cacheptr + i * cachestep);
// }
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
cptr[i * _param.ldc + j] = dptr[i * _param.ldd + j] + cacheptr[i * cachestep + j];
if constexpr (std::is_same_v<_T, float>) {
for (int i = 0; i < M; i++) {
bestla::kernel::wrapper::Add<_T>::template forward<ISA_T>(dptr + i * _param.ldd, cacheptr + i * cachestep,
cptr + i * _param.ldc, N);
}
return BTLA_CODE::Success;
} else {
return BTLA_CODE::NotSupport;
}
return BTLA_CODE::Success;
}
};
using AddFp32 = Add<float>;
Expand All @@ -157,18 +157,21 @@ class Mul {
public:
using Param = ParamMul<_T>;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
static inline BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
auto DOffset = M_offset * _param.ldd + N_offset;
auto cptr = _param.C + COffset;
auto dptr = _param.D + DOffset;
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
cptr[i * _param.ldc + j] = dptr[i * _param.ldd + j] * cacheptr[i * cachestep + j];
if constexpr (std::is_same_v<_T, float>) {
for (int i = 0; i < M; i++) {
bestla::kernel::wrapper::Mul<_T>::template forward<ISA_T>(dptr + i * _param.ldd, cacheptr + i * cachestep,
cptr + i * _param.ldc, N);
}
return BTLA_CODE::Success;
} else {
return BTLA_CODE::NotSupport;
}
return BTLA_CODE::Success;
}
};
using MulFp32 = Mul<float>;
Expand All @@ -183,20 +186,25 @@ class Add_Gelu {
public:
using Param = ParamAdd_Gelu<_T>;
template <BTLA_ISA ISA_T>
static BTLA_CODE forward( // NOLINT [build/include_what_you_use]
static inline BTLA_CODE forward( // NOLINT [build/include_what_you_use]
const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, const int N,
const Param& _param, void* tmpcache, size_t cachesize) {
auto COffset = M_offset * _param.ldc + N_offset;
auto DOffset = M_offset * _param.ldd + N_offset;
auto cptr = _param.C + COffset;
auto dptr = _param.D + DOffset;
for (int i = 0; i < M; i++) {
ne_vec_add_f32(N, cptr + i * _param.ldc, dptr + i * _param.ldd, cacheptr + i * cachestep);
if constexpr (std::is_same_v<_T, float>) {
for (int i = 0; i < M; i++) {
bestla::kernel::wrapper::Add<_T>::template forward<ISA_T>(dptr + i * _param.ldd, cacheptr + i * cachestep,
cptr + i * _param.ldc, N);
}
using GeluKernel = bestla::epilogue::gemm::AccumulatorWriteBackWithGeluFp32;
typename GeluKernel::Param param{_param.C, _param.ldc, nullptr};
auto ret = GeluKernel::forward<ISA_T>(cptr, _param.ldc, M_offset, N_offset, M, N, param, tmpcache, cachesize);
return ret;
} else {
return BTLA_CODE::NotSupport;
}
using GeluKernel = bestla::epilogue::gemm::AccumulatorWriteBackWithGeluFp32;
typename GeluKernel::Param param{_param.C, _param.ldc, nullptr};
auto ret = GeluKernel::forward<ISA_T>(cptr, _param.ldc, M_offset, N_offset, M, N, param, tmpcache, cachesize);
return ret;
}
};
using Add_GeluFp32 = Add_Gelu<float>;
Expand Down

0 comments on commit 86e0a94

Please sign in to comment.