Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
bit2bit2bit1bit1-pack
Browse files Browse the repository at this point in the history
  • Loading branch information
zhewang1-intc committed Jan 24, 2024
1 parent dad2282 commit ec1092e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 25 deletions.
48 changes: 34 additions & 14 deletions bestla/bestla/kernel_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -651,29 +651,49 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
auto zmm_0x04 = _mm512_set1_epi8(0x04);
auto zmm_0x00 = _mm512_set1_epi8(0x00);
auto zmm_shift = _mm512_set1_epi32(5);

auto bit3_interleave_decompress_pack128 = [&](void* src, int8_t* dst) {
const __m256i lowMask = _mm256_set1_epi8(0x03);
const __m256i bit2_data = _mm256_loadu_si256((const __m256i*)src);
auto ymm0 = _mm256_and_si256(lowMask, bit2_data); // uop:1 p:015
auto ymm1 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 2)); // uop:1 p:01
auto ymm2 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 4));
auto ymm3 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data, 6));
const __m256i lowMask = _mm256_set1_epi8(0x03);

auto bit3_interleave_decompress_pack256 = [&](void* src, int8_t* dst) {
const __m512i bit2_data_zmm = _mm512_loadu_si512(src);
auto bit2_data1 = _mm512_extracti32x8_epi32(bit2_data_zmm, 0x0);
auto bit2_data2 = _mm512_extracti32x8_epi32(bit2_data_zmm, 0x1);
auto ymm0 = _mm256_and_si256(lowMask, bit2_data1); // uop:1 p:015
auto ymm1 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data1, 2)); // uop:1 p:01
auto ymm2 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data1, 4));
auto ymm3 = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data1, 6));
auto zmm1 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm0), ymm1, 0x1); // lat3, tp1 uop1 p:5
auto zmm2 = _mm512_inserti32x8(_mm512_castsi256_si512(ymm2), ymm3, 0x1);

unsigned long long* bit1_ptr = reinterpret_cast<unsigned long long*>(src + 32);
auto ymm0b = _mm256_and_si256(lowMask, bit2_data2); // uop:1 p:015
auto ymm1b = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data2, 2)); // uop:1 p:01
auto ymm2b = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data2, 4));
auto ymm3b = _mm256_and_si256(lowMask, _mm256_srli_epi16(bit2_data2, 6));
auto zmm1b = _mm512_inserti32x8(_mm512_castsi256_si512(ymm0b), ymm1b, 0x1); // lat3, tp1 uop1 p:5
auto zmm2b = _mm512_inserti32x8(_mm512_castsi256_si512(ymm2b), ymm3b, 0x1);

unsigned long long* bit1_ptr = reinterpret_cast<unsigned long long*>(src + 64);
auto bit1_mask1 = _cvtu64_mask64(*bit1_ptr);
auto bit1_mask2 = _cvtu64_mask64(*(bit1_ptr + 1));
auto bit1_mask3 = _cvtu64_mask64(*(bit1_ptr + 2));
auto bit1_mask4 = _cvtu64_mask64(*(bit1_ptr + 3));
auto zmm1_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask1, zmm_0x04);
auto zmm2_ = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask2, zmm_0x04);
auto zmm1_b = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask3, zmm_0x04);
auto zmm2_b = _mm512_mask_mov_epi8(zmm_0x00, bit1_mask4, zmm_0x04);
zmm1 = _mm512_add_epi8(zmm1, zmm1_);
zmm2 = _mm512_add_epi8(zmm2, zmm2_);
zmm1 = _mm512_sllv_epi32(zmm1, zmm_shift); // int3_clip => int8
zmm2 = _mm512_sllv_epi32(zmm2, zmm_shift); // int3_clip => int8

zmm1b = _mm512_add_epi8(zmm1b, zmm1_b);
zmm2b = _mm512_add_epi8(zmm2b, zmm2_b);
zmm1b = _mm512_sllv_epi32(zmm1b, zmm_shift); // int3_clip => int8
zmm2b = _mm512_sllv_epi32(zmm2b, zmm_shift); // int3_clip => int8

_mm512_storeu_epi8((__m512i*)dst, zmm1);
_mm512_storeu_epi8((__m512i*)(dst + 64), zmm2);
_mm512_storeu_epi8((__m512i*)(dst + 128), zmm1b);
_mm512_storeu_epi8((__m512i*)(dst + 192), zmm2b);
};

assert(head_ignore_num % 8 == 0);
Expand Down Expand Up @@ -701,11 +721,11 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
auto tail_proc_num = (unpack_elt - (64 - head_ignore_num) % 64) % 64;

