repair flash attention in _ext
this does not fix the currently broken flash attention behind the define, which is only used by the VAE

Co-authored-by: FSSRepo <[email protected]>
Green-Sky and FSSRepo committed Sep 1, 2024
1 parent 58d5473 commit 2634bea
Showing 1 changed file with 29 additions and 5 deletions.
ggml_extend.hpp — 34 changes: 29 additions & 5 deletions
@@ -735,13 +735,35 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*

     float scale = (1.0f / sqrt((float)d_head));
 
-    bool use_flash_attn = false;
-    ggml_tensor* kqv = NULL;
+    LOG_DEBUG("attention_ext L_k:%d n_head:%d C:%d d_head:%d", L_k, n_head, C, d_head);
+
+    bool use_flash_attn = true;
+    // L_k == n_context AND l_k == n_token ????
+    use_flash_attn = use_flash_attn && L_k % 256 == 0;
+    use_flash_attn = use_flash_attn && d_head % 64 == 0;  // why
+
+    if (mask != nullptr) {
+        // TODO: figure out if we can bend t5 to work too
+        use_flash_attn = use_flash_attn && mask->ne[2] == 1;
+        use_flash_attn = use_flash_attn && mask->ne[3] == 1;
+    }
+
+    // TODO: more pad or disable for funny tensor shapes
+
+    ggml_tensor* kqv = nullptr;
     if (use_flash_attn) {
+        LOG_DEBUG("using flash attention");
+
+        k = ggml_cast(ctx, k, GGML_TYPE_F16);
+
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        LOG_DEBUG("k->ne[1] == %d", k->ne[1]);
         v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0);
         ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+
+        kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0);
     } else {
         v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
         v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);  // [N * n_head, d_head, L_k]
@@ -757,10 +779,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         kq = ggml_soft_max_inplace(ctx, kq);
 
         kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, L_q, d_head]
-
-        kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);  // [N, n_head, L_q, d_head]
-        kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3);                 // [N, L_q, n_head, d_head]
     }
 
+    kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);   // [N, n_head, L_q, d_head]
+    kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));  // [N, L_q, n_head, d_head]
     kqv = ggml_cont(ctx, kqv);
     kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N);  // [N, L_q, C]
 
     return kqv;
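
Aside: the `use_flash_attn` gate added above can be read as a standalone predicate. The helper below is a minimal sketch with a hypothetical name (`can_use_flash_attn_sketch` is not in the commit); it only restates the diff's conditions, and the `// why` and t5 TODO comments are the authors' own open questions, not resolved facts.

#include "ggml.h"

// Hypothetical helper: the conditions the diff inlines into use_flash_attn.
static bool can_use_flash_attn_sketch(int64_t L_k, int64_t d_head,
                                      const struct ggml_tensor* mask) {
    bool ok = true;
    ok = ok && L_k % 256 == 0;    // KV length must be a multiple of 256
    ok = ok && d_head % 64 == 0;  // head size a multiple of 64 (the diff's "// why")
    if (mask != NULL) {
        // masks with extra head/batch dims are rejected for now;
        // this is what keeps t5-style masks on the fallback path
        ok = ok && mask->ne[2] == 1;
        ok = ok && mask->ne[3] == 1;
    }
    return ok;
}

For example, cross-attention against an unpadded 77-token CLIP context has L_k = 77, which fails the `% 256` check and takes the matmul fallback; the "more pad or disable for funny tensor shapes" TODO presumably points at padding such shapes into eligibility.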
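The fallback branch's shape bookkeeping can also be checked in isolation. Below is a self-contained sketch using the same ggml calls, under stated assumptions: the numbers are illustrative (N = 1, n_head = 8, d_head = 64, L_q = L_k = 256, so C = 512), q/k/v are taken to start as [N * n_head, L, d_head] as in the function, and the lines elided between the two hunks (the k·q matmul and scaling) are filled in here with standard ggml calls — an assumption, since the commit does not show them.

#include "ggml.h"
#include <math.h>
#include <stdio.h>

int main(void) {
    const int64_t N = 1, n_head = 8, d_head = 64, L_q = 256, L_k = 256;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 64 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context* ctx = ggml_init(params);

    // ggml stores ne[] fastest-first: ne = (d_head, L, n_head * N)
    struct ggml_tensor* q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, L_q, n_head * N);
    struct ggml_tensor* k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, L_k, n_head * N);
    struct ggml_tensor* v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, L_k, n_head * N);

    const float scale = 1.0f / sqrtf((float)d_head);

    // v must expose L_k as its fastest dim for the final matmul
    v = ggml_cont(ctx, ggml_transpose(ctx, v));         // (L_k, d_head, n_head * N)

    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);   // (L_k, L_q, n_head * N)
    kq = ggml_scale_inplace(ctx, kq, scale);
    kq = ggml_soft_max_inplace(ctx, kq);

    struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // (d_head, L_q, n_head * N)

    // same epilogue as the diff: split heads, move them next to d_head, flatten to C
    kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N);
    kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // (d_head, n_head, L_q, N)
    kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // (C, L_q, N)

    printf("kqv ne = [%lld, %lld, %lld]\n",
           (long long)kqv->ne[0], (long long)kqv->ne[1], (long long)kqv->ne[2]);

    ggml_free(ctx);
    return 0;
}

ggml_mul_mat contracts over ne[0] of both arguments, which is why v is transposed to put L_k fastest before the final matmul, and why the flash branch instead wants v as [N * n_head, L_k, d_head].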
