Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
update dequantize functor
Browse files Browse the repository at this point in the history
  • Loading branch information
zhewang1-intc committed Aug 7, 2024
1 parent b5186a4 commit 47ceae0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,12 @@ struct int4_dequantize_t<
}
SW_BARRIER();
dequantize(mat_dequant_weight, mat_qweight, scale, zp, dequantize_args);
for (uint32_t j = 0; j < sg_tile_n; j++) {
trans_mat_dequant_weight.reg.xetla_select<k_stride, sg_tile_n>(j) =
mat_dequant_weight.reg.xetla_select<k_stride, 1>(j * k_stride);
}
subgroup::tile_store(
trans_mat_dequant_weight, mat_dequant_weight_payload);
// for (uint32_t j = 0; j < sg_tile_n; j++) {
// trans_mat_dequant_weight.reg.xetla_select<k_stride, sg_tile_n>(j) =
// mat_dequant_weight.reg.xetla_select<k_stride, 1>(j * k_stride);
// }
subgroup::tile_store(mat_dequant_weight, mat_dequant_weight_payload);
// trans_mat_dequant_weight, mat_dequant_weight_payload);
mat_dequant_weight_payload.template update_tdesc<tdesc_update_dir::y_dir>(
mat_dequant_weight_t::tile_size_y);
SW_BARRIER();
Expand Down
53 changes: 23 additions & 30 deletions include/subgroup/tile/impl/tile_op_functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ struct dequant_int4_weight_t {
matB_t& matB,
scale_t& scale,
zero_pt_t& zero_pt,
// [[maybe_unused]] const coord_t& coord,
[[maybe_unused]] const arguments_t& args,
[[maybe_unused]] uint32_t slm_base = 0,
[[maybe_unused]] uint32_t nbarrier_base = 0) {
Expand All @@ -90,12 +89,14 @@ struct dequant_int4_weight_t {
for (uint32_t j = 0; j < num_block_x; ++j) {
int block_id = (i * num_block_x + j);
// Must be little-endian
auto matB_blk = matB.reg.xetla_format<uint8_t>()
.xetla_select<matB_acc_t::block_elems / 2, 1>(
block_id * matB_acc_t::block_elems / 2);

auto dst_blk = matB_acc.reg.xetla_select<matB_acc_t::block_elems, 1>(
block_id * matB_acc_t::block_elems);
xetla_vector<uint8_t, matB_acc_t::block_elems / 2> matB_blk =
matB.reg.xetla_format<int8_t>()
.xetla_select<matB_acc_t::block_elems / 2, 1>(
block_id * matB_acc_t::block_elems / 2);
auto dst_blk = matB_acc.reg
.xetla_select<matB_acc_t::block_elems, 1>(
block_id * matB_acc_t::block_elems)
.xetla_format<typename matB_acc_t::dtype>();

// int8 includes 2 4bits data.
xetla_vector<int8_t, matB_acc_t::block_elems> cvt_blk_i8;
Expand All @@ -108,9 +109,10 @@ struct dequant_int4_weight_t {
// highest 4 bit
{
cvt_blk_i8.xetla_select<matB_acc_t::block_elems / 2, 2>(1) =
matB_blk >> 4;
xetla_shr<int8_t, uint8_t, matB_acc_t::block_elems / 2>(
matB_blk, 4);
}

dst_blk = cvt_blk_i8;
// (b_i8 - zero_pt_i8) x scale = fp16
constexpr uint32_t step = std::min(block_size_y_b, dequant_s);
#pragma unroll
Expand All @@ -119,43 +121,34 @@ struct dequant_int4_weight_t {
for (uint32_t ii = 0; ii < block_size_y_b; ii += step) {
uint32_t offset_y_in_tile = i * block_size_y_b + ii;
uint32_t offset_x_in_tile = j * block_size_x_b + jj;

uint32_t scale_idx =
(offset_y_in_tile) / dequant_s * scale_t::block_size_x +
offset_x_in_tile;
typename matB_acc_t::dtype scale_value =
(typename scale_t::dtype)scale.reg[scale_idx];
typename matB_acc_t::dtype add_number;

if constexpr (quant_mode == quant_mode::I4_ASYM) {
uint32_t zero_pt_idx =
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
offset_x_in_tile / pack_ratio;
native_type_t<typename matB_t::dtype> zero_pt_pack =
zero_pt.reg[zero_pt_idx];

int8_t zero_pt_i8 =
(zero_pt_pack >>
(4 * ((args.wg_start_n + offset_x_in_tile) % pack_ratio))) &
0xf;
// sycl::ext::oneapi::experimental::printf(
// "zero_pt.reg[%d} %x zero_pt_i8 %x offset_x_in_tile:%d
// \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 ,
// offset_x_in_tile);

cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
zero_pt_i8;
add_number = scale_value * zero_pt_i8;
} else if constexpr (quant_mode == quant_mode::I4_SYM) {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
int8_t(8);
add_number = scale_value * int8_t(-8);
}
#pragma unroll
for (uint32_t iii = 0; iii < step; iii += 16) {
dst_blk.xetla_select<16, 1>(jj * block_size_y_b + ii + iii) =
dst_blk.xetla_select<16, 1>(jj * block_size_y_b + ii + iii) *
scale_value +
add_number;
}
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];

// sycl::ext::oneapi::experimental::printf(
// "scale[%d] %f \n",
// scale_idx,
// float(sycl::half(scale.reg.xetla_select<1, 1>(scale_idx))));
}
}
}
Expand Down

0 comments on commit 47ceae0

Please sign in to comment.