Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

3bit weight infrastructure in BesTLA #125

Merged
merged 33 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
be584d5
support 4bits gptq for gptj
Zhenzhong1 Jan 30, 2024
018bbfe
3bit storage
zhewang1-intc Jan 19, 2024
d8fa56d
tmp
zhewang1-intc Jan 19, 2024
8b68bad
avx512f kernel draft, todo:add ut
zhewang1-intc Jan 21, 2024
8b9fdd9
ut draft
zhewang1-intc Jan 22, 2024
5891801
todo: fix n==64 bug
zhewang1-intc Jan 22, 2024
204739c
todo: more getwei & gemm ut
zhewang1-intc Jan 22, 2024
1f99030
todo: more cmpt gemm ut
zhewang1-intc Jan 22, 2024
3290f51
bf16-cmpt ut
zhewang1-intc Jan 22, 2024
cf33294
refine ut
zhewang1-intc Jan 22, 2024
9a50a60
fix parallel
zhewang1-intc Jan 22, 2024
f2fcbc0
add int8-cmpt ut, todo: re-pad N-dim
zhewang1-intc Jan 22, 2024
bfce22f
s3_clip N-dim pad to NTILE*PACK_ROW, todo: add spr-56 core case
zhewang1-intc Jan 23, 2024
70a9487
better ut
zhewang1-intc Jan 23, 2024
97225a0
emr 64core ut
zhewang1-intc Jan 23, 2024
4668aa4
fix packrow
zhewang1-intc Jan 23, 2024
a9b177b
store-unroll
zhewang1-intc Jan 23, 2024
0dfb806
int8 cmpt benchmark
zhewang1-intc Jan 23, 2024
b445a3f
jit draft
zhewang1-intc Jan 23, 2024
597b702
pack128 draft, toddo: fix jit kernel bug
zhewang1-intc Jan 24, 2024
d863a1f
fix jit & add jit_s3_s8 ut
zhewang1-intc Jan 24, 2024
d334b63
update benchmark
zhewang1-intc Jan 24, 2024
caf342a
bit3 quant-param in ns
zhewang1-intc Jan 25, 2024
bca3bb2
add s3_clip
luoyu-intel Jan 26, 2024
3a73350
support scalef32
zhewang1-intc Jan 26, 2024
5895017
remove repeat code
zhewang1-intc Jan 26, 2024
e74541f
jit unroll store 2-zmm
zhewang1-intc Jan 30, 2024
ed378fe
jit getfpkblockwei & fix 3bit getfpwei in proB
zhewang1-intc Jan 30, 2024
dd2a000
head/tail process
zhewang1-intc Jan 30, 2024
d4d4f61
clean code
zhewang1-intc Jan 30, 2024
c381892
update kpad & head/tail process pass ut
zhewang1-intc Jan 30, 2024
ae68c14
clean code & add repack script
zhewang1-intc Jan 31, 2024
ba1b3f8
perfect code
zhewang1-intc Feb 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bestla/bestla/bestla.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ enum class BTLA_DTYPE : uint32_t {
EleBitsMask = 0xff,
EleBitsShift = 0,
EleBitsUndef = 0,
EleBits3 = 3,
EleBits4 = 4,
EleBits8 = 8,
EleBits16 = 16,
Expand Down Expand Up @@ -63,6 +64,7 @@ enum class BTLA_DTYPE : uint32_t {
DQ8_BNB = EleBits8 | TypeFloat | SubType4,
S8 = EleBits8 | TypeInt,
U8 = EleBits8 | TypeInt | SubType1,
S3_CLIP = EleBits3 | TypeInt,
S4_CLIP = EleBits4 | TypeInt,
S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
F4_E2M1 = EleBits4 | TypeFloat,
Expand Down
146 changes: 107 additions & 39 deletions bestla/bestla/bestla_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,24 @@ class WeightKBlockNInteger {
});
}

