diff --git a/tests/models/roberta/test_modeling_roberta.py b/tests/models/roberta/test_modeling_roberta.py index ca557937803cff..1c128513b17d13 100644 --- a/tests/models/roberta/test_modeling_roberta.py +++ b/tests/models/roberta/test_modeling_roberta.py @@ -16,7 +16,7 @@ import unittest -from transformers import RobertaConfig, is_torch_available +from transformers import AutoTokenizer, RobertaConfig, is_torch_available from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -41,6 +41,7 @@ RobertaEmbeddings, create_position_ids_from_input_ids, ) + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 ROBERTA_TINY = "sshleifer/tiny-distilroberta-base" @@ -576,3 +577,43 @@ def test_inference_classification_head(self): # expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach() self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4)) + + @slow + def test_export(self): + if not is_torch_greater_or_equal_than_2_4: + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + roberta_model = "FacebookAI/roberta-base" + device = "cpu" + attn_implementation = "sdpa" + max_length = 512 + + tokenizer = AutoTokenizer.from_pretrained(roberta_model) + inputs = tokenizer( + "The goal of life is .", + return_tensors="pt", + padding="max_length", + max_length=max_length, + ) + + model = RobertaForMaskedLM.from_pretrained( + roberta_model, + device_map=device, + attn_implementation=attn_implementation, + use_cache=True, + ) + + logits = model(**inputs).logits + eager_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices) + self.assertEqual(eager_predicted_mask.split(), ["happiness", "love", "peace", "freedom", "simplicity"]) + + exported_program = torch.export.export( + model, + args=(inputs["input_ids"],), + kwargs={"attention_mask": inputs["attention_mask"]}, + strict=True, + ) + + result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"]) + exported_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices) + self.assertEqual(eager_predicted_mask, exported_predicted_mask)