diff --git a/tests/models/esm/test_modeling_esmfold.py b/tests/models/esm/test_modeling_esmfold.py
index 3c3d0865e5af71..5f1c03bf89e268 100644
--- a/tests/models/esm/test_modeling_esmfold.py
+++ b/tests/models/esm/test_modeling_esmfold.py
@@ -14,12 +14,16 @@
 # limitations under the License.
 """Testing suite for the PyTorch ESM model."""
 
+import inspect
+import tempfile
 import unittest
 
+import numpy as np
 from parameterized import parameterized
 
 from transformers import EsmConfig, is_torch_available
 from transformers.testing_utils import TestCasePlus, require_torch, require_torch_sdpa, slow, torch_device
+from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -268,11 +272,202 @@ def test_torchscript_simple(self):
     def test_multi_gpu_data_parallel_forward(self):
         pass
 
+    # Modified from test_modeling_common.py as ESMFold doesn't support output hidden states
     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @require_torch_sdpa
-    @unittest.skip("ESMFold doesn't support output hidden states in normal way which is required in this test.")
-    def test_eager_matches_sdpa_inference(self):
-        pass
+    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
+        if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
+            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+
+        if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
+            self.skipTest(
+                f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
+            )
+
+        # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
+        if torch_dtype == "float16":
+            torch_dtype = torch.float16
+        elif torch_dtype == "bfloat16":
+            torch_dtype = torch.bfloat16
+        elif torch_dtype == "float32":
+            torch_dtype = torch.float32
+
+        atols = {
+            ("cpu", False, torch.float32): 1e-6,
+            ("cpu", False, torch.float16): 5e-3,
+            ("cpu", False, torch.bfloat16): 1e-2,
+            ("cpu", True, torch.float32): 1e-6,
+            ("cpu", True, torch.float16): 5e-3,
+            ("cpu", True, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float32): 1e-6,
+            ("cuda", False, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float16): 5e-3,
+            ("cuda", True, torch.float32): 1e-6,
+            ("cuda", True, torch.bfloat16): 1e-2,
+            ("cuda", True, torch.float16): 5e-3,
+        }
+        rtols = {
+            ("cpu", False, torch.float32): 1e-4,
+            ("cpu", False, torch.float16): 5e-3,
+            ("cpu", False, torch.bfloat16): 1e-2,
+            ("cpu", True, torch.float32): 1e-4,
+            ("cpu", True, torch.float16): 5e-3,
+            ("cpu", True, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float32): 1e-4,
+            ("cuda", False, torch.bfloat16): 1e-2,
+            ("cuda", False, torch.float16): 5e-3,
+            ("cuda", True, torch.float32): 1e-4,
+            ("cuda", True, torch.bfloat16): 3e-2,
+            ("cuda", True, torch.float16): 5e-3,
+        }
+
+        def get_mean_reldiff(failcase, x, ref, atol, rtol):
+            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        config.layer_norm_eps = 1.0
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # Note: the half precision will only be applied to backbone model
+                model_sdpa = model_class.from_pretrained(tmpdirname)
+                model_sdpa = model_sdpa.eval().to(torch_device)
+
+                model_eager = model_class.from_pretrained(
+                    tmpdirname,
+                    attn_implementation="eager",
+                )
+                model_eager = model_eager.eval().to(torch_device)
+
+                model_sdpa.esm.to(torch_dtype)
+                model_eager.esm.to(torch_dtype)
+
+                # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
+                # but it would be nicer to have an efficient way to use parameterized.expand
+                fail_cases = []
+                for padding_side in ["left", "right"]:
+                    for use_mask in [False, True]:
+                        # TODO: if we can also check with `batch_size=1` without being flaky?
+                        for batch_size in [7]:
+                            dummy_input = inputs_dict[model.main_input_name]
+
+                            if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+                                dummy_input = dummy_input.to(torch_dtype)
+
+                            dummy_input = dummy_input[:batch_size]
+                            if dummy_input.shape[0] != batch_size:
+                                if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
+                                    extension = torch.rand(
+                                        batch_size - dummy_input.shape[0],
+                                        *dummy_input.shape[1:],
+                                        dtype=torch_dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+                                else:
+                                    extension = torch.randint(
+                                        high=5,
+                                        size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
+                                        dtype=dummy_input.dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
+
+                            if not use_mask:
+                                dummy_attention_mask = None
+                            else:
+                                dummy_attention_mask = inputs_dict.get("attention_mask", None)
+                                if dummy_attention_mask is None:
+                                    seqlen = dummy_input.shape[-1]
+                                    dummy_attention_mask = (
+                                        torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
+                                    )
+
+                                dummy_attention_mask = dummy_attention_mask[:batch_size]
+                                if dummy_attention_mask.shape[0] != batch_size:
+                                    extension = torch.ones(
+                                        batch_size - dummy_attention_mask.shape[0],
+                                        *dummy_attention_mask.shape[1:],
+                                        dtype=dummy_attention_mask.dtype,
+                                        device=torch_device,
+                                    )
+                                    dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
+                                    dummy_attention_mask = dummy_attention_mask.to(torch_device)
+
+                                dummy_attention_mask[:] = 1
+                                if padding_side == "left":
+                                    dummy_attention_mask[-1, :2] = 0
+                                    dummy_attention_mask[-1, 2:] = 1
+                                elif padding_side == "right":
+                                    dummy_attention_mask[-1, -2:] = 0
+                                    dummy_attention_mask[-1, :-2] = 1
+
+                            for enable_kernels in [False, True]:
+                                failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
+                                processed_inputs = {
+                                    model.main_input_name: dummy_input,
+                                }
+
+                                # Otherwise fails for e.g. WhisperEncoderModel
+                                if "attention_mask" in inspect.signature(model_eager.forward).parameters:
+                                    processed_inputs["attention_mask"] = dummy_attention_mask
+
+                                # TODO: test gradients as well (& for FA2 as well!)
+                                with torch.no_grad():
+                                    with torch.backends.cuda.sdp_kernel(
+                                        enable_flash=enable_kernels,
+                                        enable_math=True,
+                                        enable_mem_efficient=enable_kernels,
+                                    ):
+                                        prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
+                                        outputs_eager = model_eager(**prepared_inputs)
+                                        outputs_sdpa = model_sdpa(**prepared_inputs)
+
+                                logits_eager = outputs_eager.lm_logits
+                                logits_sdpa = outputs_sdpa.lm_logits
+
+                                if torch_device in ["cpu", "cuda"]:
+                                    atol = atols[torch_device, enable_kernels, torch_dtype]
+                                    rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                                else:
+                                    atol = 1e-7
+                                    rtol = 1e-4
+
+                                # Masked tokens output slightly deviates - we don't mind that.
+                                if use_mask:
+                                    _logits_sdpa = torch.zeros_like(input=logits_sdpa)
+                                    _logits_eager = torch.zeros_like(input=logits_eager)
+
+                                    _logits_sdpa[:-1] = logits_sdpa[:-1]
+                                    _logits_eager[:-1] = logits_eager[:-1]
+
+                                    if padding_side == "left":
+                                        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
+                                        _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
+
+                                    elif padding_side == "right":
+                                        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
+                                        _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
+
+                                    logits_sdpa = _logits_sdpa
+                                    logits_eager = _logits_eager
+
+                                results = [
+                                    torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
+                                    for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
+                                ]
+                                # If 80% batch elements have matched results, it's fine
+                                if np.mean(results) < 0.8:
+                                    fail_cases.append(
+                                        get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
+                                    )
+
+        self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
 
 
 @require_torch