diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index 1be223f8b079ba..6e3b629c9e6bae 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -29,18 +29,13 @@ def flash_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
 
-    if query.dtype == torch.float32:
-        query = query.to(torch.float16)
-        key = key.to(torch.float16)
-        value = value.to(torch.float16)
-
     attn_output = _flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
-        seq_len,
-        module.is_causal,
+        query_length=seq_len,
+        is_causal=module.is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index eacfb2b568b55b..66ffc5638838cb 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -2,10 +2,10 @@
 
 import torch
 
-from ..utils import is_torch_greater_or_equal
+from ..utils import is_torch_flex_attn_available
 
 
-if is_torch_greater_or_equal("2.5"):
+if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import flex_attention
 
 
@@ -37,8 +37,12 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         score_mod=causal_mod,
         enable_gqa=True,
         scale=scaling,
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+        # For simplification, we thus always return it as no additional computations are introduced.
         return_lse=True,
     )
+    # lse is returned in float32
+    attention_weights = attention_weights.to(value.dtype)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attention_weights
diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 265260c9b79e4c..6a94a4030b5558 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -1,6 +1,20 @@
 from typing import Optional, Tuple
 
 import torch
+from packaging import version
+
+from ..utils import get_torch_version
+
+
+def sdpa_needs_contiguous_inputs():
+    """Check if the currently installed version of torch sdpa is bugged with non-contiguous inputs.
+    Reference: https://github.com/pytorch/pytorch/issues/112577
+    """
+    torch_version = version.parse(get_torch_version())
+    return version.parse("2.1.0") <= torch_version < version.parse("2.2.0")
+
+
+_needs_contiguous_inputs = sdpa_needs_contiguous_inputs()
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -34,10 +48,15 @@ def sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
 
-    query = query.contiguous()
-    key = key.contiguous()
-    value = value.contiguous()
+    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if _needs_contiguous_inputs:
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
 
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
     if is_causal is None:
         is_causal = causal_mask is None and query.shape[2] > 1
 
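For reference, the version gate introduced in `sdpa_attention.py` follows the usual `packaging.version` comparison pattern and is evaluated once at import time so the hot attention path only pays a boolean check. Below is a minimal, self-contained sketch of the same idea, not part of the diff: it reads `torch.__version__` directly instead of the library's `get_torch_version()` helper, and `maybe_contiguous` is a hypothetical wrapper added here only to illustrate how the cached flag would be consumed.

```python
# Standalone sketch (not part of the diff above); names and helper are illustrative.
import torch
from packaging import version


def sdpa_needs_contiguous_inputs() -> bool:
    # torch 2.1.x SDPA (memory-efficient backend) mis-handles non-contiguous
    # inputs combined with a custom attn_mask, see pytorch/pytorch#112577.
    torch_version = version.parse(torch.__version__)
    return version.parse("2.1.0") <= torch_version < version.parse("2.2.0")


# Evaluated once at import time; the per-call cost is a single bool check.
_needs_contiguous_inputs = sdpa_needs_contiguous_inputs()


def maybe_contiguous(query, key, value):
    # Only force contiguity on the affected torch versions; elsewhere the
    # original (possibly transposed) views are passed through untouched.
    if _needs_contiguous_inputs:
        query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    return query, key, value
```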