From 510270af34994c24f2f8792d2afd74114c70d736 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 30 Nov 2023 14:16:31 +0000 Subject: [PATCH] Generate: `GenerationConfig` throws an exception when `generate` args are passed (#27757) --- .../generation/configuration_utils.py | 18 ++++++++++++ tests/generation/test_configuration_utils.py | 28 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index cb9240d3bf3322..7a494bf3733a0e 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -497,6 +497,24 @@ def validate(self, is_init=False): f"({self.num_beams})." ) + # 5. check common issue: passing `generate` arguments inside the generation config + generate_arguments = ( + "logits_processor", + "stopping_criteria", + "prefix_allowed_tokens_fn", + "synced_gpus", + "assistant_model", + "streamer", + "negative_prompt_ids", + "negative_prompt_attention_mask", + ) + for arg in generate_arguments: + if hasattr(self, arg): + raise ValueError( + f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to " + "`generate()` (or a pipeline) directly." + ) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index a181b00ee08d2c..e5eb1bb34cc0dc 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -120,6 +120,34 @@ def test_kwarg_init(self): self.assertEqual(loaded_config.do_sample, True) self.assertEqual(loaded_config.num_beams, 1) # default value + def test_validate(self): + """ + Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time + """ + # Case 1: A correct configuration will not throw any warning + with warnings.catch_warnings(record=True) as captured_warnings: + GenerationConfig() + self.assertEqual(len(captured_warnings), 0) + + # Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling + # parameters with `do_sample=False`). May be escalated to an error in the future. + with warnings.catch_warnings(record=True) as captured_warnings: + GenerationConfig(temperature=0.5) + self.assertEqual(len(captured_warnings), 1) + + # Case 3: Impossible sets of contraints/parameters will raise an exception + with self.assertRaises(ValueError): + GenerationConfig(num_return_sequences=2) + + # Case 4: Passing `generate()`-only flags to `validate` will raise an exception + with self.assertRaises(ValueError): + GenerationConfig(logits_processor="foo") + + # Case 5: Model-specific parameters will NOT raise an exception or a warning + with warnings.catch_warnings(record=True) as captured_warnings: + GenerationConfig(foo="bar") + self.assertEqual(len(captured_warnings), 0) + def test_refuse_to_save(self): """Tests that we refuse to save a generation config that fails validation."""