Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
Pack 1-bit & 2-bit data together
Browse files Browse the repository at this point in the history
  • Loading branch information
zhewang1-intc committed Jan 24, 2024
1 parent b322d68 commit dad2282
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 78 deletions.
13 changes: 6 additions & 7 deletions bestla/bestla/bestla_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,9 @@ class WeightKBlockNInteger {
auto row = N / _GemmCore_T::NTILE;
auto pad_64_buf = utils::avector<int8_t>(row * ld_dst, 0);
kernel::wrapper::Memcpy2D::forward<BTLA_ISA::NoSIMD>(B, pad_64_buf.data(), row, col, col, ld_dst);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(dstptr);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(dstptr + row * ld_dst / 4);
auto ret =
kernel::wrapper::CompressBit3::forward<ISA_T>(pad_64_buf.data(), bit2ptr, bit1ptr, row, col, ld_dst, ld_dst);
// auto bit2ptr = reinterpret_cast<utils::bit2x4*>(dstptr);
// auto bit1ptr = reinterpret_cast<utils::bit1x8*>(dstptr + row * ld_dst / 4);
auto ret = kernel::wrapper::CompressBit3::forward<ISA_T>(pad_64_buf.data(), dstptr, row, col, ld_dst, ld_dst);
assert(ret == BTLA_CODE::Success);
}

Expand Down Expand Up @@ -752,7 +751,7 @@ class WeightKBlockNInteger {
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto row = NPad / _GemmCore_T::NTILE;
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 8 * 3);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
kernel::wrapper::DecompressKBlockS3S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE,
Expand Down Expand Up @@ -833,7 +832,7 @@ class WeightKBlockNInteger {
auto ld_dst = _GemmCore_T::NTILE * utils::padto(KPad, 64);
auto row = NPad / _GemmCore_T::NTILE;
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 8 * 3);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
kernel::wrapper::DecompressKBlockS3Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, utils::bf16,
BTLA_DTYPE::S3_CLIP>(
Expand Down Expand Up @@ -909,7 +908,7 @@ class WeightKBlockNInteger {
for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
auto elt_offset = base_offset + i * utils::padto(KPad, 64);
assert(elt_offset % 8 == 0);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 4);
auto bit2ptr = reinterpret_cast<utils::bit2x4*>(bit3_ptr + elt_offset / 8 * 3);
auto bit1ptr = reinterpret_cast<utils::bit1x8*>(bit3_ptr + row * ld_dst / 4 + elt_offset / 8);
kernel::wrapper::DecompressKBlockS3S8Fp<int8_t>::template forward<ISA_T, BTLA_DTYPE::S3_CLIP>(
bit2ptr, bit1ptr, *dstptr + i * k_size, k_offset * _GemmCore_T::NTILE,
Expand Down
35 changes: 5 additions & 30 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,40 +652,17 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
auto zmm_0x00 = _mm512_set1_epi8(0x00);
auto zmm_shift = _mm512_set1_epi32(5);

// auto bit3_interleave_decompress = [&](__m128i bit2_data, utils::bit1x8* src2) {
auto bit3_interleave_decompress = [&](utils::bit2x4* src1, utils::bit1x8* src2) {
const __m128i lowMask = _mm_set1_epi8(0x03);
const __m128i bit2_data = _mm_loadu_si128((const __m128i*)src1);
// __m128i bit2_data;
auto xmm0 = _mm_and_si128(lowMask, bit2_data); // uop:1 p:015
auto xmm1 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 2)); // uop:1 p:01
auto xmm2 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 4));
auto xmm3 = _mm_and_si128(lowMask, _mm_srli_epi16(bit2_data, 6));
auto ymm1 = _mm256_set_m128i(xmm1, xmm0); // uop:1 p5
auto ymm2 = _mm256_set_m128i(xmm3, xmm2);
auto zmm = _mm512_inserti32x8(_mm512_castsi256_si512(ymm1), ymm2, 0x1);
// make bit1-storage as mmask64, then cvt the mask to the int8-value.
unsigned long long* bit1_ptr = reinterpret_cast<unsigned long long*>(src2);
auto bit1_mask = _cvtu64_mask64(*bit1_ptr);
// __mmask64 bit1_mask;
// __m512i zmm;
auto zmm2 = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask, zmm_0x04);
zmm = _mm512_add_epi8(zmm, zmm2);
zmm = _mm512_sllv_epi32(zmm, zmm_shift); // int3_clip => int8
return zmm;
};

