diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 08590219561531..735431fe6fec80 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -148,6 +148,11 @@ def __init__(
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
 
+        # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        self.main_model_min_length = self.generation_config.min_length
+        self.generation_config.min_length = 0
+        self.generation_config.min_new_tokens = None
+
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
@@ -166,6 +171,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
         new_cur_len = input_ids.shape[-1]
         max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
         if max_new_tokens == 0:
             return input_ids, None
 
@@ -186,6 +192,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
+            "min_new_tokens": min_new_tokens,
             "max_new_tokens": max_new_tokens,
             "generation_config": self.generation_config,
             "logits_processor": self.logits_processor,
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index cb3ac0ff1d121c..36e62794a435a1 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1173,6 +1173,56 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
                     UserWarning,
                 )
 
+    def _prepare_generated_length(
+        self,
+        generation_config,
+        has_default_max_length,
+        has_default_min_length,
+        model_input_name,
+        input_ids_length,
+        inputs_tensor,
+    ):
+        """Prepare max and min length in generation configs to avoid clashes between similar attributes"""
+
+        if generation_config.max_new_tokens is not None:
+            if not has_default_max_length and generation_config.max_length is not None:
+                logger.warning(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+
+        # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
+        # otherwise we need the total length [inputs-embeds-len + new-tokens-len] to not go beyond the indicated `max_length`
+        elif (
+            model_input_name == "inputs_embeds"
+            and input_ids_length != inputs_tensor.shape[1]
+            and not self.config.is_encoder_decoder
+        ):
+            generation_config.max_length -= inputs_tensor.shape[1]
+
+        # same for min length
+        if generation_config.min_new_tokens is not None:
+            if not has_default_min_length:
+                logger.warning(
+                    f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
+                    f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+            generation_config.min_length = generation_config.min_new_tokens + input_ids_length
+
+        elif (
+            model_input_name == "inputs_embeds"
+            and input_ids_length != inputs_tensor.shape[1]
+            and not self.config.is_encoder_decoder
+        ):
+            generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
+
+        return generation_config
+
     def _prepare_generation_config(
         self, generation_config: GenerationConfig, **kwargs: Dict
     ) -> Tuple[GenerationConfig, Dict]:
@@ -1418,24 +1468,15 @@ def generate(
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if generation_config.max_new_tokens is not None:
-            if not has_default_max_length and generation_config.max_length is not None:
-                logger.warning(
-                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
-                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
-                    "Please refer to the documentation for more information. "
-                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
-                )
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-
-        # otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length``
-        elif (
-            model_input_name == "inputs_embeds"
-            and inputs_tensor.shape[:-1] != input_ids.shape
-            and not self.config.is_encoder_decoder
-        ):
-            generation_config.max_length -= inputs_tensor.shape[1]
-            generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
+        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+        generation_config = self._prepare_generated_length(
+            generation_config=generation_config,
+            has_default_max_length=has_default_max_length,
+            has_default_min_length=has_default_min_length,
+            model_input_name=model_input_name,
+            inputs_tensor=inputs_tensor,
+            input_ids_length=input_ids_length,
+        )
 
         if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
             if generation_config.cache_implementation == "static":
@@ -1511,7 +1552,7 @@ def generate(
             )
 
             # 12. run assisted generate
-            result = self.assisted_decoding(
+            result = self._assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
                 do_sample=generation_config.do_sample,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index b346b745d8bcbe..d6b4840c4910c9 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1977,6 +1977,20 @@ def test_max_length_if_input_embeds(self):
         out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
         self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
 
+    def test_min_length_if_input_embeds(self):
+        # PT-only test: TF doesn't have StoppingCriteria
+        article = "Today a dragon flew over Paris."
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        inputs_embeds = model.get_input_embeddings()(input_ids)
+
+        min_length = 10
+        input_len = input_ids.shape[-1]
+        out_gen = model.generate(input_ids=input_ids, min_length=min_length)
+        out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length)
+        self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
+
     def test_custom_stopping_criteria_overload_error(self):
         # PT-only test: TF doesn't have StoppingCriteria
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2539,6 +2553,56 @@ def test_default_max_length_warning(self):
             model.generate(input_ids)
             self.assertEqual(len(warning_list), 0)
 
+    def test_length_warning_assisted_generation(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # This should not raise any warning that min length is not feasible in candidate generation
+        with warnings.catch_warnings(record=True) as warning_list:
+            model.generate(
+                input_ids,
+                assistant_model=assistant,
+                min_new_tokens=10,
+                max_length=20,
+            )
+            self.assertEqual(len(warning_list), 0)
+
+    def test_generated_length_assisted_generation(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+        input_length = input_ids.shape[-1]
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            min_new_tokens=10,
+            max_new_tokens=20,
+        )
+        self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length))
+
+        out = model.generate(
+            input_ids,
+            assistant_model=assistant,
+            min_new_tokens=10,
+        )
+        self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)
+
     def test_model_kwarg_assisted_decoding_decoder_only(self):
         # PT-only test: TF doesn't support assisted decoding yet.
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
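Not part of the patch: the sketch below is a reviewer convenience that exercises the two behaviours the new tests cover, `min_length` with `inputs_embeds` and `min_new_tokens` with an assistant model. The checkpoint, prompt, and length values simply mirror the test fixtures above; exact token counts depend on the tokenizer, so treat the asserts as illustrative rather than guaranteed.

```python
# Illustrative sketch only (not part of the patch); mirrors the new tests above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "hf-internal-testing/tiny-random-gpt2"  # same tiny checkpoint the tests use
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer("Today a dragon flew over Paris.", return_tensors="pt").input_ids.to(device)
inputs_embeds = model.get_input_embeddings()(input_ids)

# 1) With only `inputs_embeds`, `min_length` is now adjusted for the prompt length,
#    so both calls are forced to produce the same number of new tokens.
out_ids = model.generate(input_ids=input_ids, min_length=10)
out_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=10)
assert out_ids.shape[-1] == input_ids.shape[-1] + out_embeds.shape[-1]

# 2) With an assistant model, `min_new_tokens` is forwarded to the candidate generator,
#    so assisted decoding honours the minimum without spurious length warnings.
out = model.generate(input_ids, assistant_model=assistant, min_new_tokens=10, max_new_tokens=20)
assert input_ids.shape[-1] + 10 <= out.shape[-1] <= input_ids.shape[-1] + 20
```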