static void compressBit3Weight(const int N, const int K, const int8_t* B, int8_t* dstptr,
parallel::IThreading* threading) {
// TODO(zhe): 1D parallel compress
auto ld_dst = _GemmCore_T::NTILE * utils::padto(K, 64);
auto col = _GemmCore_T::NTILE * K;
auto row = N / _GemmCore_T::NTILE;
auto pad_64_buf = utils::avector<int8_t>(row * ld_dst, 0);
kernel::wrapper::Memcpy2D::forward<BTLA_ISA::NoSIMD>(B, pad_64_buf.data(), row, col, col, ld_dst);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(dstptr);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(dstptr + row * ld_dst / 4);
auto ret =
kernel::wrapper::CompressBit3::forward<ISA_T>(pad_64_buf.data(), bit2ptr, bit1ptr, row, col, ld_dst, ld_dst);
assert(ret == BTLA_CODE::Success);
}

static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype,
parallel::IThreading* threading) {
if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, dstptr, threading);
parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE});
threading->parallel_for([&](int tidx) {
parallel::ThreadProblem2D thdp({tidx});
Expand Down Expand Up @@ -611,6 +627,8 @@ class WeightKBlockNInteger {
return getQ8Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_CLIP || wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) {
return getQ3Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else {
assert(0);
}
Expand Down Expand Up @@ -690,40 +708,34 @@ class WeightKBlockNInteger {
auto KPad = wptr->mKPad;
int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW;
for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
if (wptr->SDtype() == BTLA_DTYPE::F32) {
auto sptr = wptr->template SPtr<float>() + n_offset + i;
if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
kernel::wrapper::DecompressKBlockS4S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S4_CLIP>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::DecompressKBlockS4S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S4_FULLRANGE>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S8) {
kernel::wrapper::DecompressKBlockS8S8Fp<T>::template forward<ISA_T>(
wptr->template WPtr<int8_t>() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
}
} else if (wptr->SDtype() == BTLA_DTYPE::BF16) {
auto sptr = wptr->template SPtr<utils::bf16>() + n_offset + i;
if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
kernel::wrapper::DecompressKBlockS4S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S4_CLIP>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::DecompressKBlockS4S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S4_FULLRANGE>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S8) {
kernel::wrapper::DecompressKBlockS8S8Fp<T>::template forward<ISA_T>(
wptr->template WPtr<int8_t>() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
}
if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
kernel::wrapper::DecompressKBlockS4S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S4_CLIP>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::DecompressKBlockS4S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S4_FULLRANGE>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S8) {
kernel::wrapper::DecompressKBlockS8S8Fp<T>::template forward<ISA_T>(
wptr->template WPtr<int8_t>() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) {
int8_t* bit3_ptr = wptr->template WPtr<int8_t>();
auto elt_offset =
n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128);
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128);
auto row = NPad / _GemmCore_T::NTILE;
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
kernel::wrapper::DecompressKBlockS3S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE,
k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize);
} else {
assert(0);
}
}
*dststep = k_size;
Expand Down Expand Up @@ -763,6 +775,20 @@ class WeightKBlockNInteger {
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) {
int8_t* bit3_ptr = wptr->template WPtr<int8_t>();
auto elt_offset =
n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128);
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128);
auto row = NPad / _GemmCore_T::NTILE;
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, float,
BTLA_DTYPE::S3_CLIP>(
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW,
ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else {
assert(0);
}
Expand Down Expand Up @@ -790,6 +816,20 @@ class WeightKBlockNInteger {
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) {
int8_t* bit3_ptr = wptr->template WPtr<int8_t>();
auto elt_offset =
n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128);
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128);
auto row = NPad / _GemmCore_T::NTILE;
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, utils::bf16,
BTLA_DTYPE::S3_CLIP>(
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW,
ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else {
assert(0);
}
Expand Down Expand Up @@ -846,6 +886,29 @@ class WeightKBlockNInteger {
return BTLA_CODE::Success;
}

static inline BTLA_CODE getQ3Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset,
const Param& _param, void* tmpcache, size_t cachesize) {
auto wptr = _param.packedW;
int8_t* bit3_ptr = wptr->template WPtr<int8_t>();
auto KPad = wptr->mKPad;
auto NPad = wptr->mNPad;
int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW;
auto row = NPad / _GemmCore_T::NTILE;
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128);
auto base_offset = n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE;
for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
auto elt_offset = base_offset + i * utils::padto(KPad, 128);
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
kernel::wrapper::DecompressKBlockS3S8Fp<int8_t>::template forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE,
k_size / _GemmCore_T::PACK_ROW * ColSize, reinterpret_cast<int8_t*>(tmpcache), cachesize);
}
*dststep = k_size;
return BTLA_CODE::Success;
}

