Skip torchscript tests if a cache object is in model's outputs (#35596)

* fix 1

* fix 1

* comment

---------

Co-authored-by: ydshieh <[email protected]>
ydshieh and ydshieh authored Jan 10, 2025
1 parent 6b73ee8 commit 6f127d3
Showing 1 changed file with 24 additions and 7 deletions.
tests/test_modeling_common.py
@@ -122,7 +122,7 @@
     from torch import nn

     from transformers import MODEL_MAPPING, AdaptiveEmbedding
-    from transformers.cache_utils import DynamicCache
+    from transformers.cache_utils import Cache, DynamicCache
     from transformers.modeling_utils import load_state_dict, no_init_weights
     from transformers.pytorch_utils import id_tensor_storage
@@ -1109,22 +1109,33 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                     attention_mask = inputs["attention_mask"]
                     decoder_input_ids = inputs["decoder_input_ids"]
                     decoder_attention_mask = inputs["decoder_attention_mask"]
-                    model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+                    outputs = model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
+                    # `torchscript` doesn't work with outputs containing a `Cache` object. However, #35235 makes
+                    # several models use `Cache` by default instead of the legacy cache (tuple), and their
+                    # `torchscript` tests are failing. We won't support them anyway, but we still want to keep
+                    # the tests for encoder models like `BERT`. So we skip the checks if the model's output
+                    # contains a `Cache` object.
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
                     )
                 elif "bbox" in inputs and "image" in inputs:  # LayoutLMv2 requires additional inputs
                     input_ids = inputs["input_ids"]
                     bbox = inputs["bbox"]
                     image = inputs["image"].tensor
-                    model(input_ids, bbox, image)
+                    outputs = model(input_ids, bbox, image)
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (input_ids, bbox, image), check_trace=False
                     )  # when traced model is checked, an error is produced due to name mangling
                 elif "bbox" in inputs:  # Bros requires additional inputs (bbox)
                     input_ids = inputs["input_ids"]
                     bbox = inputs["bbox"]
-                    model(input_ids, bbox)
+                    outputs = model(input_ids, bbox)
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (input_ids, bbox), check_trace=False
                     )  # when traced model is checked, an error is produced due to name mangling
@@ -1134,7 +1145,9 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                     pixel_values = inputs["pixel_values"]
                     prompt_pixel_values = inputs["prompt_pixel_values"]
                     prompt_masks = inputs["prompt_masks"]
-                    model(pixel_values, prompt_pixel_values, prompt_masks)
+                    outputs = model(pixel_values, prompt_pixel_values, prompt_masks)
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(
                         model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
                     )  # when traced model is checked, an error is produced due to name mangling
@@ -1149,11 +1162,15 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                     else:
                         self.skipTest(reason="testing SDPA without attention_mask is not supported")

-                    model(main_input, attention_mask=inputs["attention_mask"])
+                    outputs = model(main_input, attention_mask=inputs["attention_mask"])
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
                     traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
                 else:
-                    model(main_input)
+                    outputs = model(main_input)
+                    if any(isinstance(x, Cache) for x in outputs):
+                        continue
                     traced_model = torch.jit.trace(model, (main_input,))
         except RuntimeError:
             self.fail("Couldn't trace module.")
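
For context, here is a minimal sketch (not part of the commit) of how the new skip condition behaves. It assumes `torch` and `transformers` are installed; the tensor shape and the tuple contents are illustrative stand-ins for real model outputs. The torchscript tests run with `config.torchscript = True`, which makes models return plain tuples rather than `ModelOutput` objects, so a `Cache` instance can appear directly among the elements being iterated.

```python
# Illustrative sketch of the skip condition, not the actual test code:
# with `torchscript=True`, model outputs are plain tuples, so a `Cache`
# object can show up among their elements.
import torch
from transformers.cache_utils import Cache, DynamicCache

hidden_states = torch.zeros(1, 4, 8)  # stand-in for a model's logits/hidden states
decoder_like = (hidden_states, DynamicCache())  # e.g. (logits, past_key_values)
encoder_like = (hidden_states,)  # e.g. a BERT-style encoder output

for outputs in (decoder_like, encoder_like):
    if any(isinstance(x, Cache) for x in outputs):
        print("skip: a `Cache` is in the outputs, so `torch.jit.trace` is not attempted")
        continue
    print("trace: outputs are Cache-free and safe to pass to `torch.jit.trace`")
```

Because `DynamicCache` subclasses `Cache`, this single `isinstance` check covers the cache class that models now return by default as well as any other `Cache` implementation.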