Commit d1aa9ce (1 parent: f14637a). Showing 8 changed files with 837 additions and 353 deletions.
@@ -0,0 +1,36 @@
import torch

from ..modeling_flash_attention_utils import _flash_attention_forward


def flash_attention_forward(
    config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
):
    if attentions_mask is not None:
        seq_len = attentions_mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    else:
        seq_len = query.shape[1]

    dropout_rate = config.attention_dropout if training else 0.0

    input_dtype = query.dtype
    if input_dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attentions_mask,
        seq_len,
        config=config,
        dropout=dropout_rate,
        layer_idx=layer_idx,
        **kwargs,
    )

    return attn_output, None
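Below is a minimal, hypothetical usage sketch for the wrapper above (not part of the commit). It assumes a CUDA device with the flash-attn package installed, stands in a bare SimpleNamespace for the model config, and uses illustrative shapes in the (batch, seq_len, num_heads, head_dim) layout FlashAttention kernels expect; whether the call runs end to end also depends on the signature of `_flash_attention_forward` in the surrounding tree.

import torch
from types import SimpleNamespace

# Hypothetical stand-in config: the wrapper only reads `attention_dropout`.
config = SimpleNamespace(attention_dropout=0.0)

batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
key = torch.randn_like(query)
value = torch.randn_like(query)

# With attentions_mask=None, seq_len is read from query.shape[1] and no slicing happens.
attn_output, _ = flash_attention_forward(config, query, key, value, attentions_mask=None)
print(attn_output.shape)  # same layout as the inputs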
@@ -0,0 +1,28 @@
from ..utils import is_torch_greater_or_equal


if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention


def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        if causal_mask is not None:
            score += causal_mask[b][0][q_idx][kv_idx]
        return score

    attn_output, attention_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        enable_gqa=True,
        scale=module.scaling,
        return_lse=True,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attention_weights
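A small usage sketch for flex_attention_forward (illustrative, not part of the commit): it needs torch >= 2.5, uses a SimpleNamespace that exposes only the `scaling` attribute the function reads, and feeds (batch, num_heads, seq_len, head_dim) tensors, the layout torch.nn.attention.flex_attention operates on.

import torch
from types import SimpleNamespace

batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
# Hypothetical stand-in module: only `scaling` is read by the function.
module = SimpleNamespace(scaling=head_dim**-0.5)

query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn_like(query)
value = torch.randn_like(query)

# With attention_mask=None the score_mod leaves scores unchanged (full attention).
attn_output, lse = flex_attention_forward(module, query, key, value, attention_mask=None)
print(attn_output.shape)  # transposed back to (batch, seq_len, num_heads, head_dim)
print(lse.shape)          # log-sum-exp of the scores, returned because return_lse=True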
@@ -0,0 +1,39 @@
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
    key = repeat_kv(key, module.num_key_value_groups)
    value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    is_causal = causal_mask is None and query.shape[1] > 1
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=module.config.attention_dropout if module.training else 0.0,
        is_causal=is_causal,
        scale=module.scaling,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
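Finally, a self-contained sketch exercising repeat_kv and sdpa_attention_forward (illustrative only, not part of the commit). The SimpleNamespace provides just the attributes the function reads (num_key_value_groups, config.attention_dropout, training, scaling), the shapes are arbitrary, and the assertion checks the docstring's claim that repeat_kv matches torch.repeat_interleave along the head dimension.

import torch
from types import SimpleNamespace

batch, num_kv_heads, n_rep, seq_len, head_dim = 2, 2, 4, 16, 64
# Hypothetical stand-in module exposing only what sdpa_attention_forward reads.
module = SimpleNamespace(
    num_key_value_groups=n_rep,
    config=SimpleNamespace(attention_dropout=0.0),
    training=False,
    scaling=head_dim**-0.5,
)

query = torch.randn(batch, num_kv_heads * n_rep, seq_len, head_dim)  # all attention heads
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)            # grouped KV heads
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# repeat_kv is equivalent to repeat_interleave on the key/value head dimension.
assert torch.equal(repeat_kv(key, n_rep), torch.repeat_interleave(key, n_rep, dim=1))

# With attention_mask=None and query.shape[1] > 1 here, is_causal evaluates to True.
attn_output, _ = sdpa_attention_forward(module, query, key, value, attention_mask=None)
print(attn_output.shape)  # (batch, seq_len, num_heads, head_dim) after the final transpose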