From eb5b968c5d80271ecb29917dffecc8f4c00247a8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 31 Aug 2024 10:47:08 +0100 Subject: [PATCH] Generate: throw warning when `return_dict_in_generate` is False but should be True (#33146) --- .../generation/configuration_utils.py | 26 ++++++++++++++++--- tests/generation/test_configuration_utils.py | 4 +++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 160a8a7eae2dba..af62d0c797514c 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -288,7 +288,9 @@ class GenerationConfig(PushToHubMixin): Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.ModelOutput`], as opposed to returning exclusively the generated + sequence. This flag must be set to `True` to return the generation cache (when `use_cache` is `True`) + or optional outputs (see flags starting with `output_`) > Special tokens that can be used at generation time @@ -334,6 +336,8 @@ class GenerationConfig(PushToHubMixin): present in `generate`'s signature will be used in the model forward pass. """ + extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits") + def __init__(self, **kwargs): # Parameters that control the length of the output self.max_length = kwargs.pop("max_length", 20) @@ -727,7 +731,17 @@ def validate(self, is_init=False): self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config) self.watermarking_config.validate() - # 7. check common issue: passing `generate` arguments inside the generation config + # 7. other incorrect combinations + if self.return_dict_in_generate is not True: + for extra_output_flag in self.extra_output_flags: + if getattr(self, extra_output_flag) is True: + warnings.warn( + f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When " + f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.", + UserWarning, + ) + + # 8. check common issue: passing `generate` arguments inside the generation config generate_arguments = ( "logits_processor", "stopping_criteria", @@ -786,7 +800,8 @@ def save_pretrained( if use_auth_token is not None: warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. " + "Please use `token` instead.", FutureWarning, ) if kwargs.get("token", None) is not None: @@ -1189,6 +1204,11 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): setattr(config, attr, decoder_config[attr]) + # If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`. + if config.return_dict_in_generate is False: + if any(getattr(config, extra_output_flag, False) for extra_output_flag in config.extra_output_flags): + config.return_dict_in_generate = True + config._original_object_hash = hash(config) # Hash to detect whether the instance was modified return config diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index b22b7eebf0080b..cd5f3d50162c45 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -136,6 +136,10 @@ def test_validate(self): GenerationConfig(do_sample=False, temperature=0.5) self.assertEqual(len(captured_warnings), 1) + with warnings.catch_warnings(record=True) as captured_warnings: + GenerationConfig(return_dict_in_generate=False, output_scores=True) + self.assertEqual(len(captured_warnings), 1) + # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally, # that is done by unsetting the parameter (i.e. setting it to None) generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)