Cleaner attention interfaces #35342

Merged · 5 commits · Dec 20, 2024
21 changes: 16 additions & 5 deletions src/transformers/integrations/flash_attention.py
@@ -29,23 +29,34 @@ def flash_attention_forward(
key = key.transpose(1, 2)
value = value.transpose(1, 2)

# In PEFT, we usually cast the layer norms to float32 for training stability reasons,
# so the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32. (usually our RMSNorm modules handle it correctly)
target_dtype = None
if query.dtype == torch.float32:
query = query.to(torch.float16)
key = key.to(torch.float16)
value = value.to(torch.float16)
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(module.config, "_pre_quantization_dtype"):
target_dtype = module.config._pre_quantization_dtype
else:
target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

attn_output = _flash_attention_forward(
query,
key,
value,
attention_mask,
seq_len,
module.is_causal,
query_length=seq_len,
is_causal=module.is_causal,
dropout=dropout,
softmax_scale=scaling,
sliding_window=sliding_window,
softcap=softcap,
use_top_left_mask=_use_top_left_mask,
target_dtype=target_dtype,
**kwargs,
)

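To make the new dtype handling easier to follow, here is the same three-way fallback restated as a standalone helper. This is a minimal sketch, not code from the PR; the helper name `infer_flash_attn_dtype` is hypothetical, and it only assumes the module exposes a `config` attribute and contains at least one `nn.Linear`, as the diff above does.

```python
import torch
import torch.nn as nn

def infer_flash_attn_dtype(module: nn.Module, query: torch.Tensor):
    # Hypothetical helper mirroring the fallback order in the diff above.
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            # 1) Prefer the active autocast dtype (e.g. bf16 under torch.autocast).
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(module.config, "_pre_quantization_dtype"):
            # 2) Quantized models record the dtype they were loaded in before quantization.
            target_dtype = module.config._pre_quantization_dtype
        else:
            # 3) Fall back to the weight dtype of the first nn.Linear in the module.
            target_dtype = next(
                layer for layer in module.modules() if isinstance(layer, nn.Linear)
            ).weight.dtype
    return target_dtype
```

Passing `target_dtype` down to `_flash_attention_forward`, instead of hard-coding `float16` as the removed lines did, presumably lets the helper downcast fp32 inputs to whatever dtype the rest of the model actually runs in, so bf16 or quantized models are no longer silently forced into fp16.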
8 changes: 6 additions & 2 deletions src/transformers/integrations/flex_attention.py
@@ -2,10 +2,10 @@

import torch

from ..utils import is_torch_greater_or_equal
from ..utils import is_torch_flex_attn_available


if is_torch_greater_or_equal("2.5"):
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import flex_attention


@@ -37,8 +37,12 @@ def causal_mod(score, b, h, q_idx, kv_idx):
score_mod=causal_mod,
enable_gqa=True,
scale=scaling,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
)
# lse is returned in float32
attention_weights = attention_weights.to(value.dtype)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attention_weights
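Since `return_lse=True` changes the call's return signature, a minimal usage sketch may help. This is illustrative only: the `causal_mod` below is the standard `torch.where`-based causal score mod from the PyTorch docs, whereas the PR's version adds a precomputed causal mask to the score.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention  # requires torch >= 2.5

def causal_mod(score, b, h, q_idx, kv_idx):
    # Standard causal score_mod: mask out key positions after the query position.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

B, H, S, D = 1, 8, 128, 64
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Flex Attention computes the log-sum-exp internally anyway, so returning it is free.
attn_output, lse = flex_attention(
    q, k, v,
    score_mod=causal_mod,
    enable_gqa=True,
    scale=D**-0.5,
    return_lse=True,
)
lse = lse.to(v.dtype)  # the lse always comes back in float32
attn_output = attn_output.transpose(1, 2).contiguous()  # (B, S, H, D), matching the other backends
```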
4 changes: 4 additions & 0 deletions src/transformers/integrations/sdpa_attention.py
@@ -34,10 +34,14 @@ def sdpa_attention_forward(
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]

# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
if is_causal is None:
is_causal = causal_mask is None and query.shape[2] > 1

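Read in isolation, the sdpa_attention.py hunk boils down to the sketch below (illustrative names, not the exact transformers function): contiguous inputs to dodge the memory-efficient-backend bug, and an explicit `if` to resolve `is_causal` so torch.compile can trace dynamic shapes.

```python
import torch
import torch.nn.functional as F

def sdpa_sketch(query, key, value, causal_mask=None, is_causal=None, scaling=None, dropout=0.0):
    # Non-contiguous inputs + custom attn_mask trip the memory-efficient backend
    # on some torch versions (pytorch/pytorch#112577), so force contiguity.
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

    if is_causal is None:
        # Causal only when no explicit mask is supplied and we are not decoding
        # a single token; the explicit `if` keeps torch.compile's dynamic shapes happy.
        is_causal = causal_mask is None and query.shape[2] > 1

    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
```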
9 changes: 4 additions & 5 deletions src/transformers/modeling_utils.py
@@ -1474,11 +1474,8 @@ def _autoset_attn_implementation(
)

if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
"eager",
"sdpa",
"flash_attention_2",
"flex_attention",
]:
"eager"
] + list(ALL_ATTENTION_FUNCTIONS.keys()):
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -1540,6 +1537,8 @@ def _autoset_attn_implementation(
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
config._attn_implementation = requested_attn_implementation
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
else:
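The net effect of the modeling_utils changes is that anything present in `ALL_ATTENTION_FUNCTIONS` now passes validation and is stored on `config._attn_implementation`. The sketch below shows what registering a custom implementation could look like; the import path and the callback signature are assumptions based on the integration functions touched in this PR, not a documented contract.

```python
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS  # assumed import path

def my_eager_attention(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
    # Assumed signature, mirroring the integrations above: q/k/v arrive as
    # (batch, num_heads, seq_len, head_dim); GQA key/value repetition is omitted for brevity.
    scale = scaling if scaling is not None else query.shape[-1] ** -0.5
    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]
    weights = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    weights = torch.nn.functional.dropout(weights, p=dropout, training=module.training)
    attn_output = torch.matmul(weights, value).transpose(1, 2).contiguous()  # back to (batch, seq, heads, dim)
    return attn_output, weights

# With the check now reading `["eager"] + list(ALL_ATTENTION_FUNCTIONS.keys())`,
# a registered key becomes an accepted `attn_implementation` value:
ALL_ATTENTION_FUNCTIONS["my_eager_attention"] = my_eager_attention
# model = AutoModel.from_pretrained("some/model", attn_implementation="my_eager_attention")
```

Whether a given model actually dispatches through the registry still depends on it having been ported to the new attention interface, so treat this as a sketch of the mechanism rather than a guaranteed extension point.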