Skip to content

Commit

Permalink
support batched input for decoder start ids
Browse files Browse the repository at this point in the history
  • Loading branch information
zucchini-nlp committed Feb 6, 2024
1 parent 5346db1 commit 221f335
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
7 changes: 5 additions & 2 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,11 @@ class GenerationConfig(PushToHubMixin):
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
`decoder_input_ids`.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
decoder_start_token_id (`Union[int, torch.Tensor]`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token or a tensor with shape
`(batch_size, 1)`. Passing a tensor enables different start ids for each element in the batch
(e.g. multilingual models with different target languages in one batch).
> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)
Expand Down
16 changes: 13 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def _prepare_decoder_input_ids_for_generation(
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: int = None,
decoder_start_token_id: Union[int, torch.Tensor] = None,
bos_token_id: int = None,
device: torch.device = None,
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
Expand All @@ -515,6 +515,8 @@ def _prepare_decoder_input_ids_for_generation(
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
if device is None:
device = self.device
if isinstance(decoder_start_token_id, torch.Tensor) and decoder_start_token_id.shape != (batch_size, 1):
raise ValueError("`decoder_start_token_id` has to be shape (batch_size, 1) when passed as a torch.Tensor")
decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

# no user input -> use decoder_start_token_id as decoder_input_ids
Expand All @@ -527,7 +529,13 @@ def _prepare_decoder_input_ids_for_generation(
pass
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
if (
isinstance(decoder_start_token_id, int)
and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
) or (
isinstance(decoder_start_token_id, torch.Tensor)
and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
):
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
Expand All @@ -539,7 +547,9 @@ def _prepare_decoder_input_ids_for_generation(

return decoder_input_ids, model_kwargs

def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
def _get_decoder_start_token_id(
self, decoder_start_token_id: Union[int, torch.Tensor] = None, bos_token_id: int = None
) -> int:
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
Expand Down

0 comments on commit 221f335

Please sign in to comment.