From be584d5aff292812d79e21f106581e07179a6f52 Mon Sep 17 00:00:00 2001 From: Zhenzhong1 Date: Mon, 29 Jan 2024 19:55:10 -0800 Subject: [PATCH 01/33] support 4bits gptq for gptj --- neural_speed/convert/convert_quantized_gptj.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neural_speed/convert/convert_quantized_gptj.py b/neural_speed/convert/convert_quantized_gptj.py index 2b021f63a..8d2a22147 100644 --- a/neural_speed/convert/convert_quantized_gptj.py +++ b/neural_speed/convert/convert_quantized_gptj.py @@ -143,7 +143,6 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor - fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1)) fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2)) fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1)) From 018bbfeae08274356a0620603cfcf641304b9b37 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Fri, 19 Jan 2024 12:01:35 +0800 Subject: [PATCH 02/33] 3bit storage --- bestla/CMakeLists.txt | 2 +- bestla/bestla/bestla.h | 2 ++ bestla/bestla/bestla_prologue_b.h | 21 +++++++++++++---- bestla/bestla/bestla_storage.h | 4 +++- bestla/bestla/bestla_utils.h | 18 +++++++++++++++ bestla/bestla/kernel_ref.h | 32 ++++++++++++++++++++++++-- bestla/bestla/kernel_wrapper.h | 15 ++++++++---- bestla/bestla/ut/bestla_prologue_b.cpp | 1 + bestla/bestla/ut/kernel_jit.cpp | 2 +- 9 files changed, 83 insertions(+), 14 deletions(-) diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 2b17a8603..834e61160 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -7,7 +7,7 @@ file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) option(BTLA_USE_OPENMP "Enable OpenMP thread pool" OFF) option(BTLA_UT_ALL "Enable all unit tests" OFF) -option(BTLA_UT_DEBUG "Enable debug unit tests" OFF) +option(BTLA_UT_DEBUG "Enable debug unit tests" ON) option(BTLA_UT_EPILOGUE "Enable unit test for epilogue" OFF) option(BTLA_UT_PROLOGUE_A "Enable unit test for activation prologue" OFF) option(BTLA_UT_PROLOGUE_B "Enable unit test for weight prologue" OFF) diff --git a/bestla/bestla/bestla.h b/bestla/bestla/bestla.h index 19d85b237..d3327a656 100644 --- a/bestla/bestla/bestla.h +++ b/bestla/bestla/bestla.h @@ -36,6 +36,7 @@ enum class BTLA_DTYPE : uint32_t { EleBitsMask = 0xff, EleBitsShift = 0, EleBitsUndef = 0, + EleBits3 = 3, EleBits4 = 4, EleBits8 = 8, EleBits16 = 16, @@ -63,6 +64,7 @@ enum class BTLA_DTYPE : uint32_t { DQ8_BNB = EleBits8 | TypeFloat | SubType4, S8 = EleBits8 | TypeInt, U8 = EleBits8 | TypeInt | SubType1, + S3_CLIP = EleBits3 | TypeInt, S4_CLIP = EleBits4 | TypeInt, S4_FULLRANGE = EleBits4 | TypeInt | SubType1, F4_E2M1 = EleBits4 | TypeFloat, diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index dd3d9da74..d6f9df723 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once #include +#include "bestla.h" #include "bestla_utils.h" #include "bestla_storage.h" #include "bestla_device.h" @@ -556,8 +557,18 @@ class WeightKBlockNInteger { }); } + static void compressBit3Weight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, + parallel::IThreading* threading) { + // TODO(zhe): 1D parallel compress + auto bit2ptr = reinterpret_cast(dstptr); + auto bit1ptr = reinterpret_cast(dstptr + K * ldb / 4); + auto ret = kernel::wrapper::CompressBit3::forward(B, bit2ptr, bit1ptr, K, N, ldb, ldb); + assert(ret == BTLA_CODE::Success); + } + static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype, parallel::IThreading* threading) { + if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, ldb, dstptr, threading); parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp({tidx}); @@ -865,13 +876,13 @@ class WeightKBlockNInteger { static inline BTLA_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst, BTLA_DTYPE quant_dtype) { if (quant_dtype == BTLA_DTYPE::S4_CLIP || quant_dtype == BTLA_DTYPE::S4_FULLRANGE) { - return kernel::wrapper::CompressS8S4<_GemmCore_T::NTILE>::template forward( - srcptr, reinterpret_cast(dstptr), row, col, ld_src, ld_dst); + return kernel::wrapper::CompressS8S4::forward(srcptr, reinterpret_cast(dstptr), row, col, + ld_src, ld_dst); } else if (quant_dtype == BTLA_DTYPE::F4_BNB || quant_dtype == BTLA_DTYPE::F4_NF4 || quant_dtype == BTLA_DTYPE::F4_E2M1) { - return kernel::wrapper::CompressFp4<_GemmCore_T::NTILE>::template forward( - srcptr, reinterpret_cast(dstptr), row, col, ld_src, - ld_dst); // ld_dst here not stride + return kernel::wrapper::CompressFp4::forward(srcptr, reinterpret_cast(dstptr), row, col, + ld_src, + ld_dst); // ld_dst here not stride } else { assert(0); return BTLA_CODE::NotSupport; diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h index c7e33f645..2f9743bfb 100644 --- a/bestla/bestla/bestla_storage.h +++ b/bestla/bestla/bestla_storage.h @@ -706,7 +706,9 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase { InfoType::resize(NPad, KPad, Block, N, K, qtype); auto bits = utils::bestla_dtype_bits(qtype); auto elesize = static_cast(NPad) * KPad; - auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here + if (qtype == BTLA_DTYPE::S3_CLIP) + elesize = static_cast(utils::padto(NPad, 64)) * KPad; // pad N-dim to 64 because 64pack interleave. + auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here mQBuf.resize(bytes); int nk_scale = utils::updiv(KPad, Block); auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId); diff --git a/bestla/bestla/bestla_utils.h b/bestla/bestla/bestla_utils.h index 891cdd80d..8804a4436 100644 --- a/bestla/bestla/bestla_utils.h +++ b/bestla/bestla/bestla_utils.h @@ -211,6 +211,24 @@ struct fp16 { } }; +struct bit2x4 { + int8_t a : 2; + int8_t b : 2; + int8_t c : 2; + int8_t d : 2; +}; + +struct bit1x8 { + int8_t a : 1; + int8_t b : 1; + int8_t c : 1; + int8_t d : 1; + int8_t e : 1; + int8_t f : 1; + int8_t g : 1; + int8_t h : 1; +}; + struct bit4x2 { int8_t x : 4; int8_t y : 4; diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 5213713f1..5b158fbc5 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -152,7 +152,6 @@ static inline BTLA_CODE transpose2d(const _T* srcptr, _T* dstptr, int row, int c return BTLA_CODE::Success; } -template static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstptr, int row, int col, int ld_src, int ld_dst) { for (int j = 0; j < row; j++) { @@ -166,7 +165,6 @@ static inline BTLA_CODE compress_s8_s4(const int8_t* srcptr, utils::int4x2* dstp return BTLA_CODE::Success; } -template static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, int row, int col, int ld_src, int ld_dst) { for (int j = 0; j < row; j++) { @@ -180,6 +178,36 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i return BTLA_CODE::Success; } +static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + int row, int col, int ld_src, int ld_dst) { + assert(col % 64 == 0); + // interleave + store 2bit. + for (int i = 0; i < row; i++) { + for (int j = 0; j < col; i += 64) { + for (int k = 0; k < 16; k++) { + bit2ptr[i * ld_dst / 4 + i / 4 + k].a = srcptr[i * ld_src + col + 4 * k]; + bit2ptr[i * ld_dst / 4 + i / 4 + k].b = srcptr[i * ld_src + col + 4 * k + 1]; + bit2ptr[i * ld_dst / 4 + i / 4 + k].c = srcptr[i * ld_src + col + 4 * k + 2]; + bit2ptr[i * ld_dst / 4 + i / 4 + k].d = srcptr[i * ld_src + col + 4 * k + 3]; + } + } + } + // store 1 bit without interleave as mask. + for (int i = 0; i < row; i++) { + for (int j = 0; j < col; j += 8) { + bit1ptr[i * ld_dst].a = srcptr[i * ld_src + j] >> 2; + bit1ptr[i * ld_dst].b = srcptr[i * ld_src + j + 1] >> 2; + bit1ptr[i * ld_dst].c = srcptr[i * ld_src + j + 2] >> 2; + bit1ptr[i * ld_dst].d = srcptr[i * ld_src + j + 3] >> 2; + bit1ptr[i * ld_dst].e = srcptr[i * ld_src + j + 4] >> 2; + bit1ptr[i * ld_dst].f = srcptr[i * ld_src + j + 5] >> 2; + bit1ptr[i * ld_dst].g = srcptr[i * ld_src + j + 6] >> 2; + bit1ptr[i * ld_dst].h = srcptr[i * ld_src + j + 7] >> 2; + } + } + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_s4_f32(utils::int4x2* srcptr, float* dstptr, int row, int col, int ld_src, int ld_dst, float* scales) { diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 482879bdc..cb20e5b69 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -259,23 +259,30 @@ class Dq8GetScale { } }; -template class CompressS8S4 { public: template static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::int4x2* dstptr, int row, int col, int ld_src, int ld_dst) { - return ref::compress_s8_s4(srcptr, dstptr, row, col, ld_src, ld_dst); + return ref::compress_s8_s4(srcptr, dstptr, row, col, ld_src, ld_dst); } }; -template class CompressFp4 { public: template static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::f4x2* dstptr, int row, int col, int ld_src, int ld_dst) { - return ref::compress_f4(srcptr, dstptr, row, col, ld_src, ld_dst); + return ref::compress_f4(srcptr, dstptr, row, col, ld_src, ld_dst); + } +}; + +class CompressBit3 { + public: + template + static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row, + int col, int ld_src, int ld_dst) { + return ref::compress_3bit(srcptr, bit2ptr, bit1ptr, row, col, ld_src, ld_dst); } }; diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 470ef335e..e9610ecae 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -515,6 +515,7 @@ class UT_ShuffleIndices { delete wptr; } }; +static UT_ShuffleIndices sUT_ShuffleIndices; #ifdef BTLA_UT_PROLOGUE_B static UT_ShuffleIndices sUT_ShuffleIndices; #endif diff --git a/bestla/bestla/ut/kernel_jit.cpp b/bestla/bestla/ut/kernel_jit.cpp index 613bcc154..ce1198c99 100644 --- a/bestla/bestla/ut/kernel_jit.cpp +++ b/bestla/bestla/ut/kernel_jit.cpp @@ -286,7 +286,7 @@ class UT_DecompressS4S8 { aligned_vector src(row * col / 2); aligned_vector src8(row * col); ut::fill_buffer_randn(src8.data(), src8.size(), int8_t(-128), int8_t(127)); - kernel::ref::compress_s8_s4<48>(src8.data(), src.data(), row, col, col, col); + kernel::ref::compress_s8_s4(src8.data(), src.data(), row, col, col, col); aligned_vector ref(row * col), tar(row * col); kernel::ref::decompress_s4_s8(src.data(), ref.data(), row, col, col, col); kernel::jit::decompress_s4_s8(src.data(), tar.data(), row, col, col, col); From d8fa56de18ad485ccbe2b72e4baeff5d421f931d Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Fri, 19 Jan 2024 16:46:12 +0800 Subject: [PATCH 03/33] tmp --- bestla/bestla/bestla_prologue_b.h | 30 ++++++++++++++++++++++++++---- bestla/bestla/bestla_storage.h | 6 ++++-- bestla/bestla/kernel_avx512f.h | 9 +++++++++ bestla/bestla/kernel_wrapper.h | 17 +++++++++++++++++ 4 files changed, 56 insertions(+), 6 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index d6f9df723..4ce4e1532 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once #include +#include #include "bestla.h" #include "bestla_utils.h" #include "bestla_storage.h" @@ -557,18 +558,26 @@ class WeightKBlockNInteger { }); } - static void compressBit3Weight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, + static void compressBit3Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, parallel::IThreading* threading) { // TODO(zhe): 1D parallel compress + // N==NPad, K==Kpad, ldb==Kpad + auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(K, 64); + auto col = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * K; + assert(N % (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE) == 0); + auto row = N / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); + auto pad_64_buf = utils::avector(row * ld_dst, 0); + kernel::wrapper::Memcpy2D::forward(B, pad_64_buf.data(), row, col, col, ld_dst); auto bit2ptr = reinterpret_cast(dstptr); - auto bit1ptr = reinterpret_cast(dstptr + K * ldb / 4); - auto ret = kernel::wrapper::CompressBit3::forward(B, bit2ptr, bit1ptr, K, N, ldb, ldb); + auto bit1ptr = reinterpret_cast(dstptr + row * ld_dst / 4); + auto ret = + kernel::wrapper::CompressBit3::forward(pad_64_buf.data(), bit2ptr, bit1ptr, row, col, ld_dst, ld_dst); assert(ret == BTLA_CODE::Success); } static void compressWeight(const int N, const int K, const int8_t* B, const int ldb, int8_t* dstptr, BTLA_DTYPE qtype, parallel::IThreading* threading) { - if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, ldb, dstptr, threading); + if (qtype == BTLA_DTYPE::S3_CLIP) return compressBit3Weight(N, K, B, dstptr, threading); parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp({tidx}); @@ -801,6 +810,19 @@ class WeightKBlockNInteger { *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + // n_offset+i : real n_offset k_offset: real k_offset. + int8_t* bit3_ptr = wptr->template WPtr(); + auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; + auto ld_dst = utils::padto(NPad, 64); + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + wptr->mKPad * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, + k_size / _GemmCore_T::PACK_ROW * ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, + k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); } else { assert(0); } diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h index 2f9743bfb..81b48387f 100644 --- a/bestla/bestla/bestla_storage.h +++ b/bestla/bestla/bestla_storage.h @@ -707,8 +707,10 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase { auto bits = utils::bestla_dtype_bits(qtype); auto elesize = static_cast(NPad) * KPad; if (qtype == BTLA_DTYPE::S3_CLIP) - elesize = static_cast(utils::padto(NPad, 64)) * KPad; // pad N-dim to 64 because 64pack interleave. - auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here + elesize = static_cast(utils::padto(KPad, 64)) * + NPad; // pad K-dim to 64 because 64pack round2 interleave. round2 interleave ld_dim == PACK_ROW * + // pad_to(KPad,64) * NTILE + auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here mQBuf.resize(bytes); int nk_scale = utils::updiv(KPad, Block); auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId); diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 6817804fc..257d2c2ca 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include "bestla.h" #include "bestla_utils.h" #include "kernel_ref.h" @@ -643,6 +644,14 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d return BTLA_CODE::NotSupport; } +template +static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + _DST_T* dstptr, int interleave_n_offset, int unpack_elt, + _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad) { + return BTLA_CODE::Success; +} + template static inline BTLA_CODE decompress_kblock_f4_fp(utils::f4x2* srcptr, _DST_T* dstptr, int row, int col, int ld_src, int ld_dst, _ST* scales, int k_offset, int kblock, int NPad, diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index cb20e5b69..19c38e985 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -461,6 +461,23 @@ class DecompressKBlockS4Fp { } }; +template // zero points always be int8_t, not compressed +class DecompressKBlockS3Fp { + template + static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int unpack_elt, _SCA_T* scales, int8_t* zero_points, + int k_offset, int kblock, int NPad) { + BTLA_CODE ret = BTLA_CODE::NotSupport; +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + ret = avx512f::decompress_kblock_bit3_packrow_fp( + bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, scales, zero_points, k_offset, kblock, NPad); + } +#endif + return ret; + } +} + template // zero points always be int8_t, not compressed class DecompressKBlockS4S8Fp { public: From 8b68badb61739c1ba21bdce2ff3e731954a64f9d Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Sun, 21 Jan 2024 13:31:50 +0800 Subject: [PATCH 04/33] avx512f kernel draft, todo:add ut --- bestla/bestla/bestla_prologue_b.h | 7 ++-- bestla/bestla/kernel_avx512f.h | 67 ++++++++++++++++++++++++++++++- bestla/bestla/kernel_wrapper.h | 4 +- 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 4ce4e1532..1c1f3492f 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -811,7 +811,6 @@ class WeightKBlockNInteger { zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { - // n_offset+i : real n_offset k_offset: real k_offset. int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; auto ld_dst = utils::padto(NPad, 64); @@ -820,9 +819,9 @@ class WeightKBlockNInteger { auto bit1ptr = reinterpret_cast(bit3_ptr + wptr->mKPad * ld_dst / 4 + elt_offset / 8); kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( - bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, - k_size / _GemmCore_T::PACK_ROW * ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, - k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, + ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); } else { assert(0); } diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 257d2c2ca..dc190a2e8 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -646,9 +646,73 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d template static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, - _DST_T* dstptr, int interleave_n_offset, int unpack_elt, + _DST_T* dstptr, int interleave_n_offset, int row, int col, _ST* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { + auto head_ignore_num = interleave_n_offset % 64; + auto zmm_0x04 = _mm512_set1_epi8(0x04); + auto zmm_0x00 = _mm512_set1_epi8(0x00); + + auto bit3_interleave_decompress = [&](utils::bit2x4* src1, utils::bit1x8* src2) { + const __m128i lowMask = _mm_set1_epi8(0x03); + const __m128i bit2_data = _mm_loadu_si128((const __m128i*)src1); + auto xmm0 = _mm_and_si128(lowMask, bit2_data); // uop:1 p:015 + auto xmm1 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 2)); // uop:1 p:01 + auto xmm2 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 4)); + auto xmm3 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 6)); + auto ymm1 = _mm256_set_m128i(xmm1, xmm0); // uop:1 p5 + auto ymm2 = _mm256_set_m128i(xmm3, xmm2); + auto zmm = _mm512_castps_si512(_mm512_insertf32x8(_mm512_castsi256_si512(ymm1), ymm2, 0x1)); + // make bit1-storage as mmask64, then cvt the mask to the int8-value. + unsigned long long* bit1_ptr = reinterpret_cast(src2); + auto bit1_mask = _cvtu64_mask64(*bit1_ptr); + auto zmm2 = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask, zmm_0x04); + zmm = _mm512_add_epi8(zmm, zmm2); + return zmm; + }; + + auto unpack_elt = row * col; + auto unpack_buf = utils::avector(unpack_elt, 0); + assert(head_ignore_num % 8 == 0); + + if (head_ignore_num > unpack_elt) { + assert(0); + } + auto base_bit2ptr = bit2ptr - head_ignore_num / 4; + auto base_bit1ptr = bit1ptr - head_ignore_num / 8; + auto base_unpack_buf = unpack_buf.data() - head_ignore_num; + int compress_wei_ptr_offset = 0; + if (head_ignore_num != 0) { + auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff << head_ignore_num); + auto head_zmm = bit3_interleave_decompress(base_bit2ptr, base_bit1ptr); + _mm512_mask_storeu_epi8(base_unpack_buf, unpack_mask, head_zmm); + compress_wei_ptr_offset += 64; + } + auto body_loop = (unpack_elt - (64 - head_ignore_num) % 64) / 64; + auto tail_proc_num = (unpack_elt - (64 - head_ignore_num) % 64) % 64; + for (int i = 0; i < body_loop; i++) { + auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, + base_bit1ptr + compress_wei_ptr_offset / 8); + _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); + compress_wei_ptr_offset += 64; + } + if (tail_proc_num > 0) { + auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff >> (64 - tail_proc_num)); + auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, + base_bit1ptr + compress_wei_ptr_offset / 8); + _mm512_mask_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, unpack_mask, zmm); + } + + for (int i = 0; i < row; i++) { + int kpos = (k_offset + i) / kblock; + auto sptr = scales + kpos * NPad; + for (int j = 0; j < col; j++) { + float tmp = static_cast(unpack_buf[i * col + j]); + if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); + dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); + } + } + return BTLA_CODE::Success; } @@ -1270,6 +1334,7 @@ inline BTLA_CODE dq8_get_fp_scale(uint8_t* src, float* dst, int row, int col, in auto mask = _cvtu32_mask16(0xffff >> (16 - col)); get_fp_scale(col, mask, scale_offset, src + i * src_stride, dst + i * dst_stride); } else { + // TODO(zhe): consider head_proc_num==0 case. auto head_mask = _cvtu32_mask16(0xffff >> (16 - head_proc_num)); auto body_mask = _cvtu32_mask16(0xffff); get_fp_scale(head_proc_num, head_mask, scale_offset, src + i * src_stride, dst + i * dst_stride); diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 19c38e985..e8121bcd3 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -465,13 +465,13 @@ template // zero point class DecompressKBlockS3Fp { template static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, - int interleave_n_offset, int unpack_elt, _SCA_T* scales, int8_t* zero_points, + int interleave_n_offset, int row, int col, _SCA_T* scales, int8_t* zero_points, int k_offset, int kblock, int NPad) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { ret = avx512f::decompress_kblock_bit3_packrow_fp( - bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, scales, zero_points, k_offset, kblock, NPad); + bit2ptr, bit1ptr, dstptr, interleave_n_offset, row, col, scales, zero_points, k_offset, kblock, NPad); } #endif return ret; From 8b9fdd94fd43e3d7fddd34c063e30a86d9109668 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 08:32:03 +0800 Subject: [PATCH 05/33] ut draft --- bestla/bestla/bestla_prologue_b.h | 5 ++++ bestla/bestla/kernel_avx512f.h | 2 +- bestla/bestla/kernel_ref.h | 1 + bestla/bestla/kernel_wrapper.h | 3 ++- bestla/bestla/ut/bestla_prologue_b.cpp | 34 ++++++++++++++++++++++++++ 5 files changed, 43 insertions(+), 2 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 1c1f3492f..a95ae75d5 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -891,6 +891,11 @@ class WeightKBlockNInteger { } else if (quant_dtype == BTLA_DTYPE::S4_CLIP) { kernel::wrapper::QuantizeSignIntRowBlock::forward( srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); + } else if (quant_dtype == BTLA_DTYPE::S3_CLIP) { + kernel::wrapper::QuantizeSignIntRowBlock::forward( + srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize); + } else { + assert(0); } } diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index dc190a2e8..e3dfdc484 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -707,7 +707,7 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; for (int j = 0; j < col; j++) { - float tmp = static_cast(unpack_buf[i * col + j]); + float tmp = static_cast(unpack_buf[i * col + j] << 5); // int3_clip => int8 if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 5b158fbc5..29aac484b 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -933,6 +933,7 @@ inline BTLA_CODE quantize_f32_sign_int_rowblock(const float* srcptr, int8_t* dst switch (S4_T) { case BTLA_DTYPE::S8: case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: if (zero_points == nullptr) { s8_calc_store_scale_and_quantv_sym(blocksize); } else { diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index e8121bcd3..ac4361dfe 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -463,6 +463,7 @@ class DecompressKBlockS4Fp { template // zero points always be int8_t, not compressed class DecompressKBlockS3Fp { + public: template static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, int interleave_n_offset, int row, int col, _SCA_T* scales, int8_t* zero_points, @@ -476,7 +477,7 @@ class DecompressKBlockS3Fp { #endif return ret; } -} +}; template // zero points always be int8_t, not compressed class DecompressKBlockS4S8Fp { diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index e9610ecae..40edabfe0 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -1,3 +1,4 @@ +#include "bestla.h" #include "bestla_gemm.h" #include "bestla_prologue_b.h" #include "bestla_parallel.h" @@ -172,6 +173,39 @@ class UT_BlockQunatize_F8 { static UT_BlockQunatize_F8 sUT_BlockQunatize_F8; #endif +class UT_BlockQunatize_S3 { + public: + UT_BlockQunatize_S3() { + UT_START(); + CheckISA(AVX512F); + ut(48, 4, 2, BTLA_DTYPE::S3_CLIP); + } + void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { + printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); + int ldb = n; + utils::aligned_vector raw(n * k); + ut::fill_buffer_randn(raw.data(), raw.size(), -3.f, 3.f); + + auto constexpr RuntimeISA = BTLA_ISA::AVX512F; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; + PrologueB kernel; + auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); + auto ptr_ref = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); + avector buffer(ptr.mSize); + avector buffer_ref(ptr_ref.mSize); + ptr.assign(buffer.data()); + ptr_ref.assign(buffer_ref.data()); + kernel.packWeight(n, k, raw.data(), ldb, &ptr, &DefaultThreading); + kernel.packWeight(n, k, raw.data(), ldb, &ptr_ref, &DefaultThreading); + avector dequant(n * k, 0); + avector dequant_ref(n * k, 0); + kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); + kernel.unpackWeight(n, k, &ptr_ref, dequant_ref.data(), n, &DefaultThreading); + } +}; +#ifdef BTLA_UT_PROLOGUE_B +static UTBUT_BlockQunatize_S3 sUT_BlockQunatize_S3; +#endif class UT_TransposeBlockQuantize_F4 { public: UT_TransposeBlockQuantize_F4() { From 5891801d8e17c0983b09acbb1a77d5e07f35a7a6 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 14:32:28 +0800 Subject: [PATCH 06/33] todo: fix n==64 bug --- bestla/bestla/bestla_prologue_b.h | 5 +-- bestla/bestla/kernel_avx512f.h | 6 +++- bestla/bestla/kernel_ref.h | 48 +++++++++++++++++++------- bestla/bestla/ut/bestla_prologue_b.cpp | 29 ++++++++++++---- 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index a95ae75d5..445df7dfe 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -813,10 +813,11 @@ class WeightKBlockNInteger { } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; - auto ld_dst = utils::padto(NPad, 64); + auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); - auto bit1ptr = reinterpret_cast(bit3_ptr + wptr->mKPad * ld_dst / 4 + elt_offset / 8); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index e3dfdc484..fe965f3ee 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -652,6 +652,7 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr auto head_ignore_num = interleave_n_offset % 64; auto zmm_0x04 = _mm512_set1_epi8(0x04); auto zmm_0x00 = _mm512_set1_epi8(0x00); + auto zmm_shift = _mm512_set1_epi32(5); auto bit3_interleave_decompress = [&](utils::bit2x4* src1, utils::bit1x8* src2) { const __m128i lowMask = _mm_set1_epi8(0x03); @@ -668,6 +669,7 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr auto bit1_mask = _cvtu64_mask64(*bit1_ptr); auto zmm2 = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask, zmm_0x04); zmm = _mm512_add_epi8(zmm, zmm2); + zmm = _mm512_sllv_epi32(zmm, zmm_shift); // int3_clip => int8 return zmm; }; @@ -694,6 +696,8 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, base_bit1ptr + compress_wei_ptr_offset / 8); _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); + // int8_t* test = reinterpret_cast(&zmm); + // for (int j = 0; j < 64; j++) std::cout << int(test[j]) << std::endl; compress_wei_ptr_offset += 64; } if (tail_proc_num > 0) { @@ -707,7 +711,7 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; for (int j = 0; j < col; j++) { - float tmp = static_cast(unpack_buf[i * col + j] << 5); // int3_clip => int8 + float tmp = static_cast(unpack_buf[i * col + j]); if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 29aac484b..ab601842a 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -178,31 +178,53 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i return BTLA_CODE::Success; } +// #include static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row, int col, int ld_src, int ld_dst) { assert(col % 64 == 0); // interleave + store 2bit. + + // for (int i = 0; i < row * col; i++) { + // char tmp; + // memcpy(&tmp, &srcptr[i], 1); + // tmp &= 0xe0; + // std::cout << int(*(reinterpret_cast(&tmp))) << std::endl; + // } + // std::cout << "==============" << std::endl; + + auto bit2_interleave = [&](int8_t* src, int8_t* dst) { + for (int i = 0; i < 64 / 4; i++) { + dst[4 * i] = src[i]; + dst[4 * i + 1] = src[64 / 4 + i]; + dst[4 * i + 2] = src[64 / 4 * 2 + i]; + dst[4 * i + 3] = src[64 / 4 * 3 + i]; + } + }; + + int8_t interleave_buf[64]; + for (int i = 0; i < row; i++) { - for (int j = 0; j < col; i += 64) { + for (int j = 0; j < col; j += 64) { + bit2_interleave(const_cast(srcptr + i * ld_src + j), interleave_buf); for (int k = 0; k < 16; k++) { - bit2ptr[i * ld_dst / 4 + i / 4 + k].a = srcptr[i * ld_src + col + 4 * k]; - bit2ptr[i * ld_dst / 4 + i / 4 + k].b = srcptr[i * ld_src + col + 4 * k + 1]; - bit2ptr[i * ld_dst / 4 + i / 4 + k].c = srcptr[i * ld_src + col + 4 * k + 2]; - bit2ptr[i * ld_dst / 4 + i / 4 + k].d = srcptr[i * ld_src + col + 4 * k + 3]; + bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; + bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; + bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; + bit2ptr[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; } } } // store 1 bit without interleave as mask. for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 8) { - bit1ptr[i * ld_dst].a = srcptr[i * ld_src + j] >> 2; - bit1ptr[i * ld_dst].b = srcptr[i * ld_src + j + 1] >> 2; - bit1ptr[i * ld_dst].c = srcptr[i * ld_src + j + 2] >> 2; - bit1ptr[i * ld_dst].d = srcptr[i * ld_src + j + 3] >> 2; - bit1ptr[i * ld_dst].e = srcptr[i * ld_src + j + 4] >> 2; - bit1ptr[i * ld_dst].f = srcptr[i * ld_src + j + 5] >> 2; - bit1ptr[i * ld_dst].g = srcptr[i * ld_src + j + 6] >> 2; - bit1ptr[i * ld_dst].h = srcptr[i * ld_src + j + 7] >> 2; + bit1ptr[i * ld_dst / 8 + j / 8].a = srcptr[i * ld_src + j] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].b = srcptr[i * ld_src + j + 1] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].c = srcptr[i * ld_src + j + 2] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].d = srcptr[i * ld_src + j + 3] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].e = srcptr[i * ld_src + j + 4] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].f = srcptr[i * ld_src + j + 5] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].g = srcptr[i * ld_src + j + 6] >> 7; + bit1ptr[i * ld_dst / 8 + j / 8].h = srcptr[i * ld_src + j + 7] >> 7; } } return BTLA_CODE::Success; diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 40edabfe0..682822d23 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,13 +178,20 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - ut(48, 4, 2, BTLA_DTYPE::S3_CLIP); + ut(64, 4, 4, BTLA_DTYPE::S3_CLIP); } void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { + DefaultThreading.set_threads(1); printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); int ldb = n; - utils::aligned_vector raw(n * k); - ut::fill_buffer_randn(raw.data(), raw.size(), -3.f, 3.f); + + int kblk_num = utils::updiv(k, blocksize); + utils::aligned_vector scales(kblk_num * n); + ut::fill_buffer_randn(scales.data(), scales.size(), 0.005f, 0.01f); + ut::UT_vector_s8 quanW; + quanW.resize(k * n); + quanW.fill_rand(-8, 7); + for (int i = 0; i < k * n; i++) quanW.data()[i] = (quanW.data()[i] * 16) & 0xe0; auto constexpr RuntimeISA = BTLA_ISA::AVX512F; using PrologueB = prologue_b::gemm::WeightKBlockNInteger; @@ -195,16 +202,25 @@ class UT_BlockQunatize_S3 { avector buffer_ref(ptr_ref.mSize); ptr.assign(buffer.data()); ptr_ref.assign(buffer_ref.data()); - kernel.packWeight(n, k, raw.data(), ldb, &ptr, &DefaultThreading); - kernel.packWeight(n, k, raw.data(), ldb, &ptr_ref, &DefaultThreading); + kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr, &DefaultThreading); + kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr_ref, &DefaultThreading); avector dequant(n * k, 0); avector dequant_ref(n * k, 0); kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); kernel.unpackWeight(n, k, &ptr_ref, dequant_ref.data(), n, &DefaultThreading); + for (int i = 0; i < k; i++) { + for (int j = 0; j < n; j++) { + if ((dequant[i * n + j] - dequant_ref[i * n + j]) != 0) { + std::cout << "i: " << i << " j:" << j << std::endl; + std::cout << dequant[i * n + j] << " vs " << dequant_ref[i * n + j] << std::endl; + } + } + } } }; +static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #ifdef BTLA_UT_PROLOGUE_B -static UTBUT_BlockQunatize_S3 sUT_BlockQunatize_S3; +static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #endif class UT_TransposeBlockQuantize_F4 { public: @@ -549,7 +565,6 @@ class UT_ShuffleIndices { delete wptr; } }; -static UT_ShuffleIndices sUT_ShuffleIndices; #ifdef BTLA_UT_PROLOGUE_B static UT_ShuffleIndices sUT_ShuffleIndices; #endif From 204739c459916a396af27096ee4192e25229f2df Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 15:27:39 +0800 Subject: [PATCH 07/33] todo: more getwei & gemm ut --- bestla/bestla/bestla_prologue_b.h | 3 ++- bestla/bestla/ut/bestla_prologue_b.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 445df7dfe..1e155aa08 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -812,7 +812,8 @@ class WeightKBlockNInteger { wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { int8_t* bit3_ptr = wptr->template WPtr(); - auto elt_offset = n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad; + auto elt_offset = + n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); assert(elt_offset % 8 == 0); diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 682822d23..f98f11258 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,10 +178,10 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - ut(64, 4, 4, BTLA_DTYPE::S3_CLIP); + ut(1024, 1024, 32, BTLA_DTYPE::S3_CLIP); + ut(4128, 4096, 32, BTLA_DTYPE::S3_CLIP); } void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { - DefaultThreading.set_threads(1); printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); int ldb = n; From 1f99030dfafe296df38ca0f8cba824a8d3c5eebf Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 17:10:34 +0800 Subject: [PATCH 08/33] todo: more cmpt gemm ut --- bestla/bestla/bestla_prologue_b.h | 16 +++++++- bestla/bestla/kernel_avx512f.h | 32 +++++++++------ bestla/bestla/kernel_wrapper.h | 24 +++++++++++- bestla/bestla/ut/bestla_prologue_b.cpp | 54 +++++++++++++++++++------- 4 files changed, 96 insertions(+), 30 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 1e155aa08..ebaaa5314 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -743,6 +743,20 @@ class WeightKBlockNInteger { kernel::wrapper::DecompressKBlockS8S8Fp::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + int8_t* bit3_ptr = wptr->template WPtr(); + auto elt_offset = + n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); + auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3S8Fp::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, + k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize); + } else { + assert(0); } } } @@ -823,7 +837,7 @@ class WeightKBlockNInteger { BTLA_DTYPE::S3_CLIP>( bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, - wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad); + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else { assert(0); } diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index fe965f3ee..90c3c4711 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -644,11 +644,9 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d return BTLA_CODE::NotSupport; } -template -static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, - _DST_T* dstptr, int interleave_n_offset, int row, int col, - _ST* scales, int8_t* zero_points, int k_offset, int kblock, - int NPad) { +template +inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int unpack_elt, int8_t* tmp, size_t tmpsize) { auto head_ignore_num = interleave_n_offset % 64; auto zmm_0x04 = _mm512_set1_epi8(0x04); auto zmm_0x00 = _mm512_set1_epi8(0x00); @@ -673,8 +671,6 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr return zmm; }; - auto unpack_elt = row * col; - auto unpack_buf = utils::avector(unpack_elt, 0); assert(head_ignore_num % 8 == 0); if (head_ignore_num > unpack_elt) { @@ -682,9 +678,10 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr } auto base_bit2ptr = bit2ptr - head_ignore_num / 4; auto base_bit1ptr = bit1ptr - head_ignore_num / 8; - auto base_unpack_buf = unpack_buf.data() - head_ignore_num; + auto base_unpack_buf = tmp - head_ignore_num; int compress_wei_ptr_offset = 0; if (head_ignore_num != 0) { + assert(0); auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff << head_ignore_num); auto head_zmm = bit3_interleave_decompress(base_bit2ptr, base_bit1ptr); _mm512_mask_storeu_epi8(base_unpack_buf, unpack_mask, head_zmm); @@ -695,23 +692,34 @@ static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr for (int i = 0; i < body_loop; i++) { auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, base_bit1ptr + compress_wei_ptr_offset / 8); - _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); - // int8_t* test = reinterpret_cast(&zmm); - // for (int j = 0; j < 64; j++) std::cout << int(test[j]) << std::endl; + _mm512_storeu_epi8(base_unpack_buf, zmm); + for (int j = 0; j < 64; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); compress_wei_ptr_offset += 64; } if (tail_proc_num > 0) { + assert(0); auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff >> (64 - tail_proc_num)); auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, base_bit1ptr + compress_wei_ptr_offset / 8); _mm512_mask_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, unpack_mask, zmm); } + return BTLA_CODE::Success; +} +template +static inline BTLA_CODE decompress_kblock_bit3_packrow_fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, + _DST_T* dstptr, int interleave_n_offset, int row, int col, + _ST* scales, int8_t* zero_points, int k_offset, int kblock, + int NPad, void* tmp, size_t tmpsize) { + auto unpack_elt = row * col; + decompress_kblock_s3_s8fp<_S3_T>(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, + reinterpret_cast(tmp), tmpsize); + // TODO(zhe): simd version for (int i = 0; i < row; i++) { int kpos = (k_offset + i) / kblock; auto sptr = scales + kpos * NPad; for (int j = 0; j < col; j++) { - float tmp = static_cast(unpack_buf[i * col + j]); + float tmp = static_cast(dstptr[i * col + j]); if (zero_points != nullptr) tmp -= static_cast(zero_points[kpos * NPad + j / _PACK_ROW]); dstptr[i * col + j] = static_cast<_DST_T>(tmp * sptr[j / _PACK_ROW]); } diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index ac4361dfe..48ebefe9e 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -467,14 +467,16 @@ class DecompressKBlockS3Fp { template static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, int interleave_n_offset, int row, int col, _SCA_T* scales, int8_t* zero_points, - int k_offset, int kblock, int NPad) { + int k_offset, int kblock, int NPad, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() if constexpr (utils::isa_base::avx512f) { ret = avx512f::decompress_kblock_bit3_packrow_fp( - bit2ptr, bit1ptr, dstptr, interleave_n_offset, row, col, scales, zero_points, k_offset, kblock, NPad); + bit2ptr, bit1ptr, dstptr, interleave_n_offset, row, col, scales, zero_points, k_offset, kblock, NPad, tmp, + tmpsize); } #endif + assert(ret == BTLA_CODE::Success); return ret; } }; @@ -501,6 +503,24 @@ class DecompressKBlockS4S8Fp { } }; +template +class DecompressKBlockS3S8Fp { + public: + template + static inline BTLA_CODE forward(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, + int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; +#if CompileAVX512F() + if constexpr (utils::isa_base::avx512f) { + ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, + reinterpret_cast(tmp), tmpsize); + } +#endif + assert(ret == BTLA_CODE::Success); + return ret; + } +}; + template class DecompressKBlockF4Fp { public: diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index f98f11258..eaddb19c2 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,10 +178,12 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - ut(1024, 1024, 32, BTLA_DTYPE::S3_CLIP); - ut(4128, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(128, 1024, 1024, 32, BTLA_DTYPE::S3_CLIP); + // ut(1024, 4128, 4096, 32, BTLA_DTYPE::S3_CLIP); } - void ut(int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { + + template + void ut(int m, int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); int ldb = n; @@ -195,6 +197,14 @@ class UT_BlockQunatize_S3 { auto constexpr RuntimeISA = BTLA_ISA::AVX512F; using PrologueB = prologue_b::gemm::WeightKBlockNInteger; + using Launcher = + wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; + + Launcher launcher; + PrologueB kernel; auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); auto ptr_ref = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); @@ -204,18 +214,32 @@ class UT_BlockQunatize_S3 { ptr_ref.assign(buffer_ref.data()); kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr, &DefaultThreading); kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr_ref, &DefaultThreading); - avector dequant(n * k, 0); - avector dequant_ref(n * k, 0); - kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); - kernel.unpackWeight(n, k, &ptr_ref, dequant_ref.data(), n, &DefaultThreading); - for (int i = 0; i < k; i++) { - for (int j = 0; j < n; j++) { - if ((dequant[i * n + j] - dequant_ref[i * n + j]) != 0) { - std::cout << "i: " << i << " j:" << j << std::endl; - std::cout << dequant[i * n + j] << " vs " << dequant_ref[i * n + j] << std::endl; - } - } - } + + avector matAf32(m * k), matC(m * n), refC(m * n); + fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); + utils::GemmProblem gp(1, m, n, k, blocksize); + typename Launcher::Param args{ + gp, {matAf32.data(), k}, {&ptr}, {ptr.template SPtr(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; + parallel::GemmRun(launcher, args, &DefaultThreading); + typename Launcher::Param args_ref{gp, + {matAf32.data(), k}, + {&ptr_ref}, + {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, + {refC.data(), n}}; + parallel::GemmRun(launcher, args_ref, &DefaultThreading); + buffer_error(matC.data(), refC.data(), matC.size(), 0.001f); + // avector dequant(n * k, 0); + // avector dequant_ref(n * k, 0); + // kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); + // kernel.unpackWeight(n, k, &ptr_ref, dequant_ref.data(), n, &DefaultThreading); + // for (int i = 0; i < k; i++) { + // for (int j = 0; j < n; j++) { + // if ((dequant[i * n + j] - dequant_ref[i * n + j]) != 0) { + // std::cout << "i: " << i << " j:" << j << std::endl; + // std::cout << dequant[i * n + j] << " vs " << dequant_ref[i * n + j] << std::endl; + // } + // } + // } } }; static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; From 3290f51feccbf2aac445ff0cb3ac037104fe9563 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 17:49:44 +0800 Subject: [PATCH 09/33] bf16-cmpt ut --- bestla/bestla/ut/bestla_prologue_b.cpp | 78 ++++++++++++++++++-------- 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index eaddb19c2..38227dc00 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,11 +178,21 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - ut(128, 1024, 1024, 32, BTLA_DTYPE::S3_CLIP); - // ut(1024, 4128, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<0>(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut<0>(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut<0>(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<0>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<0>(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<0>(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<1>(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut<1>(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut<1>(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<1>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<1>(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut<1>(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); } - template + template // 0: fp32 1: bf16 2: int8. void ut(int m, int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); int ldb = n; @@ -197,13 +207,6 @@ class UT_BlockQunatize_S3 { auto constexpr RuntimeISA = BTLA_ISA::AVX512F; using PrologueB = prologue_b::gemm::WeightKBlockNInteger; - using Launcher = - wrapper::gemm::LauncherKBlock; - using Parallel = parallel::gemm::SchedulerKBlock; - - Launcher launcher; PrologueB kernel; auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); @@ -215,18 +218,49 @@ class UT_BlockQunatize_S3 { kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr, &DefaultThreading); kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr_ref, &DefaultThreading); - avector matAf32(m * k), matC(m * n), refC(m * n); - fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); - utils::GemmProblem gp(1, m, n, k, blocksize); - typename Launcher::Param args{ - gp, {matAf32.data(), k}, {&ptr}, {ptr.template SPtr(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; - parallel::GemmRun(launcher, args, &DefaultThreading); - typename Launcher::Param args_ref{gp, - {matAf32.data(), k}, - {&ptr_ref}, - {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, - {refC.data(), n}}; - parallel::GemmRun(launcher, args_ref, &DefaultThreading); + avector matC(m * n), refC(m * n); + if constexpr (CMPT_MODE == 0) { + using Launcher = + wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; + + Launcher launcher; + avector matAf32(m * k); + fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); + utils::GemmProblem gp(1, m, n, k, blocksize); + typename Launcher::Param args{ + gp, {matAf32.data(), k}, {&ptr}, {ptr.template SPtr(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; + parallel::GemmRun(launcher, args, &DefaultThreading); + typename Launcher::Param args_ref{gp, + {matAf32.data(), k}, + {&ptr_ref}, + {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, + {refC.data(), n}}; + parallel::GemmRun(launcher, args_ref, &DefaultThreading); + } else if constexpr (CMPT_MODE == 1) { + using Launcher = + wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; + + Launcher launcher; + avector matAbf16(m * k); + fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f)); + GemmProblem gp(1, m, n, k, blocksize); + typename Launcher::Param args{ + gp, {matAbf16.data(), k}, {&ptr}, {ptr.template SPtr(), ptr.SDtype(), ptr.CStep()}, {matC.data(), n}}; + parallel::GemmRun(launcher, args, &DefaultThreading); + typename Launcher::Param args_ref{gp, + {matAbf16.data(), k}, + {&ptr_ref}, + {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, + {matC.data(), n}}; + parallel::GemmRun(launcher, args_ref, &DefaultThreading); + } else { + } buffer_error(matC.data(), refC.data(), matC.size(), 0.001f); // avector dequant(n * k, 0); // avector dequant_ref(n * k, 0); From cf332944c2d0dfa66d48c3418e975972a1bc1e07 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 18:04:37 +0800 Subject: [PATCH 10/33] refine ut --- bestla/bestla/ut/bestla_prologue_b.cpp | 56 +++++++++++--------------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 38227dc00..59b2c7799 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,22 +178,23 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - ut<0>(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut<0>(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut<0>(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut<0>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut<0>(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut<0>(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut<1>(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut<1>(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut<1>(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut<1>(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut<1>(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut<1>(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); } - template // 0: fp32 1: bf16 2: int8. + template void ut(int m, int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { + DefaultThreading.set_threads(64); printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); int ldb = n; @@ -205,8 +206,7 @@ class UT_BlockQunatize_S3 { quanW.fill_rand(-8, 7); for (int i = 0; i < k * n; i++) quanW.data()[i] = (quanW.data()[i] * 16) & 0xe0; - auto constexpr RuntimeISA = BTLA_ISA::AVX512F; - using PrologueB = prologue_b::gemm::WeightKBlockNInteger; + using PrologueB = prologue_b::gemm::WeightKBlockNInteger; PrologueB kernel; auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); @@ -217,16 +217,15 @@ class UT_BlockQunatize_S3 { ptr_ref.assign(buffer_ref.data()); kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr, &DefaultThreading); kernel.packQWeight(n, k, quanW.data(), ldb, scales.data(), nullptr, &ptr_ref, &DefaultThreading); + using Launcher = + wrapper::gemm::LauncherKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; + Launcher launcher; avector matC(m * n), refC(m * n); - if constexpr (CMPT_MODE == 0) { - using Launcher = - wrapper::gemm::LauncherKBlock; - using Parallel = parallel::gemm::SchedulerKBlock; - - Launcher launcher; + if constexpr (ISA == BTLA_ISA::AVX512F) { avector matAf32(m * k); fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); utils::GemmProblem gp(1, m, n, k, blocksize); @@ -239,14 +238,7 @@ class UT_BlockQunatize_S3 { {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, {refC.data(), n}}; parallel::GemmRun(launcher, args_ref, &DefaultThreading); - } else if constexpr (CMPT_MODE == 1) { - using Launcher = - wrapper::gemm::LauncherKBlock; - using Parallel = parallel::gemm::SchedulerKBlock; - - Launcher launcher; + } else if constexpr (ISA == BTLA_ISA::AMX_BF16) { avector matAbf16(m * k); fill_buffer_randn(matAbf16.data(), matAbf16.size(), utils::bf16(-0.5f), utils::bf16(0.5f)); GemmProblem gp(1, m, n, k, blocksize); @@ -257,7 +249,7 @@ class UT_BlockQunatize_S3 { {matAbf16.data(), k}, {&ptr_ref}, {ptr_ref.template SPtr(), ptr_ref.SDtype(), ptr_ref.CStep()}, - {matC.data(), n}}; + {refC.data(), n}}; parallel::GemmRun(launcher, args_ref, &DefaultThreading); } else { } From 9a50a605466d78dbd58a23b422f46c820ae687ca Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 18:14:07 +0800 Subject: [PATCH 11/33] fix parallel --- bestla/bestla/ut/bestla_prologue_b.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 59b2c7799..7b789c3aa 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -221,7 +221,7 @@ class UT_BlockQunatize_S3 { wrapper::gemm::LauncherKBlock; - using Parallel = parallel::gemm::SchedulerKBlock; + using Parallel = parallel::gemm::SchedulerKBlock; Launcher launcher; avector matC(m * n), refC(m * n); From f2fcbc0efd4633d8e79072d74960001aae7d13f6 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Mon, 22 Jan 2024 20:38:33 +0800 Subject: [PATCH 12/33] add int8-cmpt ut, todo: re-pad N-dim --- bestla/bestla/bestla_prologue_b.h | 27 ++++++++++++++- bestla/bestla/kernel_avx512f.h | 15 ++++++-- bestla/bestla/ut/bestla_prologue_b.cpp | 48 +++++++++++++++++++------- 3 files changed, 74 insertions(+), 16 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index ebaaa5314..7da9757e4 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -564,7 +564,7 @@ class WeightKBlockNInteger { // N==NPad, K==Kpad, ldb==Kpad auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(K, 64); auto col = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * K; - assert(N % (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE) == 0); + assert(N % (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE) == 0); // TODO(zhe): consider N pad to packrow*ntile? auto row = N / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); auto pad_64_buf = utils::avector(row * ld_dst, 0); kernel::wrapper::Memcpy2D::forward(B, pad_64_buf.data(), row, col, col, ld_dst); @@ -631,6 +631,8 @@ class WeightKBlockNInteger { return getQ8Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else if (wptr->mDType == BTLA_DTYPE::S4_CLIP || wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + return getQ3Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize); } else { assert(0); } @@ -894,6 +896,29 @@ class WeightKBlockNInteger { return BTLA_CODE::Success; } + static inline BTLA_CODE getQ3Weight(int8_t** dstptr, int* dststep, int k_size, int n_size, int k_offset, int n_offset, + const Param& _param, void* tmpcache, size_t cachesize) { + auto wptr = _param.packedW; + int8_t* bit3_ptr = wptr->template WPtr(); + auto KPad = wptr->mKPad; + auto NPad = wptr->mNPad; + int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; + auto base_offset = n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE; + for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { + auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto elt_offset = base_offset + i * utils::padto(KPad, 64); + auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3S8Fp::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, + k_size / _GemmCore_T::PACK_ROW * ColSize, reinterpret_cast(tmpcache), cachesize); + } + *dststep = k_size; + return BTLA_CODE::Success; + } + virtual inline void quantRowBlock(const float* srcptr, int8_t* dstptr, int row, int col, int ld_src, int ld_dst, float* scales, int8_t* zero_points, void* stor) { auto ptr = reinterpret_cast(stor); diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 90c3c4711..5b673a2da 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -678,7 +678,12 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 } auto base_bit2ptr = bit2ptr - head_ignore_num / 4; auto base_bit1ptr = bit1ptr - head_ignore_num / 8; - auto base_unpack_buf = tmp - head_ignore_num; + int8_t* base_unpack_buf; + if constexpr (std::is_same_v<_DST_T, int8_t>) { + base_unpack_buf = dstptr - head_ignore_num; + } else { + base_unpack_buf = tmp - head_ignore_num; + } int compress_wei_ptr_offset = 0; if (head_ignore_num != 0) { assert(0); @@ -692,8 +697,12 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 for (int i = 0; i < body_loop; i++) { auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, base_bit1ptr + compress_wei_ptr_offset / 8); - _mm512_storeu_epi8(base_unpack_buf, zmm); - for (int j = 0; j < 64; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); + if constexpr (!std::is_same_v<_DST_T, int8_t>) { + _mm512_storeu_epi8(base_unpack_buf, zmm); + for (int j = 0; j < 64; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); + } else { + _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); + } compress_wei_ptr_offset += 64; } if (tail_proc_num > 0) { diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 7b789c3aa..c8eceef22 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,18 +178,24 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + // ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + // ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + // ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + // ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + // ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); + // ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + // ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + // ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); } template @@ -252,6 +258,24 @@ class UT_BlockQunatize_S3 { {refC.data(), n}}; parallel::GemmRun(launcher, args_ref, &DefaultThreading); } else { + using Launcher2 = wrapper::gemm::LauncherIntKBlock; + Launcher2 launcher; + using Parallel2 = parallel::gemm::SchedulerKBlockS; + avector matAf32(m * k); + fill_buffer_randn(matAf32.data(), matAf32.size(), -0.5f, 0.5f); + auto quanA = launcher.mProA.createStorage(m, k, blocksize, false); + auto quanA_ref = launcher.mProA.createStorage(m, k, blocksize, false); + utils::avector bufferA(quanA.mSize); + utils::avector bufferA_ref(quanA.mSize); + quanA.assign(bufferA.data()); + quanA_ref.assign(bufferA_ref.data()); + GemmProblem gp(1, m, n, k, blocksize); + typename Launcher2::Param args{gp, {matAf32.data(), k, &quanA}, {&ptr}, {matC.data(), n}}; + parallel::GemmRunWithA(launcher, args, &DefaultThreading); + typename Launcher2::Param args_ref{gp, {matAf32.data(), k, &quanA_ref}, {&ptr_ref}, {refC.data(), n}}; + parallel::GemmRunWithA(launcher, args_ref, &DefaultThreading); } buffer_error(matC.data(), refC.data(), matC.size(), 0.001f); // avector dequant(n * k, 0); From bfce22f39b9fee3b84fb63bc34968648f3fc8ab8 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 23 Jan 2024 09:27:29 +0800 Subject: [PATCH 13/33] s3_clip N-dim pad to NTILE*PACK_ROW, todo: add spr-56 core case --- bestla/bestla/bestla_prologue_b.h | 5 +++-- bestla/bestla/ut/bestla_prologue_b.cpp | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 7da9757e4..dde52af79 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -122,6 +122,7 @@ class WeightKBlockNInteger { bool is_asym) { int KPad = utils::padto(k, _GemmCore_T::KTILE); int NPad = utils::padto(n, _GemmCore_T::NTILE); + if (qtype == BTLA_DTYPE::S3_CLIP) NPad = utils::padto(n, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW); StorageWeight tmp(_GemmCore_T::ID); tmp.resize(NPad, KPad, blocksize <= 0 ? KPad : blocksize, n, k, qtype, scat, redt, is_asym); return tmp; @@ -903,11 +904,11 @@ class WeightKBlockNInteger { auto KPad = wptr->mKPad; auto NPad = wptr->mNPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; + auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); + auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); auto base_offset = n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); auto elt_offset = base_offset + i * utils::padto(KPad, 64); - auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index c8eceef22..7f2557853 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -190,12 +190,12 @@ class UT_BlockQunatize_S3 { // ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); // ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); // ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - // ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - // ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, BTLA_DTYPE::S3_CLIP); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, BTLA_DTYPE::S3_CLIP); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); - // ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); + ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, BTLA_DTYPE::S3_CLIP); + ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, BTLA_DTYPE::S3_CLIP); } template From 70a9487c6e46e110e4f086fddcc4c58f38bf42b7 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 23 Jan 2024 09:39:35 +0800 Subject: [PATCH 14/33] better ut --- bestla/bestla/ut/bestla_prologue_b.cpp | 50 ++++++++++++++------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 7f2557853..bd681c7ee 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,30 +178,36 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - // ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - // ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - // ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut(128, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - // ut(1, 4096, 16384, 32, BTLA_DTYPE::S3_CLIP); - // ut(128, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut(1, 4096, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut(128, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - // ut(1, 16384, 4096, 32, BTLA_DTYPE::S3_CLIP); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, BTLA_DTYPE::S3_CLIP); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, BTLA_DTYPE::S3_CLIP); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); - ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, BTLA_DTYPE::S3_CLIP); - ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, BTLA_DTYPE::S3_CLIP); + ut(128, 4096, 16384, 32, 56); + ut(1, 4096, 16384, 32, 56); + ut(128, 4096, 4096, 32, 56); + ut(1, 4096, 4096, 32, 56); + ut(128, 16384, 4096, 32, 56); + ut(1, 16384, 4096, 32, 56); + ut(128, 4096, 16384, 32, 56); + ut(1, 4096, 16384, 32, 56); + ut(128, 4096, 4096, 32, 56); + ut(1, 4096, 4096, 32, 56); + ut(128, 16384, 4096, 32, 56); + ut(1, 16384, 4096, 32, 56); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 56); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 56); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 56); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 56); + ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 56); + ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 56); } template - void ut(int m, int n, int k, int blocksize, BTLA_DTYPE QUANT_T) { - DefaultThreading.set_threads(64); - printf("%s: %d %d %d\n", __FUNCTION__, n, k, blocksize); + void ut(int m, int n, int k, int blocksize, int enable_thr) { + DefaultThreading.set_threads(enable_thr); + printf("%s:%d %d %d %d\n", __FUNCTION__, m, n, k, blocksize); int ldb = n; int kblk_num = utils::updiv(k, blocksize); @@ -215,7 +221,7 @@ class UT_BlockQunatize_S3 { using PrologueB = prologue_b::gemm::WeightKBlockNInteger; PrologueB kernel; - auto ptr = kernel.createStorage(n, k, blocksize, QUANT_T, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); + auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); auto ptr_ref = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); avector buffer(ptr.mSize); avector buffer_ref(ptr_ref.mSize); From 97225a0ca1bb82e37315b9627d9a5dec55ee4a3b Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 23 Jan 2024 09:50:04 +0800 Subject: [PATCH 15/33] emr 64core ut --- bestla/bestla/ut/bestla_prologue_b.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index bd681c7ee..b751a7c38 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -202,6 +202,31 @@ class UT_BlockQunatize_S3 { ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 56); ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 56); + // emr case + ut(128, 4096, 16384, 32, 64); + ut(1, 4096, 16384, 32, 64); + ut(128, 4096, 4096, 32, 64); + ut(1, 4096, 4096, 32, 64); + ut(128, 16384, 4096, 32, 64); + ut(1, 16384, 4096, 32, 64); + ut(128, 4096, 16384, 32, 64); + ut(1, 4096, 16384, 32, 64); + ut(128, 4096, 4096, 32, 64); + ut(1, 4096, 4096, 32, 64); + ut(128, 16384, 4096, 32, 64); + ut(1, 16384, 4096, 32, 64); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); + ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); + ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); + ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 64); + ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 64); + ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 64); + ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 64); + ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 64); + ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 64); } template From 4668aa458fa64a45f4bf072450febc381d960269 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 23 Jan 2024 10:52:38 +0800 Subject: [PATCH 16/33] fix packrow --- bestla/bestla/bestla_prologue_b.h | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index dde52af79..315b31909 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -122,7 +122,7 @@ class WeightKBlockNInteger { bool is_asym) { int KPad = utils::padto(k, _GemmCore_T::KTILE); int NPad = utils::padto(n, _GemmCore_T::NTILE); - if (qtype == BTLA_DTYPE::S3_CLIP) NPad = utils::padto(n, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW); + // if (qtype == BTLA_DTYPE::S3_CLIP) NPad = utils::padto(n, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW); StorageWeight tmp(_GemmCore_T::ID); tmp.resize(NPad, KPad, blocksize <= 0 ? KPad : blocksize, n, k, qtype, scat, redt, is_asym); return tmp; @@ -563,10 +563,9 @@ class WeightKBlockNInteger { parallel::IThreading* threading) { // TODO(zhe): 1D parallel compress // N==NPad, K==Kpad, ldb==Kpad - auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(K, 64); - auto col = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * K; - assert(N % (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE) == 0); // TODO(zhe): consider N pad to packrow*ntile? - auto row = N / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(K, 64); + auto col = _GemmCore_T::NTILE * K; + auto row = N / _GemmCore_T::NTILE; auto pad_64_buf = utils::avector(row * ld_dst, 0); kernel::wrapper::Memcpy2D::forward(B, pad_64_buf.data(), row, col, col, ld_dst); auto bit2ptr = reinterpret_cast(dstptr); @@ -750,8 +749,8 @@ class WeightKBlockNInteger { int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); - auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); - auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto row = NPad / _GemmCore_T::NTILE; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); @@ -831,8 +830,8 @@ class WeightKBlockNInteger { int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); - auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); - auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto row = NPad / _GemmCore_T::NTILE; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); @@ -904,8 +903,8 @@ class WeightKBlockNInteger { auto KPad = wptr->mKPad; auto NPad = wptr->mNPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; - auto row = NPad / (_GemmCore_T::PACK_ROW * _GemmCore_T::NTILE); - auto ld_dst = _GemmCore_T::PACK_ROW * _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto row = NPad / _GemmCore_T::NTILE; + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); auto base_offset = n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { auto elt_offset = base_offset + i * utils::padto(KPad, 64); From a9b177b878de26310ddac282be95b1ac540d5b42 Mon Sep 17 00:00:00 2001 From: ZheWang Date: Mon, 22 Jan 2024 23:09:16 -0800 Subject: [PATCH 17/33] store-unroll --- bestla/bestla/kernel_avx512f.h | 62 ++++++++++++++++++++----- bestla/bestla/ut/bestla_prologue_b.cpp | 64 +++++++++++++------------- 2 files changed, 84 insertions(+), 42 deletions(-) diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 5b673a2da..0214885c0 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -652,19 +652,23 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 auto zmm_0x00 = _mm512_set1_epi8(0x00); auto zmm_shift = _mm512_set1_epi32(5); + // auto bit3_interleave_decompress = [&](__m128i bit2_data, utils::bit1x8* src2) { auto bit3_interleave_decompress = [&](utils::bit2x4* src1, utils::bit1x8* src2) { const __m128i lowMask = _mm_set1_epi8(0x03); const __m128i bit2_data = _mm_loadu_si128((const __m128i*)src1); + // __m128i bit2_data; auto xmm0 = _mm_and_si128(lowMask, bit2_data); // uop:1 p:015 auto xmm1 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 2)); // uop:1 p:01 auto xmm2 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 4)); auto xmm3 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 6)); auto ymm1 = _mm256_set_m128i(xmm1, xmm0); // uop:1 p5 auto ymm2 = _mm256_set_m128i(xmm3, xmm2); - auto zmm = _mm512_castps_si512(_mm512_insertf32x8(_mm512_castsi256_si512(ymm1), ymm2, 0x1)); + auto zmm = _mm512_inserti32x8(_mm512_castsi256_si512(ymm1), ymm2, 0x1); // make bit1-storage as mmask64, then cvt the mask to the int8-value. unsigned long long* bit1_ptr = reinterpret_cast(src2); auto bit1_mask = _cvtu64_mask64(*bit1_ptr); + // __mmask64 bit1_mask; + // __m512i zmm; auto zmm2 = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask, zmm_0x04); zmm = _mm512_add_epi8(zmm, zmm2); zmm = _mm512_sllv_epi32(zmm, zmm_shift); // int3_clip => int8 @@ -687,19 +691,55 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 int compress_wei_ptr_offset = 0; if (head_ignore_num != 0) { assert(0); - auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff << head_ignore_num); - auto head_zmm = bit3_interleave_decompress(base_bit2ptr, base_bit1ptr); - _mm512_mask_storeu_epi8(base_unpack_buf, unpack_mask, head_zmm); - compress_wei_ptr_offset += 64; + // auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff << head_ignore_num); + // auto head_zmm = bit3_interleave_decompress(base_bit2ptr, base_bit1ptr); + // _mm512_mask_storeu_epi8(base_unpack_buf, unpack_mask, head_zmm); + // compress_wei_ptr_offset += 64; } auto body_loop = (unpack_elt - (64 - head_ignore_num) % 64) / 64; auto tail_proc_num = (unpack_elt - (64 - head_ignore_num) % 64) % 64; + + __m128i bit2_data[4]; + __m512i bit2_data_zmm; for (int i = 0; i < body_loop; i++) { - auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, + // if (i % 4 == 0) { + // bit2_data_zmm=_mm512_loadu_si512(base_bit2ptr + compress_wei_ptr_offset); + // bit2_data[0] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x0); + // bit2_data[1] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x1); + // bit2_data[2] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x2); + // bit2_data[3] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x3); + // } + // bit2_data = _mm512_extracti32x4_epi32(bit2_data_zmm,0x0); + // auto zmm = bit3_interleave_decompress(bit2_data[i%4], + auto zmm = bit3_interleave_decompress(base_bit2ptr+compress_wei_ptr_offset/4, base_bit1ptr + compress_wei_ptr_offset / 8); + // __m512i zmm; if constexpr (!std::is_same_v<_DST_T, int8_t>) { _mm512_storeu_epi8(base_unpack_buf, zmm); - for (int j = 0; j < 64; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); + auto xmm1 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf)); + auto xmm2 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 16)); + auto xmm3 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 32)); + auto xmm4 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 48)); + auto zmm1 = _mm512_cvtepi8_epi32(xmm1); + auto zmm2 = _mm512_cvtepi8_epi32(xmm2); + auto zmm3 = _mm512_cvtepi8_epi32(xmm3); + auto zmm4 = _mm512_cvtepi8_epi32(xmm4); + if constexpr (std::is_same_v<_DST_T, utils::bf16>) { + auto ymm1 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm1)); + auto ymm2 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm2)); + auto ymm3 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm3)); + auto ymm4 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm4)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset), ymm1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset + 16), ymm2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset + 32), ymm3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset + 48), ymm4); + } else { + _mm512_storeu_ps(dstptr+compress_wei_ptr_offset, _mm512_cvtepi32_ps(zmm1)); + _mm512_storeu_ps(dstptr+compress_wei_ptr_offset + 16, _mm512_cvtepi32_ps(zmm2)); + _mm512_storeu_ps(dstptr+compress_wei_ptr_offset + 32, _mm512_cvtepi32_ps(zmm3)); + _mm512_storeu_ps(dstptr+compress_wei_ptr_offset + 48, _mm512_cvtepi32_ps(zmm4)); + } + // for (int j = 0; j < 64; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); } else { _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); } @@ -707,10 +747,10 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 } if (tail_proc_num > 0) { assert(0); - auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff >> (64 - tail_proc_num)); - auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, - base_bit1ptr + compress_wei_ptr_offset / 8); - _mm512_mask_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, unpack_mask, zmm); + // auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff >> (64 - tail_proc_num)); + // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, + // base_bit1ptr + compress_wei_ptr_offset / 8); + // _mm512_mask_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, unpack_mask, zmm); } return BTLA_CODE::Success; } diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index b751a7c38..1ff78677d 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -203,30 +203,30 @@ class UT_BlockQunatize_S3 { ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 56); ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 56); // emr case - ut(128, 4096, 16384, 32, 64); - ut(1, 4096, 16384, 32, 64); - ut(128, 4096, 4096, 32, 64); - ut(1, 4096, 4096, 32, 64); - ut(128, 16384, 4096, 32, 64); - ut(1, 16384, 4096, 32, 64); - ut(128, 4096, 16384, 32, 64); - ut(1, 4096, 16384, 32, 64); - ut(128, 4096, 4096, 32, 64); - ut(1, 4096, 4096, 32, 64); - ut(128, 16384, 4096, 32, 64); - ut(1, 16384, 4096, 32, 64); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); - ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); - ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); - ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 64); - ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 64); - ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 64); - ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 64); - ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 64); - ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 64); + // ut(128, 4096, 16384, 32, 64); + // ut(1, 4096, 16384, 32, 64); + // ut(128, 4096, 4096, 32, 64); + // ut(1, 4096, 4096, 32, 64); + // ut(128, 16384, 4096, 32, 64); + // ut(1, 16384, 4096, 32, 64); + // ut(128, 4096, 16384, 32, 64); + // ut(1, 4096, 16384, 32, 64); + // ut(128, 4096, 4096, 32, 64); + // ut(1, 4096, 4096, 32, 64); + // ut(128, 16384, 4096, 32, 64); + // ut(1, 16384, 4096, 32, 64); + // ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); + // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 64); + // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 64); + // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 64); + // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 64); + // ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 64); + // ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 64); } template @@ -323,7 +323,7 @@ class UT_BlockQunatize_S3 { // } } }; -static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; +// static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #ifdef BTLA_UT_PROLOGUE_B static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #endif @@ -880,7 +880,8 @@ class UTBenchmark_CompFp32 { } void ut_s4() { - benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP); @@ -993,7 +994,7 @@ class UTBenchmark_CompFp32 { kernel.mProB.packWeight(n, k, B, n, &packBs[0], &DefaultThreading); for (size_t i = 1; i < batch; i++) { memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); - memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(float)); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(uint16_t)); } auto psize = (size_t)m * n * k * 2; auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); @@ -1035,18 +1036,19 @@ class UTBenchmark_CompFp32 { GetCPUDevice(); if (_cd->AVX512F()) { int blocksize = 32; - benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); + // benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + // C.data(), testtime, 48, qtype); benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), C.data(), testtime, 48, qtype); blocksize = 128; - benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); + // benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + // C.data(), testtime, 48, qtype); benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), C.data(), testtime, 48, qtype); } } }; +static UTBenchmark_CompFp32 sUTBenchmark_CompFp32; #ifdef BTLA_UT_PROLOGUE_B_ static UTBenchmark_CompFp32 sUTBenchmark_CompFp32; #endif From 0dfb80669f252b41a9e5030b1c8b890954041c03 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 23 Jan 2024 15:32:43 +0800 Subject: [PATCH 18/33] int8 cmpt benchmark --- bestla/bestla/kernel_avx512f.h | 22 +++++++++--------- bestla/bestla/ut/bestla_prologue_b.cpp | 32 +++++++++++++++++--------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 0214885c0..dfc4e046d 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -699,8 +699,8 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 auto body_loop = (unpack_elt - (64 - head_ignore_num) % 64) / 64; auto tail_proc_num = (unpack_elt - (64 - head_ignore_num) % 64) % 64; - __m128i bit2_data[4]; - __m512i bit2_data_zmm; + // __m128i bit2_data[4]; + // __m512i bit2_data_zmm; for (int i = 0; i < body_loop; i++) { // if (i % 4 == 0) { // bit2_data_zmm=_mm512_loadu_si512(base_bit2ptr + compress_wei_ptr_offset); @@ -711,7 +711,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 // } // bit2_data = _mm512_extracti32x4_epi32(bit2_data_zmm,0x0); // auto zmm = bit3_interleave_decompress(bit2_data[i%4], - auto zmm = bit3_interleave_decompress(base_bit2ptr+compress_wei_ptr_offset/4, + auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, base_bit1ptr + compress_wei_ptr_offset / 8); // __m512i zmm; if constexpr (!std::is_same_v<_DST_T, int8_t>) { @@ -729,15 +729,15 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 auto ymm2 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm2)); auto ymm3 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm3)); auto ymm4 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm4)); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset), ymm1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset + 16), ymm2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset + 32), ymm3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr+compress_wei_ptr_offset + 48), ymm4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset), ymm1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 16), ymm2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 32), ymm3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 48), ymm4); } else { - _mm512_storeu_ps(dstptr+compress_wei_ptr_offset, _mm512_cvtepi32_ps(zmm1)); - _mm512_storeu_ps(dstptr+compress_wei_ptr_offset + 16, _mm512_cvtepi32_ps(zmm2)); - _mm512_storeu_ps(dstptr+compress_wei_ptr_offset + 32, _mm512_cvtepi32_ps(zmm3)); - _mm512_storeu_ps(dstptr+compress_wei_ptr_offset + 48, _mm512_cvtepi32_ps(zmm4)); + _mm512_storeu_ps(dstptr + compress_wei_ptr_offset, _mm512_cvtepi32_ps(zmm1)); + _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 16, _mm512_cvtepi32_ps(zmm2)); + _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 32, _mm512_cvtepi32_ps(zmm3)); + _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 48, _mm512_cvtepi32_ps(zmm4)); } // for (int j = 0; j < 64; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); } else { diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 1ff78677d..ec542ba3e 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -968,9 +968,14 @@ class UTBenchmark_CompFp32 { int threads, BTLA_DTYPE qtype) { LOG_T log; using Parallel = parallel::gemm::SchedulerKBlock; + // using Launcher = + // wrapper::gemm::LauncherKBlock; using Launcher = - wrapper::gemm::LauncherKBlock; + wrapper::gemm::LauncherIntKBlock; Launcher kernel; DefaultThreading.set_threads(threads); auto corestr = gemm::CoreAttr::to_str(Core_T::ID); @@ -987,6 +992,9 @@ class UTBenchmark_CompFp32 { } std::vector packBs(batch, 0); std::vector bufB(tmpB.mSize * batch); + auto quanA = kernel.mProA.createStorage(m, k, blocksize, false); + utils::avector bufferA(quanA.mSize); + quanA.assign(bufferA.data()); for (size_t i = 0; i < batch; i++) { packBs[i] = tmpB; packBs[i].assign(bufB.data() + i * tmpB.mSize); @@ -1004,11 +1012,12 @@ class UTBenchmark_CompFp32 { for (size_t i = 0; i < batch; i++) { GemmProblem gp(1, m, n, k, blocksize); typename Launcher::Param args{gp, - {A + i * m * k, k}, + {A + i * m * k, k, &quanA}, {&packBs[i]}, - {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, + // {packBs[i].template SPtr(), packBs[i].SDtype(), packBs[i].CStep()}, {C + i * m * n, n}}; - parallel::GemmRun(kernel, args, &DefaultThreading); + // parallel::GemmRun(kernel, args, &DefaultThreading); + parallel::GemmRunWithA(kernel, args, &DefaultThreading); } if (log.stop()) { double t = log.avg_val / batch; @@ -1035,16 +1044,17 @@ class UTBenchmark_CompFp32 { float testtime = 500.f; GetCPUDevice(); if (_cd->AVX512F()) { - int blocksize = 32; + // int blocksize = 32; // benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), // C.data(), testtime, 48, qtype); - benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); - blocksize = 128; + // benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + // C.data(), testtime, 48, qtype); + int blocksize = 128; // benchmark, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), // C.data(), testtime, 48, qtype); - benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), - C.data(), testtime, 48, qtype); + // benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), + benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), + B.data(), C.data(), testtime, 48, qtype); } } }; From b445a3feba73ed851eb1fc1e5b2c8ef25869064c Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 23 Jan 2024 22:50:41 +0800 Subject: [PATCH 19/33] jit draft --- bestla/bestla/kernel_jit.h | 129 +++++++++++++++++++++++++++++++++ bestla/bestla/kernel_wrapper.h | 5 +- 2 files changed, 133 insertions(+), 1 deletion(-) diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index 81a51a41d..04d1c3ff6 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -250,6 +250,135 @@ class DequanS8FP { } } }; +class DecompresssS3 { + public: + class MicroKernelAVX512F : protected xbyak::JitAvx512f { + public: + struct params { + void *bit2ptr, *bit1ptr, *dstptr; + int unpack_elt; + int8_t ox3, ox4; + int ox5; + }; + typedef long long (*func_t)(params*); + static int constexpr VBytes = 64; + MicroKernelAVX512F() { + generate(); + this->ready(); + mKernel = this->getCode(); + } + + void generate() { + inLocalLabel(); // use local label for multiple instance + { + Xbyak::util::StackFrame st(this, 1, 13); + parambase = st.p[0]; + reg_bit1ptr = st.t[0]; + reg_bit2ptr = st.t[1]; + reg_loop = st.t[2]; + reg_iter = st.t[3]; + reg_dst = st.t[4]; + reg_tmp = st.t[5]; + reg_ret = rax; + xor_(reg_loop, reg_loop); + mov(reg_loop.cvt32(), ptr[parambase + OFFSET(unpack_elt)]); + xor_(reg_iter, reg_iter); + Xbyak::Ymm LowMask = ymm1; + Xbyak::Zmm zmm_0x04 = zmm31; + Xbyak::Zmm zmm_shift = zmm30; + vpbroadcastb(LowMask, ptr[parambase + OFFSET(ox3)]); + vpbroadcastb(zmm_0x04, ptr[parambase + OFFSET(ox4)]); + vpbroadcastd(zmm_shift, ptr[parambase + OFFSET(ox5)]); + mov(reg_bit1ptr, ptr[parambase + OFFSET(bit1ptr)]); + mov(reg_bit2ptr, ptr[parambase + OFFSET(bit2ptr)]); + mov(reg_dst, ptr[parambase + OFFSET(dstptr)]); + L("loop_label"); + imul(reg_tmp, reg_iter, 32); + kmovq(bit1_mask1, ptr[reg_bit1ptr + reg_tmp]); + kmovq(bit1_mask2, ptr[reg_bit1ptr + reg_tmp + 8]); + kmovq(bit1_mask3, ptr[reg_bit1ptr + reg_tmp + 16]); + kmovq(bit1_mask4, ptr[reg_bit1ptr + reg_tmp + 24]); + Xbyak::Zmm bit2_data_zmm = zmm0; + imul(reg_tmp, reg_iter, 64); + vmovups(bit2_data_zmm, ptr[reg_bit2ptr + reg_tmp]); + vextractf32x8(ymm2, bit2_data_zmm, 0x0); + vextractf32x8(ymm3, bit2_data_zmm, 0x1); + + vpand(ymm4, LowMask, ymm2); + vpsrlw(ymm2, bit2_data_zmm, 2); + vpand(ymm5, LowMask, ymm2); + vpsrlw(ymm2, bit2_data_zmm, 2); + vpand(ymm6, LowMask, ymm2); + vpsrlw(ymm2, bit2_data_zmm, 2); + vpand(ymm7, LowMask, ymm2); + vinsertf32x8(zmm4, zmm4, ymm5, 1); + vinsertf32x8(zmm6, zmm6, ymm7, 1); + + vpand(ymm8, LowMask, ymm3); + vpsrlw(ymm3, bit2_data_zmm, 2); + vpand(ymm9, LowMask, ymm3); + vpsrlw(ymm3, bit2_data_zmm, 2); + vpand(ymm10, LowMask, ymm3); + vpsrlw(ymm3, bit2_data_zmm, 2); + vpand(ymm11, LowMask, ymm3); + vinsertf32x8(zmm8, zmm8, ymm9, 1); + vinsertf32x8(zmm10, zmm10, ymm11, 1); + + vxorps(zmm12, zmm12); + vxorps(zmm13, zmm13); + vxorps(zmm14, zmm14); + vxorps(zmm15, zmm15); + vmovdqu8(zmm12 | bit1_mask1, zmm_0x04); + vmovdqu8(zmm13 | bit1_mask2, zmm_0x04); + vmovdqu8(zmm14 | bit1_mask3, zmm_0x04); + vmovdqu8(zmm15 | bit1_mask4, zmm_0x04); + + vpaddb(zmm4, zmm4, zmm12); + vpaddb(zmm6, zmm6, zmm13); + vpaddb(zmm8, zmm8, zmm14); + vpaddb(zmm10, zmm10, zmm15); + + vpsllvd(zmm4, zmm4, zmm_shift); + vpsllvd(zmm6, zmm6, zmm_shift); + vpsllvd(zmm8, zmm10, zmm_shift); + vpsllvd(zmm10, zmm10, zmm_shift); + + imul(reg_tmp, reg_iter, 256); + vmovups(ptr[reg_dst + reg_tmp], zmm4); + vmovups(ptr[reg_dst + reg_tmp + 64], zmm6); + vmovups(ptr[reg_dst + reg_tmp + 128], zmm8); + vmovups(ptr[reg_dst + reg_tmp + 192], zmm10); + + add(reg_iter, 1); + cmp(reg_iter, reg_loop); + jb("loop_label"); + mov(reg_ret, 0); + } + outLocalLabel(); // end of local label + } + + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_bit1ptr; + Xbyak::Reg64 reg_bit2ptr; + Xbyak::Reg64 reg_loop; + Xbyak::Reg64 reg_iter; + Xbyak::Reg64 reg_dst; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_ret; + Xbyak::Opmask bit1_mask1 = Xbyak::Opmask(1); + Xbyak::Opmask bit1_mask2 = Xbyak::Opmask(2); + Xbyak::Opmask bit1_mask3 = Xbyak::Opmask(3); + Xbyak::Opmask bit1_mask4 = Xbyak::Opmask(4); + }; + static void forward_avx512f(void* bit2ptr, void* bit1ptr, void* dstptr, int unpack_elt) { + static MicroKernelAVX512F ker; + auto param = MicroKernelAVX512F::params{bit2ptr, bit1ptr, dstptr, unpack_elt / 256, 0x03, 0x4, 5}; + ker.mKernel(¶m); + } +}; class DequanKBlockS8Fp { public: diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 48ebefe9e..26f69fc39 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -511,9 +511,12 @@ class DecompressKBlockS3S8Fp { int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() - if constexpr (utils::isa_base::avx512f) { + if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, reinterpret_cast(tmp), tmpsize); + } else { + jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, unpack_elt); + ret = BTLA_CODE::Success; } #endif assert(ret == BTLA_CODE::Success); From 597b7021f4d69abeb10d07ad62e91376d647cd69 Mon Sep 17 00:00:00 2001 From: "Zhe,Wang" Date: Tue, 23 Jan 2024 20:55:54 -0800 Subject: [PATCH 20/33] pack128 draft, toddo: fix jit kernel bug --- bestla/bestla/kernel_avx512f.h | 103 +++++++++++++++---------- bestla/bestla/kernel_jit.h | 8 +- bestla/bestla/kernel_ref.h | 26 +++++-- bestla/bestla/kernel_wrapper.h | 10 +-- bestla/bestla/ut/bestla_prologue_b.cpp | 50 ++++++------ 5 files changed, 114 insertions(+), 83 deletions(-) diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index dfc4e046d..1f9a727e9 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -675,6 +675,31 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 return zmm; }; + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t *dst) { +const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i bit2_data = _mm256_loadu_si256((const __m256i *)src1); + auto ymm0 = _mm256_and_si256(lowMask, bit2_data); // uop:1 p:015 + auto ymm1 = + _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2)); // uop:1 p:01 + auto ymm2 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 4)); + auto ymm3 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 6)); + auto zmm1 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm0), ymm1, 0x1); // lat3, tp1 uop1 p:5 + auto zmm2 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm2), ymm3, 0x1); + + unsigned long long *bit1_ptr = reinterpret_cast(src2); + auto bit1_mask1 = _cvtu64_mask64(*bit1_ptr); + auto bit1_mask2 = _cvtu64_mask64(*(bit1_ptr + 1)); + auto zmm1_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask1, zmm_0x04); + auto zmm2_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask2, zmm_0x04); + zmm1 = _mm512_add_epi8(zmm1, zmm1_); + zmm2 = _mm512_add_epi8(zmm2, zmm2_); + zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift); // int3_clip => int8 + zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift); // int3_clip => int8 + + _mm512_storeu_epi8((__m512i *)dst, zmm1); + _mm512_storeu_epi8((__m512i *)(dst + 64), zmm2); + }; + assert(head_ignore_num % 8 == 0); if (head_ignore_num > unpack_elt) { @@ -699,51 +724,45 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 auto body_loop = (unpack_elt - (64 - head_ignore_num) % 64) / 64; auto tail_proc_num = (unpack_elt - (64 - head_ignore_num) % 64) % 64; - // __m128i bit2_data[4]; - // __m512i bit2_data_zmm; - for (int i = 0; i < body_loop; i++) { - // if (i % 4 == 0) { - // bit2_data_zmm=_mm512_loadu_si512(base_bit2ptr + compress_wei_ptr_offset); - // bit2_data[0] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x0); - // bit2_data[1] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x1); - // bit2_data[2] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x2); - // bit2_data[3] = _mm512_extracti32x4_epi32(bit2_data_zmm,0x3); - // } - // bit2_data = _mm512_extracti32x4_epi32(bit2_data_zmm,0x0); - // auto zmm = bit3_interleave_decompress(bit2_data[i%4], - auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, - base_bit1ptr + compress_wei_ptr_offset / 8); - // __m512i zmm; + // for (int i = 0; i < body_loop; i++) { + for (int i = 0; i < body_loop/2; i++) { if constexpr (!std::is_same_v<_DST_T, int8_t>) { - _mm512_storeu_epi8(base_unpack_buf, zmm); - auto xmm1 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf)); - auto xmm2 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 16)); - auto xmm3 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 32)); - auto xmm4 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 48)); - auto zmm1 = _mm512_cvtepi8_epi32(xmm1); - auto zmm2 = _mm512_cvtepi8_epi32(xmm2); - auto zmm3 = _mm512_cvtepi8_epi32(xmm3); - auto zmm4 = _mm512_cvtepi8_epi32(xmm4); - if constexpr (std::is_same_v<_DST_T, utils::bf16>) { - auto ymm1 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm1)); - auto ymm2 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm2)); - auto ymm3 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm3)); - auto ymm4 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm4)); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset), ymm1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 16), ymm2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 32), ymm3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 48), ymm4); - } else { - _mm512_storeu_ps(dstptr + compress_wei_ptr_offset, _mm512_cvtepi32_ps(zmm1)); - _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 16, _mm512_cvtepi32_ps(zmm2)); - _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 32, _mm512_cvtepi32_ps(zmm3)); - _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 48, _mm512_cvtepi32_ps(zmm4)); - } - // for (int j = 0; j < 64; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); +bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4,base_bit1ptr + compress_wei_ptr_offset / 8,reinterpret_cast(base_unpack_buf)); +for (int j = 0; j < 128; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); + // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, + // base_bit1ptr + compress_wei_ptr_offset / 8); + // _mm512_storeu_epi8(base_unpack_buf, zmm); + // auto xmm1 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf)); + // auto xmm2 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 16)); + // auto xmm3 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 32)); + // auto xmm4 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 48)); + // auto zmm1 = _mm512_cvtepi8_epi32(xmm1); + // auto zmm2 = _mm512_cvtepi8_epi32(xmm2); + // auto zmm3 = _mm512_cvtepi8_epi32(xmm3); + // auto zmm4 = _mm512_cvtepi8_epi32(xmm4); + // if constexpr (std::is_same_v<_DST_T, utils::bf16>) { + // auto ymm1 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm1)); + // auto ymm2 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm2)); + // auto ymm3 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm3)); + // auto ymm4 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm4)); + // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset), ymm1); + // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 16), ymm2); + // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 32), ymm3); + // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 48), ymm4); + // } else { + // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset, _mm512_cvtepi32_ps(zmm1)); + // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 16, _mm512_cvtepi32_ps(zmm2)); + // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 32, _mm512_cvtepi32_ps(zmm3)); + // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 48, _mm512_cvtepi32_ps(zmm4)); + // } } else { - _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); + // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, + // base_bit1ptr + compress_wei_ptr_offset / 8); +bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4,base_bit1ptr + compress_wei_ptr_offset / 8,reinterpret_cast(base_unpack_buf)+compress_wei_ptr_offset); + // _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); } - compress_wei_ptr_offset += 64; + // compress_wei_ptr_offset += 64; + compress_wei_ptr_offset += 128; } if (tail_proc_num > 0) { assert(0); diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index 04d1c3ff6..eb387dcc5 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -307,9 +307,9 @@ class DecompresssS3 { vpand(ymm4, LowMask, ymm2); vpsrlw(ymm2, bit2_data_zmm, 2); vpand(ymm5, LowMask, ymm2); - vpsrlw(ymm2, bit2_data_zmm, 2); + vpsrlw(ymm2, bit2_data_zmm, 4); vpand(ymm6, LowMask, ymm2); - vpsrlw(ymm2, bit2_data_zmm, 2); + vpsrlw(ymm2, bit2_data_zmm, 6); vpand(ymm7, LowMask, ymm2); vinsertf32x8(zmm4, zmm4, ymm5, 1); vinsertf32x8(zmm6, zmm6, ymm7, 1); @@ -317,9 +317,9 @@ class DecompresssS3 { vpand(ymm8, LowMask, ymm3); vpsrlw(ymm3, bit2_data_zmm, 2); vpand(ymm9, LowMask, ymm3); - vpsrlw(ymm3, bit2_data_zmm, 2); + vpsrlw(ymm3, bit2_data_zmm, 4); vpand(ymm10, LowMask, ymm3); - vpsrlw(ymm3, bit2_data_zmm, 2); + vpsrlw(ymm3, bit2_data_zmm, 6); vpand(ymm11, LowMask, ymm3); vinsertf32x8(zmm8, zmm8, ymm9, 1); vinsertf32x8(zmm10, zmm10, ymm11, 1); diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index ab601842a..08aa3358f 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -193,20 +193,32 @@ static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x // std::cout << "==============" << std::endl; auto bit2_interleave = [&](int8_t* src, int8_t* dst) { - for (int i = 0; i < 64 / 4; i++) { + for (int i = 0; i < 128 / 4; i++) { dst[4 * i] = src[i]; - dst[4 * i + 1] = src[64 / 4 + i]; - dst[4 * i + 2] = src[64 / 4 * 2 + i]; - dst[4 * i + 3] = src[64 / 4 * 3 + i]; + dst[4 * i + 1] = src[128 / 4 + i]; + dst[4 * i + 2] = src[128 / 4 * 2 + i]; + dst[4 * i + 3] = src[128 / 4 * 3 + i]; } }; - int8_t interleave_buf[64]; + // auto bit2_interleave = [&](int8_t* src, int8_t* dst) { + // for (int i = 0; i < 64 / 4; i++) { + // dst[4 * i] = src[i]; + // dst[4 * i + 1] = src[64 / 4 + i]; + // dst[4 * i + 2] = src[64 / 4 * 2 + i]; + // dst[4 * i + 3] = src[64 / 4 * 3 + i]; + // } + // }; + + + int8_t interleave_buf[128]; for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += 64) { + for (int j = 0; j < col; j += 128) { + // for (int j = 0; j < col; j += 64) { bit2_interleave(const_cast(srcptr + i * ld_src + j), interleave_buf); - for (int k = 0; k < 16; k++) { + for (int k = 0; k < 32; k++) { + // for (int k = 0; k < 16; k++) { bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 26f69fc39..077687d15 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -511,13 +511,13 @@ class DecompressKBlockS3S8Fp { int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() - if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { + // if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, reinterpret_cast(tmp), tmpsize); - } else { - jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, unpack_elt); - ret = BTLA_CODE::Success; - } + // } else { + // jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, unpack_elt); + // ret = BTLA_CODE::Success; + // } #endif assert(ret == BTLA_CODE::Success); return ret; diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index ec542ba3e..7dfdc37a1 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -178,30 +178,30 @@ class UT_BlockQunatize_S3 { UT_BlockQunatize_S3() { UT_START(); CheckISA(AVX512F); - ut(128, 4096, 16384, 32, 56); - ut(1, 4096, 16384, 32, 56); - ut(128, 4096, 4096, 32, 56); - ut(1, 4096, 4096, 32, 56); - ut(128, 16384, 4096, 32, 56); - ut(1, 16384, 4096, 32, 56); - ut(128, 4096, 16384, 32, 56); - ut(1, 4096, 16384, 32, 56); - ut(128, 4096, 4096, 32, 56); - ut(1, 4096, 4096, 32, 56); - ut(128, 16384, 4096, 32, 56); - ut(1, 16384, 4096, 32, 56); + // ut(128, 4096, 16384, 32, 56); + // ut(1, 4096, 16384, 32, 56); + // ut(128, 4096, 4096, 32, 56); + // ut(1, 4096, 4096, 32, 56); + // ut(128, 16384, 4096, 32, 56); + // ut(1, 16384, 4096, 32, 56); + // ut(128, 4096, 16384, 32, 56); + // ut(1, 4096, 16384, 32, 56); + // ut(128, 4096, 4096, 32, 56); + // ut(1, 4096, 4096, 32, 56); + // ut(128, 16384, 4096, 32, 56); + // ut(1, 16384, 4096, 32, 56); ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 56); ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 56); ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 56); ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 56); ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 56); ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 56); - ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 56); - ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 56); - ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 56); - ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); - ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 56); - ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 56); + // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 56); + // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 56); + // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 56); + // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); + // ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 56); + // ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 56); // emr case // ut(128, 4096, 16384, 32, 64); // ut(1, 4096, 16384, 32, 64); @@ -215,12 +215,12 @@ class UT_BlockQunatize_S3 { // ut(1, 4096, 4096, 32, 64); // ut(128, 16384, 4096, 32, 64); // ut(1, 16384, 4096, 32, 64); - // ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); + ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); + ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); + ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 64); // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 64); // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 64); @@ -323,7 +323,7 @@ class UT_BlockQunatize_S3 { // } } }; -// static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; +static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #ifdef BTLA_UT_PROLOGUE_B static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #endif From d863a1fe9f7f7209db653de16eac1aef970a993c Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Wed, 24 Jan 2024 13:18:33 +0800 Subject: [PATCH 21/33] fix jit & add jit_s3_s8 ut --- bestla/bestla/kernel_avx512f.h | 61 ++++++++++++++------------ bestla/bestla/kernel_jit.h | 14 +++--- bestla/bestla/kernel_wrapper.h | 10 ++--- bestla/bestla/ut/bestla_prologue_b.cpp | 4 +- bestla/bestla/ut/kernel_jit.cpp | 61 ++++++++++++++++++++++++++ 5 files changed, 107 insertions(+), 43 deletions(-) diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index 1f9a727e9..ea031e912 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -675,29 +675,28 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 return zmm; }; - auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t *dst) { -const __m256i lowMask = _mm256_set1_epi8(0x03); - const __m256i bit2_data = _mm256_loadu_si256((const __m256i *)src1); - auto ymm0 = _mm256_and_si256(lowMask, bit2_data); // uop:1 p:015 - auto ymm1 = - _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2)); // uop:1 p:01 - auto ymm2 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 4)); - auto ymm3 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 6)); - auto zmm1 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm0), ymm1, 0x1); // lat3, tp1 uop1 p:5 - auto zmm2 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm2), ymm3, 0x1); - - unsigned long long *bit1_ptr = reinterpret_cast(src2); - auto bit1_mask1 = _cvtu64_mask64(*bit1_ptr); - auto bit1_mask2 = _cvtu64_mask64(*(bit1_ptr + 1)); - auto zmm1_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask1, zmm_0x04); - auto zmm2_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask2, zmm_0x04); - zmm1 = _mm512_add_epi8(zmm1, zmm1_); - zmm2 = _mm512_add_epi8(zmm2, zmm2_); - zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift); // int3_clip => int8 - zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift); // int3_clip => int8 - - _mm512_storeu_epi8((__m512i *)dst, zmm1); - _mm512_storeu_epi8((__m512i *)(dst + 64), zmm2); + auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { + const __m256i lowMask = _mm256_set1_epi8(0x03); + const __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); + auto ymm0 = _mm256_and_si256(lowMask, bit2_data); // uop:1 p:015 + auto ymm1 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2)); // uop:1 p:01 + auto ymm2 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 4)); + auto ymm3 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 6)); + auto zmm1 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm0), ymm1, 0x1); // lat3, tp1 uop1 p:5 + auto zmm2 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm2), ymm3, 0x1); + + unsigned long long* bit1_ptr = reinterpret_cast(src2); + auto bit1_mask1 = _cvtu64_mask64(*bit1_ptr); + auto bit1_mask2 = _cvtu64_mask64(*(bit1_ptr + 1)); + auto zmm1_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask1, zmm_0x04); + auto zmm2_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask2, zmm_0x04); + zmm1 = _mm512_add_epi8(zmm1, zmm1_); + zmm2 = _mm512_add_epi8(zmm2, zmm2_); + zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift); // int3_clip => int8 + zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift); // int3_clip => int8 + + _mm512_storeu_epi8((__m512i*)dst, zmm1); + _mm512_storeu_epi8((__m512i*)(dst + 64), zmm2); }; assert(head_ignore_num % 8 == 0); @@ -725,10 +724,12 @@ const __m256i lowMask = _mm256_set1_epi8(0x03); auto tail_proc_num = (unpack_elt - (64 - head_ignore_num) % 64) % 64; // for (int i = 0; i < body_loop; i++) { - for (int i = 0; i < body_loop/2; i++) { + for (int i = 0; i < body_loop / 2; i++) { if constexpr (!std::is_same_v<_DST_T, int8_t>) { -bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4,base_bit1ptr + compress_wei_ptr_offset / 8,reinterpret_cast(base_unpack_buf)); -for (int j = 0; j < 128; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); + bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4, + base_bit1ptr + compress_wei_ptr_offset / 8, + reinterpret_cast(base_unpack_buf)); + for (int j = 0; j < 128; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, // base_bit1ptr + compress_wei_ptr_offset / 8); // _mm512_storeu_epi8(base_unpack_buf, zmm); @@ -756,9 +757,11 @@ for (int j = 0; j < 128; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_of // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 48, _mm512_cvtepi32_ps(zmm4)); // } } else { - // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, - // base_bit1ptr + compress_wei_ptr_offset / 8); -bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4,base_bit1ptr + compress_wei_ptr_offset / 8,reinterpret_cast(base_unpack_buf)+compress_wei_ptr_offset); + // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, + // base_bit1ptr + compress_wei_ptr_offset / 8); + bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4, + base_bit1ptr + compress_wei_ptr_offset / 8, + reinterpret_cast(base_unpack_buf) + compress_wei_ptr_offset); // _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); } // compress_wei_ptr_offset += 64; diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index eb387dcc5..d1667c4f6 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -305,21 +305,21 @@ class DecompresssS3 { vextractf32x8(ymm3, bit2_data_zmm, 0x1); vpand(ymm4, LowMask, ymm2); - vpsrlw(ymm2, bit2_data_zmm, 2); + vpsrlw(ymm2, ymm2, 2); vpand(ymm5, LowMask, ymm2); - vpsrlw(ymm2, bit2_data_zmm, 4); + vpsrlw(ymm2, ymm2, 2); vpand(ymm6, LowMask, ymm2); - vpsrlw(ymm2, bit2_data_zmm, 6); + vpsrlw(ymm2, ymm2, 2); vpand(ymm7, LowMask, ymm2); vinsertf32x8(zmm4, zmm4, ymm5, 1); vinsertf32x8(zmm6, zmm6, ymm7, 1); vpand(ymm8, LowMask, ymm3); - vpsrlw(ymm3, bit2_data_zmm, 2); + vpsrlw(ymm3, ymm3, 2); vpand(ymm9, LowMask, ymm3); - vpsrlw(ymm3, bit2_data_zmm, 4); + vpsrlw(ymm3, ymm3, 2); vpand(ymm10, LowMask, ymm3); - vpsrlw(ymm3, bit2_data_zmm, 6); + vpsrlw(ymm3, ymm3, 2); vpand(ymm11, LowMask, ymm3); vinsertf32x8(zmm8, zmm8, ymm9, 1); vinsertf32x8(zmm10, zmm10, ymm11, 1); @@ -340,7 +340,7 @@ class DecompresssS3 { vpsllvd(zmm4, zmm4, zmm_shift); vpsllvd(zmm6, zmm6, zmm_shift); - vpsllvd(zmm8, zmm10, zmm_shift); + vpsllvd(zmm8, zmm8, zmm_shift); vpsllvd(zmm10, zmm10, zmm_shift); imul(reg_tmp, reg_iter, 256); diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 077687d15..26f69fc39 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -511,13 +511,13 @@ class DecompressKBlockS3S8Fp { int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() - // if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { + if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, reinterpret_cast(tmp), tmpsize); - // } else { - // jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, unpack_elt); - // ret = BTLA_CODE::Success; - // } + } else { + jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, unpack_elt); + ret = BTLA_CODE::Success; + } #endif assert(ret == BTLA_CODE::Success); return ret; diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 7dfdc37a1..54da2b7d1 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -881,8 +881,8 @@ class UTBenchmark_CompFp32 { void ut_s4() { // benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); - benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); diff --git a/bestla/bestla/ut/kernel_jit.cpp b/bestla/bestla/ut/kernel_jit.cpp index ce1198c99..33aa5d4b3 100644 --- a/bestla/bestla/ut/kernel_jit.cpp +++ b/bestla/bestla/ut/kernel_jit.cpp @@ -296,5 +296,66 @@ class UT_DecompressS4S8 { #ifdef BTLA_UT_KERNEL_JIT static UT_DecompressS4S8 sUT_DecompressS4S8; #endif + +class UT_DecompressS3S8 { + public: + UT_DecompressS3S8() { + UT_START(); + ut(1, 256); + } + void ut(int row, int col) { + ut::UT_vector_s8 quanW; + quanW.resize(row * col); + quanW.fill_rand(-8, 7); + for (int i = 0; i < row * col; i++) quanW.data()[i] = (quanW.data()[i] * 16) & 0xe0; + std::vector bit2_data(row * col / 4); + std::vector bit1_data(row * col / 8); + std::vector ref(row * col); + std::vector tar(row * col); + + auto bit2_interleave = [&](int8_t* src, int8_t* dst) { + for (int i = 0; i < 128 / 4; i++) { + dst[4 * i] = src[i]; + dst[4 * i + 1] = src[128 / 4 + i]; + dst[4 * i + 2] = src[128 / 4 * 2 + i]; + dst[4 * i + 3] = src[128 / 4 * 3 + i]; + } + }; + + int8_t interleave_buf[128]; + + for (int i = 0; i < row; i++) { + for (int j = 0; j < col; j += 128) { + bit2_interleave(const_cast(quanW.data() + i * col + j), interleave_buf); + for (int k = 0; k < 32; k++) { + bit2_data[i * col / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; + bit2_data[i * col / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; + bit2_data[i * col / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; + bit2_data[i * col / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; + } + } + } + // store 1 bit without interleave as mask. + for (int i = 0; i < row; i++) { + for (int j = 0; j < col; j += 8) { + bit1_data[i * col / 8 + j / 8].a = quanW.data()[i * col + j] >> 7; + bit1_data[i * col / 8 + j / 8].b = quanW.data()[i * col + j + 1] >> 7; + bit1_data[i * col / 8 + j / 8].c = quanW.data()[i * col + j + 2] >> 7; + bit1_data[i * col / 8 + j / 8].d = quanW.data()[i * col + j + 3] >> 7; + bit1_data[i * col / 8 + j / 8].e = quanW.data()[i * col + j + 4] >> 7; + bit1_data[i * col / 8 + j / 8].f = quanW.data()[i * col + j + 5] >> 7; + bit1_data[i * col / 8 + j / 8].g = quanW.data()[i * col + j + 6] >> 7; + bit1_data[i * col / 8 + j / 8].h = quanW.data()[i * col + j + 7] >> 7; + } + } + + kernel::avx512f::decompress_kblock_s3_s8fp(bit2_data.data(), bit1_data.data(), + ref.data(), 0, row * col, nullptr, -1); + kernel::jit::DecompresssS3::forward_avx512f(bit2_data.data(), bit1_data.data(), tar.data(), row * col); + buffer_error(tar.data(), ref.data(), ref.size()); + } +}; +static UT_DecompressS3S8 sUT_DecompressS3S8; + } // namespace ut } // namespace bestla From d334b6398cde1be9b978f25ece73e7e6f226808e Mon Sep 17 00:00:00 2001 From: "Zhe,Wang" Date: Tue, 23 Jan 2024 22:54:47 -0800 Subject: [PATCH 22/33] update benchmark --- bestla/bestla/ut/bestla_prologue_b.cpp | 32 +++---- bestla/bestla/ut/kernel_jit.cpp | 110 ++++++++++++------------- 2 files changed, 72 insertions(+), 70 deletions(-) diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 54da2b7d1..cb7dc1f61 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -190,12 +190,12 @@ class UT_BlockQunatize_S3 { // ut(1, 4096, 4096, 32, 56); // ut(128, 16384, 4096, 32, 56); // ut(1, 16384, 4096, 32, 56); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 56); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 56); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 56); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 56); - ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 56); - ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 56); + ut, BTLA_ISA::AMX_INT8>(1024, 4096, 16384, 128, 8); + ut, BTLA_ISA::AMX_INT8>(48, 4096, 16384, 128, 8); + ut, BTLA_ISA::AMX_INT8>(1024, 4096, 4096, 128, 8); + ut, BTLA_ISA::AMX_INT8>(48, 4096, 4096, 128, 8); + ut, BTLA_ISA::AMX_INT8>(1024, 16384, 4096, 128, 8); + ut, BTLA_ISA::AMX_INT8>(48, 16384, 4096, 128, 8); // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 56); // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 56); // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 56); @@ -215,12 +215,12 @@ class UT_BlockQunatize_S3 { // ut(1, 4096, 4096, 32, 64); // ut(128, 16384, 4096, 32, 64); // ut(1, 16384, 4096, 32, 64); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); - ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); - ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); - ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); - ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); + // ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 64); // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 64); // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 64); @@ -881,8 +881,10 @@ class UTBenchmark_CompFp32 { void ut_s4() { // benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); - benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(48, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + benchmark_all(48, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); @@ -1054,7 +1056,7 @@ class UTBenchmark_CompFp32 { // C.data(), testtime, 48, qtype); // benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), - B.data(), C.data(), testtime, 48, qtype); + B.data(), C.data(), testtime, 8, qtype); } } }; diff --git a/bestla/bestla/ut/kernel_jit.cpp b/bestla/bestla/ut/kernel_jit.cpp index 33aa5d4b3..1c195472f 100644 --- a/bestla/bestla/ut/kernel_jit.cpp +++ b/bestla/bestla/ut/kernel_jit.cpp @@ -297,65 +297,65 @@ class UT_DecompressS4S8 { static UT_DecompressS4S8 sUT_DecompressS4S8; #endif -class UT_DecompressS3S8 { - public: - UT_DecompressS3S8() { - UT_START(); - ut(1, 256); - } - void ut(int row, int col) { - ut::UT_vector_s8 quanW; - quanW.resize(row * col); - quanW.fill_rand(-8, 7); - for (int i = 0; i < row * col; i++) quanW.data()[i] = (quanW.data()[i] * 16) & 0xe0; - std::vector bit2_data(row * col / 4); - std::vector bit1_data(row * col / 8); - std::vector ref(row * col); - std::vector tar(row * col); +// class UT_DecompressS3S8 { +// public: +// UT_DecompressS3S8() { +// UT_START(); +// ut(1, 256); +// } +// void ut(int row, int col) { +// ut::UT_vector_s8 quanW; +// quanW.resize(row * col); +// quanW.fill_rand(-8, 7); +// for (int i = 0; i < row * col; i++) quanW.data()[i] = (quanW.data()[i] * 16) & 0xe0; +// std::vector bit2_data(row * col / 4); +// std::vector bit1_data(row * col / 8); +// std::vector ref(row * col); +// std::vector tar(row * col); - auto bit2_interleave = [&](int8_t* src, int8_t* dst) { - for (int i = 0; i < 128 / 4; i++) { - dst[4 * i] = src[i]; - dst[4 * i + 1] = src[128 / 4 + i]; - dst[4 * i + 2] = src[128 / 4 * 2 + i]; - dst[4 * i + 3] = src[128 / 4 * 3 + i]; - } - }; +// auto bit2_interleave = [&](int8_t* src, int8_t* dst) { +// for (int i = 0; i < 128 / 4; i++) { +// dst[4 * i] = src[i]; +// dst[4 * i + 1] = src[128 / 4 + i]; +// dst[4 * i + 2] = src[128 / 4 * 2 + i]; +// dst[4 * i + 3] = src[128 / 4 * 3 + i]; +// } +// }; - int8_t interleave_buf[128]; +// int8_t interleave_buf[128]; - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += 128) { - bit2_interleave(const_cast(quanW.data() + i * col + j), interleave_buf); - for (int k = 0; k < 32; k++) { - bit2_data[i * col / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; - bit2_data[i * col / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; - bit2_data[i * col / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; - bit2_data[i * col / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; - } - } - } - // store 1 bit without interleave as mask. - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j += 8) { - bit1_data[i * col / 8 + j / 8].a = quanW.data()[i * col + j] >> 7; - bit1_data[i * col / 8 + j / 8].b = quanW.data()[i * col + j + 1] >> 7; - bit1_data[i * col / 8 + j / 8].c = quanW.data()[i * col + j + 2] >> 7; - bit1_data[i * col / 8 + j / 8].d = quanW.data()[i * col + j + 3] >> 7; - bit1_data[i * col / 8 + j / 8].e = quanW.data()[i * col + j + 4] >> 7; - bit1_data[i * col / 8 + j / 8].f = quanW.data()[i * col + j + 5] >> 7; - bit1_data[i * col / 8 + j / 8].g = quanW.data()[i * col + j + 6] >> 7; - bit1_data[i * col / 8 + j / 8].h = quanW.data()[i * col + j + 7] >> 7; - } - } +// for (int i = 0; i < row; i++) { +// for (int j = 0; j < col; j += 128) { +// bit2_interleave(const_cast(quanW.data() + i * col + j), interleave_buf); +// for (int k = 0; k < 32; k++) { +// bit2_data[i * col / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; +// bit2_data[i * col / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; +// bit2_data[i * col / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; +// bit2_data[i * col / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; +// } +// } +// } +// // store 1 bit without interleave as mask. +// for (int i = 0; i < row; i++) { +// for (int j = 0; j < col; j += 8) { +// bit1_data[i * col / 8 + j / 8].a = quanW.data()[i * col + j] >> 7; +// bit1_data[i * col / 8 + j / 8].b = quanW.data()[i * col + j + 1] >> 7; +// bit1_data[i * col / 8 + j / 8].c = quanW.data()[i * col + j + 2] >> 7; +// bit1_data[i * col / 8 + j / 8].d = quanW.data()[i * col + j + 3] >> 7; +// bit1_data[i * col / 8 + j / 8].e = quanW.data()[i * col + j + 4] >> 7; +// bit1_data[i * col / 8 + j / 8].f = quanW.data()[i * col + j + 5] >> 7; +// bit1_data[i * col / 8 + j / 8].g = quanW.data()[i * col + j + 6] >> 7; +// bit1_data[i * col / 8 + j / 8].h = quanW.data()[i * col + j + 7] >> 7; +// } +// } - kernel::avx512f::decompress_kblock_s3_s8fp(bit2_data.data(), bit1_data.data(), - ref.data(), 0, row * col, nullptr, -1); - kernel::jit::DecompresssS3::forward_avx512f(bit2_data.data(), bit1_data.data(), tar.data(), row * col); - buffer_error(tar.data(), ref.data(), ref.size()); - } -}; -static UT_DecompressS3S8 sUT_DecompressS3S8; +// kernel::avx512f::decompress_kblock_s3_s8fp(bit2_data.data(), bit1_data.data(), +// ref.data(), 0, row * col, nullptr, -1); +// kernel::jit::DecompresssS3::forward_avx512f(bit2_data.data(), bit1_data.data(), tar.data(), row * col); +// buffer_error(tar.data(), ref.data(), ref.size()); +// } +// }; +// static UT_DecompressS3S8 sUT_DecompressS3S8; } // namespace ut } // namespace bestla From caf342a8aa2c7052ae7325c763aadf1c66f31874 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Thu, 25 Jan 2024 14:04:08 +0800 Subject: [PATCH 23/33] bit3 quant-param in ns --- bestla/CMakeLists.txt | 2 +- neural_speed/models/model_utils/quant_config.h | 5 ++++- neural_speed/models/model_utils/quant_utils.cpp | 3 +++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 834e61160..2b17a8603 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -7,7 +7,7 @@ file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) option(BTLA_USE_OPENMP "Enable OpenMP thread pool" OFF) option(BTLA_UT_ALL "Enable all unit tests" OFF) -option(BTLA_UT_DEBUG "Enable debug unit tests" ON) +option(BTLA_UT_DEBUG "Enable debug unit tests" OFF) option(BTLA_UT_EPILOGUE "Enable unit test for epilogue" OFF) option(BTLA_UT_PROLOGUE_A "Enable unit test for activation prologue" OFF) option(BTLA_UT_PROLOGUE_B "Enable unit test for weight prologue" OFF) diff --git a/neural_speed/models/model_utils/quant_config.h b/neural_speed/models/model_utils/quant_config.h index f470c7598..e6ee60040 100644 --- a/neural_speed/models/model_utils/quant_config.h +++ b/neural_speed/models/model_utils/quant_config.h @@ -18,8 +18,11 @@ #include "core/data_types.h" #include "bestla/bestla.h" -enum class quant_bits : int { q4 = 0, q8, fp4_e2m1, nf4, fp8_e4m3, fp8_e5m2, count }; +enum class quant_bits : int { q4 = 0, q3, q8, fp4_e2m1, nf4, fp8_e4m3, fp8_e5m2, count }; static inline quant_bits parse_bits(const std::string& bits) { + if (bits == "int3") { + return quant_bits::q3; + } if (bits == "int4") { return quant_bits::q4; } diff --git a/neural_speed/models/model_utils/quant_utils.cpp b/neural_speed/models/model_utils/quant_utils.cpp index 0eacbc771..5d68165f1 100644 --- a/neural_speed/models/model_utils/quant_utils.cpp +++ b/neural_speed/models/model_utils/quant_utils.cpp @@ -280,6 +280,9 @@ size_t bestla_quantize(const float* f32ptr, void* dstpr, const quant_params_inte bestla::parallel::StdThreading threading(nthread); #endif BTLA_DTYPE quant_type = BTLA_DTYPE::S4_CLIP; + if (params.bits == quant_bits::q3) { + quant_type = BTLA_DTYPE::S3_CLIP; + } if (params.bits == quant_bits::q8) { quant_type = BTLA_DTYPE::S8; } From bca3bb2367358fb2946fdb08e05476cccfd4e77e Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Fri, 26 Jan 2024 11:06:53 +0800 Subject: [PATCH 24/33] add s3_clip --- neural_speed/core/layers/bestla_gemm.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/neural_speed/core/layers/bestla_gemm.cpp b/neural_speed/core/layers/bestla_gemm.cpp index f2421f9d4..471dc0e03 100644 --- a/neural_speed/core/layers/bestla_gemm.cpp +++ b/neural_speed/core/layers/bestla_gemm.cpp @@ -264,6 +264,7 @@ size_t BTLAGemmPackBSize(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE QuantTyp ne_comp_type CompType, int* shuffle_indice) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, @@ -364,6 +365,7 @@ bool BTLAGemmQuantPackB(void* PackedBuf, const float* FpData, size_t N, size_t K void* ThreadPool) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmQuantPackBLocal( @@ -566,6 +568,7 @@ size_t BTLAGemmPackBSize(size_t N, size_t K, size_t BlkSize, BTLA_DTYPE QuantTyp ne_comp_type CompType, int* shuffle_indice) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmPackBSizeLocal(N, K, BlkSize, QuantType, ScaleDtype, @@ -589,6 +592,7 @@ bool BTLAGemmQuantPackB(void* PackedBuf, const float* FpData, size_t N, size_t K void* ThreadPool) { switch (QuantType) { case BTLA_DTYPE::S4_CLIP: + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: return BTLAGemmQuantPackBLocal( @@ -611,6 +615,7 @@ bool BTLAGemmPackB(void* PackedBuf, const int8_t* QData, const float* Scales, co ne_comp_type CompType, int* shuffle_indice, void* ThreadPool) { // only for integer quant switch (QuantType) { + case BTLA_DTYPE::S3_CLIP: case BTLA_DTYPE::S4_CLIP: case BTLA_DTYPE::S4_FULLRANGE: case BTLA_DTYPE::S8: From 3a73350fad0b7bc6109fe976b3b633db0d634936 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Fri, 26 Jan 2024 17:49:26 +0800 Subject: [PATCH 25/33] support scalef32 --- bestla/CMakeLists.txt | 2 +- bestla/bestla/bestla_parallel.h | 1 + bestla/bestla/bestla_prologue_b.h | 21 +++++++++++++++++++-- bestla/bestla/ut/bestla_prologue_b.cpp | 4 ++-- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 2b17a8603..834e61160 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -7,7 +7,7 @@ file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) option(BTLA_USE_OPENMP "Enable OpenMP thread pool" OFF) option(BTLA_UT_ALL "Enable all unit tests" OFF) -option(BTLA_UT_DEBUG "Enable debug unit tests" OFF) +option(BTLA_UT_DEBUG "Enable debug unit tests" ON) option(BTLA_UT_EPILOGUE "Enable unit test for epilogue" OFF) option(BTLA_UT_PROLOGUE_A "Enable unit test for activation prologue" OFF) option(BTLA_UT_PROLOGUE_B "Enable unit test for weight prologue" OFF) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index c780aa4cf..22c852f86 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -486,6 +486,7 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { mKBlock = config.problem.dims[4]; BaseScheduler::update(config); auto blks = utils::updiv(this->mBlock[2], mKBlock); + this->mBlock[2] = utils::padto(this->mBlock[2], 128); // TODO(zhe): remove it ,only for vnni 3bit linear this->mL2Use += static_cast(blks) * (this->mBlock[1] + this->mStep[0]) * (sizeof(float) + sizeof(int8_t) + sizeof(float)); // scale+zp+reduce assert(this->mL2Use <= this->mL2Size - ReservedSize); diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 315b31909..e3e42b85c 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -190,12 +190,15 @@ class WeightKBlockNInteger { void unpackWeight(const int N, const int K, StorageWeight* stor, float* B, const int ldb, parallel::IThreading* threading) { - parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); + // parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); + parallel::Scheduler2D _para({threading->num_threads(), K, N, 32, + _GemmCore_T::NTILE}); // TODO(zhe): remove it, only for 3bit vnni woq linear threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); if (thdp.valid) { - auto rowpad = utils::padto(thdp.size[0], _GemmCore_T::KTILE); + // auto rowpad = utils::padto(thdp.size[0], _GemmCore_T::KTILE); + auto rowpad = utils::padto(thdp.size[0], 32); // as above auto colpad = utils::padto(thdp.size[1], _GemmCore_T::NTILE); auto dequant = utils::amalloc((size_t)rowpad * colpad); auto dstptr = dequant; @@ -799,6 +802,20 @@ class WeightKBlockNInteger { *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + int8_t* bit3_ptr = wptr->template WPtr(); + auto elt_offset = + n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto row = NPad / _GemmCore_T::NTILE; + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, k_size / _GemmCore_T::PACK_ROW, + ColSize, sptr, zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW, + wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize); } else { assert(0); } diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index cb7dc1f61..12ee5a235 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -246,8 +246,8 @@ class UT_BlockQunatize_S3 { using PrologueB = prologue_b::gemm::WeightKBlockNInteger; PrologueB kernel; - auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); - auto ptr_ref = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::BF16, BTLA_DTYPE::F32, false); + auto ptr = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S3_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::F32, false); + auto ptr_ref = kernel.createStorage(n, k, blocksize, BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::F32, false); avector buffer(ptr.mSize); avector buffer_ref(ptr_ref.mSize); ptr.assign(buffer.data()); From 58950176566f372a21fc38a627959a3f0a626ce1 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Fri, 26 Jan 2024 19:38:29 +0800 Subject: [PATCH 26/33] remove repeat code --- bestla/bestla/bestla_prologue_b.h | 62 +++++++------------------------ 1 file changed, 14 insertions(+), 48 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index e3e42b85c..393bf569f 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -715,54 +715,20 @@ class WeightKBlockNInteger { auto KPad = wptr->mKPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - if (wptr->SDtype() == BTLA_DTYPE::F32) { - auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S8) { - kernel::wrapper::DecompressKBlockS8S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } - } else if (wptr->SDtype() == BTLA_DTYPE::BF16) { - auto sptr = wptr->template SPtr() + n_offset + i; - if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { - kernel::wrapper::DecompressKBlockS4S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + - i * KPad / 2, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S8) { - kernel::wrapper::DecompressKBlockS8S8Fp::template forward( - wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, - *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); - } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { - int8_t* bit3_ptr = wptr->template WPtr(); - auto elt_offset = - n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); - auto row = NPad / _GemmCore_T::NTILE; - assert(elt_offset % 8 == 0); - auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); - auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); - kernel::wrapper::DecompressKBlockS3S8Fp::template forward( - bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, - k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize); - } else { - assert(0); - } + if (wptr->mDType == BTLA_DTYPE::S4_CLIP) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + + i * KPad / 2, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) { + kernel::wrapper::DecompressKBlockS4S8Fp::template forward( + wptr->template WPtr() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 + + i * KPad / 2, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S8) { + kernel::wrapper::DecompressKBlockS8S8Fp::template forward( + wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, + *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); } } *dststep = k_size; From e74541fe4c563c9e4081990a402d1448a0cce517 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 30 Jan 2024 09:54:51 +0800 Subject: [PATCH 27/33] jit unroll store 2-zmm --- bestla/bestla/kernel_jit.h | 34 ++++---------------------- bestla/bestla/ut/bestla_prologue_b.cpp | 12 ++++----- 2 files changed, 11 insertions(+), 35 deletions(-) diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index d1667c4f6..b00e99f87 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -293,16 +293,12 @@ class DecompresssS3 { mov(reg_bit2ptr, ptr[parambase + OFFSET(bit2ptr)]); mov(reg_dst, ptr[parambase + OFFSET(dstptr)]); L("loop_label"); - imul(reg_tmp, reg_iter, 32); + imul(reg_tmp, reg_iter, 16); kmovq(bit1_mask1, ptr[reg_bit1ptr + reg_tmp]); kmovq(bit1_mask2, ptr[reg_bit1ptr + reg_tmp + 8]); - kmovq(bit1_mask3, ptr[reg_bit1ptr + reg_tmp + 16]); - kmovq(bit1_mask4, ptr[reg_bit1ptr + reg_tmp + 24]); Xbyak::Zmm bit2_data_zmm = zmm0; - imul(reg_tmp, reg_iter, 64); - vmovups(bit2_data_zmm, ptr[reg_bit2ptr + reg_tmp]); - vextractf32x8(ymm2, bit2_data_zmm, 0x0); - vextractf32x8(ymm3, bit2_data_zmm, 0x1); + imul(reg_tmp, reg_iter, 32); + vmovups(ymm2, ptr[reg_bit2ptr + reg_tmp]); vpand(ymm4, LowMask, ymm2); vpsrlw(ymm2, ymm2, 2); @@ -314,40 +310,20 @@ class DecompresssS3 { vinsertf32x8(zmm4, zmm4, ymm5, 1); vinsertf32x8(zmm6, zmm6, ymm7, 1); - vpand(ymm8, LowMask, ymm3); - vpsrlw(ymm3, ymm3, 2); - vpand(ymm9, LowMask, ymm3); - vpsrlw(ymm3, ymm3, 2); - vpand(ymm10, LowMask, ymm3); - vpsrlw(ymm3, ymm3, 2); - vpand(ymm11, LowMask, ymm3); - vinsertf32x8(zmm8, zmm8, ymm9, 1); - vinsertf32x8(zmm10, zmm10, ymm11, 1); - vxorps(zmm12, zmm12); vxorps(zmm13, zmm13); - vxorps(zmm14, zmm14); - vxorps(zmm15, zmm15); vmovdqu8(zmm12 | bit1_mask1, zmm_0x04); vmovdqu8(zmm13 | bit1_mask2, zmm_0x04); - vmovdqu8(zmm14 | bit1_mask3, zmm_0x04); - vmovdqu8(zmm15 | bit1_mask4, zmm_0x04); vpaddb(zmm4, zmm4, zmm12); vpaddb(zmm6, zmm6, zmm13); - vpaddb(zmm8, zmm8, zmm14); - vpaddb(zmm10, zmm10, zmm15); vpsllvd(zmm4, zmm4, zmm_shift); vpsllvd(zmm6, zmm6, zmm_shift); - vpsllvd(zmm8, zmm8, zmm_shift); - vpsllvd(zmm10, zmm10, zmm_shift); - imul(reg_tmp, reg_iter, 256); + imul(reg_tmp, reg_iter, 128); vmovups(ptr[reg_dst + reg_tmp], zmm4); vmovups(ptr[reg_dst + reg_tmp + 64], zmm6); - vmovups(ptr[reg_dst + reg_tmp + 128], zmm8); - vmovups(ptr[reg_dst + reg_tmp + 192], zmm10); add(reg_iter, 1); cmp(reg_iter, reg_loop); @@ -375,7 +351,7 @@ class DecompresssS3 { }; static void forward_avx512f(void* bit2ptr, void* bit1ptr, void* dstptr, int unpack_elt) { static MicroKernelAVX512F ker; - auto param = MicroKernelAVX512F::params{bit2ptr, bit1ptr, dstptr, unpack_elt / 256, 0x03, 0x4, 5}; + auto param = MicroKernelAVX512F::params{bit2ptr, bit1ptr, dstptr, unpack_elt / 128, 0x03, 0x4, 5}; ker.mKernel(¶m); } }; diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 12ee5a235..5e83e4dd0 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -323,7 +323,7 @@ class UT_BlockQunatize_S3 { // } } }; -static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; +// static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #ifdef BTLA_UT_PROLOGUE_B static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #endif @@ -881,10 +881,10 @@ class UTBenchmark_CompFp32 { void ut_s4() { // benchmark_all(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - benchmark_all(48, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); - benchmark_all(48, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); - benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); - benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + benchmark_all(32, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); + // benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP); + // benchmark_all(1024, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP); // benchmark_all(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE); @@ -1056,7 +1056,7 @@ class UTBenchmark_CompFp32 { // C.data(), testtime, 48, qtype); // benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), B.data(), benchmark_mem, LOG, Wei, Scale_T>(m, n, k, blocksize, batch, A.data(), - B.data(), C.data(), testtime, 8, qtype); + B.data(), C.data(), testtime, 56, qtype); } } }; From ed378fec31d82d74d81ea841822d69ea65c609a1 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 30 Jan 2024 13:31:17 +0800 Subject: [PATCH 28/33] jit getfpkblockwei & fix 3bit getfpwei in proB --- bestla/bestla/bestla_prologue_b.h | 14 +++++++++++++ bestla/bestla/kernel_jit.h | 35 ++++++++++++++++++++++++------- bestla/bestla/kernel_wrapper.h | 15 ++++++------- 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 393bf569f..1d6422911 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -729,6 +729,20 @@ class WeightKBlockNInteger { kernel::wrapper::DecompressKBlockS8S8Fp::template forward( wptr->template WPtr() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize); + } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { + int8_t* bit3_ptr = wptr->template WPtr(); + auto elt_offset = + n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); + auto row = NPad / _GemmCore_T::NTILE; + assert(elt_offset % 8 == 0); + auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); + auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); + kernel::wrapper::DecompressKBlockS3S8Fp::template forward( + bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE, + k_size / _GemmCore_T::PACK_ROW * ColSize, tmpcache, cachesize); + } else { + assert(0); } } *dststep = k_size; diff --git a/bestla/bestla/kernel_jit.h b/bestla/bestla/kernel_jit.h index b00e99f87..b19040f2c 100644 --- a/bestla/bestla/kernel_jit.h +++ b/bestla/bestla/kernel_jit.h @@ -252,10 +252,11 @@ class DequanS8FP { }; class DecompresssS3 { public: + template class MicroKernelAVX512F : protected xbyak::JitAvx512f { public: struct params { - void *bit2ptr, *bit1ptr, *dstptr; + void *bit2ptr, *bit1ptr, *dstptr, *tmpbuf; int unpack_elt; int8_t ox3, ox4; int ox5; @@ -279,6 +280,7 @@ class DecompresssS3 { reg_iter = st.t[3]; reg_dst = st.t[4]; reg_tmp = st.t[5]; + reg_cache = st.t[6]; reg_ret = rax; xor_(reg_loop, reg_loop); mov(reg_loop.cvt32(), ptr[parambase + OFFSET(unpack_elt)]); @@ -292,6 +294,7 @@ class DecompresssS3 { mov(reg_bit1ptr, ptr[parambase + OFFSET(bit1ptr)]); mov(reg_bit2ptr, ptr[parambase + OFFSET(bit2ptr)]); mov(reg_dst, ptr[parambase + OFFSET(dstptr)]); + if constexpr (!std::is_same_v<_DST_T, int8_t>) mov(reg_cache, ptr[parambase + OFFSET(tmpbuf)]); L("loop_label"); imul(reg_tmp, reg_iter, 16); kmovq(bit1_mask1, ptr[reg_bit1ptr + reg_tmp]); @@ -321,9 +324,25 @@ class DecompresssS3 { vpsllvd(zmm4, zmm4, zmm_shift); vpsllvd(zmm6, zmm6, zmm_shift); - imul(reg_tmp, reg_iter, 128); - vmovups(ptr[reg_dst + reg_tmp], zmm4); - vmovups(ptr[reg_dst + reg_tmp + 64], zmm6); + if constexpr (std::is_same_v<_DST_T, int8_t>) { + imul(reg_tmp, reg_iter, 128); + vmovups(ptr[reg_dst + reg_tmp], zmm4); + vmovups(ptr[reg_dst + reg_tmp + 64], zmm6); + } else if constexpr (std::is_same_v<_DST_T, float> || std::is_same_v<_DST_T, utils::bf16>) { + vmovups(ptr[reg_cache], zmm4); + vmovups(ptr[reg_cache + 64], zmm6); + for (int i = 0; i < 8; i++) vpmovsxbd(Xbyak::Zmm(16 + i), ptr[reg_cache + 16 * i]); + for (int i = 0; i < 8; i++) vcvtdq2ps(Xbyak::Zmm(16 + i), Xbyak::Zmm(16 + i)); + imul(reg_tmp, reg_iter, 128 * sizeof(_DST_T)); + if constexpr (std::is_same_v<_DST_T, float>) { + for (int i = 0; i < 8; i++) vmovups(ptr[reg_dst + reg_tmp + i * 64], Xbyak::Zmm(16 + i)); + } else { + for (int i = 0; i < 8; i++) vcvtneps2bf16(Xbyak::Ymm(16 + i), Xbyak::Zmm(16 + i)); + for (int i = 0; i < 8; i++) vmovups(ptr[reg_dst + reg_tmp + i * 32], Xbyak::Ymm(16 + i)); + } + } else { + assert(0); + } add(reg_iter, 1); cmp(reg_iter, reg_loop); @@ -343,15 +362,17 @@ class DecompresssS3 { Xbyak::Reg64 reg_iter; Xbyak::Reg64 reg_dst; Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_cache; Xbyak::Reg64 reg_ret; Xbyak::Opmask bit1_mask1 = Xbyak::Opmask(1); Xbyak::Opmask bit1_mask2 = Xbyak::Opmask(2); Xbyak::Opmask bit1_mask3 = Xbyak::Opmask(3); Xbyak::Opmask bit1_mask4 = Xbyak::Opmask(4); }; - static void forward_avx512f(void* bit2ptr, void* bit1ptr, void* dstptr, int unpack_elt) { - static MicroKernelAVX512F ker; - auto param = MicroKernelAVX512F::params{bit2ptr, bit1ptr, dstptr, unpack_elt / 128, 0x03, 0x4, 5}; + template + static void forward_avx512f(void* bit2ptr, void* bit1ptr, _DST_T* dstptr, void* tmpbuf, int unpack_elt) { + static MicroKernelAVX512F<_DST_T> ker; + typename MicroKernelAVX512F<_DST_T>::params param{bit2ptr, bit1ptr, dstptr, tmpbuf, unpack_elt / 128, 0x03, 0x4, 5}; ker.mKernel(¶m); } }; diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 26f69fc39..45e53228e 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -511,13 +511,14 @@ class DecompressKBlockS3S8Fp { int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() - if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { - ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, - reinterpret_cast(tmp), tmpsize); - } else { - jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, unpack_elt); - ret = BTLA_CODE::Success; - } + // if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { + // ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, + // unpack_elt, + // reinterpret_cast(tmp), tmpsize); + // } else { + jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, tmp, unpack_elt); + ret = BTLA_CODE::Success; + // } #endif assert(ret == BTLA_CODE::Success); return ret; From dd2a00030157f19cdae2d2cc15393b679eb8ca3b Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 30 Jan 2024 15:27:31 +0800 Subject: [PATCH 29/33] head/tail process --- bestla/bestla/kernel_avx512f.h | 111 ++++++-------------------------- bestla/bestla/ut/kernel_jit.cpp | 61 ------------------ 2 files changed, 21 insertions(+), 151 deletions(-) diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index ea031e912..d59dcc32e 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -14,6 +14,7 @@ #pragma once #include "bestla.h" #include "bestla_utils.h" +#include "kernel_jit.h" #include "kernel_ref.h" #include @@ -647,34 +648,11 @@ static inline BTLA_CODE decompress_kblock_s4_fp(utils::int4x2* srcptr, _DST_T* d template inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, _DST_T* dstptr, int interleave_n_offset, int unpack_elt, int8_t* tmp, size_t tmpsize) { - auto head_ignore_num = interleave_n_offset % 64; + auto head_ignore_num = interleave_n_offset % 128; auto zmm_0x04 = _mm512_set1_epi8(0x04); auto zmm_0x00 = _mm512_set1_epi8(0x00); auto zmm_shift = _mm512_set1_epi32(5); - // auto bit3_interleave_decompress = [&](__m128i bit2_data, utils::bit1x8* src2) { - auto bit3_interleave_decompress = [&](utils::bit2x4* src1, utils::bit1x8* src2) { - const __m128i lowMask = _mm_set1_epi8(0x03); - const __m128i bit2_data = _mm_loadu_si128((const __m128i*)src1); - // __m128i bit2_data; - auto xmm0 = _mm_and_si128(lowMask, bit2_data); // uop:1 p:015 - auto xmm1 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 2)); // uop:1 p:01 - auto xmm2 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 4)); - auto xmm3 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 6)); - auto ymm1 = _mm256_set_m128i(xmm1, xmm0); // uop:1 p5 - auto ymm2 = _mm256_set_m128i(xmm3, xmm2); - auto zmm = _mm512_inserti32x8(_mm512_castsi256_si512(ymm1), ymm2, 0x1); - // make bit1-storage as mmask64, then cvt the mask to the int8-value. - unsigned long long* bit1_ptr = reinterpret_cast(src2); - auto bit1_mask = _cvtu64_mask64(*bit1_ptr); - // __mmask64 bit1_mask; - // __m512i zmm; - auto zmm2 = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask, zmm_0x04); - zmm = _mm512_add_epi8(zmm, zmm2); - zmm = _mm512_sllv_epi32(zmm, zmm_shift); // int3_clip => int8 - return zmm; - }; - auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) { const __m256i lowMask = _mm256_set1_epi8(0x03); const __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1); @@ -701,78 +679,31 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 assert(head_ignore_num % 8 == 0); - if (head_ignore_num > unpack_elt) { - assert(0); - } auto base_bit2ptr = bit2ptr - head_ignore_num / 4; auto base_bit1ptr = bit1ptr - head_ignore_num / 8; - int8_t* base_unpack_buf; - if constexpr (std::is_same_v<_DST_T, int8_t>) { - base_unpack_buf = dstptr - head_ignore_num; - } else { - base_unpack_buf = tmp - head_ignore_num; - } int compress_wei_ptr_offset = 0; + int8_t* s8_ptr = reinterpret_cast(tmp); + auto head_write_num = 128 - head_ignore_num; if (head_ignore_num != 0) { - assert(0); - // auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff << head_ignore_num); - // auto head_zmm = bit3_interleave_decompress(base_bit2ptr, base_bit1ptr); - // _mm512_mask_storeu_epi8(base_unpack_buf, unpack_mask, head_zmm); - // compress_wei_ptr_offset += 64; - } - auto body_loop = (unpack_elt - (64 - head_ignore_num) % 64) / 64; - auto tail_proc_num = (unpack_elt - (64 - head_ignore_num) % 64) % 64; - - // for (int i = 0; i < body_loop; i++) { - for (int i = 0; i < body_loop / 2; i++) { - if constexpr (!std::is_same_v<_DST_T, int8_t>) { - bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4, - base_bit1ptr + compress_wei_ptr_offset / 8, - reinterpret_cast(base_unpack_buf)); - for (int j = 0; j < 128; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j); - // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, - // base_bit1ptr + compress_wei_ptr_offset / 8); - // _mm512_storeu_epi8(base_unpack_buf, zmm); - // auto xmm1 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf)); - // auto xmm2 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 16)); - // auto xmm3 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 32)); - // auto xmm4 = _mm_loadu_si128(reinterpret_cast(base_unpack_buf + 48)); - // auto zmm1 = _mm512_cvtepi8_epi32(xmm1); - // auto zmm2 = _mm512_cvtepi8_epi32(xmm2); - // auto zmm3 = _mm512_cvtepi8_epi32(xmm3); - // auto zmm4 = _mm512_cvtepi8_epi32(xmm4); - // if constexpr (std::is_same_v<_DST_T, utils::bf16>) { - // auto ymm1 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm1)); - // auto ymm2 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm2)); - // auto ymm3 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm3)); - // auto ymm4 = zmm_cvt_fp32_bf16(_mm512_cvtepi32_ps(zmm4)); - // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset), ymm1); - // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 16), ymm2); - // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 32), ymm3); - // _mm256_storeu_si256(reinterpret_cast<__m256i*>(dstptr + compress_wei_ptr_offset + 48), ymm4); - // } else { - // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset, _mm512_cvtepi32_ps(zmm1)); - // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 16, _mm512_cvtepi32_ps(zmm2)); - // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 32, _mm512_cvtepi32_ps(zmm3)); - // _mm512_storeu_ps(dstptr + compress_wei_ptr_offset + 48, _mm512_cvtepi32_ps(zmm4)); - // } - } else { - // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, - // base_bit1ptr + compress_wei_ptr_offset / 8); - bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4, - base_bit1ptr + compress_wei_ptr_offset / 8, - reinterpret_cast(base_unpack_buf) + compress_wei_ptr_offset); - // _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm); - } - // compress_wei_ptr_offset += 64; - compress_wei_ptr_offset += 128; + printf("head process\n"); + bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); + for (int i = 0; i < head_write_num; i++) dstptr[i] = s8_ptr[head_ignore_num + i]; + compress_wei_ptr_offset += head_write_num; } + + auto body_loop = (unpack_elt - head_write_num % 128) / 128; + auto tail_proc_num = (unpack_elt - head_write_num % 128) % 128; + + bestla::kernel::jit::DecompresssS3::forward_avx512f(bit2ptr + compress_wei_ptr_offset / 4, + bit1ptr + compress_wei_ptr_offset / 8, + dstptr + compress_wei_ptr_offset, tmp, body_loop * 128); + compress_wei_ptr_offset += body_loop * 128; if (tail_proc_num > 0) { - assert(0); - // auto unpack_mask = _cvtu64_mask64(0xffffffffffffffff >> (64 - tail_proc_num)); - // auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4, - // base_bit1ptr + compress_wei_ptr_offset / 8); - // _mm512_mask_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, unpack_mask, zmm); + printf("tail process\n"); + bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); + bit3_interleave_decompress_pack128(bit2ptr + compress_wei_ptr_offset / 4, bit1ptr + compress_wei_ptr_offset / 8, + tmp); + for (int i = 0; i < tail_proc_num; i++) dstptr[compress_wei_ptr_offset + i] = s8_ptr[i]; } return BTLA_CODE::Success; } diff --git a/bestla/bestla/ut/kernel_jit.cpp b/bestla/bestla/ut/kernel_jit.cpp index 1c195472f..ce1198c99 100644 --- a/bestla/bestla/ut/kernel_jit.cpp +++ b/bestla/bestla/ut/kernel_jit.cpp @@ -296,66 +296,5 @@ class UT_DecompressS4S8 { #ifdef BTLA_UT_KERNEL_JIT static UT_DecompressS4S8 sUT_DecompressS4S8; #endif - -// class UT_DecompressS3S8 { -// public: -// UT_DecompressS3S8() { -// UT_START(); -// ut(1, 256); -// } -// void ut(int row, int col) { -// ut::UT_vector_s8 quanW; -// quanW.resize(row * col); -// quanW.fill_rand(-8, 7); -// for (int i = 0; i < row * col; i++) quanW.data()[i] = (quanW.data()[i] * 16) & 0xe0; -// std::vector bit2_data(row * col / 4); -// std::vector bit1_data(row * col / 8); -// std::vector ref(row * col); -// std::vector tar(row * col); - -// auto bit2_interleave = [&](int8_t* src, int8_t* dst) { -// for (int i = 0; i < 128 / 4; i++) { -// dst[4 * i] = src[i]; -// dst[4 * i + 1] = src[128 / 4 + i]; -// dst[4 * i + 2] = src[128 / 4 * 2 + i]; -// dst[4 * i + 3] = src[128 / 4 * 3 + i]; -// } -// }; - -// int8_t interleave_buf[128]; - -// for (int i = 0; i < row; i++) { -// for (int j = 0; j < col; j += 128) { -// bit2_interleave(const_cast(quanW.data() + i * col + j), interleave_buf); -// for (int k = 0; k < 32; k++) { -// bit2_data[i * col / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; -// bit2_data[i * col / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; -// bit2_data[i * col / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; -// bit2_data[i * col / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5; -// } -// } -// } -// // store 1 bit without interleave as mask. -// for (int i = 0; i < row; i++) { -// for (int j = 0; j < col; j += 8) { -// bit1_data[i * col / 8 + j / 8].a = quanW.data()[i * col + j] >> 7; -// bit1_data[i * col / 8 + j / 8].b = quanW.data()[i * col + j + 1] >> 7; -// bit1_data[i * col / 8 + j / 8].c = quanW.data()[i * col + j + 2] >> 7; -// bit1_data[i * col / 8 + j / 8].d = quanW.data()[i * col + j + 3] >> 7; -// bit1_data[i * col / 8 + j / 8].e = quanW.data()[i * col + j + 4] >> 7; -// bit1_data[i * col / 8 + j / 8].f = quanW.data()[i * col + j + 5] >> 7; -// bit1_data[i * col / 8 + j / 8].g = quanW.data()[i * col + j + 6] >> 7; -// bit1_data[i * col / 8 + j / 8].h = quanW.data()[i * col + j + 7] >> 7; -// } -// } - -// kernel::avx512f::decompress_kblock_s3_s8fp(bit2_data.data(), bit1_data.data(), -// ref.data(), 0, row * col, nullptr, -1); -// kernel::jit::DecompresssS3::forward_avx512f(bit2_data.data(), bit1_data.data(), tar.data(), row * col); -// buffer_error(tar.data(), ref.data(), ref.size()); -// } -// }; -// static UT_DecompressS3S8 sUT_DecompressS3S8; - } // namespace ut } // namespace bestla From d4d4f61d8c2f5dd926035b08cd2bb83085daeecb Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 30 Jan 2024 15:33:20 +0800 Subject: [PATCH 30/33] clean code --- bestla/bestla/kernel_ref.h | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 08aa3358f..41231d1ec 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -181,16 +181,7 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i // #include static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row, int col, int ld_src, int ld_dst) { - assert(col % 64 == 0); - // interleave + store 2bit. - - // for (int i = 0; i < row * col; i++) { - // char tmp; - // memcpy(&tmp, &srcptr[i], 1); - // tmp &= 0xe0; - // std::cout << int(*(reinterpret_cast(&tmp))) << std::endl; - // } - // std::cout << "==============" << std::endl; + assert(col % 128 == 0); auto bit2_interleave = [&](int8_t* src, int8_t* dst) { for (int i = 0; i < 128 / 4; i++) { @@ -201,24 +192,12 @@ static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x } }; - // auto bit2_interleave = [&](int8_t* src, int8_t* dst) { - // for (int i = 0; i < 64 / 4; i++) { - // dst[4 * i] = src[i]; - // dst[4 * i + 1] = src[64 / 4 + i]; - // dst[4 * i + 2] = src[64 / 4 * 2 + i]; - // dst[4 * i + 3] = src[64 / 4 * 3 + i]; - // } - // }; - - int8_t interleave_buf[128]; for (int i = 0; i < row; i++) { for (int j = 0; j < col; j += 128) { - // for (int j = 0; j < col; j += 64) { bit2_interleave(const_cast(srcptr + i * ld_src + j), interleave_buf); for (int k = 0; k < 32; k++) { - // for (int k = 0; k < 16; k++) { bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5; bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5; bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5; From c381892d7967759c44d2500da8863c1816b30238 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Tue, 30 Jan 2024 16:18:06 +0800 Subject: [PATCH 31/33] update kpad & head/tail process pass ut --- bestla/bestla/bestla_parallel.h | 1 - bestla/bestla/bestla_prologue_b.h | 26 +++++++++++--------------- bestla/bestla/bestla_storage.h | 8 ++++---- bestla/bestla/kernel_avx512f.h | 2 -- 4 files changed, 15 insertions(+), 22 deletions(-) diff --git a/bestla/bestla/bestla_parallel.h b/bestla/bestla/bestla_parallel.h index 22c852f86..c780aa4cf 100644 --- a/bestla/bestla/bestla_parallel.h +++ b/bestla/bestla/bestla_parallel.h @@ -486,7 +486,6 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> { mKBlock = config.problem.dims[4]; BaseScheduler::update(config); auto blks = utils::updiv(this->mBlock[2], mKBlock); - this->mBlock[2] = utils::padto(this->mBlock[2], 128); // TODO(zhe): remove it ,only for vnni 3bit linear this->mL2Use += static_cast(blks) * (this->mBlock[1] + this->mStep[0]) * (sizeof(float) + sizeof(int8_t) + sizeof(float)); // scale+zp+reduce assert(this->mL2Use <= this->mL2Size - ReservedSize); diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index 1d6422911..ac7b5a408 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -122,7 +122,6 @@ class WeightKBlockNInteger { bool is_asym) { int KPad = utils::padto(k, _GemmCore_T::KTILE); int NPad = utils::padto(n, _GemmCore_T::NTILE); - // if (qtype == BTLA_DTYPE::S3_CLIP) NPad = utils::padto(n, _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW); StorageWeight tmp(_GemmCore_T::ID); tmp.resize(NPad, KPad, blocksize <= 0 ? KPad : blocksize, n, k, qtype, scat, redt, is_asym); return tmp; @@ -190,15 +189,12 @@ class WeightKBlockNInteger { void unpackWeight(const int N, const int K, StorageWeight* stor, float* B, const int ldb, parallel::IThreading* threading) { - // parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); - parallel::Scheduler2D _para({threading->num_threads(), K, N, 32, - _GemmCore_T::NTILE}); // TODO(zhe): remove it, only for 3bit vnni woq linear + parallel::Scheduler2D _para({threading->num_threads(), K, N, _GemmCore_T::KTILE, _GemmCore_T::NTILE}); threading->parallel_for([&](int tidx) { parallel::ThreadProblem2D thdp{tidx}; _para.getIndex(thdp); if (thdp.valid) { - // auto rowpad = utils::padto(thdp.size[0], _GemmCore_T::KTILE); - auto rowpad = utils::padto(thdp.size[0], 32); // as above + auto rowpad = utils::padto(thdp.size[0], _GemmCore_T::KTILE); auto colpad = utils::padto(thdp.size[1], _GemmCore_T::NTILE); auto dequant = utils::amalloc((size_t)rowpad * colpad); auto dstptr = dequant; @@ -732,8 +728,8 @@ class WeightKBlockNInteger { } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = - n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); + n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); auto row = NPad / _GemmCore_T::NTILE; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); @@ -785,8 +781,8 @@ class WeightKBlockNInteger { } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = - n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); + n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); auto row = NPad / _GemmCore_T::NTILE; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); @@ -826,8 +822,8 @@ class WeightKBlockNInteger { } else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) { int8_t* bit3_ptr = wptr->template WPtr(); auto elt_offset = - n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 64); - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); + n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE + i * utils::padto(KPad, 128); + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); auto row = NPad / _GemmCore_T::NTILE; assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); @@ -901,10 +897,10 @@ class WeightKBlockNInteger { auto NPad = wptr->mNPad; int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW; auto row = NPad / _GemmCore_T::NTILE; - auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64); - auto base_offset = n_offset * utils::padto(KPad, 64) + k_offset * _GemmCore_T::NTILE; + auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 128); + auto base_offset = n_offset * utils::padto(KPad, 128) + k_offset * _GemmCore_T::NTILE; for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) { - auto elt_offset = base_offset + i * utils::padto(KPad, 64); + auto elt_offset = base_offset + i * utils::padto(KPad, 128); assert(elt_offset % 8 == 0); auto bit2ptr = reinterpret_cast(bit3_ptr + elt_offset / 4); auto bit1ptr = reinterpret_cast(bit3_ptr + row * ld_dst / 4 + elt_offset / 8); diff --git a/bestla/bestla/bestla_storage.h b/bestla/bestla/bestla_storage.h index 81b48387f..7b13adbe9 100644 --- a/bestla/bestla/bestla_storage.h +++ b/bestla/bestla/bestla_storage.h @@ -707,10 +707,10 @@ class StorageWeightKBlockNInteger : public IWeightKBlockBase { auto bits = utils::bestla_dtype_bits(qtype); auto elesize = static_cast(NPad) * KPad; if (qtype == BTLA_DTYPE::S3_CLIP) - elesize = static_cast(utils::padto(KPad, 64)) * - NPad; // pad K-dim to 64 because 64pack round2 interleave. round2 interleave ld_dim == PACK_ROW * - // pad_to(KPad,64) * NTILE - auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here + elesize = + static_cast(utils::padto(KPad, 128)) * NPad; // pad K-dim to 128 because 128pack round2 interleave. + // round2 interleave ld_dim == pad_to(KPad,128) * NTILE + auto bytes = utils::updiv(elesize * bits, 8); // add 3bits, 5btis, 7bits size calculation here mQBuf.resize(bytes); int nk_scale = utils::updiv(KPad, Block); auto gemm_comp = bestla::gemm::CoreAttr::get_comp(mCoreId); diff --git a/bestla/bestla/kernel_avx512f.h b/bestla/bestla/kernel_avx512f.h index d59dcc32e..6783ee8e6 100644 --- a/bestla/bestla/kernel_avx512f.h +++ b/bestla/bestla/kernel_avx512f.h @@ -685,7 +685,6 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 int8_t* s8_ptr = reinterpret_cast(tmp); auto head_write_num = 128 - head_ignore_num; if (head_ignore_num != 0) { - printf("head process\n"); bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); for (int i = 0; i < head_write_num; i++) dstptr[i] = s8_ptr[head_ignore_num + i]; compress_wei_ptr_offset += head_write_num; @@ -699,7 +698,6 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8 dstptr + compress_wei_ptr_offset, tmp, body_loop * 128); compress_wei_ptr_offset += body_loop * 128; if (tail_proc_num > 0) { - printf("tail process\n"); bit3_interleave_decompress_pack128(base_bit2ptr, base_bit1ptr, tmp); bit3_interleave_decompress_pack128(bit2ptr + compress_wei_ptr_offset / 4, bit1ptr + compress_wei_ptr_offset / 8, tmp); From ae68c14df9e601a39ca30b39e0721fcc5e804702 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Wed, 31 Jan 2024 08:36:03 +0800 Subject: [PATCH 32/33] clean code & add repack script --- bestla/bestla/kernel_wrapper.h | 10 +-- bestla/bestla/ut/bestla_prologue_b.cpp | 73 ++----------------- .../convert/convert_quantized_gptj.py | 19 +++-- .../models/model_utils/quant_utils.cpp | 1 + scripts/python_api_example_for_gptq.py | 32 ++++++++ 5 files changed, 53 insertions(+), 82 deletions(-) create mode 100644 scripts/python_api_example_for_gptq.py diff --git a/bestla/bestla/kernel_wrapper.h b/bestla/bestla/kernel_wrapper.h index 45e53228e..5b7c7f5b8 100644 --- a/bestla/bestla/kernel_wrapper.h +++ b/bestla/bestla/kernel_wrapper.h @@ -511,14 +511,8 @@ class DecompressKBlockS3S8Fp { int interleave_n_offset, int unpack_elt, void* tmp, size_t tmpsize) { BTLA_CODE ret = BTLA_CODE::NotSupport; #if CompileAVX512F() - // if constexpr (utils::isa_base::avx512f && !std::is_same_v<_DST_T, int8_t>) { - // ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, - // unpack_elt, - // reinterpret_cast(tmp), tmpsize); - // } else { - jit::DecompresssS3::forward_avx512f(bit2ptr, bit1ptr, dstptr, tmp, unpack_elt); - ret = BTLA_CODE::Success; - // } + ret = avx512f::decompress_kblock_s3_s8fp(bit2ptr, bit1ptr, dstptr, interleave_n_offset, unpack_elt, + reinterpret_cast(tmp), tmpsize); #endif assert(ret == BTLA_CODE::Success); return ret; diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index 5e83e4dd0..b6db94dad 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -173,60 +173,15 @@ class UT_BlockQunatize_F8 { static UT_BlockQunatize_F8 sUT_BlockQunatize_F8; #endif -class UT_BlockQunatize_S3 { +class UT_S3_WOQ { public: - UT_BlockQunatize_S3() { + UT_S3_WOQ() { UT_START(); CheckISA(AVX512F); - // ut(128, 4096, 16384, 32, 56); - // ut(1, 4096, 16384, 32, 56); - // ut(128, 4096, 4096, 32, 56); - // ut(1, 4096, 4096, 32, 56); - // ut(128, 16384, 4096, 32, 56); - // ut(1, 16384, 4096, 32, 56); - // ut(128, 4096, 16384, 32, 56); - // ut(1, 4096, 16384, 32, 56); - // ut(128, 4096, 4096, 32, 56); - // ut(1, 4096, 4096, 32, 56); - // ut(128, 16384, 4096, 32, 56); - // ut(1, 16384, 4096, 32, 56); - ut, BTLA_ISA::AMX_INT8>(1024, 4096, 16384, 128, 8); - ut, BTLA_ISA::AMX_INT8>(48, 4096, 16384, 128, 8); - ut, BTLA_ISA::AMX_INT8>(1024, 4096, 4096, 128, 8); - ut, BTLA_ISA::AMX_INT8>(48, 4096, 4096, 128, 8); - ut, BTLA_ISA::AMX_INT8>(1024, 16384, 4096, 128, 8); - ut, BTLA_ISA::AMX_INT8>(48, 16384, 4096, 128, 8); - // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 56); - // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 56); - // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 56); - // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); - // ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 56); - // ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 56); - // emr case - // ut(128, 4096, 16384, 32, 64); - // ut(1, 4096, 16384, 32, 64); - // ut(128, 4096, 4096, 32, 64); - // ut(1, 4096, 4096, 32, 64); - // ut(128, 16384, 4096, 32, 64); - // ut(1, 16384, 4096, 32, 64); - // ut(128, 4096, 16384, 32, 64); - // ut(1, 4096, 16384, 32, 64); - // ut(128, 4096, 4096, 32, 64); - // ut(1, 4096, 4096, 32, 64); - // ut(128, 16384, 4096, 32, 64); - // ut(1, 16384, 4096, 32, 64); - // ut, BTLA_ISA::AMX_INT8>(128, 4096, 16384, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(1, 4096, 16384, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(128, 4096, 4096, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(128, 16384, 4096, 128, 64); - // ut, BTLA_ISA::AMX_INT8>(1, 16384, 4096, 128, 64); - // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 16384, 128, 64); - // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 16384, 128, 64); - // ut, BTLA_ISA::AVX512_VNNI>(128, 4096, 4096, 128, 64); - // ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 64); - // ut, BTLA_ISA::AVX512_VNNI>(128, 16384, 4096, 128, 64); - // ut, BTLA_ISA::AVX512_VNNI>(1, 16384, 4096, 128, 64); + ut(1, 4096, 4096, 32, 56); + ut(1, 4096, 4096, 32, 56); + ut, BTLA_ISA::AMX_INT8>(1, 4096, 4096, 128, 56); + ut, BTLA_ISA::AVX512_VNNI>(1, 4096, 4096, 128, 56); } template @@ -309,23 +264,10 @@ class UT_BlockQunatize_S3 { parallel::GemmRunWithA(launcher, args_ref, &DefaultThreading); } buffer_error(matC.data(), refC.data(), matC.size(), 0.001f); - // avector dequant(n * k, 0); - // avector dequant_ref(n * k, 0); - // kernel.unpackWeight(n, k, &ptr, dequant.data(), n, &DefaultThreading); - // kernel.unpackWeight(n, k, &ptr_ref, dequant_ref.data(), n, &DefaultThreading); - // for (int i = 0; i < k; i++) { - // for (int j = 0; j < n; j++) { - // if ((dequant[i * n + j] - dequant_ref[i * n + j]) != 0) { - // std::cout << "i: " << i << " j:" << j << std::endl; - // std::cout << dequant[i * n + j] << " vs " << dequant_ref[i * n + j] << std::endl; - // } - // } - // } } }; -// static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; #ifdef BTLA_UT_PROLOGUE_B -static UT_BlockQunatize_S3 sUT_BlockQunatize_S3; +static UT_S3_WOQ sUT_S3_WOQ; #endif class UT_TransposeBlockQuantize_F4 { public: @@ -1060,7 +1002,6 @@ class UTBenchmark_CompFp32 { } } }; -static UTBenchmark_CompFp32 sUTBenchmark_CompFp32; #ifdef BTLA_UT_PROLOGUE_B_ static UTBenchmark_CompFp32 sUTBenchmark_CompFp32; #endif diff --git a/neural_speed/convert/convert_quantized_gptj.py b/neural_speed/convert/convert_quantized_gptj.py index 8d2a22147..7e565af55 100644 --- a/neural_speed/convert/convert_quantized_gptj.py +++ b/neural_speed/convert/convert_quantized_gptj.py @@ -39,14 +39,15 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config): scales = model[f"{src_name}.scales"] qweight = model[f"{src_name}.qweight"] - int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config) - int_weight = int_weight.view(-1,int_weight.shape[-1]) + int_weight, gptq_scales, gptq_zeros = unpack_weight( + qweight, scales, qzeros, q_config) + int_weight = int_weight.view(-1, int_weight.shape[-1]) # shuffle weight in GPTQ when act order is on - if 'desc_act'in q_config and q_config['desc_act']: + if 'desc_act' in q_config and q_config['desc_act']: g_idx = model[f"{src_name}.g_idx"] int_weight2 = int_weight.clone() - group_size=q_config['group_size'] + group_size = q_config['group_size'] group_dict = {} for i in range(len(g_idx)): group_idx = g_idx[i].item() @@ -84,7 +85,7 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config): gptq_zeros = np.empty(0, dtype=np.int8) else: gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy()) - if 'desc_act'in q_config and q_config['desc_act']: + if 'desc_act' in q_config and q_config['desc_act']: g_idx = np.ascontiguousarray(g_idx.numpy()) else: g_idx = np.empty(0, dtype=np.int32) @@ -111,13 +112,14 @@ def main(args_in: Optional[List[str]] = None) -> None: model, config, quantize_config = load_quantized_model(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True) fout = open(out_path, "wb") # 1. write hparams hparams = config n_layer = hparams["n_layer"] - fout.write(b"ggjt"[::-1]) #0x67676d6c)) # magic: ggml in hex + fout.write(b"ggjt"[::-1]) # 0x67676d6c)) # magic: ggml in hex values = [ 1, # file version hparams["vocab_size"], @@ -140,7 +142,8 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) - fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps + fout.write(struct.pack("f", hparams.get( + "rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1)) diff --git a/neural_speed/models/model_utils/quant_utils.cpp b/neural_speed/models/model_utils/quant_utils.cpp index 5d68165f1..65e943875 100644 --- a/neural_speed/models/model_utils/quant_utils.cpp +++ b/neural_speed/models/model_utils/quant_utils.cpp @@ -238,6 +238,7 @@ size_t bestla_qpack(const int8_t* src_w, const float* src_scales, const int8_t* if (params.bits == quant_bits::q8) { quant_type = BTLA_DTYPE::S8; } + if (params.bits == quant_bits::q3) quant_type = BTLA_DTYPE::S3_CLIP; auto dtype_type = static_cast( bestla::utils::bestla_dtype_get_mask_val(quant_type, BTLA_DTYPE::TypeMask, BTLA_DTYPE::TypeShift)); if (dtype_type == BTLA_DTYPE::TypeFloat) { diff --git a/scripts/python_api_example_for_gptq.py b/scripts/python_api_example_for_gptq.py new file mode 100644 index 000000000..940d2438b --- /dev/null +++ b/scripts/python_api_example_for_gptq.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from transformers import AutoTokenizer, TextStreamer +from neural_speed import Model + +if len(sys.argv) != 2: + print("Usage: python python_api_example.py model_path") +model_name = sys.argv[1] + +prompt = "Once upon a time, a little girl" +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +inputs = tokenizer(prompt, return_tensors="pt").input_ids +streamer = TextStreamer(tokenizer) + +model = Model() +model.init(model_name, weight_dtype="int3", compute_dtype="int8", use_gptq=True) +outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300, do_sample=True) From ba1b3f8d8fe662a6108fe0782240d65903e15e05 Mon Sep 17 00:00:00 2001 From: "Wang,Zhe" Date: Sun, 18 Feb 2024 10:38:19 +0800 Subject: [PATCH 33/33] perfect code --- bestla/CMakeLists.txt | 2 +- bestla/bestla/bestla_prologue_b.h | 3 -- bestla/bestla/kernel_ref.h | 1 - bestla/bestla/ut/bestla_prologue_b.cpp | 5 ++- .../convert/convert_quantized_gptj.py | 2 +- scripts/python_api_example_for_gptq.py | 32 ------------------- 6 files changed, 4 insertions(+), 41 deletions(-) delete mode 100644 scripts/python_api_example_for_gptq.py diff --git a/bestla/CMakeLists.txt b/bestla/CMakeLists.txt index 834e61160..2b17a8603 100644 --- a/bestla/CMakeLists.txt +++ b/bestla/CMakeLists.txt @@ -7,7 +7,7 @@ file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) option(BTLA_USE_OPENMP "Enable OpenMP thread pool" OFF) option(BTLA_UT_ALL "Enable all unit tests" OFF) -option(BTLA_UT_DEBUG "Enable debug unit tests" ON) +option(BTLA_UT_DEBUG "Enable debug unit tests" OFF) option(BTLA_UT_EPILOGUE "Enable unit test for epilogue" OFF) option(BTLA_UT_PROLOGUE_A "Enable unit test for activation prologue" OFF) option(BTLA_UT_PROLOGUE_B "Enable unit test for weight prologue" OFF) diff --git a/bestla/bestla/bestla_prologue_b.h b/bestla/bestla/bestla_prologue_b.h index ac7b5a408..99f3ccc90 100644 --- a/bestla/bestla/bestla_prologue_b.h +++ b/bestla/bestla/bestla_prologue_b.h @@ -13,8 +13,6 @@ // limitations under the License. #pragma once #include -#include -#include "bestla.h" #include "bestla_utils.h" #include "bestla_storage.h" #include "bestla_device.h" @@ -561,7 +559,6 @@ class WeightKBlockNInteger { static void compressBit3Weight(const int N, const int K, const int8_t* B, int8_t* dstptr, parallel::IThreading* threading) { // TODO(zhe): 1D parallel compress - // N==NPad, K==Kpad, ldb==Kpad auto ld_dst = _GemmCore_T::NTILE * utils::padto(K, 64); auto col = _GemmCore_T::NTILE * K; auto row = N / _GemmCore_T::NTILE; diff --git a/bestla/bestla/kernel_ref.h b/bestla/bestla/kernel_ref.h index 41231d1ec..7d47aa17e 100644 --- a/bestla/bestla/kernel_ref.h +++ b/bestla/bestla/kernel_ref.h @@ -178,7 +178,6 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i return BTLA_CODE::Success; } -// #include static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row, int col, int ld_src, int ld_dst) { assert(col % 128 == 0); diff --git a/bestla/bestla/ut/bestla_prologue_b.cpp b/bestla/bestla/ut/bestla_prologue_b.cpp index b6db94dad..29e18e35d 100644 --- a/bestla/bestla/ut/bestla_prologue_b.cpp +++ b/bestla/bestla/ut/bestla_prologue_b.cpp @@ -1,4 +1,3 @@ -#include "bestla.h" #include "bestla_gemm.h" #include "bestla_prologue_b.h" #include "bestla_parallel.h" @@ -886,7 +885,7 @@ class UTBenchmark_CompFp32 { kernel.mProB.packWeight(n, k, B, n, &packBs[0], &DefaultThreading); for (size_t i = 1; i < batch; i++) { memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); - memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(float)); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); } auto psize = (size_t)m * n * k * 2; auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); @@ -946,7 +945,7 @@ class UTBenchmark_CompFp32 { kernel.mProB.packWeight(n, k, B, n, &packBs[0], &DefaultThreading); for (size_t i = 1; i < batch; i++) { memcpy(packBs[i].template WPtr(), packBs[0].template WPtr(), packBs[0].template WSize()); - memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(uint16_t)); + memcpy(packBs[i].template SPtr(), packBs[0].template SPtr(), packBs[0].CSize() * sizeof(Scale_T)); } auto psize = (size_t)m * n * k * 2; auto memsize = (size_t)packBs[0].mSize + (m * k + m * n) * sizeof(float); diff --git a/neural_speed/convert/convert_quantized_gptj.py b/neural_speed/convert/convert_quantized_gptj.py index 7e565af55..9711a698b 100644 --- a/neural_speed/convert/convert_quantized_gptj.py +++ b/neural_speed/convert/convert_quantized_gptj.py @@ -97,7 +97,7 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config): alg="sym" if q_config['sym'] else "asym", compute_dtype="int8") dst.flatten()[:byte_size].tofile(fout) - print(f"converting {dst_name} quantized tensor to bestla q4 block") + print(f"converting {dst_name} quantized tensor to bestla q{q_config['bits']} block") def main(args_in: Optional[List[str]] = None) -> None: diff --git a/scripts/python_api_example_for_gptq.py b/scripts/python_api_example_for_gptq.py deleted file mode 100644 index 940d2438b..000000000 --- a/scripts/python_api_example_for_gptq.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys -from transformers import AutoTokenizer, TextStreamer -from neural_speed import Model - -if len(sys.argv) != 2: - print("Usage: python python_api_example.py model_path") -model_name = sys.argv[1] - -prompt = "Once upon a time, a little girl" -tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) -inputs = tokenizer(prompt, return_tensors="pt").input_ids -streamer = TextStreamer(tokenizer) - -model = Model() -model.init(model_name, weight_dtype="int3", compute_dtype="int8", use_gptq=True) -outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300, do_sample=True)