From 221f3355d594294c925e02d9e941253d7a7fe21a Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 6 Feb 2024 10:02:39 +0100 Subject: [PATCH 01/13] support batched input for decoder start ids --- .../generation/configuration_utils.py | 7 +++++-- src/transformers/generation/utils.py | 16 +++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 17b8875a40195d..a4a7c75f225442 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -233,8 +233,11 @@ class GenerationConfig(PushToHubMixin): encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`. - decoder_start_token_id (`int`, *optional*): - If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. + decoder_start_token_id (`Union[int, torch.Tensor`, *optional*): + If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a tensor with shape + `(batch_size, 1)`. Indicating a tensor enables different start ids for each element in the batch + (e.g. multilingual models with different target languages in one batch) + > Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 622d6731776876..bec1de9f0ab95d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -497,7 +497,7 @@ def _prepare_decoder_input_ids_for_generation( batch_size: int, model_input_name: str, model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: int = None, + decoder_start_token_id: Union[int, torch.Tensor] = None, bos_token_id: int = None, device: torch.device = None, ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: @@ -515,6 +515,8 @@ def _prepare_decoder_input_ids_for_generation( decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device + if isinstance(decoder_start_token_id, torch.Tensor) and decoder_start_token_id.shape != (batch_size, 1): + raise ValueError("decoder_start_token_id` has to be shape (batch_size, 1) when passed as a torch.Tensor") decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id # no user input -> use decoder_start_token_id as decoder_input_ids @@ -527,7 +529,13 @@ def _prepare_decoder_input_ids_for_generation( pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) - elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + if ( + isinstance(decoder_start_token_id, int) + and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item() + ) or ( + isinstance(decoder_start_token_id, torch.Tensor) + and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item() + ): decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] @@ -539,7 +547,9 @@ def _prepare_decoder_input_ids_for_generation( return decoder_input_ids, model_kwargs - def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + def _get_decoder_start_token_id( + self, decoder_start_token_id: Union[int, torch.Tensor] = None, bos_token_id: int = None + ) -> int: decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None From b5364e4044da2f5abef58d4646c2bb83198a01e5 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 7 Feb 2024 16:26:42 +0500 Subject: [PATCH 02/13] Fix typos Co-authored-by: Joao Gante --- src/transformers/generation/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index a4a7c75f225442..4e4e8f135f5201 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -233,7 +233,7 @@ class GenerationConfig(PushToHubMixin): encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`. - decoder_start_token_id (`Union[int, torch.Tensor`, *optional*): + decoder_start_token_id (`Union[int, torch.Tensor]`, *optional*): If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a tensor with shape `(batch_size, 1)`. Indicating a tensor enables different start ids for each element in the batch (e.g. multilingual models with different target languages in one batch) From c69764038167fe045c71dab028ffa0135b392b44 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 7 Feb 2024 12:29:08 +0100 Subject: [PATCH 03/13] minor changes --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bec1de9f0ab95d..8499962eec47a7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -529,7 +529,7 @@ def _prepare_decoder_input_ids_for_generation( pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) - if ( + elif ( isinstance(decoder_start_token_id, int) and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item() ) or ( From bb4e24d326ca8f547c3092a775f8df4027ccf54f Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 7 Feb 2024 21:05:40 +0100 Subject: [PATCH 04/13] fix: decoder_start_id as list --- .../generation/configuration_utils.py | 6 +++--- src/transformers/generation/utils.py | 18 ++++++++++++----- tests/generation/test_utils.py | 20 +++++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 4e4e8f135f5201..8743cf904011b2 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -233,9 +233,9 @@ class GenerationConfig(PushToHubMixin): encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`. - decoder_start_token_id (`Union[int, torch.Tensor]`, *optional*): - If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a tensor with shape - `(batch_size, 1)`. Indicating a tensor enables different start ids for each element in the batch + decoder_start_token_id (`Union[int, List[int]]`, *optional*): + If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a list of length + `batch_size`. Indicating a list enables different start ids for each element in the batch (e.g. multilingual models with different target languages in one batch) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8499962eec47a7..8cd6b3849affe2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -497,7 +497,7 @@ def _prepare_decoder_input_ids_for_generation( batch_size: int, model_input_name: str, model_kwargs: Dict[str, torch.Tensor], - decoder_start_token_id: Union[int, torch.Tensor] = None, + decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None, device: torch.device = None, ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: @@ -515,9 +515,17 @@ def _prepare_decoder_input_ids_for_generation( decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device - if isinstance(decoder_start_token_id, torch.Tensor) and decoder_start_token_id.shape != (batch_size, 1): - raise ValueError("decoder_start_token_id` has to be shape (batch_size, 1) when passed as a torch.Tensor") - decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + if isinstance(decoder_start_token_id, list): + if len(decoder_start_token_id) != batch_size: + raise ValueError( + f"`decoder_start_token_id` expcted to have length {batch_size} but got {len(decoder_start_token_id)}" + ) + decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device) + decoder_input_ids_start = decoder_input_ids_start.view(-1, 1) + else: + decoder_input_ids_start = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + ) # no user input -> use decoder_start_token_id as decoder_input_ids if decoder_input_ids is None: @@ -548,7 +556,7 @@ def _prepare_decoder_input_ids_for_generation( return decoder_input_ids, model_kwargs def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, torch.Tensor] = None, bos_token_id: int = None + self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None ) -> int: decoder_start_token_id = ( decoder_start_token_id diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 855187778d2cf0..4a13487cf8935d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3163,6 +3163,26 @@ def test_constrained_beam_search_mixin_type_checks(self): with self.assertRaises(ValueError): model.generate(input_ids, force_words_ids=[[[-1]]]) + def test_batched_decoder_start_id(self): + # PT-only test: TF doesn't support batched_decoder_start_id + articles = [ + "Justin Timberlake and Jessica Biel, welcome to parenthood.", + "Michael Phelps is arguably the most decorated Olympian of all time.", + ] + bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") + bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to( + torch_device + ) + input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) + decoder_start_token_id = bart_model.generation_config.decoder_start_token_id + decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0] + + outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id) + + outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch) + + self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist()) + def test_contrastive_search_batched(self): # PT-only test: TF doesn't have constrained beam search # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs) From 336fc16cd9baa899026e3c0467c61fa39c27d427 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 12:19:48 +0100 Subject: [PATCH 05/13] empty commit From de3e6041a39981fd1f5207fa860d4b49deb255c4 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 12:41:15 +0100 Subject: [PATCH 06/13] empty commit From 950ad128e0c8de851f6e710db4fcc1ce8c3fdc35 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 13:03:27 +0100 Subject: [PATCH 07/13] empty commit From 62053b3e13e7e870de0664f836ce63a1455011dd Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 13:19:37 +0100 Subject: [PATCH 08/13] empty commit From cd9c913e8c1de8a7d3ebb76abc2f3a9406897553 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 13:35:06 +0100 Subject: [PATCH 09/13] empty commit From 43628109b2c9b433c368f295db3157ba8278fa87 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 13:52:13 +0100 Subject: [PATCH 10/13] empty commit From 1d81b99f1d6759cabeec749f96135485a52d3631 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 14:07:59 +0100 Subject: [PATCH 11/13] empty commit From f67327b3c8e1fbdd9ea300f0d287a6eb51691f73 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 14:30:50 +0100 Subject: [PATCH 12/13] empty commit From 2fcbee3f50a55c3d81fa648aeae234ee43b80093 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 8 Feb 2024 16:28:57 +0100 Subject: [PATCH 13/13] empty commit