From 4b3eb19fa7f359d25f62ca9108479f71de912ebc Mon Sep 17 00:00:00 2001 From: Edoardo Cetin <32273096+Aladoro@users.noreply.github.com> Date: Wed, 15 May 2024 19:48:19 +0200 Subject: [PATCH] Fix llama model sdpa attention forward function masking bug when output_attentions=True (#30652) * Fix llama model forward function with attention=True, same-length encoded sequence. * Fix style * propagate fix to modeling_cohere, gemma, dbrx, and olmo (which copy the same sdpa masking logic from llama) * Fix style * ignore unnecessary sdpa mask converter when output_attentions=True * add tests checking sdpa and eager outputs match when output_attentions=True * Split if statements in two lines Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix formatting * Add fix to new jetmoe model * Add missing output_attentions argument to jetmoe mask creation --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/cohere/modeling_cohere.py | 10 +- src/transformers/models/dbrx/modeling_dbrx.py | 11 +- .../models/gemma/modeling_gemma.py | 10 +- .../models/jetmoe/modeling_jetmoe.py | 10 +- .../models/llama/modeling_llama.py | 10 +- src/transformers/models/olmo/modeling_olmo.py | 10 +- tests/test_modeling_common.py | 314 +++++++++--------- 7 files changed, 212 insertions(+), 163 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index b25528dfe73e2e..1f08bba62011e7 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -889,7 +889,9 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) # embed positions hidden_states = inputs_embeds @@ -958,6 +960,7 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -974,7 +977,9 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1020,6 +1025,7 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 38c1fc814b1cad..eaaad0097e6740 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1123,7 +1123,10 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) # embed positions hidden_states = inputs_embeds @@ -1204,6 +1207,7 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1220,7 +1224,9 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1266,6 +1272,7 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9af816839445e6..4d9c0aaa384dd7 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -873,7 +873,9 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) # embed positions hidden_states = inputs_embeds @@ -948,6 +950,7 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -964,7 +967,9 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1008,6 +1013,7 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 15571ce9a3e02d..f8dc2c96846194 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1103,7 +1103,9 @@ def forward( " this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) hidden_states = inputs_embeds @@ -1178,6 +1180,7 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1194,7 +1197,9 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1240,6 +1245,7 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c6da59fcfb3edc..2cf0979d9095a9 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -967,7 +967,9 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) # embed positions hidden_states = inputs_embeds @@ -1036,6 +1038,7 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1052,7 +1055,9 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1098,6 +1103,7 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 6a7b2f748fcf03..063f78e5db463b 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -945,7 +945,9 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) # embed positions hidden_states = inputs_embeds @@ -1015,6 +1017,7 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + output_attentions: bool, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1031,7 +1034,9 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - if self.config._attn_implementation == "sdpa" and not using_static_cache: + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1077,6 +1082,7 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" + and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bfe37d30cbf2d3..64a8a348ff6b81 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3757,177 +3757,189 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): if not has_sdpa and model_sdpa.config.model_type != "falcon": raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = [] for padding_side in ["left", "right"]: for use_mask in [False, True]: - for batch_size in [1, 5]: - dummy_input = inputs_dict[model.main_input_name] + for output_attentions in [True, False]: + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + continue + for batch_size in [1, 5]: + dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: - dummy_input = dummy_input.to(torch_dtype) - - dummy_input = dummy_input[:batch_size] - if dummy_input.shape[0] != batch_size: if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: - extension = torch.rand( - batch_size - dummy_input.shape[0], - *dummy_input.shape[1:], - dtype=torch_dtype, - device=torch_device, - ) - dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) - else: - extension = torch.randint( - high=5, - size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), - dtype=dummy_input.dtype, - device=torch_device, - ) - dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) - - if not use_mask: - dummy_attention_mask = None - else: - dummy_attention_mask = inputs_dict.get("attention_mask", None) - if dummy_attention_mask is None: - if is_encoder_decoder: - seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + dummy_input = dummy_input.to(torch_dtype) + + dummy_input = dummy_input[:batch_size] + if dummy_input.shape[0] != batch_size: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + extension = torch.rand( + batch_size - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) else: - seqlen = dummy_input.shape[-1] - dummy_attention_mask = ( - torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) - ) + extension = torch.randint( + high=5, + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) - dummy_attention_mask = dummy_attention_mask[:batch_size] - if dummy_attention_mask.shape[0] != batch_size: - extension = torch.ones( - batch_size - dummy_attention_mask.shape[0], - *dummy_attention_mask.shape[1:], - dtype=dummy_attention_mask.dtype, - device=torch_device, - ) - dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) - dummy_attention_mask = dummy_attention_mask.to(torch_device) - - dummy_attention_mask[:] = 1 - if padding_side == "left": - dummy_attention_mask[-1, :-1] = 1 - dummy_attention_mask[-1, -4:] = 0 - elif padding_side == "right": - dummy_attention_mask[-1, 1:] = 1 - dummy_attention_mask[-1, :3] = 0 - - for enable_kernels in [False, True]: - failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" - if is_encoder_decoder: - decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:batch_size] - if decoder_input_ids.shape[0] != batch_size: + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + else: + seqlen = dummy_input.shape[-1] + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: extension = torch.ones( - batch_size - decoder_input_ids.shape[0], - *decoder_input_ids.shape[1:], - dtype=decoder_input_ids.dtype, + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, device=torch_device, ) - decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) - decoder_input_ids = decoder_input_ids.to(torch_device) - - # TODO: never an `attention_mask` arg here? - processed_inputs = { - model.main_input_name: dummy_input, - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - else: - processed_inputs = { - model.main_input_name: dummy_input, - "output_hidden_states": True, - } - - # Otherwise fails for e.g. WhisperEncoderModel - if "attention_mask" in inspect.signature(model_eager.forward).parameters: - processed_inputs["attention_mask"] = dummy_attention_mask - - # TODO: test gradients as well (& for FA2 as well!) - with torch.no_grad(): - with torch.backends.cuda.sdp_kernel( - enable_flash=enable_kernels, - enable_math=True, - enable_mem_efficient=enable_kernels, - ): - prepared_inputs = self._prepare_for_class(processed_inputs, model_class) - outputs_eager = model_eager(**prepared_inputs) - outputs_sdpa = model_sdpa(**prepared_inputs) - - logits_eager = ( - outputs_eager.hidden_states[-1] - if not is_encoder_decoder - else outputs_eager.decoder_hidden_states[-1] - ) - logits_sdpa = ( - outputs_sdpa.hidden_states[-1] - if not is_encoder_decoder - else outputs_sdpa.decoder_hidden_states[-1] - ) - - if torch_device in ["cpu", "cuda"]: - atol = atols[torch_device, enable_kernels, torch_dtype] - rtol = rtols[torch_device, enable_kernels, torch_dtype] - else: - atol = 1e-7 - rtol = 1e-4 + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) - # Masked tokens output slightly deviates - we don't mind that. - if use_mask: + dummy_attention_mask[:] = 1 if padding_side == "left": - sub_sdpa = logits_sdpa[:-1] - sub_eager = logits_eager[:-1] - if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): - fail_cases.append( - get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) - ) + dummy_attention_mask[-1, :-1] = 1 + dummy_attention_mask[-1, -4:] = 0 + elif padding_side == "right": + dummy_attention_mask[-1, 1:] = 1 + dummy_attention_mask[-1, :3] = 0 - sub_sdpa = logits_sdpa[-1, :-4] - sub_eager = logits_eager[-1, :-4] - if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): - fail_cases.append( - get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + if is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[ + :batch_size + ] + if decoder_input_ids.shape[0] != batch_size: + extension = torch.ones( + batch_size - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) + + # TODO: never an `attention_mask` arg here? + processed_inputs = { + model.main_input_name: dummy_input, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + else: + processed_inputs = { + model.main_input_name: dummy_input, + "output_hidden_states": True, + } + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + processed_inputs["attention_mask"] = dummy_attention_mask + + if ( + self.has_attentions + and "output_attentions" in inspect.signature(model_sdpa.forward).parameters + ): + processed_inputs["output_attentions"] = output_attentions + + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with torch.backends.cuda.sdp_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + + logits_eager = ( + outputs_eager.hidden_states[-1] + if not is_encoder_decoder + else outputs_eager.decoder_hidden_states[-1] + ) + logits_sdpa = ( + outputs_sdpa.hidden_states[-1] + if not is_encoder_decoder + else outputs_sdpa.decoder_hidden_states[-1] + ) - # Testing the padding tokens is not really meaningful but anyway - # sub_sdpa = logits_sdpa[-1, -4:] - # sub_eager = logits_eager[-1, -4:] - # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): - # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) - elif padding_side == "right": - sub_sdpa = logits_sdpa[:-1] - sub_eager = logits_eager[:-1] - if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): - fail_cases.append( - get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) - ) + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. + if use_mask: + if padding_side == "left": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, :-4] + sub_eager = logits_eager[-1, :-4] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, -4:] + # sub_eager = logits_eager[-1, -4:] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) + elif padding_side == "right": + sub_sdpa = logits_sdpa[:-1] + sub_eager = logits_eager[:-1] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + sub_sdpa = logits_sdpa[-1, 3:] + sub_eager = logits_eager[-1, 3:] + if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + fail_cases.append( + get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + ) + + # Testing the padding tokens is not really meaningful but anyway + # sub_sdpa = logits_sdpa[-1, :3] + # sub_eager = logits_eager[-1, :3] + # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) - sub_sdpa = logits_sdpa[-1, 3:] - sub_eager = logits_eager[-1, 3:] - if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): + else: + if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): fail_cases.append( - get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol) + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) ) - # Testing the padding tokens is not really meaningful but anyway - # sub_sdpa = logits_sdpa[-1, :3] - # sub_eager = logits_eager[-1, :3] - # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol): - # fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2)) - - else: - if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol): - fail_cases.append( - get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) - ) - self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) @require_torch_sdpa