virtual inline void quantRowBlock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst,
float* scales, int8_t* zero_points, void* stor) {
auto ptr = reinterpret_cast<StorageWeight*>(stor);
Expand All @@ -859,19 +922,24 @@ class WeightKBlockNInteger {
} else if (quant_dtype == BTLA_DTYPE::S4_CLIP) {
kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, BTLA_DTYPE::S4_CLIP>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize);
} else if (quant_dtype == BTLA_DTYPE::S3_CLIP) {
kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize);
} else {
assert(0);
}
}

static inline BTLA_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst,
BTLA_DTYPE quant_dtype) {
if (quant_dtype == BTLA_DTYPE::S4_CLIP || quant_dtype == BTLA_DTYPE::S4_FULLRANGE) {
return kernel::wrapper::CompressS8S4<_GemmCore_T::NTILE>::template forward<ISA_T>(
srcptr, reinterpret_cast<utils::int4x2*>(dstptr), row, col, ld_src, ld_dst);
return kernel::wrapper::CompressS8S4::forward<ISA_T>(srcptr, reinterpret_cast<utils::int4x2*>(dstptr), row, col,
ld_src, ld_dst);
} else if (quant_dtype == BTLA_DTYPE::F4_BNB || quant_dtype == BTLA_DTYPE::F4_NF4 ||
quant_dtype == BTLA_DTYPE::F4_E2M1) {
return kernel::wrapper::CompressFp4<_GemmCore_T::NTILE>::template forward<ISA_T>(
srcptr, reinterpret_cast<utils::f4x2*>(dstptr), row, col, ld_src,
ld_dst); // ld_dst here not stride
return kernel::wrapper::CompressFp4::forward<ISA_T>(srcptr, reinterpret_cast<utils::f4x2*>(dstptr), row, col,
ld_src,
ld_dst); // ld_dst here not stride
} else {
assert(0);
return BTLA_CODE::NotSupport;
Expand Down
6 changes: 5 additions & 1 deletion bestla/bestla/bestla_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,11 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase {
InfoType::resize(NPad, KPad, Block, N, K, qtype);
auto bits = utils::bestla_dtype_bits(qtype);
auto elesize = static_cast<size_t>(NPad) * KPad;
auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here
if (qtype == BTLA_DTYPE::S3_CLIP)
elesize =
static_cast<size_t>(utils::padto(KPad, 128)) * NPad; // pad K-dim to 128 because 128pack round2 interleave.
// round2 interleave ld_dim == pad_to(KPad,128) * NTILE
auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here
mQBuf.resize(bytes);
int nk_scale = utils::updiv(KPad, Block);
auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId);
Expand Down
18 changes: 18 additions & 0 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,24 @@ struct fp16 {
}
};

struct bit2x4 {
int8_t a : 2;
int8_t b : 2;
int8_t c : 2;
int8_t d : 2;
};

struct bit1x8 {
int8_t a : 1;
int8_t b : 1;
int8_t c : 1;
int8_t d : 1;
int8_t e : 1;
int8_t f : 1;
int8_t g : 1;
int8_t h : 1;
};

struct bit4x2 {
int8_t x : 4;
int8_t y : 4;
Expand Down
Loading
Loading