Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
mha for beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
DDEle committed Jan 24, 2024
1 parent 1d542fe commit e2604fa
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 61 deletions.
1 change: 1 addition & 0 deletions bestla/bestla/bestla_epilogue.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ struct ParamAlphaBetaProcess {
template <BTLA_ISA ISA_T>
class AlphaBetaProcessFp32 {
public:
using DType = float;
using Param = ParamAlphaBetaProcess<float>;

BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
Expand Down
316 changes: 304 additions & 12 deletions neural_speed/core/layers/mha_dense.cpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions neural_speed/core/layers/mha_dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ typedef struct attn_fp32_fp16_fp16_fp32_fwd_args_t {
int step_k_bs, step_k_head_num, step_k_sl, step_k_head_size;
int step_v_bs, step_v_head_num, step_v_sl, step_v_head_size;
int step_dst_bs, step_dst_head_num, step_dst_sl;
int n_prompt; // caller guarantees that K/V for the first n_prompt tokens are identical among batches
} attn_fp32_fp16_fp16_fp32_fwd_args_t;

void bestla_fusion_attn_bf16_forward(const attn_bf16_fwd_args_t* params);
Expand Down Expand Up @@ -165,6 +166,7 @@ typedef struct bestla_reordered_attn_fp32_fp32_fwd_args_t {
int stride_k_bs, stride_k_head_num, stride_k_sl, stride_k_head_size;
int stride_v_bs, stride_v_head_num, stride_v_sl, stride_v_head_size;
int step_dst_bs, step_dst_head_num, step_dst_sl;
int n_prompt; // caller guarantees that K/V for the first n_prompt tokens are identical among batches
} bestla_reordered_attn_fp32_fp32_fwd_args_t;
void bestla_reordered_attn_fp32_forward(const bestla_reordered_attn_fp32_fp32_fwd_args_t* params);

Expand Down
2 changes: 1 addition & 1 deletion neural_speed/core/layers/ne_test_layers_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <vector>
#include <algorithm>

#include "bestla/jit_blas_utils.h"
#include "bestla/bestla_utils.h"

#ifndef NS_TESTS
static_assert(false, "Only include this header file for testing!");
Expand Down
32 changes: 17 additions & 15 deletions neural_speed/core/ne_layers.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@
#include "ne.h"
#include "ne_bestla.h"

// if C99 - static_assert is noop
// ref: https://stackoverflow.com/a/53923785/4039976
#ifndef static_assert
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif

#if defined(_WIN32)

#include <windows.h>
Expand Down Expand Up @@ -3274,9 +3268,16 @@ struct ne_tensor* ne_conv_1d_ph(struct ne_context* ctx, struct ne_tensor* a, str
}

// ne_flash_attn

// Convenience wrapper around ne_flash_attn_with_params() for callers that only
// need a scale and a flag set.
//   ctx     - context handed through unchanged to ne_flash_attn_with_params
//   q, k, v - attention operands, handed through unchanged
//   scale   - stored in the op params as `scale`
//   flags   - stored in the op params as `flags`
// n_prompt is set to 0, i.e. no leading K/V tokens are declared to be shared
// among batches. Returns whatever ne_flash_attn_with_params returns.
struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k, struct ne_tensor* v,
                                float scale, ne_attn_flags_t flags) {
  const ne_attn_op_params_t attn_op_param = {
      .flags = flags,
      .scale = scale,
      .n_prompt = 0,  // spelled out; the designated initializer zeroed it implicitly before
  };
  return ne_flash_attn_with_params(ctx, q, k, v, &attn_op_param);
}
struct ne_tensor* ne_flash_attn_with_params(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k,
struct ne_tensor* v, const ne_attn_op_params_t* op_params) {
NE_ASSERT(ne_can_mul_mat(k, q));
int batch = q->ne[3];
int headnum = q->ne[2];
Expand All @@ -3303,8 +3304,7 @@ struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, str
result->src1 = k;
result->opt[0] = v;
result->opt[1] = tmp_t;
*(float*)result->padding = scale;
*(ne_attn_flags_t*)&result->padding[sizeof(scale)] = flags;
memcpy(result->op_params, op_params, sizeof(ne_attn_op_params_t));
return result;
}

Expand Down Expand Up @@ -8744,8 +8744,7 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa
int step_v_head_size = v->nb[1] / veles;
int step_v_head_num = v->nb[2] / veles;
int step_v_bs = k->nb[3] / veles;
float scale = *(float*)dst->padding;
ne_attn_flags_t flags = *(bool*)&dst->padding[sizeof(scale)];
const ne_attn_op_params_t* op_params = (ne_attn_op_params_t*)dst->op_params;
attn_fp32_fp16_fp16_fp32_fwd_args_t args = {
.Q = (float*)q->data,
.K = (ne_fp16_t*)k->data,
Expand All @@ -8756,8 +8755,8 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa
.V_sc = 1.f,
.dst_sc = 1.f,
.tmp = tmp->data,
.QK_scale = scale,
.attn_flags = flags,
.QK_scale = op_params->scale,
.attn_flags = op_params->flags,
.batch_size = batch,
.head_num = headnum,
.heads_kv = heads_kv,
Expand All @@ -8782,6 +8781,7 @@ static void ne_compute_forward_flash_attn_f32_f16_f16(const struct ne_compute_pa
.step_dst_bs = seq_cur * embedsize,
.step_dst_head_num = headsize,
.step_dst_sl = embedsize,
.n_prompt = op_params->n_prompt,
};
bestla_fusion_attn_fp32_fp16_fp16_fp32_forward(&args);
}
Expand All @@ -8801,8 +8801,9 @@ static void ne_compute_forward_flash_attn_reordered(const struct ne_compute_para
const int64_t dst_ele_size = ne_element_size(dst);
// const int64_t seq_past = seq_all - seq_cur;

float scale = *(float*)dst->padding;
ne_attn_flags_t flags = *(ne_attn_flags_t*)&dst->padding[sizeof(scale)];
const ne_attn_op_params_t* op_params = (ne_attn_op_params_t*)dst->op_params;
float scale = op_params->scale;
ne_attn_flags_t flags = op_params->flags;

NE_ASSERT(k->type == NE_TYPE_BTLA && v->type == NE_TYPE_BTLA);
ATTN_FWD_LAYOUT K_layout = *(ATTN_FWD_LAYOUT*)(&k->nb[0]);
Expand Down Expand Up @@ -8848,6 +8849,7 @@ static void ne_compute_forward_flash_attn_reordered(const struct ne_compute_para
.step_dst_bs = dst->nb[3] / dst_ele_size,
.step_dst_head_num = dst->nb[1] / dst_ele_size,
.step_dst_sl = dst->nb[2] / dst_ele_size,
.n_prompt = op_params->n_prompt,
};
bestla_reordered_attn_fp32_forward(&args);
}
Expand Down
15 changes: 15 additions & 0 deletions neural_speed/core/ne_layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
#include "core/data_types.h"
#include "layers/layers.h"

// static_assert compatibility shim for C translation units.
// ref: https://stackoverflow.com/a/53923785/4039976
//  - C++: never touched. Testing the value of __cplusplus is unreliable (MSVC reports
//    199711L unless /Zc:__cplusplus is given), and a function-like macro would silently
//    disable every static_assert and break the single-argument form.
//  - C23: static_assert is a keyword, nothing to do.
//  - C11/C17: map to _Static_assert (same definition <assert.h> uses, so no redefinition clash).
//  - C99 and older: no such facility, so expand to a harmless forward declaration (noop).
#if !defined(__cplusplus) && !defined(static_assert) && (!defined(__STDC_VERSION__) || __STDC_VERSION__ < 202311L)
#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
#define static_assert _Static_assert
#else
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif
#endif

#define NE_QNT_VERSION 2 // bump this on quantization format changes
#define NE_QNT_VERSION_FACTOR 1000 // do not change this

Expand Down Expand Up @@ -69,6 +75,13 @@ typedef enum NE_ATTN_FLAG {
} NE_ATTN_FLAG;
typedef uint32_t ne_attn_flags_t;

// Parameters of a flash-attention op. ne_flash_attn_with_params() copies this struct
// byte-for-byte into the result tensor's op_params, and the compute kernels read it back
// from there.
typedef struct ne_attn_op_params_t {
// bit set of NE_ATTN_FLAG values (e.g. NE_ATTN_FLAG_IS_CAUSAL)
ne_attn_flags_t flags;
// forwarded to the kernels as QK_scale
float scale;
// forwarded to the kernels as n_prompt: the caller guarantees that K/V for the first
// n_prompt tokens are identical among batches; ne_flash_attn() leaves it at 0
int n_prompt;
} ne_attn_op_params_t;
// the struct is memcpy'd into op_params, so it must fit in NE_MAX_OP_PARAMS
static_assert(sizeof(ne_attn_op_params_t) <= NE_MAX_OP_PARAMS, "ATTN OP PARAM too large!");

// convert FP16 <-> FP32
NE_API float ne_fp16_to_fp32(ne_fp16_t x);
NE_API ne_fp16_t ne_fp32_to_fp16(float x);
Expand Down Expand Up @@ -440,6 +453,8 @@ NE_API struct ne_tensor* ne_conv_1d_ph(struct ne_context* ctx, struct ne_tensor*

NE_API struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k,
struct ne_tensor* v, float scale, ne_attn_flags_t flags);
NE_API struct ne_tensor* ne_flash_attn_with_params(struct ne_context* ctx, struct ne_tensor* q, struct ne_tensor* k,
struct ne_tensor* v, const ne_attn_op_params_t* op_params);
// set no_zeroing to true to prevent zeroing unaligned seq
NE_API struct ne_tensor* ne_flash_attn_update_k(struct ne_context* ctx, struct ne_tensor* cache, struct ne_tensor* cur,
int n_past, bool no_zeroing);
Expand Down
73 changes: 40 additions & 33 deletions neural_speed/models/gptj/gptj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,11 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
if (lctx.cont_batching) {
size_t off_sl = 0;
// per_request rope
for (int gi = 0; gi < infer_groups.size(); ++gi) {
const int qk_bs = infer_groups[gi].size();
const int qk_sl = n_tokens[infer_groups[gi].front()];
const int qk_n_past = n_pasts[infer_groups[gi].front()];

for (const auto& curr_group : infer_groups) {
const int qk_bs = curr_group.size();
const int qk_sl = n_tokens[curr_group.front()];
const int qk_n_past = n_pasts[curr_group.front()];
struct ne_tensor* Qcur_req =
ne_view_4d(ctx0, Qcur, head_size, n_head, qk_sl, qk_bs, ne_element_size(Qcur) * head_size,
ne_element_size(Qcur) * head_size * n_head, ne_element_size(Qcur) * head_size * n_head * qk_sl,
Expand Down Expand Up @@ -313,11 +314,11 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
const auto k_size = kv_cache_info.k_bytes;
const auto v_size = kv_cache_info.v_bytes;
size_t off_sl = 0;
for (int gi = 0; gi < infer_groups.size(); ++gi) {
const int update_bs = infer_groups[gi].size();
const int update_sl = n_tokens[infer_groups[gi].front()];
const int update_block_id = block_ids[infer_groups[gi].front()];
const int update_n_past = n_pasts[infer_groups[gi].front()];
for (const auto& curr_group : infer_groups) {
const int update_bs = curr_group.size();
const int update_sl = n_tokens[curr_group.front()];
const int update_block_id = block_ids[curr_group.front()];
const int update_n_past = n_pasts[curr_group.front()];
struct ne_tensor* k_cache_g = ne_view_4d(ctx0, kv_self.k, // tensor
head_size, n_ctx, n_head, update_bs, // ne
0, 0, k_size, // nb (bestla managed)
Expand Down Expand Up @@ -345,20 +346,19 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
struct ne_tensor* KQV_merged_contiguous =
ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_size * n_head, seq_len_sum, NE_SIZE_CALC);
size_t off_sl = 0;
for (int gi = 0; gi < infer_groups.size(); ++gi) {
const int attn_bs = infer_groups[gi].size();
const int attn_sl = n_tokens[infer_groups[gi].front()];
const int attn_block_id = block_ids[infer_groups[gi].front()];
const int attn_n_past = n_pasts[infer_groups[gi].front()];
const int attn_n_total = n_totals[infer_groups[gi].front()];
for (const auto& curr_group : infer_groups) {
const int attn_bs = curr_group.size();
const int attn_sl = n_tokens[curr_group.front()];
const int attn_block_id = block_ids[curr_group.front()];
const int attn_n_past = n_pasts[curr_group.front()];
const int attn_n_total = n_totals[curr_group.front()];
struct ne_tensor* Q =
ne_permute(ctx0,
ne_view_4d(ctx0, Qcur, head_size, n_head, attn_sl, attn_bs, ne_element_size(Qcur) * head_size,
ne_element_size(Qcur) * head_size * n_head,
ne_element_size(Qcur) * head_size * n_head * attn_sl, off_sl * ne_element_size(Qcur)),
0, 2, 1, 3);
std::string suffix = std::to_string(gi);
ne_set_name(Q, std::string("Q_" + suffix).c_str());
ne_set_name(Q, "Q");
struct ne_tensor *K, *V;
const int n_cached_gi = shift_roped_k ? n_cached : attn_n_past + attn_sl;
if (run_mha_reordered) {
Expand Down Expand Up @@ -412,7 +412,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
}
} else {
std::vector<int> attn_block_ids;
for (const auto& bsi : infer_groups[gi]) {
for (const auto& bsi : curr_group) {
attn_block_ids.push_back(block_ids[bsi]);
}
K = model_kv_cache_seq_concat(&gf, &lctx, ctx0, head_size, n_cached_gi, n_head, attn_bs, attn_block_ids, il);
Expand All @@ -431,20 +431,27 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
V = model_kv_cache_seq_concat(&gf, &lctx, ctx0, n_cached_gi, head_size, n_head, attn_bs, attn_block_ids, il,
false);
}
ne_set_name(K, std::string("K_" + suffix).c_str());
ne_set_name(V, std::string("V_" + suffix).c_str());
ne_set_name(K, "K");
ne_set_name(V, "V");

struct ne_tensor* KQV_merged_gi;
const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_size));
ne_attn_flags_t attn_flags = NE_ATTN_FLAG_NONE;
#ifndef NDEBUG
for (const auto bi : curr_group) {
NE_ASSERT(inputs[bi].n_prompt_tokens == inputs[curr_group[0]].n_prompt_tokens);
}
#endif
const int n_prompt = curr_group.size() == 1 ? 0 : inputs[curr_group[0]].n_prompt_tokens;
if (attn_n_total == 0 || !shift_roped_k)
attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
if (run_mha_reordered) { // reordered kv-cache bf16 mha must be used if run_mha_reordered
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
const auto attn_op_param = ne_attn_op_params_t{attn_flags, attn_scale, n_prompt};
if (run_mha_reordered) { // reordered kv-cache bf16 mha must be used if run_mha_reordered
struct ne_tensor* KQV_Out = ne_flash_attn_with_params(ctx0, Q, K, V, &attn_op_param);
KQV_merged_gi = ne_view_2d(ctx0, KQV_Out, head_size * n_head, attn_sl * attn_bs,
head_size * n_head * ne_element_size(KQV_Out), 0);
} else if (run_mha_fp16) { // non-reordered kv-cache fp16 mha
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
struct ne_tensor* KQV_Out = ne_flash_attn_with_params(ctx0, Q, K, V, &attn_op_param);
KQV_merged_gi = ne_view_2d(ctx0, KQV_Out, head_size * n_head, attn_sl * attn_bs,
head_size * n_head * ne_element_size(KQV_Out), 0);
} else if (attn_n_total == 0 && run_mha_bf16_first) {
Expand All @@ -462,37 +469,37 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
} else {
// K * Q
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);
ne_set_name(KQ, std::string("KQ_" + suffix).c_str());
ne_set_name(KQ, "KQ");

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ne_tensor* KQ_scale = ne_new_f32(ctx0, attn_scale);
ne_set_name(KQ_scale, std::string("1/sqrt(n_embd/n_head)_" + suffix).c_str());
ne_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");

// KQ_scaled shape [n_cached, N, n_head, 1]
struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, KQ_scale);
ne_set_name(KQ_scaled, std::string("KQ_scaled_" + suffix).c_str());
ne_set_name(KQ_scaled, "KQ_scaled");

// KQ_scaled = mask_past(KQ_scaled)
if (attn_n_total == 0 || !shift_roped_k || !no_padding) {
std::vector<int> attn_n_padding(infer_groups[gi].size(), 0);
for (int npa = 0; !n_padding.empty() && npa < infer_groups[gi].size(); ++npa) {
attn_n_padding[npa] = n_padding[infer_groups[gi][npa]];
std::vector<int> attn_n_padding(curr_group.size(), 0);
for (int npa = 0; !n_padding.empty() && npa < curr_group.size(); ++npa) {
attn_n_padding[npa] = n_padding[curr_group[npa]];
}
KQ_scaled = ne_diag_mask_inf_with_padding_inplace(ctx0, KQ_scaled, attn_n_past, attn_n_padding.data());
ne_set_name(KQ_scaled, std::string("KQ_masked_" + suffix).c_str());
ne_set_name(KQ_scaled, "KQ_masked");
}

// KQ = soft_max(KQ_masked)
struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_scaled);
ne_set_name(KQ_soft_max, std::string("KQ_soft_max_" + suffix).c_str());
ne_set_name(KQ_soft_max, "KQ_soft_max");

struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
ne_set_name(KQV, std::string("KQV_" + suffix).c_str());
ne_set_name(KQV, "KQV");

// KQV_merged = KQV.permute(0, 2, 1, 3)
KQV_merged_gi = ne_permute(ctx0, KQV, 0, 2, 1, 3);
}
ne_set_name(KQV_merged_gi, std::string("KQV_merged_" + suffix).c_str());
ne_set_name(KQV_merged_gi, "KQV_merged");
ne_build_forward_expand(&gf, ne_cpy(ctx0, KQV_merged_gi,
ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs,
head_size * n_head * ne_element_size(KQV_merged_contiguous),
Expand Down

0 comments on commit e2604fa

Please sign in to comment.