Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into whisper_opt
Browse files Browse the repository at this point in the history
  • Loading branch information
intellinjun authored Jan 18, 2024
2 parents e974c26 + 9e4cd94 commit 25edf34
Show file tree
Hide file tree
Showing 22 changed files with 990 additions and 484 deletions.
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,5 @@ build/
debug/
.eggs/
dist/
runtime_outs/
__pycache__
CMakeUserPresets.json
*.egg-info
.cache/
.clangd
488 changes: 76 additions & 412 deletions README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions bestla/bestla/bestla.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ enum class BTLA_DTYPE : uint32_t {
SubType1 = 1 << SubTypeShift,
SubType2 = 2 << SubTypeShift,
SubType3 = 3 << SubTypeShift,
SubType4 = 4 << SubTypeShift,
F64 = EleBits64 | TypeFloat,
F32 = EleBits32 | TypeFloat,
F16 = EleBits16 | TypeFloat,
Expand All @@ -59,6 +60,7 @@ enum class BTLA_DTYPE : uint32_t {
F8_E5M2 = EleBits8 | TypeFloat | SubType1,
F8_E3M4 = EleBits8 | TypeFloat | SubType2,
F8_E8M0 = EleBits8 | TypeFloat | SubType3,
DQ8_BNB = EleBits8 | TypeFloat | SubType4,
S8 = EleBits8 | TypeInt,
U8 = EleBits8 | TypeInt | SubType1,
S4_CLIP = EleBits4 | TypeInt,
Expand Down
2 changes: 2 additions & 0 deletions bestla/bestla/bestla_epilogue.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class CompFp32BlockEpilogue {
if (_param.zps != nullptr) {
assert(0);
}
} else {
assert(0);
}
return BTLA_CODE::NotSupport;
}
Expand Down
94 changes: 94 additions & 0 deletions bestla/bestla/bestla_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,33 @@ class WeightKBlockNInteger {
return tmp;
}

// Runs the second-level ("double") quantization over an array of fp32 scales.
//   scale        : fp32 scales handed to the quantization kernel
//   scale_size   : number of entries in `scale`
//   dq_blocksize : number of scales grouped into one double-quant block
//   qtype        : double-quant type; only BTLA_DTYPE::DQ8_BNB is handled, anything else asserts
//   dq_buf       : resized to one float per double-quant block plus one trailing float for the
//                  offset, then passed to the kernel as its output buffer
void doubleQuantScale(float* scale, size_t scale_size, int dq_blocksize, BTLA_DTYPE qtype,
                      utils::aligned_vector<float>* dq_buf) {
  if (qtype != BTLA_DTYPE::DQ8_BNB) {
    assert(0);
    return;
  }
  const auto dq_block_count = utils::updiv(scale_size, dq_blocksize);
  dq_buf->resize(dq_block_count + 1);  // the extra slot holds the offset
  kernel::ref::dq8_bnb_double_quant<false>(scale, scale_size, dq_blocksize, dq_buf->data());
}

// Copies the double-quant correction values from `dq_buf` into the DQ buffer of the packed
// weight `ptr`. Only storages whose scale type is BTLA_DTYPE::DQ8_BNB are handled; any other
// scale type asserts.
// NOTE(review): the copy length is taken from dq_buf only; the destination is presumably sized
// beforehand by the storage's double-quant setup - confirm it is at least as large.
void setDoubleQuantCorrection(utils::avector<float>* dq_buf, StorageWeight* ptr) {
  if (ptr->SDtype() != BTLA_DTYPE::DQ8_BNB) {
    assert(0);
    return;
  }
  const auto copy_bytes = dq_buf->size() * sizeof(float);
  memcpy(ptr->DQPtr<float>(), dq_buf->data(), copy_bytes);
}

// Convenience wrapper: forwards to stor->enable_shuffle().
static void enableShuffle(StorageWeight* stor) { stor->enable_shuffle(); }
// Configures an already-sized storage for double quantization.
//   stor       : storage to update; its mDqBlockSize is overwritten and its size recomputed
//   stype      : double-quant scale type forwarded to enable_double_quant
//   dq_blksize : number of scales per double-quant block
// Asymmetric storages and block sizes that are not a multiple of 8 trigger assert(0).
void setDoubleQuantBlkSize(StorageWeight* stor, BTLA_DTYPE stype, int dq_blksize) {
  stor->mDqBlockSize = dq_blksize;
  // Number of K-blocks, i.e. rows of scales, derived from the padded K dimension.
  const auto kblock_count = utils::updiv(stor->mKPad, stor->mBlockSize);
  if (stor->IsAsym() || dq_blksize % 8 != 0) {
    assert(0);
  }
  // One double-quant entry per dq_blksize scales, counted over kblock_count * mN scales.
  const auto dq_entry_count = utils::updiv(kblock_count * stor->mN, dq_blksize);
  stor->mCorrection.enable_double_quant(dq_entry_count, stype);
  stor->update_size();
}

void packTransposeWeight(const int N, const int K, const float* B, const int ldb, StorageWeight* stor,
parallel::IThreading* threading) {
Expand Down Expand Up @@ -273,6 +299,25 @@ class WeightKBlockNInteger {
}
}
});
} else if (stor->SDtype() == BTLA_DTYPE::DQ8_BNB) {
threading->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp{tidx};
_para.getIndex(thdp);
if (thdp.valid) {
for (int i = thdp.loc[1]; i < thdp.loc[1] + thdp.size[1]; i++) {
if (i < rawnk_scale) {
if (scales != nullptr) {
for (size_t j = 0; j < N; j++) {
stor->template SPtr<uint8_t>()[j + i * stor->mNPad] = static_cast<uint8_t>(scales[i * N + j]);
}
}
} else {
if (scales != nullptr)
std::memset(stor->template SPtr<uint8_t>() + i * stor->mNPad, 0, stor->mNPad * sizeof(uint8_t));
}
}
}
});
} else {
assert(0);
}
Expand Down Expand Up @@ -381,6 +426,14 @@ class WeightKBlockNInteger {

void packQWeight(const int N, const int K, const int8_t* B, const int ldb, const float* scales,
const int8_t* zero_points, StorageWeight* stor, parallel::IThreading* threading) {
if (stor->SDtype() == BTLA_DTYPE::DQ8_BNB) assert(stor->mDqBlockSize != 0);
if (stor->IsDoubleQuant()) {
int nk_scale = utils::updiv(K, stor->mBlockSize);
auto ssize = static_cast<size_t>(N) * nk_scale;
utils::avector<float> dq_buf;
doubleQuantScale(const_cast<float*>(scales), ssize, stor->mDqBlockSize, stor->SDtype(), &dq_buf);
setDoubleQuantCorrection(&dq_buf, stor);
}
setQuantCorrection(N, K, zero_points, scales, stor, threading);
if (stor->mDType == BTLA_DTYPE::S8 || stor->mDType == BTLA_DTYPE::F8_E4M3 || stor->mDType == BTLA_DTYPE::F8_E5M2) {
reorderWeight(N, K, B, ldb, stor->WPtr<int8_t>(), threading);
Expand Down Expand Up @@ -596,6 +649,15 @@ class WeightKBlockNInteger {
utils::updiv(k_size, wptr->mBlockSize), n_size, wptr->CStep() * 2, n_size * 4, false);
*dststep = n_size;
}
if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) {
auto aptr = wptr->template SPtr<uint8_t>();
auto internal_k_offset = k_offset / wptr->mBlockSize;
auto dq_offset_idx = wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1;
kernel::wrapper::Dq8GetScale::template forward<ISA_T>(
aptr + internal_k_offset * wptr->CStep() + n_offset, *dstptr, utils::updiv(k_size, wptr->mBlockSize), n_size,
internal_k_offset * wptr->mN + n_offset, wptr->mDqBlockSize, dq_offset_idx, wptr->DQPtr<float>(),
wptr->CStep(), n_size, false);
}
return BTLA_CODE::Success;
}

