diff --git a/tests/integration/fmha/fmha.cpp b/tests/integration/fmha/fmha.cpp
index 1921cf206..7c881c247 100644
--- a/tests/integration/fmha/fmha.cpp
+++ b/tests/integration/fmha/fmha.cpp
@@ -245,6 +245,7 @@ void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) {
       false,
       kSeqLast,
       false,
+      false,
       false>;
   using accum_t = typename fmha_forward_op_t::accum_t;
 
@@ -346,6 +347,8 @@ void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) {
       kUseBias ? klen_pad32 * qlen : 0,
       kUseBias ? 0 : 0, // broadcast on N (head num)
       kUseBias ? klen_pad32 : 0,
+      nullptr,
+      nullptr,
       softmax_scale,
       0,
       0,
diff --git a/tests/integration/fmha/fmha_forward.hpp b/tests/integration/fmha/fmha_forward.hpp
index 623ea6d7f..03a2bc99a 100644
--- a/tests/integration/fmha/fmha_forward.hpp
+++ b/tests/integration/fmha/fmha_forward.hpp
@@ -6,12 +6,14 @@ Fused Multi-Head Attention Forward
 This is an implementation of the Flash Attention algorithm
 (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf)
 */
+#include
 #include
 
 #include "fmha_forward_policy.h"
 #include "fmha_utils.h"
 
 namespace gpu::xetla {
 namespace fmha {
+
 template <
     typename fmha_policy,
     typename scalar_t,
@@ -21,7 +23,8 @@ template <
     bool kIsCausal,
     bool kSeqLast,
     bool kIsTraining,
-    bool kIsDropout>
+    bool kIsDropout,
+    bool kVarlen>
 class fmha_forward_t {
  public:
   using accum_t = float;
@@ -47,6 +50,9 @@ class fmha_forward_t {
     uint32_t bias_strideB;
     uint32_t bias_strideN;
     uint32_t bias_strideF;
+    // Sequence length info
+    int32_t* cu_seqlen_q;
+    int32_t* cu_seqlen_k;
     // Softmax scale is the reciprocal square root of head size by default
     accum_t sm_scale;
     // Dropout scale is computed from dropout prob
@@ -77,6 +83,8 @@ class fmha_forward_t {
         uint32_t bias_strideB,
         uint32_t bias_strideN,
         uint32_t bias_strideF,
+        int32_t* cu_seqlen_q,
+        int32_t* cu_seqlen_k,
         accum_t sm_scale,
         accum_t dropout_prob,
         uint32_t alibi_padded_block_size,
@@ -100,6 +108,8 @@ class fmha_forward_t {
           bias_strideB(bias_strideB),
           bias_strideN(bias_strideN),
           bias_strideF(bias_strideF),
+          cu_seqlen_q(cu_seqlen_q),
+          cu_seqlen_k(cu_seqlen_k),
           sm_scale(sm_scale),
           dp_prob(dropout_prob),
           dp_scale(1.f / (1.f - dropout_prob)),
@@ -115,31 +125,25 @@ class fmha_forward_t {
   static constexpr uint32_t accum_step = fmha_policy::accum_step;
   static constexpr uint32_t stages = fmha_policy::stages;
   static constexpr uint32_t sync_freq = fmha_policy::sync_freq;
-  static constexpr uint32_t kBr = fmha_policy::kBr;
-  static constexpr uint32_t kBc = fmha_policy::kBc;
-  static constexpr uint32_t kHm = fmha_policy::kHm;
-  static constexpr uint32_t kSgBr = fmha_policy::kSgBr;
-  static constexpr uint32_t kSgBc = fmha_policy::kSgBc;
-  static constexpr uint32_t kSgHm = fmha_policy::kSgHm;
 
-  using comp_attr = std::conditional_t<
-      std::is_same_v && (arch_tag < gpu_arch::XeHpc),
-      group::compute_attr_t,
-      group::compute_attr_t>;
+  using comp_attr = group::compute_attr_t;
   using knobs = group::perf_tuning_knob_t;
-
-  // use fpu when M==1 even if xmx is available
-  static constexpr bool _use_xmx = arch_tag >= gpu_arch::XeHpg && kSgBr != 1;
   using compute_policy_BrBc = std::conditional_t<
-      _use_xmx,
+      (arch_tag >= gpu_arch::XeHpg),
       group::compute_policy_default_xmx,
       group::compute_policy_default_fpu>;
-  // TODO(Yi): add k slicing?
+  // TODO: add k slicing
   using compute_policy_BrBm = std::conditional_t<
-      _use_xmx,
+      (arch_tag >= gpu_arch::XeHpg),
       group::compute_policy_default_xmx,
       group::compute_policy_default_fpu>;
 
   // ---------------- // Tile shape and Threads // ---------------- //
+  static constexpr uint32_t kBr = fmha_policy::kBr;
+  static constexpr uint32_t kBc = fmha_policy::kBc;
+  static constexpr uint32_t kHm = fmha_policy::kHm;
+  static constexpr uint32_t kSgBr = fmha_policy::kSgBr;
+  static constexpr uint32_t kSgBc = fmha_policy::kSgBc;
+  static constexpr uint32_t kSgHm = fmha_policy::kSgHm;
   using tile_shape_BrBc = group::tile_shape_t;
   using tile_shape_BrHm = group::tile_shape_t;
 
@@ -268,6 +272,21 @@ class fmha_forward_t {
           args.O_ptr,
           {end_x, end_y, b_stride * args.uB},
           {start_acc, start_y});
+    } else if constexpr (kVarlen) {
+      int32_t start_y = args.cu_seqlen_q[batch_id] + item.get_group(1) * kBr;
+      uint32_t end_y = start_y + kBr;
+      int32_t limit_y = args.cu_seqlen_q[batch_id + 1];
+      end_y = end_y < limit_y ? end_y : limit_y;
+
+      int32_t start_acc = head_id * args.uH;
+      uint32_t end_x = start_acc + args.uH;
+      const uint32_t ld_qo = args.uH * args.uN;
+
+      mem_desc_Qi.init(
+          args.Q_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});
+
+      mem_desc_Oi.init(
+          args.O_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});
     } else { // 2d mem: [BxF, NxH]
       // startF
       int32_t start_y = batch_id * args.uF + item.get_group(1) * kBr;
@@ -277,16 +296,13 @@ class fmha_forward_t {
       end_y = end_y > boundary_y ? boundary_y : end_y;
 
       int32_t start_acc = head_id * args.uH;
+      uint32_t end_acc = start_acc + args.uH;
       const uint32_t ld_qo = args.uH * args.uN;
 
       mem_desc_Qi.init(
-          args.Q_ptr,
-          {args.uH * args.uN, end_y, ld_qo},
-          {start_acc, start_y});
+          args.Q_ptr, {end_acc, end_y, ld_qo}, {start_acc, start_y});
       mem_desc_Oi.init(
-          args.O_ptr,
-          {args.uH * args.uN, end_y, ld_qo},
-          {start_acc, start_y});
+          args.O_ptr, {end_acc, end_y, ld_qo}, {start_acc, start_y});
     }
 
     int32_t start_x_ml = item.get_group(1) * kBr + sg_idy * kSgBr;
@@ -331,6 +347,24 @@ class fmha_forward_t {
           args.V_ptr,
           {end_y, end_x, b_stride * args.uB},
           {start_acc, start_x});
+    } else if (kVarlen) {
+      int32_t start_x = startT + args.cu_seqlen_k[batch_id];
+      uint32_t end_x = start_x + kBc;
+      int32_t limit_x = args.cu_seqlen_k[batch_id + 1];
+      end_x = end_x < limit_x ? end_x : limit_x;
+
+      int32_t start_acc = head_id * args.uNkv / args.uN * args.uH;
+      uint32_t end_y = start_acc + args.uH;
+      mem_desc_Kj_T.init(
+          args.K_ptr,
+          {end_x, end_y, args.uNkv * args.uH},
+          {start_x, start_acc});
+
+      mem_desc_Vj.init(
+          args.V_ptr,
+          {end_y, end_x, args.uNkv * args.uH},
+          {start_acc, start_x});
+
     } else {
       int32_t start_x = batch_id * args.uT + startT;
       uint32_t end_x = start_x + kBc;
       uint32_t boundary_x = (batch_id + 1) * args.uT;
       end_x = end_x > boundary_x ? boundary_x : end_x;
 
       int32_t start_acc = head_id_kv * args.uH;
+      uint32_t end_acc = start_acc + args.uH;
       mem_desc_Kj_T.init(
           args.K_ptr,
-          {end_x, args.uH * args.uNkv, args.uH * args.uNkv},
+          {end_x, end_acc, args.uH * args.uNkv},
           {start_x, start_acc});
 
       mem_desc_Vj.init(
           args.V_ptr,
-          {args.uH * args.uNkv, end_x, args.uH * args.uNkv},
+          {end_acc, end_x, args.uH * args.uNkv},
           {start_acc, start_x});
     }
 
     // B, N, 1, T
     // gid * T + startT
-    if constexpr (kUseAlibi) {
+    if constexpr (kUseAlibi && !kVarlen) {
       int32_t batch_start = gid * args.uAT;
       int32_t start_x = batch_start + startT;
       uint32_t end_x = startT + kBc;
@@ -363,6 +398,15 @@ class fmha_forward_t {
           args.A_ptr, {end_x, 1, args.uAT * args.uN * args.uB}, {start_x, 0});
     }
 
+    // B, N or N
+    if constexpr (kUseAlibi && kVarlen) {
+      // assume uAT in varlen equals N or 0
+      int32_t start_x = batch_id * args.uAT + head_id;
+      uint32_t end_x = start_x + 1;
+      end_x = end_x >= args.uN ? end_x : args.uN;
+      mem_desc_Ai.init(args.A_ptr, {end_x, 1, 1}, {start_x, 0});
+    }
+
     if constexpr (kUseBias && !kIsCausal) {
       int32_t start_x = startT;
       uint32_t end_x = start_x + kBc;
@@ -442,7 +486,7 @@ class fmha_forward_t {
     matAccSij.reg *= args.sm_scale;
 
     // + beta * alibi
-    if constexpr (kUseAlibi) {
+    if constexpr (kUseAlibi && !kVarlen) {
       using alibi_op_t = bias_add_op_t;
       using alibi_args_t = typename alibi_op_t::arguments_t;
 
@@ -455,6 +499,16 @@ class fmha_forward_t {
       alibi_op(matAccSij, ctx.mem_desc_Ai.coord, alibi_args);
     }
 
+    if constexpr (kUseAlibi && kVarlen) {
+      using alibi_op_t =
+          bias_add_op_t;
+      using alibi_args_t = typename alibi_op_t::arguments_t;
+
+      alibi_op_t alibi_op;
+      alibi_args_t alibi_args(ctx.mem_desc_Ai.base, ctx.mem_desc_Ai.shape);
+      alibi_op(matAccSij, ctx.mem_desc_Ai.coord, alibi_args);
+    }
+
     // Add attn_mask if needed
     if constexpr (kUseBias && !kIsCausal) {
       if (args.is_bias_add) {
@@ -533,6 +587,7 @@ class fmha_forward_t {
 
   /// @brief apply mask to matAccSij.
   inline void apply_mask(
+      nd_item<3>& item,
       matAccSij_t& matAccSij,
       arguments_t& args,
       uint32_t startF,
@@ -540,7 +595,14 @@ class fmha_forward_t {
     using tile_mask = tile_mask_t;
 
     uint32_t sg_startT = startT + ctx.sg_idx * kSgBc;
-    uint32_t remainT = std::max(int(args.uT) - int(sg_startT), 0);
+    uint32_t real_T;
+    if constexpr (kVarlen) {
+      int32_t batch_id = item.get_group(0) / args.uN;
+      real_T = args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id];
+    } else {
+      real_T = args.uT;
+    }
+    uint32_t remainT = std::max(int(real_T) - int(sg_startT), 0);
     if (remainT < kSgBc) {
       tile_mask::padding_mask(matAccSij, remainT);
     }
@@ -867,6 +929,19 @@ class fmha_forward_t {
 
     // initialize context for flash mha loops
     ctx.init_context(item, args);
+    uint32_t gid = item.get_group(0);
+    uint32_t batch_id = gid / args.uN; // get batch idx
+    // Early exit when the current thread would access data beyond the actual
+    // seqlen in varlen fwd
+    if constexpr (kVarlen) {
+      int32_t actual_seqlen_q =
+          args.cu_seqlen_q[batch_id + 1] - args.cu_seqlen_q[batch_id];
+      int32_t seqlen_q = item.get_group(1) * kBr;
+
+      if (seqlen_q >= actual_seqlen_q) {
+        return;
+      }
+    }
     // preload Qi to local memory
     preload_Qi(args);
     // initialize matAccOi for accumulate the output
@@ -877,6 +952,15 @@ class fmha_forward_t {
 
     // iterate through the keys
     for (uint32_t startT = 0; startT < args.uT; startT += kBc) {
+      // Early exit for varlen_fwd if the current startT exceeds the actual
+      // seqlen.
+      if constexpr (kVarlen) {
+        int32_t actual_seqlen =
+            args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id];
+        if (startT >= actual_seqlen) {
+          break;
+        }
+      }
       if constexpr (kIsCausal) {
         if (startT >= endF)
           break;
@@ -887,7 +971,7 @@ class fmha_forward_t {
       matAccSij_t matAccSij(0);
       gemm_Sij(matAccSij, args);
       // apply mask
-      apply_mask(matAccSij, args, startF, startT);
+      apply_mask(item, matAccSij, args, startF, startT);
       // softmax
       dp_mask_tile_t mask_in;
       softmax_fwd(matAccSij, matAccOi, mask_in, args);
diff --git a/tests/integration/fmha/fmha_utils.h b/tests/integration/fmha/fmha_utils.h
index 1aef9a5f4..25e19814c 100644
--- a/tests/integration/fmha/fmha_utils.h
+++ b/tests/integration/fmha/fmha_utils.h
@@ -24,7 +24,7 @@ struct tile_mask_t {
       uint32_t start_x,
       uint32_t start_y) {
 #pragma unroll
-    for (int i = 0; i < tile_size_y / block_size_y; i++) {
+    for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
       uint32_t blk_start_y = start_y + i * block_size_y;
 #pragma unroll
       for (int j = 0; j < num_block_x; j++) {
@@ -203,16 +203,24 @@ struct group_row_reduce_t {
   }
 };
 
+enum class add_type : uint8_t {
+  single_line = 0, // add one line of data of given coord to target tile
+  single_element = 1 // add single data of given coord to target tile
+};
+
 /// @brief Is the bias_add op functor.
 /// Load the 1d bias data from memory and get the input from matAcc, update the
 /// output in place. Used in epilogue::tile_op or chained_tile_op.
 /// @tparam dtype_bias Is the data type of bias buffer.
 /// @tparam arch_tag Is the hardware architecture tag.
-template
-struct bias_add_op_t;
-/// @brief Is the bias_add op functor, specialized for Xe architecture.
+template <
+    typename dtype_bias_,
+    gpu_arch arch_tag = gpu_arch::XeHpc,
+    add_type add_tag = add_type::single_line>
+struct bias_add_op_t {};
+
 template
-struct bias_add_op_t {
+struct bias_add_op_t {
   using dtype_bias = dtype_bias_;
   using mem_desc_bias_t =
       mem_desc_t;
@@ -245,12 +253,15 @@ struct bias_add_op_t {
     using bias_tile_desc_t = subgroup::
        tile_desc_t;
     using bias_t = subgroup::tile_t;
-    using mem_desc_bias_t =
-        mem_desc_t;
     using bias_payload_t = subgroup::mem_payload_t<
-        mem_desc_bias_t,
+        mem_desc_t,
         bias_tile_desc_t,
-        subgroup::msg_type_v,
+        subgroup::msg_type_v<
+            bias_tile_desc_t,
+            mem_desc_t<
+                dtype_bias,
+                mem_desc_bias_t::layout,
+                mem_desc_bias_t::space>>,
         arch_tag>;
     coord_t bias_coord(coord.x, coord.y);
     mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord);
@@ -291,7 +302,7 @@ struct bias_add_op_t {
               tail_start_y * tile_size_x + j * tail_block_elems)
           .xetla_format();
 #pragma unroll
-      for (int row_i = 0; row_i < tail_size_y; row_i++) {
+      for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
        auto src_reg =
            bias.reg.xetla_select(j * block_size_x);
        dst_reg.row(row_i) =
@@ -302,6 +313,90 @@ struct bias_add_op_t {
     }
   }
 };
+
+template
+struct bias_add_op_t {
+  using dtype_bias = dtype_bias_;
+  using mem_desc_bias_t =
+      mem_desc_t;
+  using shape_t = typename mem_desc_bias_t::shape_t;
+  using coord_t = typename mem_desc_bias_t::coord_t;
+  using base_t = typename mem_desc_bias_t::base_t;
+
+  struct arguments_t {
+    shape_t shape;
+    base_t base;
+    inline arguments_t() = default;
+    inline arguments_t(base_t base_, shape_t shape_)
+        : base(base_), shape(shape_) {}
+  };
+  template
+  __XETLA_API KERNEL_FUNC void operator()(
+      matAcc_t& matAcc,
+      const coord_t& coord,
+      const arguments_t& args,
+      [[maybe_unused]] uint32_t slm_base = 0,
+      [[maybe_unused]] uint32_t nbarrier_base = 0) {
+    using dtype_acc = typename matAcc_t::dtype;
+    static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
+    static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
+    static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
+    static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
+    static constexpr int32_t num_block_x = matAcc_t::num_block_x;
+    static constexpr uint32_t block_elems = matAcc_t::block_elems;
+
+    dtype_bias* ptr = static_cast(args.base.base);
+    int32_t pos_x = coord.x > args.shape.x - 1 ? args.shape.x - 1 : coord.x;
+    int32_t pos_y = coord.y > args.shape.y - 1 ? args.shape.y - 1 : coord.y;
+    uint32_t offset = (pos_y + pos_x * args.shape.stride) * sizeof(dtype_bias);
+    auto bias_data_vector = xetla_load_global<
+        dtype_bias,
+        1,
+        data_size::default_size,
+        cache_hint::cached,
+        cache_hint::cached,
+        16>(ptr, offset);
+    dtype_acc bias_data =
+        xetla_cvt(bias_data_vector)[0];
+
+    {
+      auto bias_reg =
+          xetla_vector(bias_data);
+#pragma unroll
+      for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
+#pragma unroll
+        for (int j = 0; j < num_block_x; j++) {
+          auto dst_reg =
+              matAcc.reg
+                  .xetla_select(
+                      (i * num_block_x + j) * block_elems)
+                  .xetla_format();
+
+          dst_reg += bias_reg;
+        }
+      }
+    }
+    // process the tail
+    if constexpr ((tile_size_y % block_size_y) != 0) {
+      constexpr uint32_t tail_start_y =
+          tile_size_y / block_size_y * block_size_y;
+      constexpr int32_t tail_size_y = tile_size_y % block_size_y;
+      constexpr int32_t tail_block_elems = tail_size_y * block_size_x;
+      auto bias_reg =
+          xetla_vector(bias_data);
+#pragma unroll
+      for (int j = 0; j < num_block_x; j++) {
+        auto dst_reg =
+            matAcc.reg
+                .xetla_select(
+                    tail_start_y * tile_size_x + j * tail_block_elems)
+                .xetla_format();
+        dst_reg += bias_reg;
+      }
+    }
+  }
+};
+
 struct tile_mul {
   template
   static xetla_vector inline func(
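
Note on the new cu_seqlen_q / cu_seqlen_k arguments: the kernel reads them as cumulative offsets, e.g. args.cu_seqlen_q[batch_id + 1] - args.cu_seqlen_q[batch_id] is the actual query length of batch batch_id, which is the usual prefix-sum layout for varlen attention. The host-side sketch below is illustrative only (it is not part of the patch, and the helper name build_cu_seqlen is hypothetical); fmha.cpp simply passes nullptr for both pointers because it does not exercise the kVarlen path.

// Illustrative sketch (not part of the patch): build the cumulative layout
// that the kernel indexes as cu_seqlen[b] .. cu_seqlen[b + 1].
#include <cstdint>
#include <vector>

std::vector<int32_t> build_cu_seqlen(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (size_t b = 0; b < seqlens.size(); ++b) {
    // cu[b + 1] - cu[b] equals the length of sequence b, matching
    // args.cu_seqlen_q[batch_id + 1] - args.cu_seqlen_q[batch_id] in the kernel.
    cu[b + 1] = cu[b] + seqlens[b];
  }
  return cu;
}

// Example: lengths {5, 3, 8} give {0, 5, 8, 16}; the same layout applies to
// cu_seqlen_k.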