diff --git a/bestla/bestla/bestla.h b/bestla/bestla/bestla.h index 19d85b237..d3327a656 100644 --- a/bestla/bestla/bestla.h +++ b/bestla/bestla/bestla.h @@ -36,6 +36,7 @@ enum class BTLA_DTYPE : uint32_t { EleBitsMask = 0xff, EleBitsShift = 0, EleBitsUndef = 0, + EleBits3 = 3, EleBits4 = 4, EleBits8 = 8, EleBits16 = 16, @@ -63,6 +64,7 @@ enum class BTLA_DTYPE : uint32_t { DQ8_BNB = EleBits8 | TypeFloat | SubType4, S8 = EleBits8 | TypeInt, U8 = EleBits8 | TypeInt | SubType1, + S3_CLIP = EleBits3 | TypeInt, S4_CLIP = EleBits4 | TypeInt, S4_FULLRANGE = EleBits4 | TypeInt | SubType1, F4_E2M1 = EleBits4 | TypeFloat, diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index dd3d9da74..99f3ccc90 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -556,8 +556,24 @@ class WeightKBlockNInteger { }); } + static void compressBit3Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, + parallel::IThreading* threading) { + // TODO(zhe): 1D parallel compress + auto ld_dst = _GemmCore_T::NTILE * utils::padto(K, 64); + auto col = _GemmCore_T::NTILE * K; + auto row = N / _GemmCore_T::NTILE; + auto pad_64_buf = utils::avector(row * ld_dst, 0); + kernel::wrapper::Memcpy2D::forward(B, pad_64_buf.data(), row, col, col, ld_dst); + auto bit2ptr = reinterpret_cast(dstptr); + auto bit1ptr = reinterpret_cast(dstptr + row * ld_dst / 4); + auto ret = + kernel::wrapper::CompressBit3::forward(pad_64_buf.data(), bit2ptr, bit1ptr, row, col, ld_dst, ld_dst); + assert(ret == BTLA_CODE::Success); + } + static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype, parallel::IThreading* threading) { + if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, dstptr, threading); parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp({tidx}); @@ -611,6 +627,8 @@ class WeightKBlockNInteger { return getQ8Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S4_CLIP || wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + return getQ3Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else { assert(0); } @@ -690,40 +708,34 @@ class WeightKBlockNInteger { auto KPad = wptr->mKPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if (wptr->SDtype() == BTLA_DTYPE::F32) { - auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S8) { - kernel::wrapper::DecompressKBlockS8S8Fp::template 
forward( - wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } - } else if (wptr->SDtype() == BTLA_DTYPE::BF16) { - auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S8) { - kernel::wrapper::DecompressKBlockS8S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + + i * KPad / 2, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + + i * KPad / 2, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S8) { + kernel::wrapper::DecompressKBlockS8S8Fp::template forward( + wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + int8_t* bit3_ptr = wptr->template WPtr(); + auto elt_offset = + n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); + auto row = NPad / _GemmCore_T::NTILE; + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3S8Fp::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, + k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize); + } else { + assert(0); } } *dststep = k_size; @@ -763,6 +775,20 @@ class WeightKBlockNInteger { *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? 
zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + int8_t* bit3_ptr = wptr->template WPtr(); + auto elt_offset = + n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); + auto row = NPad / _GemmCore_T::NTILE; + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, + ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else { assert(0); } @@ -790,6 +816,20 @@ class WeightKBlockNInteger { *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + int8_t* bit3_ptr = wptr->template WPtr(); + auto elt_offset = + n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); + auto row = NPad / _GemmCore_T::NTILE; + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, + ColSize, sptr, zptr != nullptr ? 
zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else { assert(0); } @@ -846,6 +886,29 @@ class WeightKBlockNInteger { return BTLA_CODE::Success; } + static inline BTLA_CODE getQ3Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { + auto wptr = _param.packedW; + int8_t* bit3_ptr = wptr->template WPtr(); + auto KPad = wptr->mKPad; + auto NPad = wptr->mNPad; + int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; + auto row = NPad / _GemmCore_T::NTILE; + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); + auto base_offset = n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE; + for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { + auto elt_offset = base_offset + i * utils::padto(KPad, 128); + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3S8Fp::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, + k_size / _GemmCore_T::PACK_ROW * ColSize, reinterpret_cast(tmpcache), cachesize); + } + *dststep = k_size; + return BTLA_CODE::Success; + } + virtual inline void quantRowBlock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, void* stor) { auto ptr = reinterpret_cast(stor); @@ -859,19 +922,24 @@ class WeightKBlockNInteger { } else if (quant_dtype == BTLA_DTYPE::S4_CLIP) { kernel::wrapper::QuantizeSignIntRowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); + } else if (quant_dtype == BTLA_DTYPE::S3_CLIP) { + kernel::wrapper::QuantizeSignIntRowBlock::forward( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); + } else { + assert(0); } } static inline BTLA_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst, BTLA_DTYPE quant_dtype) { if (quant_dtype == BTLA_DTYPE::S4_CLIP || quant_dtype == BTLA_DTYPE::S4_FULLRANGE) { - return kernel::wrapper::CompressS8S4<_GemmCore_T::NTILE>::template forward( - srcptr, reinterpret_cast(dstptr), row, col, ld_src, ld_dst); + return kernel::wrapper::CompressS8S4::forward(srcptr, reinterpret_cast(dstptr), row, col, + ld_src, ld_dst); } else if (quant_dtype == BTLA_DTYPE::F4_BNB || quant_dtype == BTLA_DTYPE::F4_NF4 || quant_dtype == BTLA_DTYPE::F4_E2M1) { - return kernel::wrapper::CompressFp4<_GemmCore_T::NTILE>::template forward( - srcptr, reinterpret_cast(dstptr), row, col, ld_src, - ld_dst); // ld_dst here not stride + return kernel::wrapper::CompressFp4::forward(srcptr, reinterpret_cast(dstptr), row, col, + ld_src, + ld_dst); // ld_dst here not stride } else { assert(0); return BTLA_CODE::NotSupport; diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h index c7e33f645..7b13adbe9 100644 --- a/bestla/bestla/bestla_storage.h +++ b/bestla/bestla/bestla_storage.h @@ -706,7 +706,11 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase { InfoType::resize(NPad, KPad, Block, N, K, qtype); auto bits = utils::bestla_dtype_bits(qtype); auto elesize = static_cast(NPad) * KPad; - auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here + if (qtype == BTLA_DTYPE::S3_CLIP) + elesize = + 
static_cast(utils::padto(KPad, 128)) * NPad; // pad K-dim to 128 because 128pack round2 interleave. + // round2 interleave ld_dim == pad_to(KPad,128) * NTILE + auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here mQBuf.resize(bytes); int nk_scale = utils::updiv(KPad, Block); auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId); diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h index 891cdd80d..8804a4436 100644 --- a/bestla/bestla/bestla_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -211,6 +211,24 @@ struct fp16 { } }; +struct bit2x4 { + int8_t a : 2; + int8_t b : 2; + int8_t c : 2; + int8_t d : 2; +}; + +struct bit1x8 { + int8_t a : 1; + int8_t b : 1; + int8_t c : 1; + int8_t d : 1; + int8_t e : 1; + int8_t f : 1; + int8_t g : 1; + int8_t h : 1; +}; + struct bit4x2 { int8_t x : 4; int8_t y : 4; diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 6817804fc..6783ee8e6 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include "bestla.h" #include "bestla_utils.h" +#include "kernel_jit.h" #include "kernel_ref.h" #include @@ -643,6 +645,89 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d return BTLA_CODE::NotSupport; } +template +inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int unpack_elt, int8_t* tmp, size_t tmpsize) { + auto head_ignore_num = interleave_n_offset % 128; + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + auto zmm_shift = _mm512_set1_epi32(5); + + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { + const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); + auto ymm0 = _mm256_and_si256(lowMask, bit2_data); // uop:1 p:015 + auto ymm1 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2)); // uop:1 p:01 + auto ymm2 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 4)); + auto ymm3 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 6)); + auto zmm1 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm0), ymm1, 0x1); // lat3, tp1 uop1 p:5 + auto zmm2 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm2), ymm3, 0x1); + + unsigned long long* bit1_ptr = reinterpret_cast(src2); + auto bit1_mask1 = _cvtu64_mask64(*bit1_ptr); + auto bit1_mask2 = _cvtu64_mask64(*(bit1_ptr + 1)); + auto zmm1_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask1, zmm_0x04); + auto zmm2_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask2, zmm_0x04); + zmm1 = _mm512_add_epi8(zmm1, zmm1_); + zmm2 = _mm512_add_epi8(zmm2, zmm2_); + zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift); // int3_clip => int8 + zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift); // int3_clip => int8 + + _mm512_storeu_epi8((__m512i*)dst, zmm1); + _mm512_storeu_epi8((__m512i*)(dst + 64), zmm2); + }; + + assert(head_ignore_num % 8 == 0); + + auto base_bit2ptr = bit2ptr - head_ignore_num / 4; + auto base_bit1ptr = bit1ptr - head_ignore_num / 8; + int compress_wei_ptr_offset = 0; + int8_t* s8_ptr = reinterpret_cast(tmp); + auto head_write_num = 128 - head_ignore_num; + if (head_ignore_num != 0) { + bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); + for (int i = 0; i < head_write_num; i++) dstptr[i] = 
s8_ptr[head_ignore_num + i]; + compress_wei_ptr_offset += head_write_num; + } + + auto body_loop = (unpack_elt - head_write_num % 128) / 128; + auto tail_proc_num = (unpack_elt - head_write_num % 128) % 128; + + bestla::kernel::jit::DecompresssS3::forward_avx512f(bit2ptr + compress_wei_ptr_offset / 4, + bit1ptr + compress_wei_ptr_offset / 8, + dstptr + compress_wei_ptr_offset, tmp, body_loop * 128); + compress_wei_ptr_offset += body_loop * 128; + if (tail_proc_num > 0) { + bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); + bit3_interleave_decompress_pack128(bit2ptr + compress_wei_ptr_offset / 4, bit1ptr + compress_wei_ptr_offset / 8, + tmp); + for (int i = 0; i < tail_proc_num; i++) dstptr[compress_wei_ptr_offset + i] = s8_ptr[i]; + } + return BTLA_CODE::Success; +} + +template +static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + _DST_T* dstptr, int interleave_n_offset, int row, int col, + _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, void* tmp, size_t tmpsize) { + auto unpack_elt = row * col; + decompress_kblock_s3_s8fp<_S3_T>(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, + reinterpret_cast(tmp), tmpsize); + // TODO(zhe): simd version + for (int i = 0; i < row; i++) { + int kpos = (k_offset + i) / kblock; + auto sptr = scales + kpos * NPad; + for (int j = 0; j < col; j++) { + float tmp = static_cast(dstptr[i * col + j]); + if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); + dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); + } + } + + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, @@ -1261,6 +1346,7 @@ inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, in auto mask = _cvtu32_mask16(0xffff >> (16 - col)); get_fp_scale(col, mask, scale_offset, src + i * src_stride, dst + i * dst_stride); } else { + // TODO(zhe): consider head_proc_num==0 case. 
auto head_mask = _cvtu32_mask16(0xffff >> (16 - head_proc_num)); auto body_mask = _cvtu32_mask16(0xffff); get_fp_scale(head_proc_num, head_mask, scale_offset, src + i * src_stride, dst + i * dst_stride); diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index 81a51a41d..b19040f2c 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -250,6 +250,132 @@ class DequanS8FP { } } }; +class DecompresssS3 { + public: + template + class MicroKernelAVX512F : protected xbyak::JitAvx512f { + public: + struct params { + void *bit2ptr, *bit1ptr, *dstptr, *tmpbuf; + int unpack_elt; + int8_t ox3, ox4; + int ox5; + }; + typedef long long (*func_t)(params*); + static int constexpr VBytes = 64; + MicroKernelAVX512F() { + generate(); + this->ready(); + mKernel = this->getCode(); + } + + void generate() { + inLocalLabel(); // use local label for multiple instance + { + Xbyak::util::StackFrame st(this, 1, 13); + parambase = st.p[0]; + reg_bit1ptr = st.t[0]; + reg_bit2ptr = st.t[1]; + reg_loop = st.t[2]; + reg_iter = st.t[3]; + reg_dst = st.t[4]; + reg_tmp = st.t[5]; + reg_cache = st.t[6]; + reg_ret = rax; + xor_(reg_loop, reg_loop); + mov(reg_loop.cvt32(), ptr[parambase + OFFSET(unpack_elt)]); + xor_(reg_iter, reg_iter); + Xbyak::Ymm LowMask = ymm1; + Xbyak::Zmm zmm_0x04 = zmm31; + Xbyak::Zmm zmm_shift = zmm30; + vpbroadcastb(LowMask, ptr[parambase + OFFSET(ox3)]); + vpbroadcastb(zmm_0x04, ptr[parambase + OFFSET(ox4)]); + vpbroadcastd(zmm_shift, ptr[parambase + OFFSET(ox5)]); + mov(reg_bit1ptr, ptr[parambase + OFFSET(bit1ptr)]); + mov(reg_bit2ptr, ptr[parambase + OFFSET(bit2ptr)]); + mov(reg_dst, ptr[parambase + OFFSET(dstptr)]); + if constexpr (!std::is_same_v<_DST_T, int8_t>) mov(reg_cache, ptr[parambase + OFFSET(tmpbuf)]); + L("loop_label"); + imul(reg_tmp, reg_iter, 16); + kmovq(bit1_mask1, ptr[reg_bit1ptr + reg_tmp]); + kmovq(bit1_mask2, ptr[reg_bit1ptr + reg_tmp + 8]); + Xbyak::Zmm bit2_data_zmm = zmm0; + imul(reg_tmp, reg_iter, 32); + vmovups(ymm2, ptr[reg_bit2ptr + reg_tmp]); + + vpand(ymm4, LowMask, ymm2); + vpsrlw(ymm2, ymm2, 2); + vpand(ymm5, LowMask, ymm2); + vpsrlw(ymm2, ymm2, 2); + vpand(ymm6, LowMask, ymm2); + vpsrlw(ymm2, ymm2, 2); + vpand(ymm7, LowMask, ymm2); + vinsertf32x8(zmm4, zmm4, ymm5, 1); + vinsertf32x8(zmm6, zmm6, ymm7, 1); + + vxorps(zmm12, zmm12); + vxorps(zmm13, zmm13); + vmovdqu8(zmm12 | bit1_mask1, zmm_0x04); + vmovdqu8(zmm13 | bit1_mask2, zmm_0x04); + + vpaddb(zmm4, zmm4, zmm12); + vpaddb(zmm6, zmm6, zmm13); + + vpsllvd(zmm4, zmm4, zmm_shift); + vpsllvd(zmm6, zmm6, zmm_shift); + + if constexpr (std::is_same_v<_DST_T, int8_t>) { + imul(reg_tmp, reg_iter, 128); + vmovups(ptr[reg_dst + reg_tmp], zmm4); + vmovups(ptr[reg_dst + reg_tmp + 64], zmm6); + } else if constexpr (std::is_same_v<_DST_T, float> || std::is_same_v<_DST_T, utils::bf16>) { + vmovups(ptr[reg_cache], zmm4); + vmovups(ptr[reg_cache + 64], zmm6); + for (int i = 0; i < 8; i++) vpmovsxbd(Xbyak::Zmm(16 + i), ptr[reg_cache + 16 * i]); + for (int i = 0; i < 8; i++) vcvtdq2ps(Xbyak::Zmm(16 + i), Xbyak::Zmm(16 + i)); + imul(reg_tmp, reg_iter, 128 * sizeof(_DST_T)); + if constexpr (std::is_same_v<_DST_T, float>) { + for (int i = 0; i < 8; i++) vmovups(ptr[reg_dst + reg_tmp + i * 64], Xbyak::Zmm(16 + i)); + } else { + for (int i = 0; i < 8; i++) vcvtneps2bf16(Xbyak::Ymm(16 + i), Xbyak::Zmm(16 + i)); + for (int i = 0; i < 8; i++) vmovups(ptr[reg_dst + reg_tmp + i * 32], Xbyak::Ymm(16 + i)); + } + } else { + assert(0); + } + + add(reg_iter, 1); + cmp(reg_iter, reg_loop); + 
jb("loop_label"); + mov(reg_ret, 0); + } + outLocalLabel(); // end of local label + } + + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_bit1ptr; + Xbyak::Reg64 reg_bit2ptr; + Xbyak::Reg64 reg_loop; + Xbyak::Reg64 reg_iter; + Xbyak::Reg64 reg_dst; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_cache; + Xbyak::Reg64 reg_ret; + Xbyak::Opmask bit1_mask1 = Xbyak::Opmask(1); + Xbyak::Opmask bit1_mask2 = Xbyak::Opmask(2); + Xbyak::Opmask bit1_mask3 = Xbyak::Opmask(3); + Xbyak::Opmask bit1_mask4 = Xbyak::Opmask(4); + }; + template + static void forward_avx512f(void* bit2ptr, void* bit1ptr, _DST_T* dstptr, void* tmpbuf, int unpack_elt) { + static MicroKernelAVX512F<_DST_T> ker; + typename MicroKernelAVX512F<_DST_T>::params param{bit2ptr, bit1ptr, dstptr, tmpbuf, unpack_elt / 128, 0x03, 0x4, 5}; + ker.mKernel(¶m); + } +}; class DequanKBlockS8Fp { public: diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 5213713f1..7d47aa17e 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -152,7 +152,6 @@ static inline BTLA_CODE transpose2d(const _T* srcptr, _T* dstptr, int row, int c return BTLA_CODE::Success; } -template static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstptr, int row, int col, int ld_src, int ld_dst) { for (int j = 0; j < row; j++) { @@ -166,7 +165,6 @@ static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstp return BTLA_CODE::Success; } -template static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, int row, int col, int ld_src, int ld_dst) { for (int j = 0; j < row; j++) { @@ -180,6 +178,48 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i return BTLA_CODE::Success; } +static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + int row, int col, int ld_src, int ld_dst) { + assert(col % 128 == 0); + + auto bit2_interleave = [&](int8_t* src, int8_t* dst) { + for (int i = 0; i < 128 / 4; i++) { + dst[4 * i] = src[i]; + dst[4 * i + 1] = src[128 / 4 + i]; + dst[4 * i + 2] = src[128 / 4 * 2 + i]; + dst[4 * i + 3] = src[128 / 4 * 3 + i]; + } + }; + + int8_t interleave_buf[128]; + + for (int i = 0; i < row; i++) { + for (int j = 0; j < col; j += 128) { + bit2_interleave(const_cast(srcptr + i * ld_src + j), interleave_buf); + for (int k = 0; k < 32; k++) { + bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; + bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; + bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; + bit2ptr[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; + } + } + } + // store 1 bit without interleave as mask. 
+ for (int i = 0; i < row; i++) { + for (int j = 0; j < col; j += 8) { + bit1ptr[i * ld_dst / 8 + j / 8].a = srcptr[i * ld_src + j] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].b = srcptr[i * ld_src + j + 1] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].c = srcptr[i * ld_src + j + 2] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].d = srcptr[i * ld_src + j + 3] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].e = srcptr[i * ld_src + j + 4] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].f = srcptr[i * ld_src + j + 5] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].g = srcptr[i * ld_src + j + 6] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].h = srcptr[i * ld_src + j + 7] >> 7; + } + } + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_s4_f32(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src, int ld_dst, float* scales) { @@ -905,6 +945,7 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst switch (S4_T) { case BTLA_DTYPE::S8: case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: if (zero_points == nullptr) { s8_calc_store_scale_and_quantv_sym(blocksize); } else { diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 482879bdc..5b7c7f5b8 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -259,23 +259,30 @@ class Dq8GetScale { } }; -template class CompressS8S4 { public: template static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::int4x2* dstptr, int row, int col, int ld_src, int ld_dst) { - return ref::compress_s8_s4(srcptr, dstptr, row, col, ld_src, ld_dst); + return ref::compress_s8_s4(srcptr, dstptr, row, col, ld_src, ld_dst); } }; -template class CompressFp4 { public: template static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::f4x2* dstptr, int row, int col, int ld_src, int ld_dst) { - return ref::compress_f4(srcptr, dstptr, row, col, ld_src, ld_dst); + return ref::compress_f4(srcptr, dstptr, row, col, ld_src, ld_dst); + } +}; + +class CompressBit3 { + public: + template + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row, + int col, int ld_src, int ld_dst) { + return ref::compress_3bit(srcptr, bit2ptr, bit1ptr, row, col, ld_src, ld_dst); } }; @@ -454,6 +461,26 @@ class DecompressKBlockS4Fp { } }; +template // zero points always be int8_t, not compressed +class DecompressKBlockS3Fp { + public: + template + static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int row, int col, _SCA_T* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad, void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + ret = avx512f::decompress_kblock_bit3_packrow_fp( + bit2ptr, bit1ptr, dstptr, interleave_n_offset, row, col, scales, zero_points, k_offset, kblock, NPad, tmp, + tmpsize); + } +#endif + assert(ret == BTLA_CODE::Success); + return ret; + } +}; + template // zero points always be int8_t, not compressed class DecompressKBlockS4S8Fp { public: @@ -476,6 +503,22 @@ class DecompressKBlockS4S8Fp { } }; +template +class DecompressKBlockS3S8Fp { + public: + template + static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; +#if CompileAVX512F() + ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, 
interleave_n_offset, unpack_elt, + reinterpret_cast(tmp), tmpsize); +#endif + assert(ret == BTLA_CODE::Success); + return ret; + } +}; + template class DecompressKBlockF4Fp { public: diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 470ef335e..29e18e35d 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -172,6 +172,102 @@ class UT_BlockQunatize_F8 { static UT_BlockQunatize_F8 sUT_BlockQunatize_F8; #endif +class UT_S3_WOQ { + public: + UT_S3_WOQ() { + UT_START(); + CheckISA(AVX512F); + ut(1, 4096, 4096, 32, 56); + ut(1, 4096, 4096, 32, 56); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); + } + + template + void ut(int m, int n, int k, int blocksize, int enable_thr) { + DefaultThreading.set_threads(enable_thr); + printf("%s:%d %d %d %d\n", __FUNCTION__, m, n, k, blocksize); + int ldb = n; + + int kblk_num = utils::updiv(k, blocksize); + utils::aligned_vector scales(kblk_num * n); + ut::fill_buffer_randn(scales.data(), scales.size(), 0.005f, 0.01f); + ut::UT_vector_s8 quanW; + quanW.resize(k * n); + quanW.fill_rand(-8, 7); + for (int i = 0; i < k * n; i++) quanW.data()[i] = (quanW.data()[i] * 16) & 0xe0; + + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; + + PrologueB kernel; + auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::F32, false); + auto ptr_ref = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::F32, false); + avector buffer(ptr.mSize); + avector buffer_ref(ptr_ref.mSize); + ptr.assign(buffer.data()); + ptr_ref.assign(buffer_ref.data()); + kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr, &DefaultThreading); + kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr_ref, &DefaultThreading); + using Launcher = + wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; + + Launcher launcher; + avector matC(m * n), refC(m * n); + if constexpr (ISA == BTLA_ISA::AVX512F) { + avector matAf32(m * k); + fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); + utils::GemmProblem gp(1, m, n, k, blocksize); + typename Launcher::Param args{ + gp, {matAf32.data(), k}, {&ptr}, {ptr.template SPtr(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; + parallel::GemmRun(launcher, args, &DefaultThreading); + typename Launcher::Param args_ref{gp, + {matAf32.data(), k}, + {&ptr_ref}, + {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, + {refC.data(), n}}; + parallel::GemmRun(launcher, args_ref, &DefaultThreading); + } else if constexpr (ISA == BTLA_ISA::AMX_BF16) { + avector matAbf16(m * k); + fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f)); + GemmProblem gp(1, m, n, k, blocksize); + typename Launcher::Param args{ + gp, {matAbf16.data(), k}, {&ptr}, {ptr.template SPtr(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; + parallel::GemmRun(launcher, args, &DefaultThreading); + typename Launcher::Param args_ref{gp, + {matAbf16.data(), k}, + {&ptr_ref}, + {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, + {refC.data(), n}}; + parallel::GemmRun(launcher, args_ref, &DefaultThreading); + } else { + using Launcher2 = wrapper::gemm::LauncherIntKBlock; + Launcher2 launcher; + using Parallel2 = parallel::gemm::SchedulerKBlockS; + avector matAf32(m * k); + fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); + auto 
quanA = launcher.mProA.createStorage(m, k, blocksize, false); + auto quanA_ref = launcher.mProA.createStorage(m, k, blocksize, false); + utils::avector bufferA(quanA.mSize); + utils::avector bufferA_ref(quanA.mSize); + quanA.assign(bufferA.data()); + quanA_ref.assign(bufferA_ref.data()); + GemmProblem gp(1, m, n, k, blocksize); + typename Launcher2::Param args{gp, {matAf32.data(), k, &quanA}, {&ptr}, {matC.data(), n}}; + parallel::GemmRunWithA(launcher, args, &DefaultThreading); + typename Launcher2::Param args_ref{gp, {matAf32.data(), k, &quanA_ref}, {&ptr_ref}, {refC.data(), n}}; + parallel::GemmRunWithA(launcher, args_ref, &DefaultThreading); + } + buffer_error(matC.data(), refC.data(), matC.size(), 0.001f); + } +}; +#ifdef BTLA_UT_PROLOGUE_B +static UT_S3_WOQ sUT_S3_WOQ; +#endif class UT_TransposeBlockQuantize_F4 { public: UT_TransposeBlockQuantize_F4() { @@ -725,8 +821,11 @@ class UTBenchmark_CompFp32 { } void ut_s4() { - benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + // benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); @@ -786,7 +885,7 @@ class UTBenchmark_CompFp32 { kernel.mProB.packWeight(n, k, B, n, &packBs[0], &DefaultThreading); for (size_t i = 1; i < batch; i++) { memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); - memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(float)); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); } auto psize = (size_t)m * n * k * 2; auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); @@ -812,9 +911,14 @@ class UTBenchmark_CompFp32 { int threads, BTLA_DTYPE qtype) { LOG_T log; using Parallel = parallel::gemm::SchedulerKBlock; + // using Launcher = + // wrapper::gemm::LauncherKBlock; using Launcher = - wrapper::gemm::LauncherKBlock; + wrapper::gemm::LauncherIntKBlock; Launcher kernel; DefaultThreading.set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); @@ -831,6 +935,9 @@ class UTBenchmark_CompFp32 { } std::vector packBs(batch, 0); std::vector bufB(tmpB.mSize * batch); + auto quanA = kernel.mProA.createStorage(m, k, blocksize, false); + utils::avector bufferA(quanA.mSize); + quanA.assign(bufferA.data()); for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); @@ -838,7 +945,7 @@ class UTBenchmark_CompFp32 { kernel.mProB.packWeight(n, k, B, n, &packBs[0], &DefaultThreading); for (size_t i = 1; i < batch; i++) { memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); - memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(float)); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); } auto psize = (size_t)m * n * k * 2; auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); @@ -848,11 +955,12 @@ class UTBenchmark_CompFp32 { for (size_t i = 0; i < batch; i++) { GemmProblem gp(1, m, n, k, blocksize); typename 
Launcher::Param args{gp, - {A + i * m * k, k}, + {A + i * m * k, k, &quanA}, {&packBs[i]}, - {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, + // {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + // parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRunWithA(kernel, args, &DefaultThreading); } if (log.stop()) { double t = log.avg_val / batch; @@ -879,16 +987,17 @@ class UTBenchmark_CompFp32 { float testtime = 500.f; GetCPUDevice(); if (_cd->AVX512F()) { - int blocksize = 32; - benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); - benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); - blocksize = 128; - benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); - benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); + // int blocksize = 32; + // benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + // C.data(), testtime, 48, qtype); + // benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + // C.data(), testtime, 48, qtype); + int blocksize = 128; + // benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + // C.data(), testtime, 48, qtype); + // benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), + B.data(), C.data(), testtime, 56, qtype); } } }; diff --git a/bestla/bestla/ut/kernel_jit.cpp b/bestla/bestla/ut/kernel_jit.cpp index 613bcc154..ce1198c99 100644 --- a/bestla/bestla/ut/kernel_jit.cpp +++ b/bestla/bestla/ut/kernel_jit.cpp @@ -286,7 +286,7 @@ class UT_DecompressS4S8 { aligned_vector src(row * col / 2); aligned_vector src8(row * col); ut::fill_buffer_randn(src8.data(), src8.size(), int8_t(-128), int8_t(127)); - kernel::ref::compress_s8_s4<48>(src8.data(), src.data(), row, col, col, col); + kernel::ref::compress_s8_s4(src8.data(), src.data(), row, col, col, col); aligned_vector ref(row * col), tar(row * col); kernel::ref::decompress_s4_s8(src.data(), ref.data(), row, col, col, col); kernel::jit::decompress_s4_s8(src.data(), tar.data(), row, col, col, col); diff --git a/neural_speed/convert/convert_quantized_gptj.py b/neural_speed/convert/convert_quantized_gptj.py index 2b021f63a..9711a698b 100644 --- a/neural_speed/convert/convert_quantized_gptj.py +++ b/neural_speed/convert/convert_quantized_gptj.py @@ -39,14 +39,15 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config): scales = model[f"{src_name}.scales"] qweight = model[f"{src_name}.qweight"] - int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config) - int_weight = int_weight.view(-1,int_weight.shape[-1]) + int_weight, gptq_scales, gptq_zeros = unpack_weight( + qweight, scales, qzeros, q_config) + int_weight = int_weight.view(-1, int_weight.shape[-1]) # shuffle weight in GPTQ when act order is on - if 'desc_act'in q_config and q_config['desc_act']: + if 'desc_act' in q_config and q_config['desc_act']: g_idx = model[f"{src_name}.g_idx"] int_weight2 = int_weight.clone() - group_size=q_config['group_size'] + group_size = q_config['group_size'] group_dict = {} for i in range(len(g_idx)): group_idx = g_idx[i].item() @@ -84,7 +85,7 @@ def 
convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config): gptq_zeros = np.empty(0, dtype=np.int8) else: gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy()) - if 'desc_act'in q_config and q_config['desc_act']: + if 'desc_act' in q_config and q_config['desc_act']: g_idx = np.ascontiguousarray(g_idx.numpy()) else: g_idx = np.empty(0, dtype=np.int32) @@ -96,7 +97,7 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config): alg="sym" if q_config['sym'] else "asym", compute_dtype="int8") dst.flatten()[:byte_size].tofile(fout) - print(f"converting {dst_name} quantized tensor to bestla q4 block") + print(f"converting {dst_name} quantized tensor to bestla q{q_config['bits']} block") def main(args_in: Optional[List[str]] = None) -> None: @@ -111,13 +112,14 @@ def main(args_in: Optional[List[str]] = None) -> None: model, config, quantize_config = load_quantized_model(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True) fout = open(out_path, "wb") # 1. write hparams hparams = config n_layer = hparams["n_layer"] - fout.write(b"ggjt"[::-1]) #0x67676d6c)) # magic: ggml in hex + fout.write(b"ggjt"[::-1]) # 0x67676d6c)) # magic: ggml in hex values = [ 1, # file version hparams["vocab_size"], @@ -140,10 +142,10 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) - fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps + fout.write(struct.pack("f", hparams.get( + "rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor - fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1)) fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2)) fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1)) diff --git a/neural_speed/core/layers/bestla_gemm.cpp b/neural_speed/core/layers/bestla_gemm.cpp index f2421f9d4..471dc0e03 100644 --- a/neural_speed/core/layers/bestla_gemm.cpp +++ b/neural_speed/core/layers/bestla_gemm.cpp @@ -264,6 +264,7 @@ size_t BTLAGemmPackBSize(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE QuantTyp ne_comp_type CompType, int* shuffle_indice) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, @@ -364,6 +365,7 @@ bool BTLAGemmQuantPackB(void* PackedBuf, const float* FpData, size_t N, size_t K void* ThreadPool) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmQuantPackBLocal( @@ -566,6 +568,7 @@ size_t BTLAGemmPackBSize(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE QuantTyp ne_comp_type CompType, int* shuffle_indice) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, @@ -589,6 +592,7 @@ bool BTLAGemmQuantPackB(void* PackedBuf, const float* FpData, size_t N, size_t K void* ThreadPool) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmQuantPackBLocal( 
@@ -611,6 +615,7 @@ bool BTLAGemmPackB(void* PackedBuf, const int8_t* QData, const float* Scales, co
                    ne_comp_type CompType, int* shuffle_indice, void* ThreadPool) {
   // only for integer quant
   switch (QuantType) {
+    case BTLA_DTYPE::S3_CLIP:
     case BTLA_DTYPE::S4_CLIP:
     case BTLA_DTYPE::S4_FULLRANGE:
     case BTLA_DTYPE::S8:
diff --git a/neural_speed/models/model_utils/quant_config.h b/neural_speed/models/model_utils/quant_config.h
index f470c7598..e6ee60040 100644
--- a/neural_speed/models/model_utils/quant_config.h
+++ b/neural_speed/models/model_utils/quant_config.h
@@ -18,8 +18,11 @@
 #include "core/data_types.h"
 #include "bestla/bestla.h"

-enum class quant_bits : int { q4 = 0, q8, fp4_e2m1, nf4, fp8_e4m3, fp8_e5m2, count };
+enum class quant_bits : int { q4 = 0, q3, q8, fp4_e2m1, nf4, fp8_e4m3, fp8_e5m2, count };
 static inline quant_bits parse_bits(const std::string& bits) {
+  if (bits == "int3") {
+    return quant_bits::q3;
+  }
   if (bits == "int4") {
     return quant_bits::q4;
   }
diff --git a/neural_speed/models/model_utils/quant_utils.cpp b/neural_speed/models/model_utils/quant_utils.cpp
index 0eacbc771..65e943875 100644
--- a/neural_speed/models/model_utils/quant_utils.cpp
+++ b/neural_speed/models/model_utils/quant_utils.cpp
@@ -238,6 +238,7 @@ size_t bestla_qpack(const int8_t* src_w, const float* src_scales, const int8_t*
   if (params.bits == quant_bits::q8) {
     quant_type = BTLA_DTYPE::S8;
   }
+  if (params.bits == quant_bits::q3) quant_type = BTLA_DTYPE::S3_CLIP;
   auto dtype_type = static_cast(
       bestla::utils::bestla_dtype_get_mask_val(quant_type, BTLA_DTYPE::TypeMask, BTLA_DTYPE::TypeShift));
   if (dtype_type == BTLA_DTYPE::TypeFloat) {
@@ -280,6 +281,9 @@ size_t bestla_quantize(const float* f32ptr, void* dstpr, const quant_params_inte
   bestla::parallel::StdThreading threading(nthread);
 #endif
   BTLA_DTYPE quant_type = BTLA_DTYPE::S4_CLIP;
+  if (params.bits == quant_bits::q3) {
+    quant_type = BTLA_DTYPE::S3_CLIP;
+  }
   if (params.bits == quant_bits::q8) {
     quant_type = BTLA_DTYPE::S8;
   }
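A note on the S3_CLIP layout introduced above: the packed format keeps, for each weight, only the top three bits of its int8 representation (the unit test prepares such weights with (q * 16) & 0xe0), splitting them into a 2-bit plane and a 1-bit plane that the kernels recombine. The following is a minimal scalar sketch of that round trip; it is not part of the patch and only illustrates the bit manipulation under the assumption that the 3-bit payload sits in bits 5..7 of the byte.

// Scalar sketch (editor's illustration, not patch code) of the S3_CLIP bit-plane round trip.
#include <cassert>
#include <cstdint>

int main() {
  for (int q = -4; q <= 3; ++q) {
    const int8_t src = static_cast<int8_t>(q << 5);  // 3-bit value stored in the high bits of an int8

    // compress_3bit(): the low two bits of the 3-bit value go to the 2-bit plane,
    // the top (sign) bit goes to the 1-bit plane.
    const uint8_t bit2 = static_cast<uint8_t>(src >> 5) & 0x03;
    const uint8_t bit1 = static_cast<uint8_t>(src) >> 7;

    // decompress (see bit3_interleave_decompress_pack128 / DecompresssS3): rebuild the 3-bit
    // pattern as bit2 + (bit1 ? 0x04 : 0), then shift it back into the high bits of the byte.
    const int8_t restored = static_cast<int8_t>(((bit1 << 2) | bit2) << 5);

    assert(restored == src);  // identical to src & 0xE0
  }
  return 0;
}

The buffer sizing in bestla_storage.h follows from this split: with KPad padded up to a multiple of 128, the packed weight occupies updiv(padto(KPad, 128) * NPad * 3, 8) bytes — a 2-bit plane of elesize / 4 bytes followed by a 1-bit plane of elesize / 8 bytes, which is why the decompress paths locate bit1ptr at bit3_ptr + row * ld_dst / 4.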