diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index da21000cf3..bea6284fb5 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -31,6 +31,23 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, return original_is_causal +def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor: + """Perform repeat of kv heads along a particular dimension. + + hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim) + n_rep: amount of repetitions of kv_n_heads + Unlike torch.repeat_interleave, this function avoids allocating new memory. + """ + if n_rep == 1: + return hidden + + b, s, kv_n_heads, d = hidden.shape + + hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d) + + return hidden.reshape(b, s, kv_n_heads * n_rep, d) + + def scaled_multihead_dot_product_attention( query: torch.Tensor, key: torch.Tensor, @@ -84,8 +101,11 @@ def scaled_multihead_dot_product_attention( # grouped query case if kv_n_heads > 1 and kv_n_heads < n_heads: - k = k.repeat_interleave(n_heads // kv_n_heads, dim=1) - v = v.repeat_interleave(n_heads // kv_n_heads, dim=1) + # necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function + k = repeat_kv_for_gqa(k.transpose(1, 2), + n_heads // kv_n_heads).transpose(1, 2) + v = repeat_kv_for_gqa(v.transpose(1, 2), + n_heads // kv_n_heads).transpose(1, 2) if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) @@ -243,10 +263,16 @@ def flash_attn_fn( elif kv_n_heads < n_heads: # Each query belong to a group of kv heads of group size n_heads // kv_n_heads # We repeat each kv head by the group size number to use the underlying MHA kernels - # done along the head dimension = 1 - key_unpad = key_unpad.repeat_interleave(n_heads // kv_n_heads, dim=1) - value_unpad = value_unpad.repeat_interleave(n_heads // kv_n_heads, - dim=1) + + # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d) + # we use .view to modify {key, value}_unpad appropriately + + key_unpad = repeat_kv_for_gqa( + key_unpad.view(batch_size, seqlen, kv_n_heads, -1), + n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1) + value_unpad = repeat_kv_for_gqa( + value_unpad.view(batch_size, seqlen, kv_n_heads, -1), + n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1) dropout_p = dropout_p if training else 0.0 @@ -383,9 +409,8 @@ def triton_flash_attn_fn( elif kv_n_heads < n_heads: # Each query belong to a group of kv heads of group size n_heads // kv_n_heads # We repeat each kv head by the group size number to use the underlying MHA kernels - # done along dim = 2, unlike the implementation for flash and torch attn - key = key.repeat_interleave(n_heads // kv_n_heads, dim=2) - value = value.repeat_interleave(n_heads // kv_n_heads, dim=2) + key = repeat_kv_for_gqa(key, n_heads // kv_n_heads) + value = repeat_kv_for_gqa(value, n_heads // kv_n_heads) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) attn_output = flash_attn_func( # type: ignore