From 8430293337b98f20a03f905523643d869a3ae7e1 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Thu, 8 Feb 2024 21:00:53 +0500
Subject: [PATCH] Support batched input for decoder start ids (#28887)

* support batched input for decoder start ids

* Fix typos

Co-authored-by: Joao Gante

* minor changes

* fix: decoder_start_id as list

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

---------

Co-authored-by: Joao Gante
---
 .../generation/configuration_utils.py |  7 +++--
 src/transformers/generation/utils.py  | 26 ++++++++++++++++---
 tests/generation/test_utils.py        | 20 ++++++++++++++
 3 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 69e1afe63c2e9b..4c3cdc12a44993 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -233,8 +233,11 @@ class GenerationConfig(PushToHubMixin):
         encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
             If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
             `decoder_input_ids`.
-        decoder_start_token_id (`int`, *optional*):
-            If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
+        decoder_start_token_id (`Union[int, List[int]]`, *optional*):
+            If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a list of length
+            `batch_size`. Passing a list enables different start ids for each element in the batch
+            (e.g. multilingual models with different target languages in one batch).
+

         > Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 1405425e623827..0bbdd643421996 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -501,7 +501,7 @@ def _prepare_decoder_input_ids_for_generation(
         batch_size: int,
         model_input_name: str,
         model_kwargs: Dict[str, torch.Tensor],
-        decoder_start_token_id: int = None,
+        decoder_start_token_id: Union[int, List[int]] = None,
         bos_token_id: int = None,
         device: torch.device = None,
     ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
@@ -519,7 +519,17 @@
         decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
         if device is None:
             device = self.device
-        decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
+        if isinstance(decoder_start_token_id, list):
+            if len(decoder_start_token_id) != batch_size:
+                raise ValueError(
+                    f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
+                )
+            decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device)
+            decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
+        else:
+            decoder_input_ids_start = (
+                torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
+            )

         # no user input -> use decoder_start_token_id as decoder_input_ids
         if decoder_input_ids is None:
@@ -531,7 +541,13 @@
             pass
         # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
        # decoder_attention_mask if provided)
-        elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
+        elif (
+            isinstance(decoder_start_token_id, int)
+            and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
+        ) or (
+            isinstance(decoder_start_token_id, torch.Tensor)
+            and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
+        ):
             decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
             if "decoder_attention_mask" in model_kwargs:
                 decoder_attention_mask = model_kwargs["decoder_attention_mask"]
@@ -543,7 +559,9 @@

         return decoder_input_ids, model_kwargs

-    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
+    def _get_decoder_start_token_id(
+        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
+    ) -> int:
         decoder_start_token_id = (
             decoder_start_token_id
             if decoder_start_token_id is not None
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 855187778d2cf0..4a13487cf8935d 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -3163,6 +3163,26 @@ def test_constrained_beam_search_mixin_type_checks(self):
         with self.assertRaises(ValueError):
             model.generate(input_ids, force_words_ids=[[[-1]]])

+    def test_batched_decoder_start_id(self):
+        # PT-only test: TF doesn't support batched_decoder_start_id
+        articles = [
+            "Justin Timberlake and Jessica Biel, welcome to parenthood.",
+            "Michael Phelps is arguably the most decorated Olympian of all time.",
+        ]
+        bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+            torch_device
+        )
+        input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
+        decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
+        decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0]
+
+        outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id)
+
+        outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch)
+
+        self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())
+
     def test_contrastive_search_batched(self):
         # PT-only test: TF doesn't have constrained beam search
         # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
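
For reference, a minimal usage sketch of the behaviour this patch enables, modelled on the new test above. The checkpoint is the same tiny testing model the test uses, and both list entries reuse the default start id; with a real multilingual model each entry could be a different target-language token id (those specifics are illustrative assumptions, not part of the patch):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Same tiny testing checkpoint as the new test; any encoder-decoder model works.
checkpoint = "hf-internal-testing/tiny-random-bart"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

inputs = tokenizer(["first article", "second article"], return_tensors="pt", padding=True)

# One decoder start id per batch row. Both rows reuse the model's default id here;
# a multilingual model could put a different target-language token id in each slot.
start_id = model.generation_config.decoder_start_token_id
start_ids = [start_id, start_id]  # list of length batch_size

# With this patch, `decoder_start_token_id` accepts a list as well as a single int.
outputs = model.generate(**inputs, decoder_start_token_id=start_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))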