diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index c29a15efd33342..bb2b17f8e5288e 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -122,7 +122,7 @@
     from torch import nn

     from transformers import MODEL_MAPPING, AdaptiveEmbedding
-    from transformers.cache_utils import DynamicCache
+    from transformers.cache_utils import Cache, DynamicCache
     from transformers.modeling_utils import load_state_dict, no_init_weights
     from transformers.pytorch_utils import id_tensor_storage

@@ -1109,7 +1109,14 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                     attention_mask = inputs["attention_mask"]
                     decoder_input_ids = inputs["decoder_input_ids"]
                     decoder_attention_mask = inputs["decoder_attention_mask"]
-                    model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+                    outputs = model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+                    # `torchscript` doesn't work with outputs containing a `Cache` object. However, #35235 makes
+                    # several models use `Cache` by default instead of the legacy cache (tuple), and
+                    # their `torchscript` tests are failing. We won't support them anyway, but we still want to keep
+                    # the tests for encoder models like `BERT`. So we skip the checks if the model's output contains
+                    # a `Cache` object.
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
                     )
@@ -1117,14 +1124,18 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                     input_ids = inputs["input_ids"]
                     bbox = inputs["bbox"]
                     image = inputs["image"].tensor
-                    model(input_ids, bbox, image)
+                    outputs = model(input_ids, bbox, image)
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (input_ids, bbox, image), check_trace=False
                     )  # when traced model is checked, an error is produced due to name mangling
                 elif "bbox" in inputs:  # Bros requires additional inputs (bbox)
                     input_ids = inputs["input_ids"]
                     bbox = inputs["bbox"]
-                    model(input_ids, bbox)
+                    outputs = model(input_ids, bbox)
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (input_ids, bbox), check_trace=False
                     )  # when traced model is checked, an error is produced due to name mangling
@@ -1134,7 +1145,9 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                     pixel_values = inputs["pixel_values"]
                     prompt_pixel_values = inputs["prompt_pixel_values"]
                     prompt_masks = inputs["prompt_masks"]
-                    model(pixel_values, prompt_pixel_values, prompt_masks)
+                    outputs = model(pixel_values, prompt_pixel_values, prompt_masks)
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
                     )  # when traced model is checked, an error is produced due to name mangling
@@ -1149,11 +1162,15 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                         else:
                             self.skipTest(reason="testing SDPA without attention_mask is not supported")

-                        model(main_input, attention_mask=inputs["attention_mask"])
+                        outputs = model(main_input, attention_mask=inputs["attention_mask"])
+                        if any(isinstance(x, Cache) for x in outputs):
+                            continue
                         # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
                         traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
                     else:
-                        model(main_input)
+                        outputs = model(main_input)
+                        if any(isinstance(x, Cache) for x in outputs):
+                            continue
                         traced_model = torch.jit.trace(model, (main_input,))
             except RuntimeError:
                 self.fail("Couldn't trace module.")
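For context, here is a minimal, self-contained sketch of the failure mode this guard works around: `torch.jit.trace` only supports outputs made of tensors (or nested containers of tensors), so a forward pass whose output tuple carries a `Cache` object cannot be traced. The `ToyModelWithCache` module below is hypothetical and not part of the test suite; it merely mimics a model that returns a `DynamicCache` in its output tuple, as several models do by default after #35235.

```python
import torch
from torch import nn

from transformers.cache_utils import Cache, DynamicCache


class ToyModelWithCache(nn.Module):
    """Hypothetical stand-in for a model that returns a `Cache` in its outputs."""

    def forward(self, x):
        # Real models return (logits, past_key_values, ...) when `return_dict`
        # is disabled; the cache here is empty, but its *type* is what breaks
        # `torch.jit.trace`.
        return x * 2, DynamicCache()


model = ToyModelWithCache().eval()
example = torch.ones(1, 4)

# Run the model eagerly first, exactly as the patched test does.
outputs = model(example)

# The same guard the patch adds: skip tracing whenever any element of the
# output tuple is a `Cache`, since tracing would fail on that output type.
if any(isinstance(x, Cache) for x in outputs):
    print("output contains a `Cache` -> tracing would fail, skipping")
else:
    traced = torch.jit.trace(model, (example,))
```

Checking the output types at runtime, rather than hard-coding a list of affected model classes, keeps the skip logic local to `_create_and_check_torchscript` and automatically covers any model that later switches from the legacy tuple cache to `Cache` outputs, while encoder models like `BERT` keep their torchscript coverage.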