From 5728b5ad0071be8fa062f8b72c1345343d9b1a48 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:51:17 +0200 Subject: [PATCH] FIX: Fixes unexpected behaviour for Llava / LLama & AWQ Fused modules + revert #30070 at the same time (#30317) * Update awq.py * style * revert felix PR * fix * add felix comments --- src/transformers/integrations/awq.py | 27 ++++- src/transformers/modeling_attn_mask_utils.py | 99 ++++++++++++------- .../models/cohere/modeling_cohere.py | 36 +++++-- .../models/gemma/modeling_gemma.py | 36 +++++-- .../models/llama/modeling_llama.py | 37 +++++-- src/transformers/models/olmo/modeling_olmo.py | 34 +++++-- tests/test_modeling_common.py | 36 +++++++ 7 files changed, 230 insertions(+), 75 deletions(-) diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py index 3f9f0d1d216f1c..a543860f100396 100644 --- a/src/transformers/integrations/awq.py +++ b/src/transformers/integrations/awq.py @@ -229,6 +229,8 @@ def fuse_awq_modules(model, quantization_config): else: raise ValueError("Fusing is only supported for the AutoAWQ backend") + fused_attention_modules = [] + for name, module in model.named_modules(): if modules_to_not_convert is not None: if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert): @@ -241,7 +243,23 @@ def fuse_awq_modules(model, quantization_config): _fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP) # Replace attention layers - _fuse_awq_attention_layers(model, module, modules_to_fuse, name, QuantAttentionFused) + attention_has_been_fused = _fuse_awq_attention_layers( + model, module, modules_to_fuse, name, QuantAttentionFused + ) + + if attention_has_been_fused: + fused_attention_modules.append(name.split(".")[0]) + + # For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass + # `None` attention mask to the fused attention modules as now the attention mask is dropped by our models and dealt + # by the `AttentionMaskConverter` module. + if len(fused_attention_modules) > 0: + for module_name, module in model.named_modules(): + if any( + module_name in fused_attention_modules for fused_attention_parent_module in fused_attention_modules + ): + if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"): + module.config._attn_implementation = "custom" return model @@ -332,8 +350,10 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na """ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV + module_has_been_fused = False + if len(modules_to_fuse["attention"]) == 0: - return + return module_has_been_fused if hasattr(module, modules_to_fuse["attention"][0]): # First, we pack the QKV layers together @@ -394,6 +414,9 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na setattr(parent, child_name, fused_attention_layer.to(previous_device)) del q_proj, k_proj, v_proj, o_proj + module_has_been_fused = True + + return module_has_been_fused def post_init_awq_exllama_modules(model, exllama_config): diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 130f7100b21d28..c69d9555b2afc8 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -234,6 +234,63 @@ def _unmask_unattended( return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) + @staticmethod + def _ignore_causal_mask_sdpa( + attention_mask: Optional[torch.Tensor], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, + ) -> bool: + """ + Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the `attention_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). + """ + + batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + key_value_length = query_length + past_key_values_length + + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + + ignore_causal_mask = False + + if attention_mask is None: + # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or + # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). + # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag. + # + # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`). + if ( + not is_tracing + and (query_length == 1 or key_value_length == query_length) + and (sliding_window is None or key_value_length < sliding_window) + ): + ignore_causal_mask = True + elif sliding_window is None or key_value_length < sliding_window: + if len(attention_mask.shape) == 4: + expected_shape = (batch_size, 1, query_length, key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." + ) + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1 or key_value_length == query_length: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation + # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. + + return ignore_causal_mask + def _prepare_4d_causal_attention_mask( attention_mask: Optional[torch.Tensor], @@ -305,7 +362,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa( attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) key_value_length = input_shape[-1] + past_key_values_length - _, query_length = input_shape # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. @@ -316,41 +372,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa( or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) ) - ignore_causal_mask = False - - if attention_mask is None: - if ( - not is_tracing - and (query_length == 1 or key_value_length == query_length) - and (sliding_window is None or key_value_length < sliding_window) - ): - ignore_causal_mask = True - elif sliding_window is None or key_value_length < sliding_window: - # 4d mask is passed through - if len(attention_mask.shape) == 4: - expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) - else: - # if the 4D mask has correct shape - invert it and fill with negative infinity - inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype) - attention_mask = inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min - ) - return attention_mask - - elif not is_tracing and torch.all(attention_mask == 1): - if query_length == 1: - # For query_length == 1, causal attention and bi-directional attention are the same. - ignore_causal_mask = True - elif key_value_length == query_length: - ignore_causal_mask = True - - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 + ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) if ignore_causal_mask: expanded_4d_mask = None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 95a7d768273eeb..950d45ea867a30 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -590,12 +590,15 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -908,9 +911,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -974,16 +975,31 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -991,7 +1007,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c8b9b11c557972..6077259d0b0fac 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -570,12 +570,15 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -888,9 +891,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -960,16 +961,31 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -977,7 +993,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e1afb61be0dfc6..2b8e8f6d0958dd 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -656,7 +656,6 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] @@ -667,12 +666,15 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -987,9 +989,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -1053,16 +1053,31 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -1070,7 +1085,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index b8fb01d7b23cad..83637536a12531 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -653,6 +653,7 @@ def forward( value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -970,9 +971,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1] - ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) # embed positions hidden_states = inputs_embeds @@ -1036,17 +1035,32 @@ def forward( attentions=all_self_attns, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask - def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_seen_tokens: int, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + if self.config._attn_implementation == "sdpa": + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, + # in order to dispatch on Flash Attention 2. + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] @@ -1054,7 +1068,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position, curr target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a3cbcc081857a5..71cb28d7548555 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3772,6 +3772,42 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + @require_torch_sdpa + @require_torch_gpu + @slow + def test_sdpa_can_dispatch_on_flash(self): + compute_capability = torch.cuda.get_device_capability() + major, _ = compute_capability + + if not torch.version.cuda or major < 8: + self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0") + + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if config.model_type in ["llava", "llava_next", "vipllava"]: + self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input") + if config.model_type in ["idefics"]: + self.skipTest("Idefics currently (transformers==4.39.1) requires an image_attention_mask input") + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa") + model.to(torch_device) + + inputs_dict.pop("attention_mask", None) + inputs_dict.pop("decoder_attention_mask", None) + + for name, inp in inputs_dict.items(): + if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]: + inputs_dict[name] = inp.to(torch.float16) + + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + _ = model(**inputs_dict) + @require_torch_sdpa @slow def test_eager_matches_sdpa_generate(self):