This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

XeTLA Sync FMHA Tests (#316)
DDEle authored Jul 24, 2024
1 parent 1081543 commit b733d92
Showing 3 changed files with 221 additions and 39 deletions.
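The substance of the commit is a new kVarlen template parameter on fmha_forward_t plus cumulative sequence-length arrays (cu_seqlen_q, cu_seqlen_k) threaded through the kernel arguments. As background, the varlen convention packs a batch's sequences back to back along the sequence axis, with batch b occupying rows [cu_seqlen[b], cu_seqlen[b + 1]). A minimal host-side sketch of that layout, with hypothetical names not taken from this commit:

#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the cu_seqlen convention: per-batch
// lengths {3, 5, 2} become offsets {0, 3, 8, 10}, so batch b's tokens sit in
// rows [cu[b], cu[b + 1]) of the packed buffer and cu[b + 1] - cu[b]
// recovers its length.
std::vector<int32_t> build_cu_seqlen(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (size_t b = 0; b < seqlens.size(); ++b)
    cu[b + 1] = cu[b] + seqlens[b];
  return cu;
}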
3 changes: 3 additions & 0 deletions tests/integration/fmha/fmha.cpp
@@ -245,6 +245,7 @@ void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) {
false,
kSeqLast,
false,
false,
false>;
using accum_t = typename fmha_forward_op_t::accum_t;

@@ -346,6 +347,8 @@ void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) {
kUseBias ? klen_pad32 * qlen : 0,
kUseBias ? 0 : 0, // broadcast on N (head num)
kUseBias ? klen_pad32 : 0,
nullptr,
nullptr,
softmax_scale,
0,
0,
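In the test, the two new constructor slots are filled with nullptr, i.e. the fixed-length path stays in effect; a varlen caller would instead pass device pointers holding B + 1 cumulative offsets per array. A hedged sketch of the invariants the kernel's cu_seqlen indexing relies on (hypothetical helper, not part of this commit):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side check: the kernel computes cu_seqlen[batch_id + 1]
// - cu_seqlen[batch_id], which is only a valid per-batch length if the
// offsets start at 0 and never decrease.
void check_cu_seqlen(const std::vector<int32_t>& cu) {
  assert(!cu.empty() && cu.front() == 0);
  for (size_t b = 0; b + 1 < cu.size(); ++b)
    assert(cu[b] <= cu[b + 1]);
}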
142 changes: 113 additions & 29 deletions tests/integration/fmha/fmha_forward.hpp
@@ -6,12 +6,14 @@ Fused Multi-Head Attention Forward
This is an implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf)
*/
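// For reference, the online-softmax recurrence flash attention uses, written
// per query row over key tiles j (a standard statement of Dao et al.,
// Algorithm 1; symbol names are illustrative, not from this file):
//   m_j = max(m_{j-1}, rowmax(S_j))                  // running max
//   P_j = exp(S_j - m_j)                             // shifted exponentials
//   l_j = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(P_j) // running denominator
//   O_j = exp(m_{j-1} - m_j) * O_{j-1} + P_j @ V_j   // rescaled accumulator
// with the final output O / l.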
#include <sys/types.h>
#include <limits>
#include "fmha_forward_policy.h"
#include "fmha_utils.h"

namespace gpu::xetla {
namespace fmha {

template <
typename fmha_policy,
typename scalar_t,
@@ -21,7 +23,8 @@ template <
bool kIsCausal,
bool kSeqLast,
bool kIsTraining,
- bool kIsDropout>
+ bool kIsDropout,
+ bool kVarlen>
class fmha_forward_t {
public:
using accum_t = float;
@@ -47,6 +50,9 @@ class fmha_forward_t {
uint32_t bias_strideB;
uint32_t bias_strideN;
uint32_t bias_strideF;
// Cumulative sequence-length info for varlen inputs (B + 1 offsets each)
int32_t* cu_seqlen_q;
int32_t* cu_seqlen_k;
// Softmax scale is the reciprocal square root of head size by default
accum_t sm_scale;
// Dropout scale is computed from dropout prob
@@ -77,6 +83,8 @@ class fmha_forward_t {
uint32_t bias_strideB,
uint32_t bias_strideN,
uint32_t bias_strideF,
int32_t* cu_seqlen_q,
int32_t* cu_seqlen_k,
accum_t sm_scale,
accum_t dropout_prob,
uint32_t alibi_padded_block_size,
@@ -100,6 +108,8 @@ class fmha_forward_t {
bias_strideB(bias_strideB),
bias_strideN(bias_strideN),
bias_strideF(bias_strideF),
cu_seqlen_q(cu_seqlen_q),
cu_seqlen_k(cu_seqlen_k),
sm_scale(sm_scale),
dp_prob(dropout_prob),
dp_scale(1.f / (1.f - dropout_prob)),
@@ -115,31 +125,25 @@
static constexpr uint32_t accum_step = fmha_policy::accum_step;
static constexpr uint32_t stages = fmha_policy::stages;
static constexpr uint32_t sync_freq = fmha_policy::sync_freq;
- static constexpr uint32_t kBr = fmha_policy::kBr;
- static constexpr uint32_t kBc = fmha_policy::kBc;
- static constexpr uint32_t kHm = fmha_policy::kHm;
- static constexpr uint32_t kSgBr = fmha_policy::kSgBr;
- static constexpr uint32_t kSgBc = fmha_policy::kSgBc;
- static constexpr uint32_t kSgHm = fmha_policy::kSgHm;

- using comp_attr = std::conditional_t<
-     std::is_same_v<scalar_t, bf16> && (arch_tag < gpu_arch::XeHpc),
-     group::compute_attr_t<accum_t, accum_t, accum_t>,
-     group::compute_attr_t<scalar_t, scalar_t, accum_t>>;
+ using comp_attr = group::compute_attr_t<scalar_t, scalar_t, accum_t>;
using knobs = group::perf_tuning_knob_t<accum_step, stages, sync_freq>;

- // use fpu when M==1 even if xmx is available
- static constexpr bool _use_xmx = arch_tag >= gpu_arch::XeHpg && kSgBr != 1;
using compute_policy_BrBc = std::conditional_t<
-     _use_xmx,
+     (arch_tag >= gpu_arch::XeHpg),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
- // TODO(Yi): add k slicing?
+ // TODO: add k slicing
using compute_policy_BrBm = std::conditional_t<
-     _use_xmx,
+     (arch_tag >= gpu_arch::XeHpg),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
+ // ---------------- // Tile shape and Threads // ---------------- //
+ static constexpr uint32_t kBr = fmha_policy::kBr;
+ static constexpr uint32_t kBc = fmha_policy::kBc;
+ static constexpr uint32_t kHm = fmha_policy::kHm;
+ static constexpr uint32_t kSgBr = fmha_policy::kSgBr;
+ static constexpr uint32_t kSgBc = fmha_policy::kSgBc;
+ static constexpr uint32_t kSgHm = fmha_policy::kSgHm;

using tile_shape_BrBc = group::tile_shape_t<kBc, kBr, kSgBc, kSgBr>;
using tile_shape_BrHm = group::tile_shape_t<kHm, kBr, kSgHm, kSgBr>;
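// Orientation note (illustrative values, not from any policy in this repo):
// each work-group owns a kBr x kBc tile of S = Q * K^T, split into
// (kBr / kSgBr) * (kBc / kSgBc) subgroup tiles; e.g. kBr = 64, kBc = 32,
// kSgBr = 16, kSgBc = 32 gives 64/16 * 32/32 = 4 subgroups per work-group.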
@@ -268,6 +272,21 @@ class fmha_forward_t {
args.O_ptr,
{end_x, end_y, b_stride * args.uB},
{start_acc, start_y});
} else if constexpr (kVarlen) {
int32_t start_y = args.cu_seqlen_q[batch_id] + item.get_group(1) * kBr;
uint32_t end_y = start_y + kBr;
int32_t limit_y = args.cu_seqlen_q[batch_id + 1];
end_y = end_y < limit_y ? end_y : limit_y;

int32_t start_acc = head_id * args.uH;
uint32_t end_x = start_acc + args.uH;
const uint32_t ld_qo = args.uH * args.uN;

mem_desc_Qi.init(
args.Q_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});

mem_desc_Oi.init(
args.O_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});
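      // Worked example (illustrative): with cu_seqlen_q = {0, 3, 8} and
      // kBr = 2, batch 1 / group 1 gives start_y = 3 + 2 = 5 and end_y =
      // min(7, 8) = 7, so this tile covers packed query rows [5, 7).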
} else { // 2d mem: [BxF, NxH]
// startF
int32_t start_y = batch_id * args.uF + item.get_group(1) * kBr;
@@ -277,16 +296,13 @@
end_y = end_y > boundary_y ? boundary_y : end_y;

int32_t start_acc = head_id * args.uH;
+ uint32_t end_acc = start_acc + args.uH;
const uint32_t ld_qo = args.uH * args.uN;

mem_desc_Qi.init(
-     args.Q_ptr,
-     {args.uH * args.uN, end_y, ld_qo},
-     {start_acc, start_y});
+     args.Q_ptr, {end_acc, end_y, ld_qo}, {start_acc, start_y});
mem_desc_Oi.init(
-     args.O_ptr,
-     {args.uH * args.uN, end_y, ld_qo},
-     {start_acc, start_y});
+     args.O_ptr, {end_acc, end_y, ld_qo}, {start_acc, start_y});
}

int32_t start_x_ml = item.get_group(1) * kBr + sg_idy * kSgBr;
@@ -331,27 +347,46 @@
args.V_ptr,
{end_y, end_x, b_stride * args.uB},
{start_acc, start_x});
} else if constexpr (kVarlen) {
int32_t start_x = startT + args.cu_seqlen_k[batch_id];
uint32_t end_x = start_x + kBc;
int32_t limit_x = args.cu_seqlen_k[batch_id + 1];
end_x = end_x < limit_x ? end_x : limit_x;

int32_t start_acc = head_id * args.uNkv / args.uN * args.uH;
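// Note: integer evaluation is left to right, so head_id * uNkv / uN maps a
// query head to its (possibly shared) KV head under grouped-query attention
// (uNkv <= uN assumed); multiplying by uH then gives that head's offset.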
uint32_t end_y = start_acc + args.uH;
mem_desc_Kj_T.init(
args.K_ptr,
{end_x, end_y, args.uNkv * args.uH},
{start_x, start_acc});

mem_desc_Vj.init(
args.V_ptr,
{end_y, end_x, args.uNkv * args.uH},
{start_acc, start_x});

} else {
int32_t start_x = batch_id * args.uT + startT;
uint32_t end_x = start_x + kBc;
uint32_t boundary_x = (batch_id + 1) * args.uT;
end_x = end_x > boundary_x ? boundary_x : end_x;

int32_t start_acc = head_id_kv * args.uH;
uint32_t end_acc = start_acc + args.uH;

mem_desc_Kj_T.init(
args.K_ptr,
{end_x, args.uH * args.uNkv, args.uH * args.uNkv},
{end_x, end_acc, args.uH * args.uNkv},
{start_x, start_acc});
mem_desc_Vj.init(
args.V_ptr,
{args.uH * args.uNkv, end_x, args.uH * args.uNkv},
{end_acc, end_x, args.uH * args.uNkv},
{start_acc, start_x});
}

// B, N, 1, T
// gid * T + startT
- if constexpr (kUseAlibi) {
+ if constexpr (kUseAlibi && !kVarlen) {
int32_t batch_start = gid * args.uAT;
int32_t start_x = batch_start + startT;
uint32_t end_x = startT + kBc;
@@ -363,6 +398,15 @@
args.A_ptr, {end_x, 1, args.uAT * args.uN * args.uB}, {start_x, 0});
}

// Alibi slopes laid out as [B, N] or [N]
if constexpr (kUseAlibi && kVarlen) {
// assume uAT in varlen equals N or 0
int32_t start_x = batch_id * args.uAT + head_id;
uint32_t end_x = start_x + 1;
end_x = end_x >= args.uN ? end_x : args.uN;
mem_desc_Ai.init(args.A_ptr, {end_x, 1, 1}, {start_x, 0});
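// Worked example (illustrative): with uAT = uN = 4, batch 1 / head 2 reads
// the single slope at index 1 * 4 + 2 = 6 of the packed [B, N] slope array.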
}

if constexpr (kUseBias && !kIsCausal) {
int32_t start_x = startT;
uint32_t end_x = start_x + kBc;
@@ -442,7 +486,7 @@ class fmha_forward_t {
matAccSij.reg *= args.sm_scale;

// + beta * alibi
- if constexpr (kUseAlibi) {
+ if constexpr (kUseAlibi && !kVarlen) {
using alibi_op_t = bias_add_op_t<scalar_t, arch_tag>;
using alibi_args_t = typename alibi_op_t::arguments_t;

@@ -455,6 +499,16 @@
alibi_op(matAccSij, ctx.mem_desc_Ai.coord, alibi_args);
}

if constexpr (kUseAlibi && kVarlen) {
using alibi_op_t =
bias_add_op_t<scalar_t, arch_tag, add_type::single_element>;
using alibi_args_t = typename alibi_op_t::arguments_t;

alibi_op_t alibi_op;
alibi_args_t alibi_args(ctx.mem_desc_Ai.base, ctx.mem_desc_Ai.shape);
alibi_op(matAccSij, ctx.mem_desc_Ai.coord, alibi_args);
}

// Add attn_mask if needed
if constexpr (kUseBias && !kIsCausal) {
if (args.is_bias_add) {
@@ -533,14 +587,22 @@

/// @brief apply mask to matAccSij.
inline void apply_mask(
+ nd_item<3>& item,
matAccSij_t& matAccSij,
arguments_t& args,
uint32_t startF,
uint32_t startT) {
using tile_mask = tile_mask_t<matAccSij_t>;

uint32_t sg_startT = startT + ctx.sg_idx * kSgBc;
- uint32_t remainT = std::max(int(args.uT) - int(sg_startT), 0);
+ uint32_t real_T;
+ if constexpr (kVarlen) {
+   int32_t batch_id = item.get_group(0) / args.uN;
+   real_T = args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id];
+ } else {
+   real_T = args.uT;
+ }
+ uint32_t remainT = std::max(int(real_T) - int(sg_startT), 0);
if (remainT < kSgBc) {
tile_mask::padding_mask(matAccSij, remainT);
}
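// With kVarlen, real_T is batch b's actual key length, so a subgroup whose
// sg_startT falls at or past it masks out all kSgBc of its columns.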
@@ -867,6 +929,19 @@

// initialize context for flash mha loops
ctx.init_context(item, args);
uint32_t gid = item.get_group(0);
uint32_t batch_id = gid / args.uN; // get batch idx
// Early exit when the current thread would access data beyond the actual
// seqlen in varlen fwd
if constexpr (kVarlen) {
int32_t actual_seqlen_q =
args.cu_seqlen_q[batch_id + 1] - args.cu_seqlen_q[batch_id];
int32_t seqlen_q = item.get_group(1) * kBr;

if (seqlen_q >= actual_seqlen_q) {
return;
}
}
// preload Qi to local memory
preload_Qi(args);
// initialize matAccOi for accumulating the output
@@ -877,6 +952,15 @@

// iterate through the keys
for (uint32_t startT = 0; startT < args.uT; startT += kBc) {
// Early exit for varlen fwd once the current startT passes the actual
// seqlen.
if constexpr (kVarlen) {
int32_t actual_seqlen =
args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id];
if (startT >= actual_seqlen) {
break;
}
}
if constexpr (kIsCausal) {
if (startT >= endF)
break;
@@ -887,7 +971,7 @@
matAccSij_t matAccSij(0);
gemm_Sij(matAccSij, args);
// apply mask
- apply_mask(matAccSij, args, startF, startT);
+ apply_mask(item, matAccSij, args, startF, startT);
// softmax
dp_mask_tile_t mask_in;
softmax_fwd(matAccSij, matAccOi, mask_in, args);
