Cleaner attention interfaces #35342

Merged
merged 5 commits · Dec 20, 2024
Changes from 2 commits
9 changes: 2 additions & 7 deletions src/transformers/integrations/flash_attention.py
@@ -29,18 +29,13 @@ def flash_attention_forward(
key = key.transpose(1, 2)
value = value.transpose(1, 2)

- if query.dtype == torch.float32:
-     query = query.to(torch.float16)
-     key = key.to(torch.float16)
-     value = value.to(torch.float16)

attn_output = _flash_attention_forward(
query,
key,
value,
attention_mask,
- seq_len,
- module.is_causal,
+ query_length=seq_len,
+ is_causal=module.is_causal,
dropout=dropout,
softmax_scale=scaling,
sliding_window=sliding_window,
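Side note on the removed cast: FlashAttention kernels only run in half precision, so fp32 queries/keys/values have to be downcast somewhere before the kernel is invoked; dropping the cast here suggests it is now handled further down the call chain (e.g. inside `_flash_attention_forward`). Below is a minimal PyTorch-only sketch of that constraint, not the transformers code path, assuming a CUDA device and torch >= 2.3, and using SDPA's flash backend as a stand-in:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
query = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32)
key, value = torch.randn_like(query), torch.randn_like(query)

# Flash kernels only support fp16/bf16, so fp32 inputs are downcast first.
if query.dtype == torch.float32:
    query, key, value = (t.to(torch.float16) for t in (query, key, value))

# Force the flash backend so an unsupported dtype errors instead of silently falling back.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
```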
8 changes: 6 additions & 2 deletions src/transformers/integrations/flex_attention.py
@@ -2,10 +2,10 @@

import torch

- from ..utils import is_torch_greater_or_equal
+ from ..utils import is_torch_flex_attn_available


- if is_torch_greater_or_equal("2.5"):
+ if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import flex_attention


@@ -37,8 +37,12 @@ def causal_mod(score, b, h, q_idx, kv_idx):
score_mod=causal_mod,
enable_gqa=True,
scale=scaling,
+ # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+ # For simplification, we thus always return it as no additional computations are introduced.
+ return_lse=True,
)
+ # lse is returned in float32
+ attention_weights = attention_weights.to(value.dtype)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attention_weights
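For reference, a standalone sketch of what the updated integration relies on: `flex_attention` with `return_lse=True` returns both the output and the log-sum-exp, and the lse comes back in float32 so it is cast to the value dtype. The shapes, the `causal_mod` body, and the call site are illustrative, not the transformers implementation; the fused path additionally needs `torch.compile` and a supported device.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention  # torch >= 2.5

def causal_mod(score, b, h, q_idx, kv_idx):
    # Standard causal score_mod: future positions get -inf before the softmax.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

# (batch, num_heads, seq_len, head_dim); fewer KV heads are broadcast via enable_gqa.
query = torch.randn(1, 8, 64, 64)
key = torch.randn(1, 2, 64, 64)
value = torch.randn(1, 2, 64, 64)

attn_output, attention_weights = flex_attention(
    query,
    key,
    value,
    score_mod=causal_mod,
    enable_gqa=True,
    scale=query.shape[-1] ** -0.5,
    return_lse=True,  # the lse is computed anyway, so returning it is free
)
attention_weights = attention_weights.to(value.dtype)  # lse is returned in float32
attn_output = attn_output.transpose(1, 2).contiguous()  # (batch, seq_len, num_heads, head_dim)
```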
25 changes: 22 additions & 3 deletions src/transformers/integrations/sdpa_attention.py
@@ -1,6 +1,20 @@
from typing import Optional, Tuple

import torch
+ from packaging import version
+
+ from ..utils import get_torch_version
+
+
+ def sdpa_needs_contiguous_inputs():
+     """Check if the currently installed version of torch sdpa is bugged with non-contiguous inputs.
+     Reference: https://github.com/pytorch/pytorch/issues/112577
+     """
+     torch_version = version.parse(get_torch_version())
+     return version.parse("2.1.0") <= torch_version < version.parse("2.2.0")
+
+
+ _needs_contiguous_inputs = sdpa_needs_contiguous_inputs()


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -34,10 +48,15 @@ def sdpa_attention_forward(
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]

- query = query.contiguous()
- key = key.contiguous()
- value = value.contiguous()
+ # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if _needs_contiguous_inputs:
+     query = query.contiguous()
+     key = key.contiguous()
+     value = value.contiguous()
Collaborator: We could also just always apply contiguous!

Contributor: Agree, calling contiguous doesn't affect performance much either way, so the comment clarification suffices imo.
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
if is_causal is None:
is_causal = causal_mask is None and query.shape[2] > 1

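The sdpa change gates the `contiguous()` workaround on the affected torch versions (2.1.x); together with the pre-existing `is_causal` resolution shown in the surrounding context, the forward reduces to roughly the following. A condensed, hypothetical sketch (the standalone helper and names are illustrative, not the transformers code):

```python
import torch
import torch.nn.functional as F
from packaging import version

# Version gate mirroring sdpa_needs_contiguous_inputs(): only torch 2.1.x is affected
# by https://github.com/pytorch/pytorch/issues/112577.
_NEEDS_CONTIGUOUS = version.parse("2.1.0") <= version.parse(torch.__version__) < version.parse("2.2.0")

def sdpa_forward(query, key, value, causal_mask=None, is_causal=None, dropout=0.0, scaling=None):
    # Hypothetical standalone helper; inputs are (batch, num_heads, seq_len, head_dim).
    if _NEEDS_CONTIGUOUS:
        query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    # Resolving is_causal to a plain bool up front (instead of an inline conditional in the
    # SDPA call) keeps torch.compile's dynamic shapes and fullgraph options working.
    if is_causal is None:
        is_causal = causal_mask is None and query.shape[2] > 1
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        is_causal=is_causal,
        scale=scaling,
    )
```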
9 changes: 3 additions & 6 deletions src/transformers/modeling_utils.py
@@ -1473,12 +1473,7 @@ def _autoset_attn_implementation(
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)

- if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
-     "eager",
-     "sdpa",
-     "flash_attention_2",
-     "flex_attention",
- ]:
+ if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in ["eager"] + list(ALL_ATTENTION_FUNCTIONS.keys()):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -1540,6 +1535,8 @@ def _autoset_attn_implementation(
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
+ elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.keys():
+     config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
else:
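The modeling_utils change replaces the hard-coded list of supported backends with a lookup against `ALL_ATTENTION_FUNCTIONS`, so any registered attention implementation is accepted automatically. A hypothetical, self-contained sketch of that registry pattern (the dict contents and the helper are illustrative, not the library's actual objects):

```python
from typing import Callable, Dict

# Stand-in for transformers' ALL_ATTENTION_FUNCTIONS: a mapping from implementation
# name to the attention forward callable used by the models.
ALL_ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "sdpa": lambda *args, **kwargs: ...,
    "flash_attention_2": lambda *args, **kwargs: ...,
    "flex_attention": lambda *args, **kwargs: ...,
}

def resolve_attn_implementation(requested: str) -> str:
    # Mirrors the diff: "eager" plus whatever is registered is valid, so adding a new
    # backend only requires registering it, not editing a hard-coded list.
    supported = ["eager"] + list(ALL_ATTENTION_FUNCTIONS.keys())
    if requested not in supported:
        raise ValueError(
            f'Specified `attn_implementation="{requested}"` is not supported. '
            f"Available options: {supported}"
        )
    return requested

# Example: resolve_attn_implementation("flex_attention") returns "flex_attention".
```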