Fix length related warnings in speculative decoding (#29585)
* avoid generation length warning

* add tests

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <[email protected]>

* add tests and minor fixes

* refine `min_new_tokens`

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <[email protected]>

* add method to prepare length arguments

* add test for min length

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Joao Gante <[email protected]>

* fix variable naming

* empty commit for tests

* trigger tests (empty)

---------

Co-authored-by: Joao Gante <[email protected]>
2 people authored and ArthurZucker committed Apr 22, 2024
1 parent d2beb63 commit 3260549
Showing 3 changed files with 131 additions and 19 deletions.
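For context, a minimal sketch of the call pattern this commit targets: assisted (speculative) decoding where `min_new_tokens` is set on the main model, which previously caused the assistant model to emit spurious warnings about `min_length` being larger than `max_new_tokens` on every candidate round. The snippet mirrors the new tests below and reuses the same tiny test checkpoint; any compatible pair of causal LMs would do.

```python
# Illustrative sketch only -- mirrors the scenario exercised by the new tests.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "hf-internal-testing/tiny-random-gpt2"  # tiny checkpoint used by the tests

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids.to(device)

# With this fix, the call below completes without length-related warnings and
# the output length respects both bounds.
out = model.generate(
    input_ids,
    assistant_model=assistant,
    min_new_tokens=10,
    max_length=20,
)
print(out.shape[-1])  # between prompt length + 10 and 20
```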
7 changes: 7 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -148,6 +148,11 @@ def __init__(
self.generation_config.return_dict_in_generate = True
self.generation_config.output_scores = True

# avoid unnecessary warnings that min_length is larger than max_new_tokens
self.main_model_min_length = self.generation_config.min_length
self.generation_config.min_length = 0
self.generation_config.min_new_tokens = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
@@ -166,6 +171,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
if max_new_tokens == 0:
return input_ids, None

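To make the new clamp concrete, here is a worked example with made-up lengths (not part of the diff): the assistant is asked to produce at least enough tokens to cover what remains of the main model's original `min_length`, but never more than its own `max_new_tokens` budget and never a negative amount.

```python
# Hypothetical numbers, purely to illustrate the formula in the hunk above.
num_assistant_tokens = 5       # assistant's speculation budget
max_length = 20                # generation_config.max_length
main_model_min_length = 10     # min_length stashed in __init__ before being zeroed
new_cur_len = 7                # current input_ids.shape[-1]

max_new_tokens = min(num_assistant_tokens, max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, main_model_min_length - new_cur_len), 0)
print(max_new_tokens, min_new_tokens)  # 5 3

# Once the sequence already satisfies the main model's min_length, the lower
# bound drops to 0 and the assistant is no longer forced to keep generating.
new_cur_len = 12
max_new_tokens = min(num_assistant_tokens, max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, main_model_min_length - new_cur_len), 0)
print(max_new_tokens, min_new_tokens)  # 5 0
```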
@@ -186,6 +192,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
79 changes: 60 additions & 19 deletions src/transformers/generation/utils.py
@@ -1173,6 +1173,56 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)

def _prepare_generated_length(
self,
generation_config,
has_default_max_length,
has_default_min_length,
model_input_name,
input_ids_length,
inputs_tensor,
):
"""Prepared max and min length in generaion configs to avoid clashes between similar attributes"""

if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length

# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
# otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`
elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]

# same for min length
if generation_config.min_new_tokens is not None:
if not has_default_min_length:
logger.warning(
f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.min_length = generation_config.min_new_tokens + input_ids_length

elif (
model_input_name == "inputs_embeds"
and input_ids_length != inputs_tensor.shape[1]
and not self.config.is_encoder_decoder
):
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)

return generation_config

def _prepare_generation_config(
self, generation_config: GenerationConfig, **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
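A rough illustration (made-up numbers, outside the diff) of what the new `_prepare_generated_length` helper computes: `*_new_tokens` arguments take precedence and are converted into absolute lengths by adding the prompt length, while generation that starts from `inputs_embeds` alone (decoder-only models) has its absolute limits shifted down by the embedding length, with `min_length` floored at zero.

```python
# Hypothetical values, for illustration only.
input_ids_length = 7           # prompt length in tokens
max_new_tokens, min_new_tokens = 20, 5

# *_new_tokens are rebased onto the prompt length:
max_length = max_new_tokens + input_ids_length   # 27
min_length = min_new_tokens + input_ids_length   # 12

# With only inputs_embeds and absolute limits instead, the limits are shifted,
# since the returned sequence will not contain the prompt tokens:
inputs_embeds_length = 7
max_length = 30 - inputs_embeds_length            # 23
min_length = max(10 - inputs_embeds_length, 0)    # 3
```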
@@ -1418,24 +1468,15 @@ def generate(
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
- if generation_config.max_new_tokens is not None:
- if not has_default_max_length and generation_config.max_length is not None:
- logger.warning(
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
- "Please refer to the documentation for more information. "
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
- )
- generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-
- # otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length``
- elif (
- model_input_name == "inputs_embeds"
- and inputs_tensor.shape[:-1] != input_ids.shape
- and not self.config.is_encoder_decoder
- ):
- generation_config.max_length -= inputs_tensor.shape[1]
- generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static":
@@ -1511,7 +1552,7 @@ def generate(
)

# 12. run assisted generate
- result = self.assisted_decoding(
+ result = self._assisted_decoding(
input_ids,
candidate_generator=candidate_generator,
do_sample=generation_config.do_sample,
64 changes: 64 additions & 0 deletions tests/generation/test_utils.py
@@ -1977,6 +1977,20 @@ def test_max_length_if_input_embeds(self):
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

def test_min_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)

min_length = 10
input_len = input_ids.shape[-1]
out_gen = model.generate(input_ids=input_ids, min_length=min_length)
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=min_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

def test_custom_stopping_criteria_overload_error(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -2539,6 +2553,56 @@ def test_default_max_length_warning(self):
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)

def test_length_warning_assisted_generation(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)

# This should not raise any warning that min length is not feasible in candidate generation
with warnings.catch_warnings(record=True) as warning_list:
model.generate(
input_ids,
assistant_model=assistant,
min_new_tokens=10,
max_length=20,
)
self.assertEqual(len(warning_list), 0)

def test_generated_length_assisted_generation(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
input_length = input_ids.shape[-1]

out = model.generate(
input_ids,
assistant_model=assistant,
min_new_tokens=10,
max_new_tokens=20,
)
self.assertTrue((10 + input_length) <= out.shape[-1] <= (20 + input_length))

out = model.generate(
input_ids,
assistant_model=assistant,
min_new_tokens=10,
)
self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)

def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
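To exercise the new tests locally, something along these lines should work (assuming a development checkout of transformers with pytest installed; the `-k` filter is just a convenience and matches the two assisted-generation tests added above):

```python
# Run the assisted-generation length tests from a transformers dev checkout.
import pytest

pytest.main(["tests/generation/test_utils.py", "-k", "assisted_generation", "-q"])
```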
