diff --git a/bestla/bestla/bestla_epilogue.h b/bestla/bestla/bestla_epilogue.h index f2228c22f..b2349013b 100644 --- a/bestla/bestla/bestla_epilogue.h +++ b/bestla/bestla/bestla_epilogue.h @@ -93,6 +93,7 @@ struct ParamAlphaBetaProcess { template class AlphaBetaProcessFp32 { public: + using DType = float; using Param = ParamAlphaBetaProcess; BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, diff --git a/neural_speed/core/layers/mha_dense.cpp b/neural_speed/core/layers/mha_dense.cpp index 5539bd3dd..13c4cea4c 100644 --- a/neural_speed/core/layers/mha_dense.cpp +++ b/neural_speed/core/layers/mha_dense.cpp @@ -80,6 +80,7 @@ struct attn_fwd_args_t { int step_k_bs, step_k_head_num, step_k_sl, step_k_head_size; int step_v_bs, step_v_head_num, step_v_sl, step_v_head_size; int step_dst_bs, step_dst_head_num, step_dst_sl; + int n_prompt; // caller grantees that K/V for first n_prompt tokens are identical among batches }; struct mha_problem_t { @@ -657,7 +658,7 @@ class mha_interface_t { const auto num_heads = p.batch_size * p.head_num; // Total number of heads device::CpuBase cb; // Note: DO NOT use cb.mNumThreads; use th.num_threads() instead - const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0; + const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0 && p.sl_q > 1; const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0; assert(!is_causal || p.sl_q <= p.sl_kv); assert(("alibi not supported!", !is_alibi)); @@ -1338,7 +1339,7 @@ class mha_stable_interface_t { assert((p.V_layout != ATTN_FWD_LAYOUT_PLAIN || p.step_k_sl == 1)); const auto num_heads = p.batch_size * p.head_num; // Total number of heads device::CpuBase cb; // Note: DO NOT use cb.mNumThreads; use th.num_threads() instead - const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0; + const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0 && p.sl_q > 1; const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0; assert(!is_causal || p.sl_q <= p.sl_kv); assert(("head_num must be a multiple of heads_kv!", p.head_num % p.heads_kv == 0)); @@ -1505,6 +1506,284 @@ class mha_stable_interface_t { return BTLA_CODE::Success; } + BTLA_CODE compute_beams(const attn_fwd_args_t& p, const parallel::IThreading& th) { + assert((std::is_same::value || p.Q_sc == 1)); + assert((std::is_same::value || p.K_sc == 1)); + assert((std::is_same::value || p.V_sc == 1)); + assert((std::is_same::value || p.dst_sc == 1)); + + assert((p.Q_layout == ATTN_FWD_LAYOUT_PLAIN && p.dst_layout == ATTN_FWD_LAYOUT_PLAIN)); + assert((p.K_layout == ATTN_FWD_LAYOUT_PLAIN || + (std::is_same::value && p.K_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4) || + (std::is_same::value && p.K_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2))); + assert((p.V_layout == ATTN_FWD_LAYOUT_PLAIN || + (std::is_same::value && p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4) || + (std::is_same::value && p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2))); + + assert((!std::is_same>::value) || + p.K_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 || + p.K_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2); // WeightForward can only be used with preprocessed layout + assert( + (!std::is_same>::value) || + p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 || + p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2); // WeightForward can only be used with preprocessed layout + + assert((p.K_layout != ATTN_FWD_LAYOUT_PLAIN || p.step_v_head_size == 1)); + assert((p.V_layout != ATTN_FWD_LAYOUT_PLAIN || p.step_k_sl == 1)); + assert(p.sl_q == 1 && p.batch_size > 1); // beam search next-token cases + device::CpuBase cb; // Note: DO NOT use cb.mNumThreads; use th.num_threads() instead + const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0 && p.sl_q > 1; + const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0; + assert(!is_causal || p.sl_q <= p.sl_kv); + assert(("head_num must be a multiple of heads_kv!", p.head_num % p.heads_kv == 0)); + assert(("Not Implemented", !is_alibi)); + assert(("Not Implemented", !is_causal)); + const auto group_heads = p.head_num / p.heads_kv; + const auto sl_diff = p.sl_kv - p.sl_q; + + // TP will need the real rank oder of k + int32_t k_offset = 0; + int32_t log_head_num = p.head_num; +#ifdef NS_TP_MODEL + NE_ASSERT(("Not implemented", false)) +#endif + + // alibi slope + const int n_heads_log2_floor = 1 << static_cast(floor(log2(log_head_num))); + const float m0 = powf(2.0f, -(8.f) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(8.f / 2.0f) / n_heads_log2_floor); + + const auto m_tiles = updiv(p.batch_size, M_TILE); + assert(p.batch_size <= M_TILE && m_tiles == 1); + const auto num_tasks = p.head_num; + + using Scheduler2D = bestla::parallel::Scheduler2D; + const Scheduler2D parl({th.num_threads(), {num_tasks, 1}, {1, 1}}); // main parallel scheduler + + th.parallel_for([&](int tid) { + const int tmp_s_size = M_TILE * padto(padto(p.sl_kv, GemmQK::NTILE), GemmPV::KTILE); + const int tmp_p_size = tmp_s_size; + const int tmp_bytes = tmp_s_size * sizeof(float); // S & exp + const auto tmp_s = reinterpret_cast(p.tmp + tid * tmp_bytes); + using PType = typename GemmPV::AType; + const auto tmp_p = reinterpret_cast(tmp_s); // overwrite tmp_s row-wisely + + // calculate mm + softmax + mm + { + typename parallel::ThreadProblem2D thdp{tid}; + parl.getIndex(thdp); + const auto [task_start, _assert0] = thdp.loc; + auto [task_size, _assert_max1] = thdp.size; + assert(task_size == 0 || _assert0 == 0); + assert(task_size == 0 || _assert_max1 == 1 || _assert_max1 == 0); + if (_assert_max1 == 0 || !thdp.valid) task_size = 0; + + for (int task_id = task_start; task_id < task_start + task_size; ++task_id) { + const int ihn = task_id; + const int ihkv = ihn / group_heads; + + const auto alibi_ihn_m = 0.f; // Alibi not implemented + + float s_max[M_TILE]{}; // maximum for each row of the S matrix + std::fill_n(s_max, M_TILE, -INFINITY); + + // ptr to Q / dst matrix of the current head + const auto head_q_bs0 = p.Q + ihn * p.step_q_head_num; + // const auto head_k = p.K + ibs * p.step_k_bs + ihkv * p.step_k_head_num; + // const auto head_v = p.V + ibs * p.step_v_bs + ihkv * p.step_v_head_num; + const auto head_k_bs0 = p.K + ihkv * p.step_k_head_num; // bs here is beam + const auto head_v_bs0 = p.V + ihkv * p.step_v_head_num; // bs here is beam + // const auto head_dst = p.dst + ibs * p.step_dst_bs + ihn * p.step_dst_head_num; + const auto head_dst_bs0 = p.dst + ihn * p.step_dst_head_num; + + assert(!is_causal); + const auto unmasked_size = p.sl_kv; + + const auto unmasked_size_pad_qk = std::min(p.sl_kv, padto(unmasked_size, GemmQK::NTILE)); + const auto unmasked_size_pad_pv = std::min(p.sl_kv, padto(unmasked_size, GemmPV::KTILE)); + const int ld_tmp_s = padto(padto(unmasked_size_pad_pv, GemmQK::NTILE), GemmPV::KTILE); + static_assert(sizeof(float) >= sizeof(PType), "PType exceeded float size!"); + const int ld_tmp_p = ld_tmp_s * sizeof(float) / sizeof(PType); + const auto qk_prok_ldb = p.step_k_sl == 1 ? p.step_k_head_size + : p.K_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? p.step_k_sl + : p.K_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? p.step_k_sl + : (assert(0), 0); + + const auto n_prompt_le_n = padto_le(p.n_prompt, GemmQK::NTILE); + typename parallel::gemm::ThreadProblemBase tpQKBatch{ + /* ThreadProblem2D */ {tid, {}, {0, 0}, {p.batch_size, n_prompt_le_n}, true}, + /* .block = */ {M_TILE, GemmQK::NTILE, p.head_size}, + /* .stacksize = */ cb.mL2Cache, + /* .tmpcachesize = */ cb.mL2Cache, + }; + l_qk.run( // QxK => S ==exp==> P + QKArgs{ + utils::GemmProblem{ + /* .batch */ 1, + /* .M = */ p.batch_size, + /* .N = */ n_prompt_le_n, + /* .K = */ p.head_size, + }, + /* .paramA = */ + QKProQArgs{ + head_q_bs0, + p.step_q_bs, + }, + /* .paramB = */ + QKProKArgs{ + /* .B = */ head_k_bs0, + /* .ldb = */ qk_prok_ldb, + /* .is_padded = */ true, + }, // K should be pre-transposed + /* .paramC = */ + QKEpiArgs{ + /* .dst = */ tmp_s, + /* .dst_sum = */ s_max, + /* .ld_dst = */ ld_tmp_s, + /* .scale = */ p.QK_scale * p.Q_sc * p.K_sc, + /* .causal_offset = */ -1, + /* .alibi_slope = */ alibi_ihn_m, + }, + // /* .workspace = */ nullptr, + }, + tpQKBatch); + for (int ibs = 0; ibs < p.batch_size; ++ibs) { + typename parallel::gemm::ThreadProblemBase tpQKBeam{ + /* ThreadProblem2D */ {tid, {}, {ibs, n_prompt_le_n}, {1, p.sl_kv - n_prompt_le_n}, true}, + /* .block = */ {M_TILE, GemmQK::NTILE, p.head_size}, + /* .stacksize = */ cb.mL2Cache, + /* .tmpcachesize = */ cb.mL2Cache, + }; + l_qk.run( // QxK => S ==exp==> P + QKArgs{ + utils::GemmProblem{ + /* .batch */ 1, + /* .M = */ 1, + /* .N = */ p.sl_kv, + /* .K = */ p.head_size, + }, + /* .paramA = */ + QKProQArgs{ + head_q_bs0, + p.step_q_bs, + }, + /* .paramB = */ + QKProKArgs{ + /* .B = */ head_k_bs0 + ibs * p.step_k_bs, + /* .ldb = */ qk_prok_ldb, + /* .is_padded = */ true, + }, // K should be pre-transposed + /* .paramC = */ + QKEpiArgs{ + /* .dst = */ tmp_s, + /* .dst_sum = */ s_max, + /* .ld_dst = */ ld_tmp_s, + /* .scale = */ p.QK_scale * p.Q_sc * p.K_sc, + /* .causal_offset = */ is_causal ? sl_diff : -1, + /* .alibi_slope = */ alibi_ihn_m, + }, + // /* .workspace = */ nullptr, + }, + tpQKBeam); + } + + // softmax (with pre-computed row_max) + assert(!is_causal); + const auto unmasked_size_start = p.sl_kv; + float expsum[M_TILE]{}; // maximum for each row of the S matrix + const auto softmax_npad_size = padto(unmasked_size_pad_pv, GemmPV::KTILE); + inplace_precompute_max_softmax_t::forward( // + p.batch_size, unmasked_size_start, softmax_npad_size, // m / n + is_causal, tmp_s, tmp_p, s_max, expsum, ld_tmp_s, ld_tmp_p); // + + const auto pv_scale = expsum; + for (int i = 0; i < M_TILE; ++i) pv_scale[i] = p.V_sc / UINT8_MAX / expsum[i] / p.dst_sc; + + const auto pv_prov_ldb = p.step_v_head_size == 1 ? p.step_v_sl + : p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK4 ? p.step_v_head_size + : p.V_layout == ATTN_FWD_LAYOUT_NTILE48_ROWPACK2 ? p.step_v_head_size + : (assert(0), 0); + + const auto n_prompt_le_k = padto_le(p.n_prompt, GemmPV::KTILE); + typename parallel::gemm::ThreadProblemBase tpPVBatch{ + /* ThreadProblem2D */ {tid, {}, {0, 0}, {p.batch_size, p.head_size}, true}, + /* .block = */ {M_TILE, GemmPV::NTILE, n_prompt_le_k}, + /* .stacksize = */ cb.mL2Cache, + /* .tmpcachesize = */ cb.mL2Cache, + }; + l_pv.run( // PxV => O + PVArgs{ + utils::GemmProblem{ + /* .batch */ 1, + /* .M = */ p.batch_size, + /* .N = */ p.head_size, + /* .K = */ n_prompt_le_k, + }, + /* .paramA = */ PVProPArgs{tmp_p, ld_tmp_p}, + /* .paramB = */ + PVProVArgs{ + /* .B = */ head_v_bs0, + /* .ldb = */ pv_prov_ldb, + /* .is_padded = */ true, + }, + /* .paramC = */ + PVEpiArgs{ + /* .C = */ head_dst_bs0, + /* .D = */ head_dst_bs0, + /* .ldc = */ p.step_dst_bs, + /* .ldd = */ p.step_dst_bs, + /* .alpha = */ 1.f, + /* .beta = */ 0, + }, + // /* .workspace = */ nullptr, + }, + tpPVBatch); + for (int ibs = 0; ibs < p.batch_size; ++ibs) { + if constexpr (std::is_same_v::Param>) { + typename parallel::gemm::ThreadProblemBase tpPVBeam{ + /* ThreadProblem2D */ {tid, {}, {ibs, 0}, {1, p.head_size}, true}, + /* .block = */ {M_TILE, GemmPV::NTILE, unmasked_size_pad_pv - n_prompt_le_k}, + /* .stacksize = */ cb.mL2Cache, + /* .tmpcachesize = */ cb.mL2Cache, + }; + l_pv.run( // PxV => O + PVArgs{ + utils::GemmProblem{ + /* .batch */ 1, + /* .M = */ p.batch_size, + /* .N = */ p.head_size, + /* .K = */ unmasked_size_pad_pv - n_prompt_le_k, + }, + /* .paramA = */ PVProPArgs{tmp_p + n_prompt_le_k, ld_tmp_p}, + /* .paramB = */ + PVProVArgs{ + /* .B = */ head_v_bs0 + ibs * p.step_v_bs + n_prompt_le_k * GemmPV::NTILE, + /* .ldb = */ pv_prov_ldb, + /* .is_padded = */ true, + }, + /* .paramC = */ + PVEpiArgs{ + /* .C = */ head_dst_bs0, + /* .D = */ head_dst_bs0, + /* .ldc = */ p.step_dst_bs, + /* .ldd = */ p.step_dst_bs, + /* .alpha = */ 1.f, + /* .beta = */ 1.f, + }, + // /* .workspace = */ nullptr, + }, + tpPVBeam); + } else { + static_assert(false, "Not implemented"); + } + } + } + } + }); + return BTLA_CODE::Success; + } + protected: L_Max l_qk; L_Scale l_pv; @@ -1679,15 +1958,27 @@ void bestla_fusion_attn_forward(const attn_fwd_args_t< prologue_a::gemm::ActivationConverterFp32, // ::weight_forward_n_tile48_t, // ::ScaleTrackMaxFp32Fp32>; // - using GemmKernelBF16 = ::launcher_base_weight_t< // - BTLA_ISA::AMX_BF16, // - gemm::HCoreRowNAmxbf16<48, 16>, // - ::activation_identity_t, // pretty sure we have enough paddings for P-matrix - ::weight_forward_n_tile48_t, // - epilogue::gemm::AccumulatorWriteBackFp32>; // - static mha_stable_interface_t mha; - [[maybe_unused]] const auto ret = mha.compute(params, *pth); - assert(ret == BTLA_CODE::Success); + if (params.n_prompt > 0 && params.batch_size > 1) { // beam search optimization + using GemmKernelBF16 = ::launcher_base_weight_t< // + BTLA_ISA::AMX_BF16, // + gemm::HCoreRowNAmxbf16<48, 16>, // + ::activation_identity_t, // pretty sure we have enough paddings for P-matrix + ::weight_forward_n_tile48_t, // + epilogue::gemm::AlphaBetaProcessFp32>; // + static mha_stable_interface_t mha; + [[maybe_unused]] const auto ret = mha.compute_beams(params, *pth); + assert(ret == BTLA_CODE::Success); + } else { + using GemmKernelBF16 = ::launcher_base_weight_t< // + BTLA_ISA::AMX_BF16, // + gemm::HCoreRowNAmxbf16<48, 16>, // + ::activation_identity_t, // pretty sure we have enough paddings for P-matrix + ::weight_forward_n_tile48_t, // + epilogue::gemm::AccumulatorWriteBackFp32>; // + static mha_stable_interface_t mha; + [[maybe_unused]] const auto ret = mha.compute(params, *pth); + assert(ret == BTLA_CODE::Success); + } } else { assert(0); } @@ -1695,7 +1986,7 @@ void bestla_fusion_attn_forward(const attn_fwd_args_t< template void bestla_fusion_attn_forward_ref(const attn_fwd_args_t& p) { - const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0; + const bool is_causal = (p.attn_flags & NE_ATTN_FLAG_IS_CAUSAL) != 0 && p.sl_q > 1; const bool is_alibi = (p.attn_flags & NE_ATTN_FLAG_IS_ALIBI8) != 0; assert(!is_causal || p.sl_q <= p.sl_kv); assert(("head_num must be a multiple of heads_kv!", p.head_num % p.heads_kv == 0)); @@ -1937,6 +2228,7 @@ void bestla_reordered_attn_fp32_forward(const bestla_reordered_attn_fp32_fp32_fw /* .step_dst_bs = */ params->step_dst_bs, /* .step_dst_head_num = */ params->step_dst_head_num, /* .step_dst_sl = */ params->step_dst_sl, + /* .n_prompt = */ params->n_prompt, }; return bestla_fusion_attn_forward(bestla_params); } diff --git a/neural_speed/core/layers/mha_dense.h b/neural_speed/core/layers/mha_dense.h index af8581b5b..0683e2d72 100644 --- a/neural_speed/core/layers/mha_dense.h +++ b/neural_speed/core/layers/mha_dense.h @@ -80,6 +80,7 @@ typedef struct attn_fp32_fp16_fp16_fp32_fwd_args_t { int step_k_bs, step_k_head_num, step_k_sl, step_k_head_size; int step_v_bs, step_v_head_num, step_v_sl, step_v_head_size; int step_dst_bs, step_dst_head_num, step_dst_sl; + int n_prompt; // caller grantees that K/V for first n_prompt tokens are identical among batches } attn_fp32_fp16_fp16_fp32_fwd_args_t; void bestla_fusion_attn_bf16_forward(const attn_bf16_fwd_args_t* params); @@ -165,6 +166,7 @@ typedef struct bestla_reordered_attn_fp32_fp32_fwd_args_t { int stride_k_bs, stride_k_head_num, stride_k_sl, stride_k_head_size; int stride_v_bs, stride_v_head_num, stride_v_sl, stride_v_head_size; int step_dst_bs, step_dst_head_num, step_dst_sl; + int n_prompt; // caller grantees that K/V for first n_prompt tokens are identical among batches } bestla_reordered_attn_fp32_fp32_fwd_args_t; void bestla_reordered_attn_fp32_forward(const bestla_reordered_attn_fp32_fp32_fwd_args_t* params); diff --git a/neural_speed/core/layers/ne_test_layers_utils.hpp b/neural_speed/core/layers/ne_test_layers_utils.hpp index fce60c499..67027be07 100644 --- a/neural_speed/core/layers/ne_test_layers_utils.hpp +++ b/neural_speed/core/layers/ne_test_layers_utils.hpp @@ -19,7 +19,7 @@ #include #include -#include "bestla/jit_blas_utils.h" +#include "bestla/bestla_utils.h" #ifndef NS_TESTS static_assert(false, "Only include this header file for testing!"); diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index a6154bda2..400b81767 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -50,12 +50,6 @@ #include "ne.h" #include "ne_bestla.h" -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef static_assert -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif - #if defined(_WIN32) #include @@ -3274,9 +3268,16 @@ struct ne_tensor* ne_conv_1d_ph(struct ne_context* ctx, struct ne_tensor* a, str } // ne_flash_attn - struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k, struct ne_tensor* v, float scale, ne_attn_flags_t flags) { + const ne_attn_op_params_t attn_op_param = { + .flags = flags, + .scale = scale, + }; + return ne_flash_attn_with_params(ctx, q, k, v, &attn_op_param); +}; +struct ne_tensor* ne_flash_attn_with_params(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k, + struct ne_tensor* v, const ne_attn_op_params_t* op_params) { NE_ASSERT(ne_can_mul_mat(k, q)); int batch = q->ne[3]; int headnum = q->ne[2]; @@ -3303,8 +3304,7 @@ struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, str result->src1 = k; result->opt[0] = v; result->opt[1] = tmp_t; - *(float*)result->padding = scale; - *(ne_attn_flags_t*)&result->padding[sizeof(scale)] = flags; + memcpy(result->op_params, op_params, sizeof(ne_attn_op_params_t)); return result; } @@ -8744,8 +8744,7 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa int step_v_head_size = v->nb[1] / veles; int step_v_head_num = v->nb[2] / veles; int step_v_bs = k->nb[3] / veles; - float scale = *(float*)dst->padding; - ne_attn_flags_t flags = *(bool*)&dst->padding[sizeof(scale)]; + const ne_attn_op_params_t* op_params = (ne_attn_op_params_t*)dst->op_params; attn_fp32_fp16_fp16_fp32_fwd_args_t args = { .Q = (float*)q->data, .K = (ne_fp16_t*)k->data, @@ -8756,8 +8755,8 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa .V_sc = 1.f, .dst_sc = 1.f, .tmp = tmp->data, - .QK_scale = scale, - .attn_flags = flags, + .QK_scale = op_params->scale, + .attn_flags = op_params->flags, .batch_size = batch, .head_num = headnum, .heads_kv = heads_kv, @@ -8782,6 +8781,7 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa .step_dst_bs = seq_cur * embedsize, .step_dst_head_num = headsize, .step_dst_sl = embedsize, + .n_prompt = op_params->n_prompt, }; bestla_fusion_attn_fp32_fp16_fp16_fp32_forward(&args); } @@ -8801,8 +8801,9 @@ static void ne_compute_forward_flash_attn_reordered(const struct ne_compute_para const int64_t dst_ele_size = ne_element_size(dst); // const int64_t seq_past = seq_all - seq_cur; - float scale = *(float*)dst->padding; - ne_attn_flags_t flags = *(ne_attn_flags_t*)&dst->padding[sizeof(scale)]; + const ne_attn_op_params_t* op_params = (ne_attn_op_params_t*)dst->op_params; + float scale = op_params->scale; + ne_attn_flags_t flags = op_params->flags; NE_ASSERT(k->type == NE_TYPE_BTLA && v->type == NE_TYPE_BTLA); ATTN_FWD_LAYOUT K_layout = *(ATTN_FWD_LAYOUT*)(&k->nb[0]); @@ -8848,6 +8849,7 @@ static void ne_compute_forward_flash_attn_reordered(const struct ne_compute_para .step_dst_bs = dst->nb[3] / dst_ele_size, .step_dst_head_num = dst->nb[1] / dst_ele_size, .step_dst_sl = dst->nb[2] / dst_ele_size, + .n_prompt = op_params->n_prompt, }; bestla_reordered_attn_fp32_forward(&args); } diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h index 032283696..34ba76201 100644 --- a/neural_speed/core/ne_layers.h +++ b/neural_speed/core/ne_layers.h @@ -37,6 +37,12 @@ #include "core/data_types.h" #include "layers/layers.h" +// if C99 - static_assert is noop +// ref: https://stackoverflow.com/a/53923785/4039976 +#if !defined(__cplusplus) || __cplusplus < 201103L +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif + #define NE_QNT_VERSION 2 // bump this on quantization format changes #define NE_QNT_VERSION_FACTOR 1000 // do not change this @@ -69,6 +75,13 @@ typedef enum NE_ATTN_FLAG { } NE_ATTN_FLAG; typedef uint32_t ne_attn_flags_t; +typedef struct ne_attn_op_params_t { + ne_attn_flags_t flags; + float scale; + int n_prompt; +} ne_attn_op_params_t; +static_assert(sizeof(ne_attn_op_params_t) <= NE_MAX_OP_PARAMS, "ATTN OP PARAM too large!"); + // convert FP16 <-> FP32 NE_API float ne_fp16_to_fp32(ne_fp16_t x); NE_API ne_fp16_t ne_fp32_to_fp16(float x); @@ -440,6 +453,8 @@ NE_API struct ne_tensor* ne_conv_1d_ph(struct ne_context* ctx, struct ne_tensor* NE_API struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k, struct ne_tensor* v, float scale, ne_attn_flags_t flags); +NE_API struct ne_tensor* ne_flash_attn_with_params(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k, + struct ne_tensor* v, const ne_attn_op_params_t* op_params); // set no_zeroing to true to prevent zeroing unaligned seq NE_API struct ne_tensor* ne_flash_attn_update_k(struct ne_context* ctx, struct ne_tensor* cache, struct ne_tensor* cur, int n_past, bool no_zeroing); diff --git a/neural_speed/models/gptj/gptj.cpp b/neural_speed/models/gptj/gptj.cpp index 49709975b..1eeef540d 100644 --- a/neural_speed/models/gptj/gptj.cpp +++ b/neural_speed/models/gptj/gptj.cpp @@ -208,10 +208,11 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu if (lctx.cont_batching) { size_t off_sl = 0; // per_request rope - for (int gi = 0; gi < infer_groups.size(); ++gi) { - const int qk_bs = infer_groups[gi].size(); - const int qk_sl = n_tokens[infer_groups[gi].front()]; - const int qk_n_past = n_pasts[infer_groups[gi].front()]; + + for (const auto& curr_group : infer_groups) { + const int qk_bs = curr_group.size(); + const int qk_sl = n_tokens[curr_group.front()]; + const int qk_n_past = n_pasts[curr_group.front()]; struct ne_tensor* Qcur_req = ne_view_4d(ctx0, Qcur, head_size, n_head, qk_sl, qk_bs, ne_element_size(Qcur) * head_size, ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size * n_head * qk_sl, @@ -313,11 +314,11 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu const auto k_size = kv_cache_info.k_bytes; const auto v_size = kv_cache_info.v_bytes; size_t off_sl = 0; - for (int gi = 0; gi < infer_groups.size(); ++gi) { - const int update_bs = infer_groups[gi].size(); - const int update_sl = n_tokens[infer_groups[gi].front()]; - const int update_block_id = block_ids[infer_groups[gi].front()]; - const int update_n_past = n_pasts[infer_groups[gi].front()]; + for (const auto& curr_group : infer_groups) { + const int update_bs = curr_group.size(); + const int update_sl = n_tokens[curr_group.front()]; + const int update_block_id = block_ids[curr_group.front()]; + const int update_n_past = n_pasts[curr_group.front()]; struct ne_tensor* k_cache_g = ne_view_4d(ctx0, kv_self.k, // tensor head_size, n_ctx, n_head, update_bs, // ne 0, 0, k_size, // nb (bestla managed) @@ -345,20 +346,19 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu struct ne_tensor* KQV_merged_contiguous = ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC); size_t off_sl = 0; - for (int gi = 0; gi < infer_groups.size(); ++gi) { - const int attn_bs = infer_groups[gi].size(); - const int attn_sl = n_tokens[infer_groups[gi].front()]; - const int attn_block_id = block_ids[infer_groups[gi].front()]; - const int attn_n_past = n_pasts[infer_groups[gi].front()]; - const int attn_n_total = n_totals[infer_groups[gi].front()]; + for (const auto& curr_group : infer_groups) { + const int attn_bs = curr_group.size(); + const int attn_sl = n_tokens[curr_group.front()]; + const int attn_block_id = block_ids[curr_group.front()]; + const int attn_n_past = n_pasts[curr_group.front()]; + const int attn_n_total = n_totals[curr_group.front()]; struct ne_tensor* Q = ne_permute(ctx0, ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size, ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)), 0, 2, 1, 3); - std::string suffix = std::to_string(gi); - ne_set_name(Q, std::string("Q_" + suffix).c_str()); + ne_set_name(Q, "Q"); struct ne_tensor *K, *V; const int n_cached_gi = shift_roped_k ? n_cached : attn_n_past + attn_sl; if (run_mha_reordered) { @@ -412,7 +412,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu } } else { std::vector attn_block_ids; - for (const auto& bsi : infer_groups[gi]) { + for (const auto& bsi : curr_group) { attn_block_ids.push_back(block_ids[bsi]); } K = model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head, attn_bs, attn_block_ids, il); @@ -431,20 +431,27 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu V = model_kv_cache_seq_concat(&gf, &lctx, ctx0, n_cached_gi, head_size, n_head, attn_bs, attn_block_ids, il, false); } - ne_set_name(K, std::string("K_" + suffix).c_str()); - ne_set_name(V, std::string("V_" + suffix).c_str()); + ne_set_name(K, "K"); + ne_set_name(V, "V"); struct ne_tensor* KQV_merged_gi; const float attn_scale = 1.0f / sqrtf(static_cast(head_size)); ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE; +#ifndef NDEBUG + for (const auto bi : curr_group) { + NE_ASSERT(inputs[bi].n_prompt_tokens == inputs[curr_group[0]].n_prompt_tokens); + } +#endif + const int n_prompt = curr_group.size() == 1 ? 0 : inputs[curr_group[0]].n_prompt_tokens; if (attn_n_total == 0 || !shift_roped_k) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases - if (run_mha_reordered) { // reordered kv-cache bf16 mha must be used if run_mha_reordered - struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); + const auto attn_op_param = ne_attn_op_params_t{attn_flags, attn_scale, n_prompt}; + if (run_mha_reordered) { // reordered kv-cache bf16 mha must be used if run_mha_reordered + struct ne_tensor* KQV_Out = ne_flash_attn_with_params(ctx0, Q, K, V, &attn_op_param); KQV_merged_gi = ne_view_2d(ctx0, KQV_Out, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_Out), 0); } else if (run_mha_fp16) { // non-reordered kv-cache fp16 mha - struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); + struct ne_tensor* KQV_Out = ne_flash_attn_with_params(ctx0, Q, K, V, &attn_op_param); KQV_merged_gi = ne_view_2d(ctx0, KQV_Out, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_Out), 0); } else if (attn_n_total == 0 && run_mha_bf16_first) { @@ -462,37 +469,37 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu } else { // K * Q struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); - ne_set_name(KQ, std::string("KQ_" + suffix).c_str()); + ne_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd/n_head) struct ne_tensor* KQ_scale = ne_new_f32(ctx0, attn_scale); - ne_set_name(KQ_scale, std::string("1/sqrt(n_embd/n_head)_" + suffix).c_str()); + ne_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); // KQ_scaled shape [n_cached, N, n_head, 1] struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, KQ_scale); - ne_set_name(KQ_scaled, std::string("KQ_scaled_" + suffix).c_str()); + ne_set_name(KQ_scaled, "KQ_scaled"); // KQ_scaled = mask_past(KQ_scaled) if (attn_n_total == 0 || !shift_roped_k || !no_padding) { - std::vector attn_n_padding(infer_groups[gi].size(), 0); - for (int npa = 0; !n_padding.empty() && npa < infer_groups[gi].size(); ++npa) { - attn_n_padding[npa] = n_padding[infer_groups[gi][npa]]; + std::vector attn_n_padding(curr_group.size(), 0); + for (int npa = 0; !n_padding.empty() && npa < curr_group.size(); ++npa) { + attn_n_padding[npa] = n_padding[curr_group[npa]]; } KQ_scaled = ne_diag_mask_inf_with_padding_inplace(ctx0, KQ_scaled, attn_n_past, attn_n_padding.data()); - ne_set_name(KQ_scaled, std::string("KQ_masked_" + suffix).c_str()); + ne_set_name(KQ_scaled, "KQ_masked"); } // KQ = soft_max(KQ_masked) struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_scaled); - ne_set_name(KQ_soft_max, std::string("KQ_soft_max_" + suffix).c_str()); + ne_set_name(KQ_soft_max, "KQ_soft_max"); struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); - ne_set_name(KQV, std::string("KQV_" + suffix).c_str()); + ne_set_name(KQV, "KQV"); // KQV_merged = KQV.permute(0, 2, 1, 3) KQV_merged_gi = ne_permute(ctx0, KQV, 0, 2, 1, 3); } - ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str()); + ne_set_name(KQV_merged_gi, "KQV_merged"); ne_build_forward_expand(&gf, ne_cpy(ctx0, KQV_merged_gi, ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_merged_contiguous),