From e210081a6be504dd6b93c0d0853ea17731b5142b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 13 May 2024 17:26:14 +0000
Subject: [PATCH] added test (a few models need fixes)

---
 src/transformers/generation/utils.py |  7 ++++++-
 tests/generation/test_utils.py       | 26 ++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 816b05cdb66f6c..6349eb67e75a3d 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1163,7 +1163,12 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
         """Performs validation related to the resulting generated length"""

         # 1. Max length warnings related to poor parameterization
-        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
+        if (
+            not is_torchdynamo_compiling()
+            and has_default_max_length
+            and generation_config.max_new_tokens is None
+            and generation_config.max_length == 20
+        ):
             # 20 is the default max_length of the generation config
             warnings.warn(
                 f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index cf703d8a22317b..c605a165840464 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -28,6 +28,7 @@
     is_flaky,
     require_accelerate,
     require_torch,
+    require_torch_gpu,
     require_torch_multi_accelerator,
     slow,
     torch_device,
@@ -1652,6 +1653,31 @@ def test_new_cache_format(self, num_beams, do_sample):
                 )
             )

+    @require_torch_gpu
+    @slow
+    def test_generate_compile_fullgraph(self):
+        """Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results"""
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_cache_class:
+                continue
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config).to(torch_device)
+            input_ids = inputs_dict["input_ids"].to(torch_device)
+
+            # dynamic cache
+            output_dynamic = model.generate(input_ids)
+
+            # eager static cache
+            model.generation_config.cache_implementation = "static"
+            output_static = model.generate(input_ids)
+            self.assertListEqual(output_dynamic.tolist(), output_static.tolist())
+
+            # compiled static cache
+            compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
+            output_compiled = compiled_generate(input_ids)
+            self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
+
     def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
         batch_size, seq_length = input_ids.shape
         num_sequences_in_output = batch_size * num_return_sequences
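
Note (not part of the patch): the new test exercises the same user-facing path as the sketch below, which is a minimal standalone reproduction of "dynamic cache vs. compiled static cache" generation. The checkpoint name ("gpt2"), prompt, and token counts are illustrative assumptions, not values from the patch; the cache/compile calls mirror those used in the test itself.

    # minimal sketch, assuming a CUDA device and an arbitrary small checkpoint
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint choice
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

    inputs = tokenizer("Compiled generation check:", return_tensors="pt").to("cuda")

    # baseline: default dynamic cache, greedy decoding
    baseline = model.generate(**inputs, max_new_tokens=20)

    # switch to the static cache so the decoding forward pass has fixed shapes
    model.generation_config.cache_implementation = "static"

    # compile `generate` end to end; fullgraph=True surfaces any graph breaks
    compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
    compiled = compiled_generate(**inputs, max_new_tokens=20)

    # the patch's test asserts the compiled path reproduces the eager results
    assert baseline.tolist() == compiled.tolist()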