Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
rename after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Dec 29, 2023
1 parent dce5e41 commit 940ab31
Show file tree
Hide file tree
Showing 11 changed files with 300 additions and 384 deletions.
117 changes: 58 additions & 59 deletions bestla/bestla/bestla_prologue_b.h

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion bestla/bestla/bestla_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ class PackedWeightParser {
case BTLA_PROLOGUEB_IDS::WeightPack:
ptr = new gemm::StoragePackedWeight(0);
break;
case JBLAS_PROLOGUEB_IDS::WeightKBlockNInteger:
case BTLA_PROLOGUEB_IDS::WeightKBlockNInteger:
ptr = new gemm::StorageWeightKBlockNInteger(0);
break;
case BTLA_PROLOGUEB_IDS::WeightKBlockNFloat:
Expand Down
8 changes: 4 additions & 4 deletions bestla/bestla/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,25 +251,25 @@ inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) {
}

// Specializations of convert_s4_s8_8: expand 8 packed 4-bit codes into 8
// int8 values, with per-dtype post-processing where the encoding needs it.
// (The scraped diff left stale pre-rename JBLAS_DTYPE signatures interleaved
// with the BTLA_DTYPE ones; only the post-rename versions are kept here.)

template <>
inline void convert_s4_s8_8<BTLA_DTYPE::S4_FULLRANGE>(int8_t* dstptr, int8_t* srcptr) {
  convert_s4_s8_8_lowbits(dstptr, srcptr);
  // S4_FULLRANGE stores values with a +8 bias; remove it after expansion.
  for (size_t i = 0; i < 8; i++) {
    dstptr[i] -= 8;
  }
}

// F4_BNB: raw nibble expansion only — the 4-bit codes are used as-is
// (presumably as lookup-table indices downstream; not visible here).
template <>
inline void convert_s4_s8_8<BTLA_DTYPE::F4_BNB>(int8_t* dstptr, int8_t* srcptr) {
  convert_s4_s8_8_lowbits(dstptr, srcptr);
}

// F4_NF4: same as F4_BNB — no bias correction after expansion.
template <>
inline void convert_s4_s8_8<BTLA_DTYPE::F4_NF4>(int8_t* dstptr, int8_t* srcptr) {
  convert_s4_s8_8_lowbits(dstptr, srcptr);
}

// F4_E2M1: same as F4_BNB — no bias correction after expansion.
template <>
inline void convert_s4_s8_8<BTLA_DTYPE::F4_E2M1>(int8_t* dstptr, int8_t* srcptr) {
  convert_s4_s8_8_lowbits(dstptr, srcptr);
}

Expand Down
206 changes: 76 additions & 130 deletions bestla/bestla/ut/bestla_prologue_b.cpp

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions bestla/bestla/ut/bestla_ut.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,29 @@ static int8_t cache[CacheSize];
#define INT4_ERR 2.f
#define FP4_ERR 3.f

static inline float get_ut_err(JBLAS_DTYPE qtype){
auto dbits = utils::jblas_dtype_bits(qtype);
auto type = utils::jblas_dtype_type(qtype);
static inline float get_ut_err(BTLA_DTYPE qtype) {
auto dbits = utils::bestla_dtype_bits(qtype);
auto type = utils::bestla_dtype_type(qtype);
auto err = FP32_ERR;
auto constexpr dtype_int = utils::jblas_dtype_type(JBLAS_DTYPE::TypeInt);
auto constexpr dtype_int = utils::bestla_dtype_type(BTLA_DTYPE::TypeInt);
if (type == dtype_int) {
if (dbits == 8) {
err = INT8_ERR;
} else {
err = INT4_ERR;
}
} else {
if (dbits==4) {
if (dbits == 4) {
err = FP4_ERR;
} else if (dbits == 8) {
err = F8_ERR;
} else if (dbits == 16) {
if (qtype==JBLAS_DTYPE::F16) {
if (qtype == BTLA_DTYPE::F16) {
err = FP16_ERR;
} else {
err = BF16_ERR;
}
}
}
}
return err;
}
Expand Down
8 changes: 4 additions & 4 deletions neural_speed/core/layers/bestla_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>;
using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>;
using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>;

// Weight-prologue alias templates for KBlock-quantized packed weights
// (integer and float variants), parameterized on GEMM core and ISA.
// (The scraped diff kept both the old jblas/JBLAS_ISA declarations and their
// renamed replacements; only the post-rename BTLA aliases belong here.)
template <class GC_T, BTLA_ISA ISA_T>
using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger<GC_T, ISA_T>;
template <class GC_T, BTLA_ISA ISA_T>
using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat<GC_T, ISA_T>;

template <class GC_T, BTLA_ISA ISA_T>
using tActKBaseF32 = prologue_a::gemm::ShuffleActivationKBlockBaseF32<GC_T, ISA_T>;
Expand Down
203 changes: 87 additions & 116 deletions neural_speed/core/layers/bestla_gemm.cpp

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions neural_speed/core/layers/inner_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ bool bestla_fusion_add_f32f32_support(void* weiptr, int _m, int _n, int _k) {
constexpr size_t EleNum = sizeof(AllKBlockCores) / sizeof(AllKBlockCores[0]); // supported cores
support = contains(wtmp->mCoreId, AllKBlockCores, EleNum);
support &= hasISA(AllKBlockCores, EleNum);
} else if (wtmp->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockNFloat) {
} else if (wtmp->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNFloat) {
constexpr size_t EleNum = sizeof(FloatCores) / sizeof(FloatCores[0]);
support = contains(wtmp->mCoreId, FloatCores, EleNum);
support &= hasISA(FloatCores, EleNum);
Expand Down Expand Up @@ -188,26 +188,26 @@ void bestla_fusion_add_f32f32_forward(float* activation, void* weiptr, float* bi
}
}
}
if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockNFloat) {
auto bptr = reinterpret_cast<jblas::storage::gemm::IWeightKBlockBase*>(ptr);
if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNFloat) {
auto bptr = reinterpret_cast<storage::gemm::IWeightKBlockBase*>(ptr);
auto BlkSize = bptr->mBlockSize;
if (btype == gemm::CompType::tFP32 && PackRow == 1) {
if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
ip_add::JblasGemmCompF32<tAVX512F, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output, ldo,
ip_add::BTLAGemmCompF32<tAVX512F, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output, ldo,
bias, broadcast_bias, workspace, pth);
} else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
ip_add::JblasGemmCompF32<tAVX2, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output, ldo, bias,
ip_add::BTLAGemmCompF32<tAVX2, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output, ldo, bias,
broadcast_bias, workspace, pth);
}
}
if (btype == gemm::CompType::tBF16 && PackRow == 2 && BlkSize % tAMX_BF16::KTILE == 0) {
if (NTile == tAMX_BF16::NTILE && _cd->AMX_BF16()) {
if (_m <= tAVX512_BF16::MTILE) {
static_assert(tAVX512_BF16::NTILE == tAMX_BF16::NTILE);
ip_add::JblasGemmCompF32<tAVX512_BF16, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output,
ip_add::BTLAGemmCompF32<tAVX512_BF16, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output,
ldo, bias, broadcast_bias, workspace, pth);
} else {
ip_add::JblasGemmCompF32<tAMX_BF16, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output, ldo,
ip_add::BTLAGemmCompF32<tAMX_BF16, tWeiNFloat, tActKBaseF32>(_m, _n, _k, activation, lda, ptr, output, ldo,
bias, broadcast_bias, workspace, pth);
}
}
Expand Down
Loading

0 comments on commit 940ab31

Please sign in to comment.