
Commit

minimal changes
ArthurZucker committed Dec 12, 2024
1 parent f14637a commit d1aa9ce
Showing 8 changed files with 837 additions and 353 deletions.
36 changes: 36 additions & 0 deletions src/transformers/integrations/flash_attention.py
@@ -0,0 +1,36 @@
import torch

from ..modeling_flash_attention_utils import _flash_attention_forward


def flash_attention_forward(
    config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
):
    if attentions_mask is not None:
        seq_len = attentions_mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    else:
        seq_len = query.shape[1]

    dropout_rate = config.attention_dropout if training else 0.0

    input_dtype = query.dtype
    if input_dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attentions_mask,
        seq_len,
        config=config,
        dropout=dropout_rate,
        layer_idx=layer_idx,
        **kwargs,
    )

    return attn_output, None
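
A minimal usage sketch of the wrapper above (not part of the commit): it assumes flash-attn is installed, a CUDA device is available, tensors are laid out as (batch, seq_len, num_heads, head_dim) in half precision, and that the _flash_attention_forward updated elsewhere in this commit accepts the forwarded config and layer_idx keywords.

import torch
from types import SimpleNamespace

from transformers.integrations.flash_attention import flash_attention_forward

# Hypothetical stand-in for a model config; only attention_dropout is read by the wrapper.
config = SimpleNamespace(attention_dropout=0.0)

# (batch, seq_len, num_heads, head_dim), already in fp16 on GPU as flash-attn requires.
query = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
key = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
value = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")

# With attentions_mask=None the query length is taken from query.shape[1].
attn_output, _ = flash_attention_forward(config, query, key, value, attentions_mask=None)
print(attn_output.shape)  # expected: torch.Size([1, 128, 8, 64])
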
28 changes: 28 additions & 0 deletions src/transformers/integrations/flex_attention.py
@@ -0,0 +1,28 @@
from ..utils import is_torch_greater_or_equal


if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention


def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        if causal_mask is not None:
            score += causal_mask[b][0][q_idx][kv_idx]
        return score

    attn_output, attention_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        enable_gqa=True,
        scale=module.scaling,
        return_lse=True,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attention_weights
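
A minimal usage sketch (not part of the commit), assuming torch >= 2.5 so flex_attention is importable; module only needs a .scaling attribute here, so a SimpleNamespace stands in for a real attention layer, and passing attention_mask=None exercises the no-mask path of causal_mod.

import torch
from types import SimpleNamespace

from transformers.integrations.flex_attention import flex_attention_forward

module = SimpleNamespace(scaling=0.125)  # 1 / sqrt(head_dim) for head_dim = 64

# (batch, num_heads, seq_len, head_dim), the layout flex_attention expects.
query = torch.randn(1, 8, 128, 64)
key = torch.randn(1, 8, 128, 64)
value = torch.randn(1, 8, 128, 64)

attn_output, lse = flex_attention_forward(module, query, key, value, attention_mask=None)
print(attn_output.shape)  # torch.Size([1, 128, 8, 64]) after the transpose back
print(lse.shape)          # torch.Size([1, 8, 128]); return_lse=True returns the log-sum-exp, not probabilities
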
39 changes: 39 additions & 0 deletions src/transformers/integrations/sdpa_attention.py
@@ -0,0 +1,39 @@
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
    key = repeat_kv(key, module.num_key_value_groups)
    value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    is_causal = True if causal_mask is None and query.shape[1] > 1 else False
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=module.config.attention_dropout if module.training else 0.0,
        is_causal=is_causal,
        scale=module.scaling,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
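
A minimal usage sketch (not part of the commit): module stands in for an attention layer and only needs the attributes read above (num_key_value_groups, config.attention_dropout, training, scaling); the key/value tensors carry fewer heads than the query to show repeat_kv expanding them for grouped-query attention.

import torch
from types import SimpleNamespace

from transformers.integrations.sdpa_attention import sdpa_attention_forward

module = SimpleNamespace(
    num_key_value_groups=4,  # 8 query heads share 2 key/value heads
    config=SimpleNamespace(attention_dropout=0.0),
    training=False,
    scaling=0.125,  # 1 / sqrt(head_dim) for head_dim = 64
)

query = torch.randn(1, 8, 128, 64)  # (batch, num_attention_heads, seq_len, head_dim)
key = torch.randn(1, 2, 128, 64)    # (batch, num_key_value_heads, seq_len, head_dim)
value = torch.randn(1, 2, 128, 64)

# With attention_mask=None, is_causal=True is inferred and SDPA applies its own causal mask.
attn_output, _ = sdpa_attention_forward(module, query, key, value, attention_mask=None)
print(attn_output.shape)  # torch.Size([1, 128, 8, 64]) after the transpose back
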
(The remaining 5 changed files are not rendered in this view.)
