Support batched input for decoder start ids #28887

Merged · 15 commits · Feb 8, 2024
7 changes: 5 additions & 2 deletions src/transformers/generation/configuration_utils.py
@@ -233,8 +233,11 @@ class GenerationConfig(PushToHubMixin):
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
`decoder_input_ids`.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
decoder_start_token_id (`Union[int, List[int]]`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a list of
length `batch_size`. Passing a list enables a different start id for each element in the batch
(e.g. multilingual models with different target languages in one batch).


> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)

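For illustration, a minimal usage sketch of the batched start ids described in the new docstring. The checkpoint name, `src_lang` setup, and language codes below are assumptions (an mBART-50-style multilingual model whose tokenizer exposes `lang_code_to_id`), not part of this PR; only the list-valued `decoder_start_token_id` argument is what the diff adds:

```python
# Hypothetical sketch (assumed mBART-50 checkpoint and language codes) of passing
# one decoder start token per batch element, as enabled by this PR.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/mbart-large-50-many-to-many-mmt"  # assumed multilingual checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

tokenizer.src_lang = "en_XX"  # assumed English source text
inputs = tokenizer(["Hello world.", "How are you?"], return_tensors="pt", padding=True)

# One target-language start token per batch element (French for the first
# example, German for the second).
start_ids = [tokenizer.lang_code_to_id["fr_XX"], tokenizer.lang_code_to_id["de_DE"]]
outputs = model.generate(**inputs, decoder_start_token_id=start_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```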
26 changes: 22 additions & 4 deletions src/transformers/generation/utils.py
@@ -497,7 +497,7 @@ def _prepare_decoder_input_ids_for_generation(
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: int = None,
decoder_start_token_id: Union[int, List[int]] = None,
bos_token_id: int = None,
device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
@@ -515,7 +515,17 @@ def _prepare_decoder_input_ids_for_generation(
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
if device is None:
device = self.device
decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
if isinstance(decoder_start_token_id, list):
if len(decoder_start_token_id) != batch_size:
raise ValueError(
f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
)
decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device)
decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
else:
decoder_input_ids_start = (
torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
)

# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
@@ -527,7 +537,13 @@
pass
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
elif (
isinstance(decoder_start_token_id, int)
and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
) or (
isinstance(decoder_start_token_id, torch.Tensor)
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
):
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
@@ -539,7 +555,9 @@

return decoder_input_ids, model_kwargs

def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
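As a standalone sketch of the shape handling added in `_prepare_decoder_input_ids_for_generation` above: a list of per-example start ids becomes a `(batch_size, 1)` column, while a single int is broadcast over the batch as before. The helper name and the sample token ids are illustrative, not part of the PR:

```python
import torch

def make_decoder_input_ids_start(decoder_start_token_id, batch_size, device="cpu"):
    # Mirrors the logic added above, extracted for illustration only.
    if isinstance(decoder_start_token_id, list):
        if len(decoder_start_token_id) != batch_size:
            raise ValueError(
                f"`decoder_start_token_id` expected to have length {batch_size} "
                f"but got {len(decoder_start_token_id)}"
            )
        # One start id per example -> column vector of shape (batch_size, 1).
        return torch.tensor(decoder_start_token_id, dtype=torch.long, device=device).view(-1, 1)
    # Single int -> broadcast to every example, as before this PR.
    return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

print(make_decoder_input_ids_start(2, batch_size=2))                # tensor([[2], [2]])
print(make_decoder_input_ids_start([250008, 250003], batch_size=2)) # tensor([[250008], [250003]])
```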
20 changes: 20 additions & 0 deletions tests/generation/test_utils.py
@@ -3163,6 +3163,26 @@ def test_constrained_beam_search_mixin_type_checks(self):
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])

def test_batched_decoder_start_id(self):
# PT-only test: TF doesn't support batched_decoder_start_id
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0]

outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id)

outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch)

self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())

def test_contrastive_search_batched(self):
# PT-only test: TF doesn't have contrastive search
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
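A possible follow-up test, sketched here for illustration only (not part of this PR), would exercise the length check added in `_prepare_decoder_input_ids_for_generation`. It assumes the same test class, imports, and `torch_device` fixture as `test_batched_decoder_start_id` above:

```python
def test_batched_decoder_start_id_length_mismatch(self):
    # Hypothetical follow-up test: a list whose length does not match the batch
    # size should trigger the ValueError added in this PR.
    articles = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]
    bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
    bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
        torch_device
    )
    input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
    start_id = bart_model.generation_config.decoder_start_token_id

    with self.assertRaises(ValueError):
        # Three start ids for a batch of two -> length mismatch.
        bart_model.generate(input_ids, decoder_start_token_id=[start_id] * (input_ids.shape[0] + 1))
```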