diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 8a106cae7386f1..e6bc7c1668615a 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -232,6 +232,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
 * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
 * [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
+* [ESM](https://huggingface.co/docs/transformers/model_doc/esm#transformers.ESMModel)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
 * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py
index 5df5435bb1229a..32b3ace49e344e 100755
--- a/src/transformers/models/esm/modeling_esm.py
+++ b/src/transformers/models/esm/modeling_esm.py
@@ -19,10 +19,15 @@
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,7 +36,7 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import logging
+from ...utils import get_torch_version, logging
 from .configuration_esm import EsmConfig
 
@@ -88,7 +93,6 @@ def __init__(self, dim: int):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
-        inv_freq = inv_freq
         self.register_buffer("inv_freq", inv_freq)
 
         self._seq_len_cached = None
@@ -102,12 +106,20 @@ def _update_cos_sin_tables(self, x, seq_dimension=2):
         # or if we're on a new device (possibly due to tracing for instance)
         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
             self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            t = torch.arange(x.shape[seq_dimension], device=x.device).float()
+
+            # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+            device_type = x.device.type
+            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+            with torch.autocast(device_type=device_type, enabled=False):
+                freqs = torch.outer(t.float(), self.inv_freq.float())
+                emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            cos = emb.cos()[None, None, :, :]
+            sin = emb.sin()[None, None, :, :]
 
-            self._cos_cached = emb.cos()[None, None, :, :]
-            self._sin_cached = emb.sin()[None, None, :, :]
+            self._cos_cached = cos.to(dtype=x.dtype)
+            self._sin_cached = sin.to(dtype=x.dtype)
 
         return self._cos_cached, self._sin_cached
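Editor's note — illustration, not part of the patch: the rotary-embedding hunk above forces the frequency and cos/sin computation into float32 even under autocast (following https://github.com/huggingface/transformers/pull/29285) and only casts the cached tables back to the activation dtype at the end. A minimal standalone sketch of that pattern follows; the `build_cos_sin` helper is hypothetical and exists only for this example.

```python
import torch


def build_cos_sin(x: torch.Tensor, inv_freq: torch.Tensor, seq_len: int):
    """Sketch of the float32-forced rotary table computation used in the hunk above."""
    t = torch.arange(seq_len, device=x.device).float()
    # Disable autocast so the outer product and cos/sin run in float32 even when the
    # surrounding forward pass runs in fp16/bf16; "mps" does not support autocast here.
    device_type = x.device.type if x.device.type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = torch.outer(t, inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
    # Cast back to the activation dtype only after the trig functions are applied.
    return cos.to(x.dtype), sin.to(x.dtype)
```

Computing the angles in reduced precision would quantize them and degrade rotary accuracy at long sequence lengths, which is why the cast to `x.dtype` happens only when the tables are cached.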
@@ -370,7 +382,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
+        context_layer = torch.matmul(attention_probs, value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -383,6 +395,113 @@ def forward(
         return outputs
 
 
+class EsmSdpaSelfAttention(EsmSelfAttention):
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__(config, position_embedding_type)
+        self.attention_dropout_prob = config.attention_probs_dropout_prob
+        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        if self.position_embedding_type not in ["absolute", "rotary"] or output_attentions or head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
+            logger.warning_once(
+                "EsmSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "non-absolute or non-rotary `position_embedding_type` or `output_attentions=True` or `head_mask`. "
+                "Falling back to the manual attention implementation, but specifying the manual implementation will "
+                "be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument "
+                '`attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
+        # mask needs to be such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
+        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+            key_layer, value_layer = past_key_value
+        else:
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
+            if past_key_value is not None and not is_cross_attention:
+                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        # Scale the query for rotary embeddings
+        query_layer = query_layer * self.attention_head_size**-0.5
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        if self.position_embedding_type == "rotary":
+            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout_prob if self.training else 0.0,
+            is_causal=is_causal,
+            scale=1.0,  # Scale is already applied to query_layer
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        outputs = (attn_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+
 class EsmSelfOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -396,10 +515,17 @@ def forward(self, hidden_states, input_tensor):
         return hidden_states
 
 
+ESM_SELF_ATTENTION_CLASSES = {
+    "eager": EsmSelfAttention,
+    "sdpa": EsmSdpaSelfAttention,
+}
+
+
 class EsmAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.self = EsmSelfAttention(config)
+        # self.self = EsmSelfAttention(config)
+        self.self = ESM_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
         self.output = EsmSelfOutput(config)
         self.pruned_heads = set()
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -680,6 +806,7 @@ class EsmPreTrainedModel(PreTrainedModel):
     base_model_prefix = "esm"
     supports_gradient_checkpointing = True
     _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
+    _supports_sdpa = True
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
@@ -787,6 +914,9 @@ def __init__(self, config, add_pooling_layer=True):
             in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
         )
 
+        self.attn_implementation = config._attn_implementation
+        self.position_embedding_type = config.position_embedding_type
+
         # Initialize weights and apply final processing
         self.post_init()
@@ -875,9 +1005,40 @@ def forward(
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
 
-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        use_sdpa_attention_masks = (
+            self.attn_implementation == "sdpa"
+            and self.position_embedding_type in ["absolute", "rotary"]
+            and head_mask is None
+            and not output_attentions
+        )
+
+        # Expand the attention mask
+        if use_sdpa_attention_masks and attention_mask.dim() == 2:
+            # Expand the attention mask for SDPA.
+            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+            if self.config.is_decoder:
+                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    input_shape,
+                    embedding_output,
+                    past_key_values_length,
+                )
+            else:
+                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+        else:
+            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+            # ourselves in which case we just need to make it broadcastable to all heads.
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -886,7 +1047,15 @@ def forward(
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+
+            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
+                # Expand the attention mask for SDPA.
+                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+            else:
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
@@ -897,13 +1066,6 @@ def forward(
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py
index 56a7e4d0c67fa1..6c2acc7560209c 100644
--- a/tests/models/esm/test_modeling_esm.py
+++ b/tests/models/esm/test_modeling_esm.py
@@ -16,6 +16,8 @@
 
 import unittest
 
+from parameterized import parameterized
+
 from transformers import EsmConfig, is_torch_available
 from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
@@ -306,7 +308,7 @@ def test_inference_masked_lm(self):
         with torch.no_grad():
             model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
             model.eval()
-            input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
+            input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]).to(model.device)
             output = model(input_ids)[0]
 
             vocab_size = 33
@@ -316,7 +318,7 @@ def test_inference_masked_lm(self):
             expected_slice = torch.tensor(
                 [[[8.9215, -10.5898, -6.4671], [-6.3967, -13.9114, -1.1212], [-7.7812, -13.9516, -3.7406]]]
-            )
+            ).to(model.device)
             self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
 
     def test_inference_no_head(self):
@@ -324,25 +326,21 @@ def test_inference_no_head(self):
         model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
         model.eval()
-        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
         output = model(input_ids)[0]
         # compare the actual values for a slice.
         expected_slice = torch.tensor(
             [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
-        )
+        ).to(model.device)
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
 
+    @parameterized.expand([({"load_in_8bit": True},), ({"load_in_4bit": True},)])
     @require_bitsandbytes
-    def test_inference_bitsandbytes(self):
-        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
+    def test_inference_bitsandbytes(self, bnb_kwargs):
+        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", **bnb_kwargs)
+        model.eval()
 
-        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]]).to(model.device)
         # Just test if inference works
         with torch.no_grad():
             _ = model(input_ids)[0]
-
-        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True)
-
-        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
-        # Just test if inference works
-        _ = model(input_ids)[0]
diff --git a/tests/models/esm/test_modeling_esmfold.py b/tests/models/esm/test_modeling_esmfold.py
index 5c05efb03f2fdd..5f1c03bf89e268 100644
--- a/tests/models/esm/test_modeling_esmfold.py
+++ b/tests/models/esm/test_modeling_esmfold.py
@@ -14,10 +14,16 @@
 # limitations under the License.
 """Testing suite for the PyTorch ESM model."""
 
+import inspect
+import tempfile
 import unittest
 
+import numpy as np
+from parameterized import parameterized
+
 from transformers import EsmConfig, is_torch_available
-from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
+from transformers.testing_utils import TestCasePlus, require_torch, require_torch_sdpa, slow, torch_device
+from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -266,6 +272,203 @@ def test_torchscript_simple(self):
     def test_multi_gpu_data_parallel_forward(self):
         pass
 
+    # Modified from test_modeling_common.py as ESMFold doesn't support output hidden states
+    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+    @require_torch_sdpa
+    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+        if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
+            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+
+        if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
+            self.skipTest(
+                f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
+            )
+
+        # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
+        if torch_dtype == "float16":
+            torch_dtype = torch.float16
+        elif torch_dtype == "bfloat16":
+            torch_dtype = torch.bfloat16
+        elif torch_dtype == "float32":
+            torch_dtype = torch.float32
+
+        atols = {
+            ("cpu", False, torch.float32): 1e-6,
+            ("cpu", False, torch.float16): 5e-3,
+            ("cpu", False, torch.bfloat16): 1e-2,
+            ("cpu", True, torch.float32): 1e-6,
+            ("cpu", True, torch.float16): 5e-3,
+            ("cpu", True, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float32): 1e-6,
+            ("cuda", False, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float16): 5e-3,
+            ("cuda", True, torch.float32): 1e-6,
+            ("cuda", True, torch.bfloat16): 1e-2,
+            ("cuda", True, torch.float16): 5e-3,
+        }
+        rtols = {
+            ("cpu", False, torch.float32): 1e-4,
+            ("cpu", False, torch.float16): 5e-3,
+            ("cpu", False, torch.bfloat16): 1e-2,
+            ("cpu", True, torch.float32): 1e-4,
+            ("cpu", True, torch.float16): 5e-3,
+            ("cpu", True, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float32): 1e-4,
+            ("cuda", False, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float16): 5e-3,
+            ("cuda", True, torch.float32): 1e-4,
+            ("cuda", True, torch.bfloat16): 3e-2,
+            ("cuda", True, torch.float16): 5e-3,
+        }
+
+        def get_mean_reldiff(failcase, x, ref, atol, rtol):
+            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        config.layer_norm_eps = 1.0
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # Note: the half precision will only be applied to backbone model
+                model_sdpa = model_class.from_pretrained(tmpdirname)
+                model_sdpa = model_sdpa.eval().to(torch_device)
+
+                model_eager = model_class.from_pretrained(
+                    tmpdirname,
+                    attn_implementation="eager",
+                )
+                model_eager = model_eager.eval().to(torch_device)
+
+                model_sdpa.esm.to(torch_dtype)
+                model_eager.esm.to(torch_dtype)
+
+                # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
+                # but it would be nicer to have an efficient way to use parameterized.expand
+                fail_cases = []
+                for padding_side in ["left", "right"]:
+                    for use_mask in [False, True]:
+                        # TODO: if we can also check with `batch_size=1` without being flaky?
+                        for batch_size in [7]:
+                            dummy_input = inputs_dict[model.main_input_name]
+
+                            if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+                                dummy_input = dummy_input.to(torch_dtype)
+
+                            dummy_input = dummy_input[:batch_size]
+                            if dummy_input.shape[0] != batch_size:
+                                if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+                                    extension = torch.rand(
+                                        batch_size - dummy_input.shape[0],
+                                        *dummy_input.shape[1:],
+                                        dtype=torch_dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+                                else:
+                                    extension = torch.randint(
+                                        high=5,
+                                        size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
+                                        dtype=dummy_input.dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+
+                            if not use_mask:
+                                dummy_attention_mask = None
+                            else:
+                                dummy_attention_mask = inputs_dict.get("attention_mask", None)
+                                if dummy_attention_mask is None:
+                                    seqlen = dummy_input.shape[-1]
+                                    dummy_attention_mask = (
+                                        torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
+                                    )
+
+                                dummy_attention_mask = dummy_attention_mask[:batch_size]
+                                if dummy_attention_mask.shape[0] != batch_size:
+                                    extension = torch.ones(
+                                        batch_size - dummy_attention_mask.shape[0],
+                                        *dummy_attention_mask.shape[1:],
+                                        dtype=dummy_attention_mask.dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
+                                    dummy_attention_mask = dummy_attention_mask.to(torch_device)
+
+                                dummy_attention_mask[:] = 1
+                                if padding_side == "left":
+                                    dummy_attention_mask[-1, :2] = 0
+                                    dummy_attention_mask[-1, 2:] = 1
+                                elif padding_side == "right":
+                                    dummy_attention_mask[-1, -2:] = 0
+                                    dummy_attention_mask[-1, :-2] = 1
+
+                            for enable_kernels in [False, True]:
+                                failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
+                                processed_inputs = {
+                                    model.main_input_name: dummy_input,
+                                }
+
+                                # Otherwise fails for e.g. WhisperEncoderModel
+                                if "attention_mask" in inspect.signature(model_eager.forward).parameters:
+                                    processed_inputs["attention_mask"] = dummy_attention_mask
+
+                                # TODO: test gradients as well (& for FA2 as well!)
+                                with torch.no_grad():
+                                    with torch.backends.cuda.sdp_kernel(
+                                        enable_flash=enable_kernels,
+                                        enable_math=True,
+                                        enable_mem_efficient=enable_kernels,
+                                    ):
+                                        prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
+                                        outputs_eager = model_eager(**prepared_inputs)
+                                        outputs_sdpa = model_sdpa(**prepared_inputs)
+
+                                logits_eager = outputs_eager.lm_logits
+                                logits_sdpa = outputs_sdpa.lm_logits
+
+                                if torch_device in ["cpu", "cuda"]:
+                                    atol = atols[torch_device, enable_kernels, torch_dtype]
+                                    rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                                else:
+                                    atol = 1e-7
+                                    rtol = 1e-4
+
+                                # Masked tokens output slightly deviates - we don't mind that.
+                                if use_mask:
+                                    _logits_sdpa = torch.zeros_like(input=logits_sdpa)
+                                    _logits_eager = torch.zeros_like(input=logits_eager)
+
+                                    _logits_sdpa[:-1] = logits_sdpa[:-1]
+                                    _logits_eager[:-1] = logits_eager[:-1]
+
+                                    if padding_side == "left":
+                                        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
+                                        _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
+
+                                    elif padding_side == "right":
+                                        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
+                                        _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
+
+                                    logits_sdpa = _logits_sdpa
+                                    logits_eager = _logits_eager
+
+                                results = [
+                                    torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
+                                    for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
+                                ]
+                                # If 80% batch elements have matched results, it's fine
+                                if np.mean(results) < 0.8:
+                                    fail_cases.append(
+                                        get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
+                                    )
+
+                self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
+
 
 @require_torch
 class EsmModelIntegrationTest(TestCasePlus):
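Editor's note — usage sketch, not part of the patch: with `_supports_sdpa = True` on `EsmPreTrainedModel` and the `ESM_SELF_ATTENTION_CLASSES` dispatch above, the attention backend for ESM checkpoints is selected through the standard `attn_implementation` argument at load time. The checkpoint, dtype handling, and input sequence below are illustrative only.

```python
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    attn_implementation="sdpa",  # "eager" restores the original attention path
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
model.eval()

# Illustrative protein sequence; any amino-acid string accepted by the ESM tokenizer works.
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(**inputs).logits
```

Note that passing `output_attentions=True` or a `head_mask` makes `EsmSdpaSelfAttention` fall back to the eager implementation with the warning added in this patch, so throughput comparisons should avoid those arguments.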