
Commit: cleaner attention interfaces
Cyrilvallez committed Dec 19, 2024
1 parent 56ff1e9 commit f74a08e
Showing 3 changed files with 30 additions and 12 deletions.
src/transformers/integrations/flash_attention.py (9 changes: 2 additions & 7 deletions)
@@ -29,18 +29,13 @@ def flash_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
 
-    if query.dtype == torch.float32:
-        query = query.to(torch.float16)
-        key = key.to(torch.float16)
-        value = value.to(torch.float16)
-
     attn_output = _flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
-        seq_len,
-        module.is_causal,
+        query_length=seq_len,
+        is_causal=module.is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
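Note: the hunk above switches the `seq_len` and `module.is_causal` arguments to explicit keywords (`query_length=`, `is_causal=`) and drops the local float32-to-float16 cast. The following is a small standalone sketch, not code from the repository: it uses a toy stand-in for `_flash_attention_forward` (the helper name, shapes, and math are invented for the example) to illustrate the keyword-argument call pattern the new call site uses.

    import torch


    def _toy_attention_kernel(query, key, value, attention_mask=None, *,
                              query_length, is_causal, dropout=0.0, softmax_scale=None):
        # Toy stand-in for _flash_attention_forward: plain masked softmax attention
        # over (batch, heads, seq_len, head_dim) tensors; dropout is ignored here.
        scale = softmax_scale if softmax_scale is not None else query.shape[-1] ** -0.5
        scores = torch.matmul(query, key.transpose(-1, -2)) * scale
        if is_causal:
            causal = torch.ones(query_length, key.shape[-2], dtype=torch.bool).tril()
            scores = scores.masked_fill(~causal, float("-inf"))
        return torch.matmul(torch.softmax(scores, dim=-1), value)


    query = key = value = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
    attn_output = _toy_attention_kernel(
        query,
        key,
        value,
        None,
        query_length=8,            # keyword arguments, as in the updated call site
        is_causal=True,
        softmax_scale=None,
    )
    print(attn_output.shape)  # torch.Size([1, 4, 8, 16])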
src/transformers/integrations/flex_attention.py (8 changes: 6 additions & 2 deletions)
@@ -2,10 +2,10 @@
 
 import torch
 
-from ..utils import is_torch_greater_or_equal
+from ..utils import is_torch_flex_attn_available
 
 
-if is_torch_greater_or_equal("2.5"):
+if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import flex_attention
 
 
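Note: the import switch above replaces a raw version comparison with `is_torch_flex_attn_available()`. The sketch below is a hypothetical, self-contained stand-in for such a check (it is not the implementation in `transformers.utils`; the function name is made up): it gates on the torch version and on the actual presence of the flex attention module before importing it.

    import importlib.util

    import torch
    from packaging import version


    def is_flex_attention_usable() -> bool:
        # Hypothetical availability check: require torch >= 2.5 and the module itself.
        if version.parse(torch.__version__).release < (2, 5):
            return False
        return importlib.util.find_spec("torch.nn.attention.flex_attention") is not None


    if is_flex_attention_usable():
        from torch.nn.attention.flex_attention import flex_attention  # noqa: F401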
@@ -37,8 +37,12 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         score_mod=causal_mod,
         enable_gqa=True,
         scale=scaling,
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+        # For simplification, we thus always return it as no additional computations are introduced.
         return_lse=True,
     )
+    # lse is returned in float32
+    attention_weights = attention_weights.to(value.dtype)
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, attention_weights
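Note: per the comments added above, the flex kernel computes the logsumexp anyway, so it is always requested and returned in the attention-weights slot, cast from float32 back to `value.dtype`. Below is a pure-PyTorch sketch of that return contract; it does not call the flex kernel, and the function name and shapes are invented for illustration. The second output is the row-wise logsumexp of the scaled scores, computed in float32 and then cast down.

    import torch


    def toy_attention_with_lse(query, key, value, scaling=None):
        scale = scaling if scaling is not None else query.shape[-1] ** -0.5
        # Compute scores in float32 for numerical stability, as attention kernels do internally.
        scores = (query.float() @ key.float().transpose(-1, -2)) * scale
        lse = torch.logsumexp(scores, dim=-1)                        # float32, like the kernel's lse
        attn_output = (scores.softmax(dim=-1) @ value.float()).to(value.dtype)
        return attn_output, lse.to(value.dtype)                      # mirror the cast in the diff


    q = k = v = torch.randn(1, 2, 16, 8, dtype=torch.bfloat16)
    out, weights = toy_attention_with_lse(q, k, v)
    print(out.dtype, weights.dtype)  # torch.bfloat16 torch.bfloat16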
src/transformers/integrations/sdpa_attention.py (25 changes: 22 additions & 3 deletions)
@@ -1,6 +1,20 @@
 from typing import Optional, Tuple
 
 import torch
+from packaging import version
+
+from ..utils import get_torch_version
+
+
+def sdpa_needs_contiguous_inputs():
+    """Check if the currently installed version of torch sdpa is bugged with non-contiguous inputs.
+    Reference: https://github.com/pytorch/pytorch/issues/112577
+    """
+    torch_version = version.parse(get_torch_version())
+    return version.parse("2.1.0") <= torch_version < version.parse("2.2.0")
+
+
+_needs_contiguous_inputs = sdpa_needs_contiguous_inputs()
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
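Note: `sdpa_needs_contiguous_inputs()` above gates the workaround on torch 2.1.x only. As a quick standalone illustration (not repository code, with made-up version strings), this is how a `packaging`-based range check like that one classifies a few versions, including local-version builds:

    from packaging import version


    def needs_contiguous_inputs(torch_version_str: str) -> bool:
        # Same range check as above, but taking the version string as a parameter.
        v = version.parse(torch_version_str)
        return version.parse("2.1.0") <= v < version.parse("2.2.0")


    for v_str in ("2.0.1", "2.1.0", "2.1.2+cu118", "2.2.0", "2.5.1"):
        print(v_str, needs_contiguous_inputs(v_str))
    # 2.0.1 False / 2.1.0 True / 2.1.2+cu118 True / 2.2.0 False / 2.5.1 False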
@@ -34,10 +48,15 @@ def sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
 
-    query = query.contiguous()
-    key = key.contiguous()
-    value = value.contiguous()
+    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if _needs_contiguous_inputs:
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
 
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
     if is_causal is None:
         is_causal = causal_mask is None and query.shape[2] > 1
 
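Note: the comment added above explains why `is_causal` is resolved with an explicit `if` before the SDPA call rather than inline. A minimal standalone sketch of that pattern with plain `torch.nn.functional.scaled_dot_product_attention` (toy shapes, no transformers code) looks like this:

    import torch
    import torch.nn.functional as F

    query = key = value = torch.randn(1, 2, 16, 8)
    causal_mask = None  # e.g. no explicit 4D mask was built for this forward pass

    # Resolve is_causal to a plain bool before the call (instead of an inline conditional),
    # which, per the diff comment, keeps torch.compile with dynamic shapes / fullgraph happy.
    is_causal = causal_mask is None and query.shape[2] > 1

    attn_output = F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        is_causal=is_causal,
    )
    print(attn_output.shape)  # torch.Size([1, 2, 16, 8])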