auto bit3_interleave_decompress_pack128 = [&](utils::bit2x4* src1, utils::bit1x8* src2, int8_t* dst) {
auto bit3_interleave_decompress_pack128 = [&](void* src, int8_t* dst) {
const __m256i lowMask = _mm256_set1_epi8(0x03);
const __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src1);
const __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src);
auto ymm0 = _mm256_and_si256(lowMask, bit2_data); // uop:1 p:015
auto ymm1 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2)); // uop:1 p:01
auto ymm2 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 4));
auto ymm3 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 6));
auto zmm1 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm0), ymm1, 0x1); // lat3, tp1 uop1 p:5
auto zmm2 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm2), ymm3, 0x1);

unsigned long long* bit1_ptr = reinterpret_cast<unsigned long long*>(src2);
unsigned long long* bit1_ptr = reinterpret_cast<unsigned long long*>(src + 32);
auto bit1_mask1 = _cvtu64_mask64(*bit1_ptr);
auto bit1_mask2 = _cvtu64_mask64(*(bit1_ptr + 1));
auto zmm1_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask1, zmm_0x04);
Expand Down Expand Up @@ -726,8 +703,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
// for (int i = 0; i < body_loop; i++) {
for (int i = 0; i < body_loop / 2; i++) {
if constexpr (!std::is_same_v<_DST_T, int8_t>) {
bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4,
base_bit1ptr + compress_wei_ptr_offset / 8,
bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 8 * 3,
reinterpret_cast<int8_t*>(base_unpack_buf));
for (int j = 0; j < 128; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j);
// auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4,
Expand Down Expand Up @@ -759,8 +735,7 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
} else {
// auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4,
// base_bit1ptr + compress_wei_ptr_offset / 8);
bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 4,
base_bit1ptr + compress_wei_ptr_offset / 8,
bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 8 * 3,
reinterpret_cast<int8_t*>(base_unpack_buf) + compress_wei_ptr_offset);
// _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm);
}
Expand Down
13 changes: 6 additions & 7 deletions bestla/bestla/kernel_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,14 @@ class DecompresssS3 {
mov(reg_dst, ptr[parambase + OFFSET(dstptr)]);
L("loop_label");
imul(reg_tmp, reg_iter, 96);
kmovq(bit1_mask1, ptr[reg_bit2ptr + reg_tmp+64]);
kmovq(bit1_mask2, ptr[reg_bit2ptr + reg_tmp + 72]);
Xbyak::Zmm bit2_data_zmm = zmm0;

vmovups(ymm2, ptr[reg_bit2ptr + reg_tmp]);
kmovq(bit1_mask1, ptr[reg_bit2ptr + reg_tmp + 32]);
kmovq(bit1_mask2, ptr[reg_bit2ptr + reg_tmp + 40]);
vmovups(ymm3, ptr[reg_bit2ptr + reg_tmp + 48]);
kmovq(bit1_mask3, ptr[reg_bit2ptr + reg_tmp + 80]);
kmovq(bit1_mask4, ptr[reg_bit2ptr + reg_tmp + 88]);
Xbyak::Zmm bit2_data_zmm = zmm0;
// imul(reg_tmp, reg_iter, 64);
vmovups(bit2_data_zmm, ptr[reg_bit2ptr + reg_tmp]);
vextractf32x8(ymm2, bit2_data_zmm, 0x0);
vextractf32x8(ymm3, bit2_data_zmm, 0x1);

vpand(ymm4, LowMask, ymm2);
vpsrlw(ymm2, ymm2, 2);
Expand Down
51 changes: 25 additions & 26 deletions bestla/bestla/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ static inline BTLA_CODE compress_f4(const int8_t* srcptr, utils::f4x2* dstptr, i
}

// #include <iostream>
static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr,
int row, int col, int ld_src, int ld_dst) {
static inline BTLA_CODE compress_3bit(const int8_t* srcptr, void* dst, int row, int col, int ld_src, int ld_dst) {
assert(col % 64 == 0);
// interleave + store 2bit.

Expand All @@ -201,42 +200,42 @@ static inline BTLA_CODE compress_3bit(const int8_t* srcptr, bestla::utils::bit2x
}
};

// auto bit2_interleave = [&](int8_t* src, int8_t* dst) {
// for (int i = 0; i < 64 / 4; i++) {
// dst[4 * i] = src[i];
// dst[4 * i + 1] = src[64 / 4 + i];
// dst[4 * i + 2] = src[64 / 4 * 2 + i];
// dst[4 * i + 3] = src[64 / 4 * 3 + i];
// }
// };


int8_t interleave_buf[128];
using namespace bestla::utils;
std::vector<bit2x4> bit2_data_(row * col / 4);
std::vector<bit1x8> bit1_data_(row * col / 8);

for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j += 128) {
// for (int j = 0; j < col; j += 64) {
bit2_interleave(const_cast<int8_t*>(srcptr + i * ld_src + j), interleave_buf);
for (int k = 0; k < 32; k++) {
// for (int k = 0; k < 16; k++) {
bit2ptr[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5;
bit2ptr[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5;
bit2ptr[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5;
bit2ptr[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5;
bit2_data_[i * ld_dst / 4 + j / 4 + k].a = interleave_buf[4 * k] >> 5;
bit2_data_[i * ld_dst / 4 + j / 4 + k].b = interleave_buf[4 * k + 1] >> 5;
bit2_data_[i * ld_dst / 4 + j / 4 + k].c = interleave_buf[4 * k + 2] >> 5;
bit2_data_[i * ld_dst / 4 + j / 4 + k].d = interleave_buf[4 * k + 3] >> 5;
}
}
}
// store 1 bit without interleave as mask.
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j += 8) {
bit1ptr[i * ld_dst / 8 + j / 8].a = srcptr[i * ld_src + j] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].b = srcptr[i * ld_src + j + 1] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].c = srcptr[i * ld_src + j + 2] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].d = srcptr[i * ld_src + j + 3] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].e = srcptr[i * ld_src + j + 4] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].f = srcptr[i * ld_src + j + 5] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].g = srcptr[i * ld_src + j + 6] >> 7;
bit1ptr[i * ld_dst / 8 + j / 8].h = srcptr[i * ld_src + j + 7] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].a = srcptr[i * ld_src + j] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].b = srcptr[i * ld_src + j + 1] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].c = srcptr[i * ld_src + j + 2] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].d = srcptr[i * ld_src + j + 3] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].e = srcptr[i * ld_src + j + 4] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].f = srcptr[i * ld_src + j + 5] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].g = srcptr[i * ld_src + j + 6] >> 7;
bit1_data_[i * ld_dst / 8 + j / 8].h = srcptr[i * ld_src + j + 7] >> 7;
}
}

