diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index 1be223f8b079ba..b8407bc29c6a8a 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -29,23 +29,34 @@ def flash_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)

+    # In PEFT, the layer norms are usually cast to float32 for training stability,
+    # so the input hidden states get silently cast to float32. Hence, we need to cast
+    # them back to the correct dtype just to be sure everything works as expected.
+    # This might slow down training & inference, so it is recommended not to cast the
+    # LayerNorms to fp32 (our RMSNorm modules usually handle it correctly).
+    target_dtype = None
     if query.dtype == torch.float32:
-        query = query.to(torch.float16)
-        key = key.to(torch.float16)
-        value = value.to(torch.float16)
+        if torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(module.config, "_pre_quantization_dtype"):
+            target_dtype = module.config._pre_quantization_dtype
+        else:
+            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

     attn_output = _flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
-        seq_len,
-        module.is_causal,
+        query_length=seq_len,
+        is_causal=module.is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
         softcap=softcap,
         use_top_left_mask=_use_top_left_mask,
+        target_dtype=target_dtype,
         **kwargs,
     )
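The dtype fallback added in the hunk above resolves a downcast target in priority order: the autocast compute dtype when autocast is active, the `_pre_quantization_dtype` recorded on the config for quantized models, and otherwise the dtype of the first `torch.nn.Linear` weight found in the module. A minimal standalone sketch of that resolution logic, assuming `module.config` exists as it does on the library's attention modules (the helper name `resolve_flash_attn_dtype` is ours, not part of the patch):

```python
from typing import Optional

import torch


def resolve_flash_attn_dtype(module: torch.nn.Module, query: torch.Tensor) -> Optional[torch.dtype]:
    """Sketch of the fp32 downcast target used by the patched flash_attention_forward."""
    # No downcast needed if the query is already in a flash-attention-friendly dtype.
    if query.dtype != torch.float32:
        return None
    # Under autocast, reuse the autocast compute dtype (e.g. fp16 or bf16).
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # Quantized models record the dtype they had before quantization on the config.
    if hasattr(module.config, "_pre_quantization_dtype"):
        return module.config._pre_quantization_dtype
    # Otherwise, fall back to the dtype of the first Linear layer's weights.
    return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
```

The resolved `target_dtype` is then forwarded to `_flash_attention_forward`, which performs the cast of query, key, and value just before calling the flash-attention kernels.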
diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index eacfb2b568b55b..66ffc5638838cb 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -2,10 +2,10 @@

 import torch

-from ..utils import is_torch_greater_or_equal
+from ..utils import is_torch_flex_attn_available


-if is_torch_greater_or_equal("2.5"):
+if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import flex_attention


@@ -37,8 +37,12 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         score_mod=causal_mod,
         enable_gqa=True,
         scale=scaling,
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+        # For simplification, we thus always return it as no additional computations are introduced.
         return_lse=True,
     )
+    # lse is returned in float32
+    attention_weights = attention_weights.to(value.dtype)

     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attention_weights
diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 265260c9b79e4c..38701690bf7c2a 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -34,10 +34,14 @@ def sdpa_attention_forward(
     if attention_mask is not None:
         causal_mask = causal_mask[:, :, :, : key.shape[-2]]

+    # SDPA with the memory-efficient backend is bugged with non-contiguous inputs and a custom attn_mask for some torch versions.
+    # Reference: https://github.com/pytorch/pytorch/issues/112577
    query = query.contiguous()
     key = key.contiguous()
     value = value.contiguous()

+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline
+    # conditional assignment, to support both torch.compile's dynamic shapes and full graph options; an inline conditional prevents dynamic shapes from compiling.
     if is_causal is None:
         is_causal = causal_mask is None and query.shape[2] > 1
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9dcd6d758ecbe7..49d086c76e8683 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1474,11 +1474,8 @@ def _autoset_attn_implementation(
             )

         if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
-            "eager",
-            "sdpa",
-            "flash_attention_2",
-            "flex_attention",
-        ]:
+            "eager"
+        ] + list(ALL_ATTENTION_FUNCTIONS.keys()):
             message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
             if cls._supports_flash_attn_2:
                 message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -1540,6 +1537,8 @@ def _autoset_attn_implementation(
                     "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
                 )
                 torch.backends.cuda.enable_flash_sdp(False)
+        elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
+            config._attn_implementation = requested_attn_implementation
         elif isinstance(requested_attn_implementation, dict):
             config._attn_implementation = None
         else:
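With the `modeling_utils.py` change, the accepted `attn_implementation` values are no longer a hard-coded list but `["eager"] + list(ALL_ATTENTION_FUNCTIONS.keys())`, so any backend present in the registry passes validation and is stored on the config as-is. A hedged sketch of what that enables, assuming `ALL_ATTENTION_FUNCTIONS` behaves like a plain dict here and that entries follow the same `(module, query, key, value, attention_mask, **kwargs)` calling convention as the built-in backends; the `my_custom_attention` function below is purely illustrative:

```python
import torch

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def my_custom_attention(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
    # Minimal eager-style attention so the registry entry is runnable; assumes
    # (batch, num_heads, seq_len, head_dim) inputs, an additive 4-D mask, and no GQA handling.
    scale = scaling if scaling is not None else query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scale
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Anything registered here is now accepted by _autoset_attn_implementation's membership check.
ALL_ATTENTION_FUNCTIONS["my_custom_attention"] = my_custom_attention
assert "my_custom_attention" in ["eager"] + list(ALL_ATTENTION_FUNCTIONS.keys())
```

The new `elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys())` branch then simply copies the requested name onto `config._attn_implementation`, which is what the per-model attention dispatch reads at forward time.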