// for (int i = 0; i < body_loop; i++) {
for (int i = 0; i < body_loop / 2; i++) {
for (int i = 0; i < body_loop / 4; i++) {
if constexpr (!std::is_same_v<_DST_T, int8_t>) {
bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 8 * 3,
bit3_interleave_decompress_pack256(base_bit2ptr + compress_wei_ptr_offset / 8 * 3,
reinterpret_cast<int8_t*>(base_unpack_buf));
for (int j = 0; j < 128; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j);
for (int j = 0; j < 256; j += 16) convert_s8_fp_v16(dstptr + compress_wei_ptr_offset + j, tmp + j);
// auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4,
// base_bit1ptr + compress_wei_ptr_offset / 8);
// _mm512_storeu_epi8(base_unpack_buf, zmm);
Expand Down Expand Up @@ -735,12 +755,12 @@ inline BTLA_CODE decompress_kblock_s3_s8fp(utils::bit2x4* bit2ptr, utils::bit1x8
} else {
// auto zmm = bit3_interleave_decompress(base_bit2ptr + compress_wei_ptr_offset / 4,
// base_bit1ptr + compress_wei_ptr_offset / 8);
bit3_interleave_decompress_pack128(base_bit2ptr + compress_wei_ptr_offset / 8 * 3,
bit3_interleave_decompress_pack256(base_bit2ptr + compress_wei_ptr_offset / 8 * 3,
reinterpret_cast<int8_t*>(base_unpack_buf) + compress_wei_ptr_offset);
// _mm512_storeu_epi8(base_unpack_buf + compress_wei_ptr_offset, zmm);
}
// compress_wei_ptr_offset += 64;
compress_wei_ptr_offset += 128;
compress_wei_ptr_offset += 256;
}
if (tail_proc_num > 0) {
assert(0);
Expand Down
12 changes: 6 additions & 6 deletions bestla/bestla/kernel_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,14 @@ class DecompresssS3 {
mov(reg_dst, ptr[parambase + OFFSET(dstptr)]);
L("loop_label");
imul(reg_tmp, reg_iter, 96);
Xbyak::Zmm bit2_data_zmm = zmm0;

vmovups(ymm2, ptr[reg_bit2ptr + reg_tmp]);
kmovq(bit1_mask1, ptr[reg_bit2ptr + reg_tmp + 32]);
kmovq(bit1_mask2, ptr[reg_bit2ptr + reg_tmp + 40]);
vmovups(ymm3, ptr[reg_bit2ptr + reg_tmp + 48]);
kmovq(bit1_mask1, ptr[reg_bit2ptr + reg_tmp + 64]);
kmovq(bit1_mask2, ptr[reg_bit2ptr + reg_tmp + 72]);
kmovq(bit1_mask3, ptr[reg_bit2ptr + reg_tmp + 80]);
kmovq(bit1_mask4, ptr[reg_bit2ptr + reg_tmp + 88]);
Xbyak::Zmm bit2_data_zmm = zmm0;
vmovups(bit2_data_zmm, ptr[reg_bit2ptr + reg_tmp]);
vextractf32x8(ymm2, bit2_data_zmm, 0x0);
vextractf32x8(ymm3, bit2_data_zmm, 0x1);

vpand(ymm4, LowMask, ymm2);
vpsrlw(ymm2, ymm2, 2);
Expand Down
8 changes: 4 additions & 4 deletions bestla/bestla/kernel_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,11 @@ static inline BTLA_CODE compress_3bit(const int8_t* srcptr, void* dst, int row,
}

for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j += 128) {
for (int j = 0; j < col; j += 256) {
bit2x4* bit2 = reinterpret_cast<bit2x4*>(dst + i * ld_dst / 8 * 3 + j / 8 * 3);
bit1x8* bit1 = reinterpret_cast<bit1x8*>(dst + i * ld_dst / 8 * 3 + j / 8 * 3 + 32);
for (int k = 0; k < 32; k++) bit2[k] = bit2_data_[i * ld_dst / 4 + j / 4 + k];
for (int k = 0; k < 16; k++) bit1[k] = bit1_data_[i * ld_dst / 8 + j / 8 + k];
bit1x8* bit1 = reinterpret_cast<bit1x8*>(dst + i * ld_dst / 8 * 3 + j / 8 * 3 + 64);
for (int k = 0; k < 64; k++) bit2[k] = bit2_data_[i * ld_dst / 4 + j / 4 + k];
for (int k = 0; k < 32; k++) bit1[k] = bit1_data_[i * ld_dst / 8 + j / 8 + k];
}
}
return BTLA_CODE::Success;
Expand Down
2 changes: 1 addition & 1 deletion bestla/bestla/ut/bestla_prologue_b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ class UT_BlockQunatize_S3 {
// }
}
};
static UT_BlockQunatize_S3 sUT_BlockQunatize_S3;
// static UT_BlockQunatize_S3 sUT_BlockQunatize_S3;
#ifdef BTLA_UT_PROLOGUE_B
static UT_BlockQunatize_S3 sUT_BlockQunatize_S3;
#endif
Expand Down

0 comments on commit ec1092e

Please sign in to comment.