Cleaner attention interfaces (#35342)
* cleaner attention interfaces

* correctly set the _attn_implementation when adding other functions to it

* update

* Update modeling_utils.py

* CIs
Cyrilvallez authored Dec 20, 2024
1 parent eafbb0e commit 0d51d65
Showing 4 changed files with 30 additions and 12 deletions.
21 changes: 16 additions & 5 deletions src/transformers/integrations/flash_attention.py
@@ -29,23 +29,34 @@ def flash_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
 
+    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+    # therefore the input hidden states gets silently casted in float32. Hence, we need
+    # cast them back in the correct dtype just to be sure everything works as expected.
+    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+    # in fp32. (usually our RMSNorm modules handle it correctly)
+    target_dtype = None
     if query.dtype == torch.float32:
-        query = query.to(torch.float16)
-        key = key.to(torch.float16)
-        value = value.to(torch.float16)
+        if torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(module.config, "_pre_quantization_dtype"):
+            target_dtype = module.config._pre_quantization_dtype
+        else:
+            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
 
     attn_output = _flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
-        seq_len,
-        module.is_causal,
+        query_length=seq_len,
+        is_causal=module.is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
         softcap=softcap,
         use_top_left_mask=_use_top_left_mask,
+        target_dtype=target_dtype,
         **kwargs,
     )
 
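For readability, here is a minimal, self-contained sketch of the dtype fallback chain the new flash-attention code follows for float32 inputs. The helper name pick_flash_target_dtype and its standalone form are illustrative only and not part of the commit; module is a stand-in for a Transformers attention module that exposes a .config attribute.

import torch

def pick_flash_target_dtype(module: torch.nn.Module, query: torch.Tensor):
    # Returns the dtype that float32 inputs should be cast back to inside
    # _flash_attention_forward, or None when no downcast is needed.
    if query.dtype != torch.float32:
        return None
    if torch.is_autocast_enabled():
        # autocast is active: reuse its GPU dtype
        return torch.get_autocast_gpu_dtype()
    if hasattr(module.config, "_pre_quantization_dtype"):
        # quantized model: the original compute dtype is kept on the config
        return module.config._pre_quantization_dtype
    # otherwise, fall back to the dtype of the first Linear layer's weights
    return next(m for m in module.modules() if isinstance(m, torch.nn.Linear)).weight.dtype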
8 changes: 6 additions & 2 deletions src/transformers/integrations/flex_attention.py
@@ -2,10 +2,10 @@
 
 import torch
 
-from ..utils import is_torch_greater_or_equal
+from ..utils import is_torch_flex_attn_available
 
 
-if is_torch_greater_or_equal("2.5"):
+if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import flex_attention
 
 
@@ -37,8 +37,12 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         score_mod=causal_mod,
         enable_gqa=True,
         scale=scaling,
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+        # For simplification, we thus always return it as no additional computations are introduced.
         return_lse=True,
     )
+    # lse is returned in float32
+    attention_weights = attention_weights.to(value.dtype)
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, attention_weights
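As a usage note, a small standalone sketch (shapes and dtype illustrative) of the return_lse behaviour the new comments describe: the logsumexp comes back in float32 and is cast to the value dtype before being handed back as the attention weights. This assumes a PyTorch version that ships torch.nn.attention.flex_attention (2.5+).

import torch
from torch.nn.attention.flex_attention import flex_attention

# (batch, heads, seq_len, head_dim) -- illustrative shapes and dtype
query = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)
key = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)
value = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)

attn_output, lse = flex_attention(query, key, value, return_lse=True)
# the logsumexp is computed in float32 regardless of the input dtype,
# so cast it to the value dtype before returning it as attention weights
attention_weights = lse.to(value.dtype)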
4 changes: 4 additions & 0 deletions src/transformers/integrations/sdpa_attention.py
@@ -34,10 +34,14 @@ def sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
 
+    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
     query = query.contiguous()
     key = key.contiguous()
     value = value.contiguous()
 
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
     if is_causal is None:
         is_causal = causal_mask is None and query.shape[2] > 1
 
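A brief sketch (tensor shapes illustrative, not from the commit) of the dispatch pattern the second new comment describes: is_causal is resolved to a plain Python bool before the SDPA call rather than as an inline conditional inside it, so torch.compile with dynamic shapes does not have to trace a data-dependent expression.

import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 16, 64)   # (batch, heads, seq_len, head_dim)
key = torch.randn(1, 8, 16, 64)
value = torch.randn(1, 8, 16, 64)
causal_mask = None                   # no explicit mask was built for this call

# decide causality up front as a Python bool, outside the SDPA call
is_causal = causal_mask is None and query.shape[2] > 1

attn_output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=causal_mask,
    is_causal=is_causal,
)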
9 changes: 4 additions & 5 deletions src/transformers/modeling_utils.py
@@ -1474,11 +1474,8 @@ def _autoset_attn_implementation(
                 )
 
             if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
-                "eager",
-                "sdpa",
-                "flash_attention_2",
-                "flex_attention",
-            ]:
+                "eager"
+            ] + list(ALL_ATTENTION_FUNCTIONS.keys()):
                 message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
                 if cls._supports_flash_attn_2:
                     message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -1540,6 +1537,8 @@ def _autoset_attn_implementation(
                     "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
                 )
                 torch.backends.cuda.enable_flash_sdp(False)
+        elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
+            config._attn_implementation = requested_attn_implementation
         elif isinstance(requested_attn_implementation, dict):
            config._attn_implementation = None
         else:
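Taken together with the commit message bullet about setting _attn_implementation, the relaxed check means any key present in ALL_ATTENTION_FUNCTIONS is accepted as attn_implementation and propagated to config._attn_implementation. A hedged sketch of what that enables is below; it assumes ALL_ATTENTION_FUNCTIONS can be imported from transformers.modeling_utils (where this diff references it), and the function name sdpa_clone and its exact signature are illustrative only, not an API documented by the commit.

import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def sdpa_clone(module, query, key, value, attention_mask, dropout=0.0, scaling=None, **kwargs):
    # receives (batch, heads, seq_len, head_dim) tensors, like the built-in functions
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
    )
    # attention interface functions return (output, weights); weights may be None
    return attn_output.transpose(1, 2).contiguous(), None

ALL_ATTENTION_FUNCTIONS["sdpa_clone"] = sdpa_clone

# after registration, the string is a valid attn_implementation, e.g.:
# model = AutoModelForCausalLM.from_pretrained("some/model", attn_implementation="sdpa_clone")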
