added test (a few models need fixes)
gante committed May 14, 2024
1 parent f35d62f commit e210081
Showing 2 changed files with 32 additions and 1 deletion.
src/transformers/generation/utils.py (6 additions, 1 deletion)
@@ -1163,7 +1163,12 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
"""Performs validation related to the resulting generated length"""

# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
if (
not is_torchdynamo_compiling()
and has_default_max_length
and generation_config.max_new_tokens is None
and generation_config.max_length == 20
):
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
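
Context for the utils.py change: warnings.warn is a Python side effect that TorchDynamo cannot capture, so emitting it while tracing would break fullgraph=True compilation of `generate`; the warning is therefore skipped when compiling. The sketch below illustrates that guard pattern only; the `is_torchdynamo_compiling` shown here is a hypothetical stand-in for the transformers helper of the same name, and `validate_length` is a simplified placeholder, not the library's actual method.

# Minimal sketch of the guard pattern (assumed behavior, not the library's exact code).
import warnings

import torch


def is_torchdynamo_compiling() -> bool:
    # Hypothetical stand-in for transformers.utils.is_torchdynamo_compiling;
    # assumed to defer to torch's own "am I being traced?" check.
    try:
        return torch._dynamo.is_compiling()
    except Exception:
        return False


def validate_length(max_new_tokens, max_length):
    # Skip the warning while compiling: warnings.warn would otherwise force a
    # graph break under torch.compile(fullgraph=True).
    if not is_torchdynamo_compiling() and max_new_tokens is None and max_length == 20:
        warnings.warn(
            f"Using the model-agnostic default `max_length` (={max_length}) to control generation length."
        )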
tests/generation/test_utils.py (26 additions, 0 deletions)
@@ -28,6 +28,7 @@
     is_flaky,
     require_accelerate,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_accelerator,
     slow,
     torch_device,
@@ -1652,6 +1653,31 @@ def test_new_cache_format(self, num_beams, do_sample):
                     )
                 )
 
+    @require_torch_gpu
+    @slow
+    def test_generate_compile_fullgraph(self):
+        """Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results"""
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_cache_class:
+                continue
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config).to(torch_device)
+            input_ids = inputs_dict["input_ids"].to(torch_device)
+
+            # dynamic cache
+            output_dynamic = model.generate(input_ids)
+
+            # eager static cache
+            model.generation_config.cache_implementation = "static"
+            output_static = model.generate(input_ids)
+            self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
+
+            # compiled static cache
+            compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
+            output_compiled = compiled_generate(input_ids)
+            self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
+
     def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
         batch_size, seq_length = input_ids.shape
         num_sequences_in_output = batch_size * num_return_sequences
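
For reference, here is a minimal end-to-end sketch of the pattern the new test exercises, outside the test harness. The checkpoint name is a placeholder (it must be a model whose class sets `_supports_cache_class = True`, which per the commit message is not yet all of them), `max_new_tokens=20` is an arbitrary choice, and a CUDA device is assumed; the three-step comparison mirrors the test.

# Usage sketch mirroring the test (hypothetical checkpoint; assumes a CUDA device).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "some-org/causal-lm-with-cache-class-support"  # placeholder name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids.to("cuda")

# 1. Baseline: default dynamic cache
reference = model.generate(input_ids, max_new_tokens=20)

# 2. Eager static cache: pre-allocated key/value tensors with fixed shapes
model.generation_config.cache_implementation = "static"
static = model.generate(input_ids, max_new_tokens=20)

# 3. Compiled static cache: fixed shapes make fullgraph compilation possible
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
compiled = compiled_generate(input_ids, max_new_tokens=20)

# All three paths should produce identical tokens (greedy decoding by default)
assert reference.tolist() == static.tolist() == compiled.tolist()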
