Change repeat to expand in GQA #628

Merged 3 commits on Sep 26, 2023
Changes from all commits
43 changes: 34 additions & 9 deletions llmfoundry/models/layers/attention.py
@@ -31,6 +31,23 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
    return original_is_causal


def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat kv heads along the kv_n_heads dimension (dim 2).

    hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
    n_rep: number of times to repeat each kv head
    Unlike torch.repeat_interleave, this function avoids allocating new memory.
    """
    if n_rep == 1:
        return hidden

    b, s, kv_n_heads, d = hidden.shape

    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)

    return hidden.reshape(b, s, kv_n_heads * n_rep, d)
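
A minimal sketch (not part of the diff) of why the expand-based helper is a drop-in replacement for the repeat_interleave calls it removes: it produces the same values along dim 2, and the intermediate expand is a view that shares storage with the input. Sizes below are illustrative, not taken from the repo.

import torch

b, s, kv_n_heads, d, n_rep = 2, 5, 4, 8, 3          # illustrative sizes
hidden = torch.randn(b, s, kv_n_heads, d)

out = repeat_kv_for_gqa(hidden, n_rep)               # (b, s, kv_n_heads * n_rep, d)
ref = hidden.repeat_interleave(n_rep, dim=2)         # old behaviour
assert torch.equal(out, ref)

# expand alone returns a view: no copy is made until the final reshape
view = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
assert view.data_ptr() == hidden.data_ptr()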


def scaled_multihead_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
@@ -84,8 +101,11 @@ def scaled_multihead_dot_product_attention(

    # grouped query case
    if kv_n_heads > 1 and kv_n_heads < n_heads:
        k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
        v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)
        # transpose to swap (b, h, s, d) -> (b, s, h, d), the layout expected by repeat_kv_for_gqa
        k = repeat_kv_for_gqa(k.transpose(1, 2),
                              n_heads // kv_n_heads).transpose(1, 2)
        v = repeat_kv_for_gqa(v.transpose(1, 2),
                              n_heads // kv_n_heads).transpose(1, 2)

    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
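
A hedged illustration of the transpose trick above (sizes are hypothetical, not taken from the file): repeat_kv_for_gqa expands its third dimension, so sandwiching it between two transposes reproduces the old repeat_interleave along the head dimension of a (b, h, s, d) tensor.

import torch

b, kv_n_heads, s, d, group = 2, 2, 5, 8, 4          # hypothetical sizes
k = torch.randn(b, kv_n_heads, s, d)

k_new = repeat_kv_for_gqa(k.transpose(1, 2), group).transpose(1, 2)
k_old = k.repeat_interleave(group, dim=1)            # previous implementation
assert torch.equal(k_new, k_old)                     # identical values, heads repeated in place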
@@ -243,10 +263,16 @@ def flash_attn_fn(
    elif kv_n_heads < n_heads:
        # Each query belongs to a group of kv heads of group size n_heads // kv_n_heads.
        # We repeat each kv head by the group size to use the underlying MHA kernels.
        # done along the head dimension = 1
        key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1)
        value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads,
                                                    dim=1)

        # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d),
        # we use .view to reshape {key, value}_unpad appropriately

        key_unpad = repeat_kv_for_gqa(
            key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
            n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
        value_unpad = repeat_kv_for_gqa(
            value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
            n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)

    dropout_p = dropout_p if training else 0.0
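
A hedged sketch of the view round-trip above, with made-up sizes: the unpadded tensors are flattened as (total_tokens, heads, head_dim), so they are temporarily viewed as (batch_size, seqlen, kv_n_heads, head_dim) for repeat_kv_for_gqa and then flattened back for the flash-attention kernel.

import torch

batch_size, seqlen, kv_n_heads, n_heads, head_dim = 2, 5, 2, 8, 16   # made-up sizes
key_unpad = torch.randn(batch_size * seqlen, kv_n_heads, head_dim)

key_unpad = repeat_kv_for_gqa(
    key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
    n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)

assert key_unpad.shape == (batch_size * seqlen, n_heads, head_dim)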

@@ -383,9 +409,8 @@ def triton_flash_attn_fn(
    elif kv_n_heads < n_heads:
        # Each query belongs to a group of kv heads of group size n_heads // kv_n_heads.
        # We repeat each kv head by the group size to use the underlying MHA kernels.
        # done along dim = 2, unlike the implementation for flash and torch attn
        key = key.repeat_interleave(n_heads // kv_n_heads, dim=2)
        value = value.repeat_interleave(n_heads // kv_n_heads, dim=2)
        key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
        value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
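
In the triton path the kv tensors already carry the head dimension at index 2 (as the removed dim=2 repeat_interleave implies), so the helper is applied directly. A small illustrative shape check, with made-up sizes:

import torch

b, s, kv_n_heads, n_heads, d = 2, 5, 2, 8, 16        # made-up sizes
key = torch.randn(b, s, kv_n_heads, d)
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
assert key.shape == (b, s, n_heads, d)               # heads repeated up to n_heads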

    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_func( # type: ignore