for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j += 128) {
bit2x4* bit2 = reinterpret_cast<bit2x4*>(dst + i * ld_dst / 8 * 3 + j / 8 * 3);
bit1x8* bit1 = reinterpret_cast<bit1x8*>(dst + i * ld_dst / 8 * 3 + j / 8 * 3 + 32);
for (int k = 0; k < 32; k++) bit2[k] = bit2_data_[i * ld_dst / 4 + j / 4 + k];
for (int k = 0; k < 16; k++) bit1[k] = bit1_data_[i * ld_dst / 8 + j / 8 + k];
}
}
return BTLA_CODE::Success;
Expand Down
5 changes: 2 additions & 3 deletions bestla/bestla/kernel_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,8 @@ class CompressFp4 {
class CompressBit3 {
public:
template <BTLA_ISA ISA_T>
static inline BTLA_CODE forward(const int8_t* srcptr, bestla::utils::bit2x4* bit2ptr, utils::bit1x8* bit1ptr, int row,
int col, int ld_src, int ld_dst) {
return ref::compress_3bit(srcptr, bit2ptr, bit1ptr, row, col, ld_src, ld_dst);
static inline BTLA_CODE forward(const int8_t* srcptr, void* dst, int row, int col, int ld_src, int ld_dst) {
return ref::compress_3bit(srcptr, dst, row, col, ld_src, ld_dst);
}
};

Expand Down
8 changes: 3 additions & 5 deletions bestla/bestla/ut/bestla_prologue_b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ class UT_BlockQunatize_S3 {
// }
}
};
// static UT_BlockQunatize_S3 sUT_BlockQunatize_S3;
static UT_BlockQunatize_S3 sUT_BlockQunatize_S3;
#ifdef BTLA_UT_PROLOGUE_B
static UT_BlockQunatize_S3 sUT_BlockQunatize_S3;
#endif
Expand Down Expand Up @@ -881,10 +881,8 @@ class UTBenchmark_CompFp32 {

void ut_s4() {
// benchmark_all<prologue_b::gemm::WeightKBlockNInteger, float>(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(48, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(48, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(1024, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(1024, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(1, 4096, 4096, 128, BTLA_DTYPE::S3_CLIP);
benchmark_all<prologue_b::gemm::WeightKBlockNInteger, utils::bf16>(1, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP);
// benchmark_all<prologue_b::gemm::WeightKBlockS4, float>(2048, 4096, 4096, 128, BTLA_DTYPE::S4_CLIP);
// benchmark_all<prologue_b::gemm::WeightKBlockS4, float>(4096, 4096, 11008, 128, BTLA_DTYPE::S4_CLIP);
// benchmark_all<prologue_b::gemm::WeightKBlockS4, float>(2, 4096, 4096, 32, BTLA_DTYPE::S4_FULLRANGE);
Expand Down

0 comments on commit dad2282

Please sign in to comment.