Expand Down Expand Up @@ -731,6 +793,22 @@ class WeightKBlockNInteger {
} else {
assert(0);
}
} else if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) {
auto internal_n_offset = n_offset + i;
if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
kernel::wrapper::DecompressDQKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T,
BTLA_DTYPE::S4_CLIP>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize,
wptr->template SPtr<uint8_t>(), wptr->template DQPtr<float>(), k_offset / _GemmCore_T::PACK_ROW,
internal_n_offset, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, wptr->mN, wptr->mDqBlockSize,
wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1, tmpcache, cachesize);
} else {
assert(0);
}
} else {
assert(0);
}
}
*dststep = k_size;
Expand Down Expand Up @@ -899,6 +977,22 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> {
} else {
assert(0);
}
} else if (wptr->SDtype() == BTLA_DTYPE::DQ8_BNB) {
auto f4ptr = reinterpret_cast<utils::f4x2*>(bptr + i * KPad / 2);
auto fp_ptr = *dstptr + i * k_size;
auto internal_n_offset = n_offset + i;
auto internal_k_offset = k_offset / _GemmCore_T::PACK_ROW;
auto internal_kblock = wptr->mBlockSize / _GemmCore_T::PACK_ROW;
auto dq_offset_idx = wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1;
if (wptr->mDType == BTLA_DTYPE::F4_NF4) {
kernel::wrapper::DecompressDqKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T,
BTLA_DTYPE::F4_NF4>(
f4ptr, fp_ptr, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, wptr->template SPtr<uint8_t>(),
wptr->template DQPtr<float>(), internal_k_offset, internal_n_offset, internal_kblock, wptr->mDqBlockSize,
dq_offset_idx, NPad, wptr->mN, tmpcache, cachesize);
} else {
assert(0);
}
} else {
assert(0);
}
Expand Down
33 changes: 33 additions & 0 deletions bestla/bestla/bestla_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class ObjectQuantCorrection : public ISerialObject {
BTLA_DTYPE mScaT = BTLA_DTYPE::F32, mZpT = BTLA_DTYPE::F32, mRedT = BTLA_DTYPE::F32;
ObjectAlignedBuffer<Alignment> mScaleBuf;
ObjectOptionalBuffer<Alignment> mZpBuf, mRedBuf;
ObjectOptionalBuffer<Alignment> mDQCorrectionBuf;

// non-ser
public:
Expand Down Expand Up @@ -184,6 +185,7 @@ class ObjectQuantCorrection : public ISerialObject {
totalsize += mScaleBuf.getSerializedSize();
totalsize += mZpBuf.getSerializedSize();
totalsize += mRedBuf.getSerializedSize();
totalsize += mDQCorrectionBuf.getSerializedSize();
return totalsize;
}
virtual void serializeToBuffer(int8_t*& wptr) override {
Expand All @@ -195,6 +197,7 @@ class ObjectQuantCorrection : public ISerialObject {
mScaleBuf.serializeToBuffer(wptr);
mZpBuf.serializeToBuffer(wptr);
mRedBuf.serializeToBuffer(wptr);
mDQCorrectionBuf.serializeToBuffer(wptr);
}
virtual void deserializeBuffer(int8_t*& rptr, bool locate_buf) override {
if (!locate_buf) {
Expand All @@ -214,7 +217,17 @@ class ObjectQuantCorrection : public ISerialObject {
mScaleBuf.deserializeBuffer(rptr, locate_buf);
mZpBuf.deserializeBuffer(rptr, locate_buf);
mRedBuf.deserializeBuffer(rptr, locate_buf);
mDQCorrectionBuf.deserializeBuffer(rptr, locate_buf);
}
// Allocates the double-quant correction buffer.
//   scale_size : number of second-level ("super") scales to store
//   stype      : double-quant type; only BTLA_DTYPE::DQ8_BNB is handled, anything else asserts
// The buffer is resized to scale_size floats plus one additional float.
void enable_double_quant(size_t scale_size, BTLA_DTYPE stype) {
  if (stype != BTLA_DTYPE::DQ8_BNB) {
    assert(0);
    return;
  }
  const auto scale_bytes = scale_size * sizeof(float);
  const auto extra_bytes = sizeof(float);  // single trailing float after the scales
  mDQCorrectionBuf.resize(scale_bytes + extra_bytes);
}

protected:
inline void updateSize() {
Expand Down Expand Up @@ -306,6 +319,7 @@ class IWeightBase : public storage::ISerializable {
class IWeightKBlockBase : public IWeightBase {
public:
int mBlockSize = 1;
int mDqBlockSize = 0;
IWeightKBlockBase(uint64_t _id) : IWeightBase(_id) {}
void resize(int NPad, int KPad, int Block, int N, int K, BTLA_DTYPE dtype) {
IWeightBase::resize(NPad, KPad, N, K, dtype);
Expand All @@ -321,19 +335,23 @@ class IWeightKBlockBase : public IWeightBase {
// Serializes the base-class state first, then mBlockSize, then mDqBlockSize.
// The field order here must stay in sync with deserializeBuffer().
virtual void serializeToBuffer(int8_t*& wptr) {
IWeightBase::serializeToBuffer(wptr);
utils::serialize(wptr, mBlockSize);
utils::serialize(wptr, mDqBlockSize);
}

// Restores (or re-emits) the block-size fields after the base-class state.
// When map_buf is false, mBlockSize and mDqBlockSize are read from rptr, in that order.
// When map_buf is true, the current member values are passed to utils::serialize on rptr
// instead of being read back.
// NOTE(review): presumably utils::serialize/deserialize advance the pointer - confirm.
virtual void deserializeBuffer(int8_t*& rptr, bool map_buf) {
IWeightBase::deserializeBuffer(rptr, map_buf);
if (!map_buf) {
mBlockSize = utils::deserialize<int>(rptr);
mDqBlockSize = utils::deserialize<int>(rptr);
} else {
utils::serialize(rptr, mBlockSize);
utils::serialize(rptr, mDqBlockSize);
}
}

// Byte count of the fields this class adds to the serialized form:
// mBlockSize followed by mDqBlockSize.
inline constexpr size_t getMiscSize() {
  return sizeof(mBlockSize) + sizeof(mDqBlockSize);
}
};
Expand Down Expand Up @@ -694,10 +712,17 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase {
auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId);
auto is_cint = bestla::gemm::CompTypeHelper::is_integer(gemm_comp);
mCorrection.resize(nk_scale, NPad, scalet, zpt, redt, IsAsym, is_cint);
if (scalet == BTLA_DTYPE::DQ8_BNB) initDoubleQuantBlkSize(Block, nk_scale, IsAsym, N, scalet);
update_size();
return mSize;
}

// Records the double-quant block size and allocates the matching correction buffer.
//   dq_blksize : number of scales per double-quant block (stored in mDqBlockSize)
//   nk_scale   : number of scale rows (K-blocks)
//   asym       : whether the weight is asymmetric; asymmetric weights trigger assert(0)
//   N          : number of scale columns
//   stype      : double-quant scale type forwarded to enable_double_quant
// A dq_blksize that is not a multiple of 8 also triggers assert(0).
void initDoubleQuantBlkSize(int dq_blksize, int nk_scale, bool asym, int N, BTLA_DTYPE stype) {
  mDqBlockSize = dq_blksize;
  if (asym || dq_blksize % 8 != 0) {
    assert(0);
  }
  // One double-quant entry per dq_blksize scales across all nk_scale * N scales.
  const auto dq_entry_count = utils::updiv(nk_scale * N, dq_blksize);
  mCorrection.enable_double_quant(dq_entry_count, stype);
}

void enable_shuffle() {
auto indicessize = mK * sizeof(int);
mShuffleIndices.resize(indicessize);
Expand All @@ -709,6 +734,7 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase {
// Scale data type recorded in the quantization correction object (mScaT).
inline constexpr BTLA_DTYPE SDtype() { return mCorrection.mScaT; }
// True when the zero-point buffer is flagged as non-empty.
inline constexpr bool IsAsym() { return mCorrection.mZpBuf.mNotEmpty; }
// True when the reduce buffer is flagged as non-empty.
inline constexpr bool HasReduce() { return mCorrection.mRedBuf.mNotEmpty; }
// True when the double-quant correction buffer is flagged as non-empty.
inline constexpr bool IsDoubleQuant() { return mCorrection.mDQCorrectionBuf.mNotEmpty; }
// Forwards mCorrection.mCSize.
inline constexpr size_t CSize() { return mCorrection.mCSize; }
// Forwards mCorrection.mCStep; the prologue code uses it as the row stride when indexing scales.
inline constexpr int CStep() { return mCorrection.mCStep; }

Expand Down Expand Up @@ -737,6 +763,11 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase {
return mCorrection.mRedBuf.get<T>();
}

// Returns the double-quant correction buffer viewed as T*.
// NOTE(review): the result is only meaningful when IsDoubleQuant() is true - presumably the
// buffer is empty otherwise; confirm against ObjectOptionalBuffer::get.
template <typename T>
inline constexpr T* DQPtr() {
return mCorrection.mDQCorrectionBuf.get<T>();
}

// Returns the shuffle-index buffer (mShuffleIndices) viewed as int*.
inline constexpr int* ShfIndice() { return mShuffleIndices.get<int>(); }

void update_size() {
Expand Down Expand Up @@ -782,6 +813,8 @@ class StorageWeightKBlockNFloat : public StorageWeightKBlockNInteger {
int nk_scale = utils::updiv(KPad, Block);
StorageWeightKBlockNInteger::mCorrection.resize(nk_scale, NPad, scalet, BTLA_DTYPE::EleBitsUndef,
BTLA_DTYPE::EleBitsUndef, false, false);
if (scalet == BTLA_DTYPE::DQ8_BNB) initDoubleQuantBlkSize(Block, nk_scale, false, N, scalet);
update_size();
mSize = StorageWeightKBlockNInteger::InfoType::getSerializedSize() +
StorageWeightKBlockNInteger::mQBuf.getSerializedSize() +
StorageWeightKBlockNInteger::mCorrection.getSerializedSize();
Expand Down
33 changes: 33 additions & 0 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ inline const char* bestla_dtype_str(BTLA_DTYPE dtype) {
return "fp8_e5m2";
case BTLA_DTYPE::F8_E3M4:
return "fp8_e3m4";
case BTLA_DTYPE::F8_E8M0:
return "fp8_e8m0";
case BTLA_DTYPE::S8:
return "signed_int8";
case BTLA_DTYPE::U8:
Expand All @@ -316,6 +318,8 @@ inline const char* bestla_dtype_str(BTLA_DTYPE dtype) {
return "signed_int32";
case BTLA_DTYPE::U32:
return "unsigned_int32";
case BTLA_DTYPE::DQ8_BNB:
return "dq8_bnb";
default:
return "ErrType";
}
Expand Down Expand Up @@ -700,4 +704,33 @@ static float nf4_dequant_fp32_LUT[] = {0.f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};

// 8-bit dynamic-tree-quantization map taken from the bitsandbytes double-quant implementation.
// The table has 256 entries in ascending order, one per uint8_t code; it is indexed directly by
// the stored 8-bit scale code.
// For more details, please refer to
// [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561)
static float dq8_bnb_LUT[] = {
-0.99297, -0.97891, -0.96484, -0.95078, -0.93672, -0.92266, -0.90859, -0.89453, -0.88047, -0.86641, -0.85234,
-0.83828, -0.82422, -0.81016, -0.79609, -0.78203, -0.76797, -0.75391, -0.73984, -0.72578, -0.71172, -0.69766,
-0.68359, -0.66953, -0.65547, -0.64141, -0.62734, -0.61328, -0.59922, -0.58516, -0.57109, -0.55703, -0.54297,
-0.52891, -0.51484, -0.50078, -0.48672, -0.47266, -0.45859, -0.44453, -0.43047, -0.41641, -0.40234, -0.38828,
-0.37422, -0.36016, -0.34609, -0.33203, -0.31797, -0.30391, -0.28984, -0.27578, -0.26172, -0.24766, -0.23359,
-0.21953, -0.20547, -0.19141, -0.17734, -0.16328, -0.14922, -0.13516, -0.12109, -0.10703, -0.09859, -0.09578,
-0.09297, -0.09016, -0.08734, -0.08453, -0.08172, -0.07891, -0.07609, -0.07328, -0.07047, -0.06766, -0.06484,
-0.06203, -0.05922, -0.05641, -0.05359, -0.05078, -0.04797, -0.04516, -0.04234, -0.03953, -0.03672, -0.03391,
-0.03109, -0.02828, -0.02547, -0.02266, -0.01984, -0.01703, -0.01422, -0.01141, -0.00972, -0.00916, -0.00859,
-0.00803, -0.00747, -0.00691, -0.00634, -0.00578, -0.00522, -0.00466, -0.00409, -0.00353, -0.00297, -0.00241,
-0.00184, -0.00128, -0.00094, -0.00083, -0.00072, -0.00061, -0.00049, -0.00038, -0.00027, -0.00016, -0.00009,
-0.00007, -0.00004, -0.00002, -0.00001, -0.00000, -0.00000, 0.00000, 0.00000, 0.00000, 0.00001, 0.00002,
0.00004, 0.00007, 0.00009, 0.00016, 0.00027, 0.00038, 0.00049, 0.00061, 0.00072, 0.00083, 0.00094,
0.00128, 0.00184, 0.00241, 0.00297, 0.00353, 0.00409, 0.00466, 0.00522, 0.00578, 0.00634, 0.00691,
0.00747, 0.00803, 0.00859, 0.00916, 0.00972, 0.01141, 0.01422, 0.01703, 0.01984, 0.02266, 0.02547,
0.02828, 0.03109, 0.03391, 0.03672, 0.03953, 0.04234, 0.04516, 0.04797, 0.05078, 0.05359, 0.05641,
0.05922, 0.06203, 0.06484, 0.06766, 0.07047, 0.07328, 0.07609, 0.07891, 0.08172, 0.08453, 0.08734,
0.09016, 0.09297, 0.09578, 0.09859, 0.10703, 0.12109, 0.13516, 0.14922, 0.16328, 0.17734, 0.19141,
0.20547, 0.21953, 0.23359, 0.24766, 0.26172, 0.27578, 0.28984, 0.30391, 0.31797, 0.33203, 0.34609,
0.36016, 0.37422, 0.38828, 0.40234, 0.41641, 0.43047, 0.44453, 0.45859, 0.47266, 0.48672, 0.50078,
0.51484, 0.52891, 0.54297, 0.55703, 0.57109, 0.58516, 0.59922, 0.61328, 0.62734, 0.64141, 0.65547,
0.66953, 0.68359, 0.69766, 0.71172, 0.72578, 0.73984, 0.75391, 0.76797, 0.78203, 0.79609, 0.81016,
0.82422, 0.83828, 0.85234, 0.86641, 0.88047, 0.89453, 0.90859, 0.92266, 0.93672, 0.95078, 0.96484,
0.97891, 0.99297, 1.00000};
} // namespace bestla
46 changes: 46 additions & 0 deletions bestla/bestla/kernel_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,52 @@ static inline void dequant_s8_N_avx2(float* dstptr, int8_t* srcptr, __m256* vsca
}
}

// Converts 8-bit double-quantized (DQ8_BNB) scale codes to fp32 scales:
//   dst = dq8_bnb_LUT[code] * dq_scale[flat_index / dq_blk] + dq_scale[dq_offset_idx]
//
//   src, src_stride : uint8_t codes; `row` rows of `col` entries, rows src_stride elements apart
//   dst, dst_stride : fp32 output; rows dst_stride elements apart
//   scale_offset    : flat index of the first scale of the tile, used to select the dq_scale entry
//   dq_blk          : number of scales that share one dq_scale entry
//   dq_offset_idx   : index inside dq_scale of the additive offset
//   dq_scale        : second-level scales plus the offset entry
//   zeropadding     : not supported; asserts when true
// Returns BTLA_CODE::Success.
//
// Each row is split into a scalar head (up to the next multiple of 8 of the flat index), a
// vectorized body of 8-wide groups, and a scalar tail. Every group uses a single dq_scale entry,
// which assumes dq_blk is a multiple of 8.
// NOTE(review): every row uses the same scale_offset when selecting the dq_scale entry; if rows
// are meant to map to different flat scale indices this needs a per-row offset - confirm with
// the callers.
inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, int scale_offset, int dq_blk,
                                  int dq_offset_idx, float* dq_scale, int src_stride, int dst_stride,
                                  bool zeropadding) {
  // Number of leading entries before the flat scale index reaches a multiple of 8.
  auto head_proc_num = utils::updiv(scale_offset, 8) * 8 - scale_offset;
  auto ymm_dq_offset = _mm256_set1_ps(dq_scale[dq_offset_idx]);

  // Scalar path: converts exactly `proc_src_num` entries, all with the dq_scale entry selected
  // by `blk_offset`. It must never touch more than what is left in the row.
  auto get_fp_scale_ref = [&](int proc_src_num, int blk_offset, uint8_t* in, float* out) {
    auto dq_s_idx = blk_offset / dq_blk;
    for (int j = 0; j < proc_src_num; j++) {
      out[j] = dq8_bnb_LUT[in[j]] * dq_scale[dq_s_idx] + dq_scale[dq_offset_idx];
    }
  };

  // Vector path: converts one group of 8 entries.
  auto get_fp_scale_avx2 = [&](int blk_offset, uint8_t* in, float* out) {
    auto dq_s_idx = blk_offset / dq_blk;
    auto ymm_dq_scale = _mm256_set1_ps(dq_scale[dq_s_idx]);
    // The table lookup is done lane by lane; subscripting __m256 relies on the GCC/Clang vector
    // extension.
    __m256 fp32_dq_v;
    for (int i = 0; i < 8; i++) fp32_dq_v[i] = dq8_bnb_LUT[in[i]];
    auto fymm = _mm256_mul_ps(fp32_dq_v, ymm_dq_scale);
    fymm = _mm256_add_ps(fymm, ymm_dq_offset);
    _mm256_storeu_ps(out, fymm);
  };

  for (int i = 0; i < row; i++) {
    uint8_t* src_row = src + i * src_stride;
    float* dst_row = dst + i * dst_stride;
    if (head_proc_num > col) {
      // The whole row fits before the next 8-aligned index.
      get_fp_scale_ref(col, scale_offset, src_row, dst_row);
    } else {
      get_fp_scale_ref(head_proc_num, scale_offset, src_row, dst_row);
      auto scale_offset_iter = scale_offset + head_proc_num;
      uint8_t* src_iter_ptr = src_row + head_proc_num;
      float* dst_iter_ptr = dst_row + head_proc_num;
      auto body_loop = (col - head_proc_num) / 8;
      auto tail_proc_num = (col - head_proc_num) % 8;
      int ii = 0;
      for (; ii < body_loop; ii++) {
        get_fp_scale_avx2(scale_offset_iter + ii * 8, src_iter_ptr + ii * 8, dst_iter_ptr + ii * 8);
      }
      if (tail_proc_num > 0) {
        get_fp_scale_ref(tail_proc_num, scale_offset_iter + ii * 8, src_iter_ptr + ii * 8, dst_iter_ptr + ii * 8);
      }
    }
  }
  if (zeropadding) assert(0);
  return BTLA_CODE::Success;
}

static inline BTLA_CODE alphabeta_f32_f32(const float alpha, const float* srcptr, const int srcstep, const float beta,
const float* src1ptr, const int src1step, float* dstptr, const int dststep,
const int M, const int N) {
Expand Down
Loading

0 comments on commit 25edf34

Please sign in to comment.