This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

XeTLA Zero-Passthrough (#321)
* sync fmha

* support pass_thru for 2024.1

* XeTLA use mask with zero-passthrough

* reformat

---------

Co-authored-by: Sun, Jiwei1 <[email protected]>
DDEle and sunjiweiswift authored Aug 9, 2024
1 parent 16a9a20 commit 8fc6c57
Showing 5 changed files with 86 additions and 73 deletions.
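
What the change does: a masked gather can leave disabled lanes undefined, so the loads now pass an explicit pass-through vector of zeros, and masked-off lanes read back as 0. This removes a per-lane fix-up in the tile loader. A minimal plain-C++ model of these gather semantics (illustrative sketch only, not XeTLA code):

#include <array>
#include <cstdint>

// Hypothetical scalar model of a masked gather with pass-through:
// enabled lanes load from base + byte_offsets[i]; disabled lanes
// return pass_thru[i] and touch no memory.
template <typename T, int N>
std::array<T, N> gather_pass_thru(
    const T* base,
    const std::array<uint32_t, N>& byte_offsets,
    const std::array<bool, N>& mask,
    const std::array<T, N>& pass_thru) {
  std::array<T, N> out{};
  for (int i = 0; i < N; ++i) {
    const char* p = reinterpret_cast<const char*>(base) + byte_offsets[i];
    out[i] = mask[i] ? *reinterpret_cast<const T*>(p) : pass_thru[i];
  }
  return out;
}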
11 changes: 11 additions & 0 deletions include/common/core/memory.hpp
@@ -500,12 +500,23 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
xetla_vector<OffsetT, N / VS> byte_offsets,
xetla_mask<N / VS> mask,
xetla_vector<T, N> pass_thru) {
#if __INTEL_LLVM_COMPILER >= 20240200
__ESIMD_NS::properties props{
__ESIMD_NS::cache_hint_L1<gpu::xetla::detail::get_cache_hint(L1H)>,
__ESIMD_NS::cache_hint_L2<gpu::xetla::detail::get_cache_hint(L2H)>,
__ESIMD_NS::alignment<alignment>};

return __ESIMD_NS::gather<T, N, VS>(p, byte_offsets, mask, pass_thru, props);
#else
constexpr data_size DS = data_size::default_size;
return __ESIMD_ENS::lsc_gather<
T,
VS,
gpu::xetla::detail::get_data_size(DS),
gpu::xetla::detail::get_cache_hint(L1H),
gpu::xetla::detail::get_cache_hint(L2H),
N / VS>(p, byte_offsets, mask, pass_thru);
#endif
}

/// template <typename T, int N, int VS, typename OffsetT,
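The guard selects the ESIMD gather that accepts a pass_thru operand (compilers reporting __INTEL_LLVM_COMPILER >= 20240200) and falls back to lsc_gather otherwise. A hedged usage sketch of the overload above; the template arguments mirror the call sites in load_xe.hpp, while p and byte_offsets are assumed to be set up elsewhere in a kernel:

constexpr int N = 32;                 // total elements to load (assumed)
constexpr int VS = 1;                 // elements per address (assumed)
xetla_mask<N / VS> mask = 1;          // all lanes enabled in this sketch
xetla_vector<float, N> zeros(0);      // pass-through value: zero
xetla_vector<float, N> v = xetla_load_global<
    float,
    N,
    VS,
    cache_hint::cached,
    cache_hint::cached>(p, byte_offsets, mask, zeros);
// Any lane with mask == 0 comes back as 0.0f instead of undefined data.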
38 changes: 23 additions & 15 deletions include/subgroup/tile/impl/load_xe.hpp
@@ -457,6 +457,7 @@ tile_load(tile_t& tile, payload_t& payload) {
constexpr uint32_t num_channel = payload_t::num_channel;
constexpr uint32_t load_elems = num_channel * payload_t::vector_size;
constexpr uint32_t pack_factor = payload_t::pack_factor;
const xetla_vector<load_dtype, load_elems> reg_zeros(0);

auto channel_offset = payload.channel_offset + payload.base_offset;
#pragma unroll
@@ -494,28 +495,35 @@
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
size_ch_dim)
: 1;
reg_tmp = xetla_load_global<
load_dtype,
load_elems,
payload_t::vector_size,
L1,
L2>(
payload.base_ptr,
channel_offset + address_offset,
mask,
reg_zeros);
} else {
reg_tmp = xetla_load_global<
load_dtype,
load_elems,
payload_t::vector_size,
L1,
L2>(payload.base_ptr, channel_offset + address_offset, mask);
}
reg_tmp = xetla_load_global<
load_dtype,
load_elems,
payload_t::vector_size,
L1,
L2>(payload.base_ptr, channel_offset + address_offset, mask);

if constexpr (
payload_t::vector_size > 1 && payload_t::num_channel > 1) {
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
#pragma unroll
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
if ((bool)mask[iii]) // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
iii * payload_t::vector_size) =
reg_tmp.xetla_select<
payload_t::vector_size,
payload_t::num_channel>(iii);
else // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
iii * payload_t::vector_size) = 0;
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
iii * payload_t::vector_size) =
reg_tmp.xetla_select<
payload_t::vector_size,
payload_t::num_channel>(iii);
}
reg_sub
.xetla_select<load_elems * pack_factor, 1>(
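Because the load itself now guarantees zeros in masked-off lanes, the per-lane mask branch in the regrouping loop (the "delete after driver fix" workaround) becomes a single unconditional strided copy. A plain-C++ model of what that strided xetla_select does (illustrative sketch, not XeTLA code):

#include <array>

// The gather returns data channel-interleaved: element j of channel i
// sits at index j * NumChannel + i. The loop makes each channel's
// VectorSize elements contiguous, as xetla_select<VectorSize,
// NumChannel>(i) does above.
template <typename T, int NumChannel, int VectorSize>
std::array<T, NumChannel * VectorSize> regroup(
    const std::array<T, NumChannel * VectorSize>& reg_tmp) {
  std::array<T, NumChannel * VectorSize> out{};
  for (int i = 0; i < NumChannel; ++i)
    for (int j = 0; j < VectorSize; ++j)
      out[i * VectorSize + j] = reg_tmp[j * NumChannel + i];
  return out;
}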
18 changes: 8 additions & 10 deletions include/subgroup/tile/impl/payload_xe.hpp
@@ -1655,12 +1655,11 @@ struct prefetch_payload_t<
reg_layout_>,
num_coop_sg_,
arch_tag_,
std::enable_if_t<
(!arch_has_2d_load_store<arch_tag_>) &&
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ != 1 || tile_size_x_ != 1) &&
mem_layout_ == mem_layout::col_major))>> {
std::enable_if_t<(!arch_has_2d_load_store<arch_tag_>)&&(
((block_size_y_ != 1 || tile_size_y_ != 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ != 1 || tile_size_x_ != 1) &&
mem_layout_ == mem_layout::col_major))>> {
using dtype = native_type_t<dtype_>;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -1902,10 +1901,9 @@ struct prefetch_payload_t<
reg_layout_>,
num_coop_sg_,
arch_tag_,
std::enable_if_t<
(arch_has_2d_load_store<arch_tag_>) &&
(((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
std::enable_if_t<(arch_has_2d_load_store<arch_tag_>)&&(
((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
using dtype = dtype_;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
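Both hunks in this file are formatting-only; the underlying pattern is an enable_if_t-gated partial specialization that routes each architecture to the right prefetch implementation. A reduced sketch of that dispatch pattern (all names here are stand-ins, not XeTLA's):

#include <type_traits>

template <int TileY, bool Has2DLoadStore, typename Enable = void>
struct prefetch_demo;  // primary template; some specialization must match

// Chosen when the target lacks 2D block load/store.
template <int TileY, bool Has2DLoadStore>
struct prefetch_demo<
    TileY,
    Has2DLoadStore,
    std::enable_if_t<!Has2DLoadStore && (TileY != 1)>> {
  static constexpr const char* path = "scattered prefetch";
};

// Chosen when 2D block load/store is available.
template <int TileY, bool Has2DLoadStore>
struct prefetch_demo<
    TileY,
    Has2DLoadStore,
    std::enable_if_t<Has2DLoadStore && (TileY != 1)>> {
  static constexpr const char* path = "2d block prefetch";
};

static_assert(prefetch_demo<8, false>::path[0] == 's');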
69 changes: 31 additions & 38 deletions tests/integration/fmha/fmha_forward.hpp
@@ -129,12 +129,12 @@ class fmha_forward_t {
using comp_attr = group::compute_attr_t<scalar_t, scalar_t, accum_t>;
using knobs = group::perf_tuning_knob_t<accum_step, stages, sync_freq>;
using compute_policy_BrBc = std::conditional_t<
(arch_tag >= gpu_arch::XeHpg),
(arch_has_xmx<arch_tag>),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// TODO: add k slicing
using compute_policy_BrBm = std::conditional_t<
(arch_tag >= gpu_arch::XeHpg),
(arch_has_xmx<arch_tag>),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// ---------------- // Tile shape and Threads // ---------------- //
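
The policy choice now keys on the arch_has_xmx capability trait rather than an ordering comparison against gpu_arch::XeHpg, so any architecture that reports an XMX (matrix) engine gets the xmx policy and the rest fall back to the fpu policy. Reduced sketch of the std::conditional_t dispatch (policy types are stand-ins):

#include <type_traits>

struct policy_xmx {};  // stands in for compute_policy_default_xmx
struct policy_fpu {};  // stands in for compute_policy_default_fpu

template <bool HasXmx>
using compute_policy_demo = std::conditional_t<HasXmx, policy_xmx, policy_fpu>;

static_assert(std::is_same_v<compute_policy_demo<true>, policy_xmx>);
static_assert(std::is_same_v<compute_policy_demo<false>, policy_fpu>);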
@@ -688,7 +688,7 @@ class fmha_forward_t {
uint8_t,
mem_desc_Dp_Mask_t::layout,
mem_desc_Dp_Mask_t::space>>,
gpu_arch::XeHpc>;
arch_tag>;
load_payload_mask_t load_payload_mask(ctx.mem_desc_Dpij);
subgroup::tile_load(mask_in, load_payload_mask);
matAccSij.reg = matAccSij.reg * mask_in.reg * args.dp_scale;
@@ -771,7 +771,7 @@ class fmha_forward_t {
uint32_t height = args.uB * args.uN * args.uF;
uint32_t offset_height = b * args.uN * args.uF + f * args.uN + n;

if constexpr (arch_tag != gpu_arch::XeHpc) {
if constexpr (!arch_has_2d_load_store<arch_tag>) {
// offset for curr work item
const uint32_t O_offset = offset_height * args.uH + h;
const auto ld_c = args.uN * args.uH;
@@ -798,30 +798,30 @@
matOi_store_t matOi_store(mem_desc_Oi);
subgroup::tile_store<cache_hint::write_back, cache_hint::write_back>(
matOi, matOi_store);
return;
}

xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
transpose_tdecs.xetla_format<uint32_t>(),
args.O_ptr,
args.uH,
height,
args.uH,
h,
offset_height);

for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
// load data from matAccOi
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);

xetla_tstore_global<
scalar_t,
kSgHm,
cache_hint::write_back,
cache_hint::write_back>(transpose_tdecs, v_out);
xetla_update_tdesc_offsety(
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
} else {
xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
transpose_tdecs.xetla_format<uint32_t>(),
args.O_ptr,
args.uH,
height,
args.uH,
h,
offset_height);

for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
// load data from matAccOi
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);

xetla_tstore_global<
scalar_t,
kSgHm,
cache_hint::write_back,
cache_hint::write_back,
arch_tag>(transpose_tdecs, v_out);
xetla_update_tdesc_offsety(
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
}
}
}
// ====================== // preload_Qi // ====================== //
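
The output store likewise switches from an XeHpc equality test to the arch_has_2d_load_store capability, and the early return is folded into an if constexpr/else so the tile-descriptor branch is only instantiated on targets that have 2D block stores. Reduced sketch of the control-flow shape (illustrative):

template <bool Has2DStore>
void store_Oi_demo() {
  if constexpr (!Has2DStore) {
    // boundary-checked element-wise stores; compiles for every target
  } else {
    // 2D block store through a tile descriptor; this branch is discarded
    // (never instantiated) on targets without 2D load/store
  }
}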
@@ -888,16 +888,9 @@
/// @return The size of local memory required.
inline static constexpr uint32_t get_slm_size() {
constexpr uint32_t size = slm_size_Qi + slm_size_Pij + slm_size_softmax;
if constexpr (arch_tag == gpu_arch::XeHpc) {
static_assert(
size <= (128 * 1024),
"The local memory size should be less than 128KB!");

} else {
static_assert(
size <= (64 * 1024),
"The local memory size should be less than 64KB!");
}
static_assert(
size <= (arch_attr_t<arch_tag>::local_mem_size),
"The local memory size should be less than arch total local memory size");
return size;
};

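get_slm_size now validates the requirement against the per-architecture attribute instead of hard-coding 128KB for XeHpc and 64KB for everything else. Reduced sketch, with LocalMemSize standing in for arch_attr_t<arch_tag>::local_mem_size:

#include <cstdint>

template <uint32_t LocalMemSize, uint32_t SlmNeeded>
constexpr uint32_t get_slm_size_demo() {
  static_assert(
      SlmNeeded <= LocalMemSize,
      "The local memory size should not exceed the architecture's limit");
  return SlmNeeded;
}

// For example, a 96KB requirement fits a 128KB (XeHpc-like) budget:
constexpr uint32_t slm = get_slm_size_demo<128 * 1024, 96 * 1024>();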
23 changes: 13 additions & 10 deletions tests/integration/fmha/fmha_utils.h
@@ -134,7 +134,7 @@ template <
typename mat_t,
uint32_t kNumSg,
reduce_op reduce_kind,
gpu_arch arch_tag = gpu_arch::XeHpc>
gpu_arch arch_tag>
struct group_row_reduce_t {
using T = typename mat_t::dtype;
static constexpr uint32_t kNum = mat_t::tile_desc::tile_size_y;
@@ -215,7 +215,7 @@ enum class add_type : uint8_t {
/// @tparam arch_tag Is the hardware architecture tag.
template <
typename dtype_bias_,
gpu_arch arch_tag = gpu_arch::XeHpc,
gpu_arch arch_tag,
add_type add_tag = add_type::single_line>
struct bias_add_op_t {};

@@ -324,8 +324,8 @@ struct bias_add_op_t<dtype_bias_, arch_tag, add_type::single_element> {
using base_t = typename mem_desc_bias_t::base_t;

struct arguments_t {
shape_t shape;
base_t base;
shape_t shape;
inline arguments_t() = default;
inline arguments_t(base_t base_, shape_t shape_)
: base(base_), shape(shape_) {}
@@ -351,11 +351,10 @@
uint32_t offset = (pos_y + pos_x * args.shape.stride) * sizeof(dtype_bias);
auto bias_data_vector = xetla_load_global<
dtype_bias,
16,
1,
data_size::default_size,
cache_hint::cached,
cache_hint::cached,
16>(ptr, offset);
cache_hint::cached>(ptr, offset);
dtype_acc bias_data =
xetla_cvt<dtype_acc, dtype_bias, 16>(bias_data_vector)[0];

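The bias load drops the explicit data_size template argument to match the updated xetla_load_global signature, keeping the trick of issuing a 16-wide load and using only lane 0. Plain-C++ model of that single-element broadcast (illustrative sketch, not XeTLA code):

#include <cstdint>

// Load 16 contiguous elements at base + byte_offset, convert, and keep
// lane 0 only -- one bias scalar per work-item.
template <typename AccT, typename BiasT>
AccT load_bias_scalar(const BiasT* base, uint32_t byte_offset) {
  const BiasT* src = reinterpret_cast<const BiasT*>(
      reinterpret_cast<const char*>(base) + byte_offset);
  BiasT lanes[16];
  for (int i = 0; i < 16; ++i)
    lanes[i] = src[i];                 // the 16-wide vector load
  return static_cast<AccT>(lanes[0]);  // xetla_cvt(...)[0]
}
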
@@ -418,15 +417,19 @@ template <
typename mem_desc_c_t_>
class epilogue_transp_t {};

template <typename tile_op_t_, typename tile_shape_, typename mem_desc_c_t_>
template <
typename tile_op_t_,
typename tile_shape_,
typename mem_desc_c_t_,
gpu_arch arch_tag_>
class epilogue_transp_t<
epilogue_policy_tile_op<tile_op_t_, gpu_arch::XeHpc>,
epilogue_policy_tile_op<tile_op_t_, arch_tag_>,
tile_shape_,
mem_desc_c_t_> {
public:
using tile_shape = tile_shape_;
using mem_desc_c_t = mem_desc_c_t_;
static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
static constexpr gpu_arch arch_tag = arch_tag_;
static constexpr uint32_t barrier_count = 0;
static constexpr uint32_t slm_size = 0;

@@ -505,7 +508,7 @@ class epilogue_write_back_t<
epilogue_policy_default<arch_tag_>,
tile_shape_,
mem_desc_c_t_,
std::enable_if_t<((arch_tag_ <= gpu_arch::XeHpc))>> {
std::enable_if_t<valid_xe_arch_tag<arch_tag_>>> {
public:
using epilogue_policy = epilogue_policy_default<arch_tag_>;
using tile_shape = tile_shape_;
