From e8fdd7875def7be59e2c9b823705fbf003163ea0 Mon Sep 17 00:00:00 2001 From: Pavarissy <69553539+pavaris-pm@users.noreply.github.com> Date: Tue, 10 Oct 2023 22:05:48 +0700 Subject: [PATCH 01/11] [docstring] Fix docstring for `LlamaConfig` (#26685) * Your commit message here * fix LlamaConfig docstring * run make fixup * fix formatting after review reformat of the file to prevent script issues * rerun make fixup after reformat --- .../models/llama/configuration_llama.py | 23 +++++++++++-------- utils/check_docstrings.py | 1 - 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 5bebd936d65e15..f3da8ab4cdc242 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -58,11 +58,6 @@ class LlamaConfig(PretrainedConfig): by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 2048): @@ -70,12 +65,23 @@ class LlamaConfig(PretrainedConfig): Llama 2 up to 4096, CodeLlama up to 16384. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): + rms_norm_eps (`float`, *optional*, defaults to 1e-06): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. @@ -87,10 +93,9 @@ class LlamaConfig(PretrainedConfig): these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. - attention_bias (`bool`, defaults to `False`): + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. - Example: ```python >>> from transformers import LlamaModel, LlamaConfig diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 2832e347ab5f1e..e140be28037d59 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -361,7 +361,6 @@ "LevitConfig", "LiltConfig", "LiltModel", - "LlamaConfig", "LlamaTokenizer", "LlamaTokenizerFast", "LongT5Config", From 975003eacb959011a7bb6fc6413904d84de06726 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?th=C3=A9o=20gigant?= <71786646+giganttheo@users.noreply.github.com> Date: Tue, 10 Oct 2023 20:36:32 +0200 Subject: [PATCH 02/11] fix a typo in flax T5 attention - attention_mask variable is misnamed (#26663) * fix a typo in flax t5 attention * fix the typo in flax longt5 attention --- src/transformers/models/longt5/modeling_flax_longt5.py | 2 +- src/transformers/models/t5/modeling_flax_t5.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py index 6b7bc7c28fcf7b..91ca9c72c22c4e 100644 --- a/src/transformers/models/longt5/modeling_flax_longt5.py +++ b/src/transformers/models/longt5/modeling_flax_longt5.py @@ -545,7 +545,7 @@ def __call__( # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_attention_mask = self._concatenate_to_cache( + key_states, value_states, attention_mask = self._concatenate_to_cache( key_states, value_states, query_states, attention_mask ) diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index b2a7181421527c..db4ca90c270ccb 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -405,7 +405,7 @@ def __call__( # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_attention_mask = self._concatenate_to_cache( + key_states, value_states, attention_mask = self._concatenate_to_cache( key_states, value_states, query_states, attention_mask ) From 3eceaa3637197fa78dd3525cb3df57fcaf5ba00d Mon Sep 17 00:00:00 2001 From: jheitmann Date: Tue, 10 Oct 2023 20:49:10 +0200 Subject: [PATCH 03/11] Fix source_prefix default value (#26654) --- examples/pytorch/summarization/run_summarization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index 74c83eee49f215..5f20aac6cbc9c8 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -264,7 +264,7 @@ class DataTrainingArguments: }, ) source_prefix: Optional[str] = field( - default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} ) forced_bos_token: Optional[str] = field( From fc63914399b6f60512c720959f9182b02ae4a45c Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Tue, 10 Oct 2023 12:35:16 -0700 Subject: [PATCH 04/11] [JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. (#26703) `jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Co-authored-by: Peter Hawkins --- .../flax/image-captioning/run_image_captioning_flax.py | 2 +- examples/flax/language-modeling/run_clm_flax.py | 2 +- examples/flax/question-answering/run_qa.py | 2 +- .../run_flax_speech_recognition_seq2seq.py | 2 +- examples/flax/summarization/run_summarization_flax.py | 2 +- examples/flax/text-classification/run_flax_glue.py | 2 +- examples/flax/token-classification/run_flax_ner.py | 2 +- examples/flax/vision/run_image_classification.py | 2 +- .../jax-projects/hybrid_clip/run_hybrid_clip.py | 2 +- .../jax-projects/model_parallel/run_clm_mp.py | 2 +- src/transformers/models/bart/modeling_flax_bart.py | 2 +- src/transformers/models/bert/modeling_flax_bert.py | 2 +- .../models/big_bird/modeling_flax_big_bird.py | 2 +- .../models/blenderbot/modeling_flax_blenderbot.py | 2 +- .../blenderbot_small/modeling_flax_blenderbot_small.py | 2 +- src/transformers/models/electra/modeling_flax_electra.py | 8 ++++---- src/transformers/models/longt5/modeling_flax_longt5.py | 2 +- src/transformers/models/marian/modeling_flax_marian.py | 2 +- src/transformers/models/mt5/modeling_flax_mt5.py | 2 +- src/transformers/models/pegasus/modeling_flax_pegasus.py | 2 +- src/transformers/models/roberta/modeling_flax_roberta.py | 2 +- .../modeling_flax_roberta_prelayernorm.py | 2 +- src/transformers/models/t5/modeling_flax_t5.py | 2 +- .../models/xlm_roberta/modeling_flax_xlm_roberta.py | 2 +- .../modeling_flax_{{cookiecutter.lowercase_modelname}}.py | 2 +- 25 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index bbc79977a46793..d8c89c1a242f14 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 95e175d494bfe2..c61b24f4d7e615 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -326,7 +326,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index 9cd90f285a02c4..0d35f302f8f37a 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -389,7 +389,7 @@ def cross_entropy_loss(logits, labels): # region Create learning rate function def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py index 4a2915a31ac7dc..8af835b6a4b4d3 100644 --- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py +++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py @@ -360,7 +360,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( num_train_steps: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) decay_fn = optax.linear_schedule( diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index d57aa1769135db..782e9ee88c498f 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -409,7 +409,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index b42a2565310799..1535ff8492781b 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -288,7 +288,7 @@ def cross_entropy_loss(logits, labels): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index f4b40220ff12c7..e06a85cb67c0a0 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -340,7 +340,7 @@ def cross_entropy_loss(logits, labels): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py index 66505014ec57dd..4bed9b663f6f00 100644 --- a/examples/flax/vision/run_image_classification.py +++ b/examples/flax/vision/run_image_classification.py @@ -249,7 +249,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py index f54641408f80a2..c5a4a202534b87 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py @@ -283,7 +283,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py index a6da8729f0ce3b..bb297e3e0db6e4 100644 --- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py +++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py @@ -214,7 +214,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 9858eb2d1bf416..6abfcdc398422f 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -217,7 +217,7 @@ """ -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 99dfa2a0e2f9ab..bb2af0e0602aba 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -295,7 +295,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index afdac2645f2652..c6d8b7c1612ec9 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -316,7 +316,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 1035272fd05350..61239335be3b63 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -204,7 +204,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index 2bf8b59e2757bc..b5272fb3bca9e2 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -216,7 +216,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 32e76b8b586f4f..8fced6ff1ea25e 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -263,7 +263,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, @@ -1228,13 +1228,13 @@ def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): Compute a single vector summary of a sequence hidden states. Args: - hidden_states (`jnp.array` of shape `[batch_size, seq_len, hidden_size]`): + hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`): The hidden states of the last layer. - cls_index (`jnp.array` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. Returns: - `jnp.array`: The summary of the sequence hidden states. + `jnp.ndarray`: The summary of the sequence hidden states. """ # NOTE: this doest "first" type summary always output = hidden_states[:, 0] diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py index 91ca9c72c22c4e..36e273d5725a4f 100644 --- a/src/transformers/models/longt5/modeling_flax_longt5.py +++ b/src/transformers/models/longt5/modeling_flax_longt5.py @@ -56,7 +56,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index a713fdb05dcfd9..5197c906895917 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -227,7 +227,7 @@ def create_sinusoidal_positions(n_pos, dim): # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py index 86ddf477ffab56..0046e02ca7308e 100644 --- a/src/transformers/models/mt5/modeling_flax_mt5.py +++ b/src/transformers/models/mt5/modeling_flax_mt5.py @@ -27,7 +27,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index c5189746b1065f..17772251bf0629 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -210,7 +210,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 845fcea4429787..6bc72f12b40790 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -256,7 +256,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py index b7c347693d951b..e98897993742db 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py @@ -258,7 +258,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index db4ca90c270ccb..09575fdcc3b82e 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -56,7 +56,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py index f6f39ee93ba687..fb03c390f6f419 100644 --- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py @@ -266,7 +266,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 83263a6a47ef11..63b5d83d308ad1 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -251,7 +251,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, From 1e3c9ddacc7fc4142253bc9ddcba85c4d5b977e7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 11 Oct 2023 16:08:54 +0800 Subject: [PATCH 05/11] Make Whisper Encoder's sinusoidal PE non-trainable by default (#26032) * set encoder's PE as non-trainable * freeze flax * init sinusoids * add test for non-trainable embed positions * simplify TF encoder embed_pos * revert tf * clean up * add sinusoidal init for jax * make consistent sinusoidal function * fix dtype * add default dtype * use numpy for sinusoids. fix jax * add sinusoid init for TF * fix * use custom embedding * use specialized init for each impl * fix sinusoids init. add test for pytorch * fix TF dtype * simplify sinusoid init for flax and tf * add tests for TF * change default dtype to float32 * add sinusoid test for flax * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * move sinusoidal init to _init_weights --------- Co-authored-by: sanchit-gandhi Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- .../models/whisper/modeling_flax_whisper.py | 24 ++++++++++++++- .../models/whisper/modeling_tf_whisper.py | 30 +++++++++++++++++-- .../models/whisper/modeling_whisper.py | 17 +++++++++++ .../whisper/test_modeling_flax_whisper.py | 14 +++++++++ .../whisper/test_modeling_tf_whisper.py | 23 +++++++++++++- tests/models/whisper/test_modeling_whisper.py | 16 +++++++++- 6 files changed, 119 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 0f158fb602084a..ffcaeb53ad7153 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -14,6 +14,7 @@ # limitations under the License. """ Flax whisper model.""" +import math import random from functools import partial from typing import Optional, Tuple @@ -58,6 +59,19 @@ remat = nn_partitioning.remat +def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array: + """Returns sinusoids for positional embedding""" + length, channels = shape + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(10000) / (channels // 2 - 1) + inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2)) + scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1) + return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype) + + WHISPER_START_DOCSTRING = r""" This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads @@ -649,7 +663,13 @@ def setup(self) -> None: dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing, ) - self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype) + + self.embed_positions = nn.Embed( + self.config.max_source_positions, + self.config.d_model, + dtype=self.dtype, + embedding_init=sinusoidal_embedding_init, + ) self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) @@ -673,6 +693,8 @@ def __call__( hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) + # freeze the sinusoidal embeddings by stopping the back-prop + embed_positions = jax.lax.stop_gradient(embed_positions) hidden_states = hidden_states + embed_positions hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index 27b6ff63cedacb..1dfe413da2ae46 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -59,6 +59,19 @@ LARGE_NEGATIVE = -1e8 +def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor: + """Returns sinusoids for positional embedding""" + length, channels = shape + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(10000) / (channels // 2 - 1) + inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32)) + scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1)) + return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype) + + # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): pad_token_id = tf.cast(pad_token_id, input_ids.dtype) @@ -117,16 +130,25 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): class TFWhisperPositionalEmbedding(tf.keras.layers.Layer): - def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs): + def __init__( + self, + num_positions: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + embedding_initializer=None, + **kwargs, + ): super().__init__(**kwargs) self.num_positions = num_positions self.embedding_dim = embedding_dim self.padding_idx = padding_idx + self.embedding_initializer = tf.keras.initializers.get(embedding_initializer) def build(self, input_shape): self.weight = self.add_weight( name="weight", shape=[self.num_positions, self.embedding_dim], + initializer=self.embedding_initializer, trainable=True, ) super().build(input_shape) @@ -620,8 +642,12 @@ def __init__(self, config: WhisperConfig, **kwargs): self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2") self.embed_positions = TFWhisperPositionalEmbedding( - self.max_source_positions, self.embed_dim, name="embed_positions" + num_positions=self.max_source_positions, + embedding_dim=self.embed_dim, + embedding_initializer=sinusoidal_embedding_init, + name="embed_positions", ) + self.embed_positions.trainable = False self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 447d7275d5572d..be5f50dbffa2aa 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -55,6 +55,18 @@ ] +def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: + """Returns sinusoids for positional embedding""" + if channels % 2 != 0: + raise ValueError( + f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels." + ) + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) + return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) + + # Copied from transformers.models.bart.modeling_bart.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ @@ -668,6 +680,10 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, WhisperEncoder): + with torch.no_grad(): + embed_positions = module.embed_positions.weight + embed_positions.copy_(sinusoids(*embed_positions.shape)) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (WhisperDecoder, WhisperEncoder)): @@ -835,6 +851,7 @@ def __init__(self, config: WhisperConfig): self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.embed_positions.requires_grad_(False) self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 7ec5f90f0fcd53..982dcb4827a0d1 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -46,6 +46,7 @@ WhisperProcessor, ) from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init @require_flax @@ -387,6 +388,19 @@ def test_save_load_to_base(self): max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def test_encoder_sinusoidal_embed_positions(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + params = model.params + if model.base_model_prefix in params: + params = model.params[model.base_model_prefix] + + embeds = params["encoder"]["embed_positions"]["embedding"] + sinusoids = sinusoidal_embedding_init(None, embeds.shape) + self.assertTrue(jax.numpy.allclose(embeds, sinusoids)) + @slow @require_flax diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 7fae1e466e7a6e..75c62ae1ad07e6 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -42,7 +42,11 @@ import tensorflow as tf from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed - from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder + from transformers.models.whisper.modeling_tf_whisper import ( + TFWhisperDecoder, + TFWhisperEncoder, + sinusoidal_embedding_init, + ) def prepare_whisper_inputs_dict( @@ -297,6 +301,23 @@ def test_model_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_forward(*config_and_inputs) + def test_requires_grad_encoder_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + encoder = model.get_encoder() + self.assertFalse(encoder.embed_positions.trainable) + + def test_encoder_sinusoidal_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + model.build() + + embeds = model.get_encoder().embed_positions.get_weights()[0] + sinusoids = sinusoidal_embedding_init(embeds.shape).numpy() + self.assertTrue(np.allclose(embeds, sinusoids)) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 9decb7192aee00..337d33485210cd 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -49,7 +49,7 @@ WhisperProcessor, set_seed, ) - from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder + from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids if is_flax_available(): import jax.numpy as jnp @@ -351,6 +351,20 @@ def test_requires_grad_with_frozen_encoder(self): self.assertFalse(all(encoder_grads)) self.assertTrue(all(decoder_grads)) + def test_requires_grad_encoder_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + encoder = model.get_encoder() + self.assertFalse(encoder.embed_positions.weight.requires_grad) + + def test_encoder_sinusoidal_embed_positions(self): + config = self.model_tester.get_config() + for model_class in self.all_model_classes: + model = model_class(config) + embeds = model.get_encoder().embed_positions.weight + self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape))) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) From dcc49d8a7ef91c5e1baeb4d510ec4f37bc259760 Mon Sep 17 00:00:00 2001 From: Billy Bradley Date: Wed, 11 Oct 2023 12:18:42 +0100 Subject: [PATCH 06/11] In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) (#25242) * In assisted decoding, pass model_kwargs to model's forward call Previously, assisted decoding would ignore any additional kwargs that it doesn't explicitly handle. This was inconsistent with other generation methods, which pass the model_kwargs through prepare_inputs_for_generation and forward the returned dict to the model's forward call. The prepare_inputs_for_generation method needs to be amended in all models, as previously it only kept the last input ID when a past_key_values was passed. * Improve variable names in _extend_attention_mask * Refactor extending token_type_ids into a function * Replace deepcopy with copy to optimize performance * Update new persimmon model with llama changes for assisted generation * Update new mistral model for assisted generation with prepare_inputs_for_generation * Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation --- src/transformers/generation/utils.py | 91 +++++++++++-------- src/transformers/models/bark/modeling_bark.py | 15 ++- src/transformers/models/bart/modeling_bart.py | 22 ++++- src/transformers/models/bert/modeling_bert.py | 11 ++- .../modeling_bert_generation.py | 13 ++- .../models/big_bird/modeling_big_bird.py | 13 ++- .../modeling_bigbird_pegasus.py | 11 ++- .../models/biogpt/modeling_biogpt.py | 15 ++- .../models/blenderbot/modeling_blenderbot.py | 22 ++++- .../modeling_blenderbot_small.py | 22 ++++- .../models/blip/modeling_blip_text.py | 11 ++- .../models/bloom/modeling_bloom.py | 15 ++- .../models/camembert/modeling_camembert.py | 13 ++- .../models/codegen/modeling_codegen.py | 17 +++- src/transformers/models/ctrl/modeling_ctrl.py | 15 ++- .../models/data2vec/modeling_data2vec_text.py | 13 ++- .../open_llama/modeling_open_llama.py | 15 ++- .../models/electra/modeling_electra.py | 13 ++- .../models/ernie/modeling_ernie.py | 11 ++- .../models/falcon/modeling_falcon.py | 13 ++- src/transformers/models/gpt2/modeling_gpt2.py | 35 +++++-- .../gpt_bigcode/modeling_gpt_bigcode.py | 20 +++- .../models/gpt_neo/modeling_gpt_neo.py | 17 +++- .../models/gpt_neox/modeling_gpt_neox.py | 19 +++- src/transformers/models/gptj/modeling_gptj.py | 17 +++- .../models/imagegpt/modeling_imagegpt.py | 17 +++- .../models/llama/modeling_llama.py | 15 ++- .../models/longt5/modeling_longt5.py | 13 ++- .../models/m2m_100/modeling_m2m_100.py | 11 ++- .../models/marian/modeling_marian.py | 22 ++++- .../models/markuplm/modeling_markuplm.py | 11 ++- .../models/mbart/modeling_mbart.py | 22 ++++- .../megatron_bert/modeling_megatron_bert.py | 13 ++- .../models/mistral/modeling_mistral.py | 14 ++- src/transformers/models/mpt/modeling_mpt.py | 15 ++- src/transformers/models/mt5/modeling_mt5.py | 13 ++- .../models/musicgen/modeling_musicgen.py | 12 ++- src/transformers/models/mvp/modeling_mvp.py | 22 ++++- .../models/nllb_moe/modeling_nllb_moe.py | 11 ++- src/transformers/models/opt/modeling_opt.py | 13 ++- .../models/pegasus/modeling_pegasus.py | 22 ++++- .../models/pegasus_x/modeling_pegasus_x.py | 11 ++- .../models/persimmon/modeling_persimmon.py | 15 ++- .../models/pix2struct/modeling_pix2struct.py | 13 ++- .../models/plbart/modeling_plbart.py | 22 ++++- .../models/qdqbert/modeling_qdqbert.py | 13 ++- .../models/rembert/modeling_rembert.py | 13 ++- .../models/roberta/modeling_roberta.py | 13 ++- .../modeling_roberta_prelayernorm.py | 13 ++- .../models/roc_bert/modeling_roc_bert.py | 13 ++- .../models/roformer/modeling_roformer.py | 13 ++- .../modeling_speech_to_text_2.py | 11 ++- .../models/speecht5/modeling_speecht5.py | 11 ++- .../modeling_switch_transformers.py | 13 ++- src/transformers/models/t5/modeling_t5.py | 13 ++- .../models/trocr/modeling_trocr.py | 11 ++- src/transformers/models/umt5/modeling_umt5.py | 13 ++- .../models/whisper/modeling_whisper.py | 12 ++- src/transformers/models/xglm/modeling_xglm.py | 17 +++- .../xlm_roberta/modeling_xlm_roberta.py | 13 ++- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 13 ++- src/transformers/models/xmod/modeling_xmod.py | 13 ++- tests/generation/test_utils.py | 86 ++++++++++++++++++ 63 files changed, 911 insertions(+), 179 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3b1bef6f040084..49b213cc5e92cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1297,6 +1297,43 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) + def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]: + if self.config.is_encoder_decoder: + key = "decoder_attention_mask" + else: + key = "attention_mask" + + if key not in model_kwargs: + return model_kwargs + + mask = model_kwargs[key] + mask_extension_length = new_mask_length - mask.shape[1] + + if mask_extension_length < 0: + raise ValueError("Cannot extend attention mask to a length less than it already is") + + model_kwargs[key] = torch.cat( + [mask, mask.new_ones((mask.shape[0], mask_extension_length))], + dim=-1, + ) + + return model_kwargs + + def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: + return model_kwargs + + token_type_ids = model_kwargs["token_type_ids"] + final_token_type = token_type_ids[:, -1].unsqueeze(-1) + extension_length = new_length - token_type_ids.shape[1] + token_type_copies = final_token_type.repeat(1, extension_length) + model_kwargs["token_type_ids"] = torch.cat( + [model_kwargs["token_type_ids"], token_type_copies], + dim=-1, + ) + + return model_kwargs + @torch.no_grad() def generate( self, @@ -4441,47 +4478,21 @@ def assisted_decoding( # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, # we use this forward pass to also pick the subsequent logits in the original model. - # 2.1. Run a forward pass on the candidate sequence - if "past_key_values" in model_kwargs: - model_attn = torch.ones_like(candidate_input_ids) - model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] - if self.config.is_encoder_decoder: - outputs = self( - decoder_input_ids=model_input_ids, - decoder_attention_mask=model_attn, - past_key_values=model_kwargs["past_key_values"], - encoder_outputs=model_kwargs["encoder_outputs"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - outputs = self( - model_input_ids, - attention_mask=model_attn, - past_key_values=model_kwargs["past_key_values"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - if self.config.is_encoder_decoder: - outputs = self( - decoder_input_ids=candidate_input_ids, - encoder_outputs=model_kwargs["encoder_outputs"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - outputs = self( - candidate_input_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) + # 2.1. Prepare the model inputs + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1]) + candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) + + # 2.2. Run a forward pass on the candidate sequence + outputs = self( + **model_inputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) - # 2.2. Process the new logits + # 2.3. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length): diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index bdafb6347755d3..649719e0eefa5d 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -483,9 +483,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = kwargs.get("position_ids", None) if past_key_values is not None: - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values seq_len = input_ids.shape[1] - input_ids = input_ids[:, [-1]] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # input_embeds have already been used and is not required anymore input_embeds = None @@ -507,7 +516,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 52dfa5e39229f8..9e7763ca23d885 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1443,7 +1443,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1934,7 +1943,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 29846b8051f867..1b0fad3f9d6546 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1282,7 +1282,16 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index f245ac155e75ca..abe2d828b28bb9 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -993,9 +993,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 867aca67e99e8c..e266b1a67b7d41 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -2628,9 +2628,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index a32f3ecde76fdb..4e279f9dc059fe 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2627,7 +2627,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 7534ed17fe849a..d1c471aa8090c9 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -729,9 +729,18 @@ def forward( def prepare_inputs_for_generation( self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs ): - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for inputs_ids if past is defined in kwargs + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index bdb8c52a552041..1db81905210b63 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1392,7 +1392,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1622,7 +1631,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index a1e888aec90807..129de3dd1456e3 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1359,7 +1359,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1589,7 +1598,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 2ae3ac053beab9..49b958afc2ebae 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -920,7 +920,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d12ec1724f7097..d90bb6ad8fdfd5 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -844,9 +844,18 @@ def prepare_inputs_for_generation( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for input_ids if past is not None + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed if past_key_values[0][0].shape[0] == input_ids.shape[0]: diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 4635c061980b53..8d7d279579e3e9 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -1542,9 +1542,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 93d5aa7ee47650..172a45544bac0d 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -617,11 +617,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -631,7 +640,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 70cd4ec0597a14..cec68de07dda75 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -526,9 +526,18 @@ def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs): - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for inputs_ids if past is defined in kwargs + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 7cbaee692564b4..a521ccb39aaf0c 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -1009,9 +1009,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index c975aa40877c26..6853f5333f137c 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -843,8 +843,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -852,7 +861,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index c06d306c1a241d..da3ee8e51d3602 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -1667,9 +1667,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 7ee6f4381290ae..d55155f80093bc 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -1223,7 +1223,16 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index ab29322613bea3..33b9fdde739e58 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1228,7 +1228,16 @@ def prepare_inputs_for_generation( **kwargs, ) -> dict: if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. if not self.transformer.use_alibi and attention_mask is not None and position_ids is None: @@ -1236,7 +1245,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 714f0351b3e4df..838e7ca2992520 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1005,11 +1005,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -1019,7 +1028,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None @@ -1038,6 +1047,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ "token_type_ids": token_type_ids, } ) + return model_inputs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @@ -1201,11 +1211,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -1215,7 +1234,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d58e00af1dac13..be90f61e45bf1b 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -737,11 +737,23 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + if self.config.multi_query: + past_length = past_key_values[0].shape[1] + else: + past_length = past_key_values[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -751,7 +763,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 6364cfc316220a..3ad49554c0ac8f 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -680,11 +680,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -694,7 +703,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index b4aa4154459cf7..9391805a77b851 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -808,10 +808,21 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): input_shape = input_ids.shape + print(input_shape) + print(past_key_values[0][0].shape if past_key_values is not None else "no pkv") # cut decoder_input_ids if past is used - if past_key_values and past_key_values[0] is not None: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -819,7 +830,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: @@ -830,7 +841,7 @@ def prepare_inputs_for_generation( model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} - + print(position_ids.shape) model_inputs.update( { "attention_mask": attention_mask, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a93bdeaacd9d23..6b5607f235b1a6 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -785,11 +785,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -799,7 +808,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 5f193a137b00cc..54edcd30fc870d 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -912,11 +912,20 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -926,7 +935,7 @@ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None return { diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 55753d5f75d9af..4afa3293ed46c5 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1080,8 +1080,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1089,7 +1098,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index d08ed83af07ea1..4e8aef0678367d 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -2103,9 +2103,18 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 88e543b54b5249..6db8bbb5213b14 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1367,7 +1367,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index b4e3aac5be0b42..69de5b2e7d0e6f 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1509,7 +1509,16 @@ def prepare_inputs_for_generation( ) -> Dict: # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1740,7 +1749,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index ca6bea40337257..530c66a0c80b36 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -948,7 +948,16 @@ def prepare_inputs_for_generation( # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 276f94aebdbb9e..b53ad8848dd3c2 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1413,7 +1413,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1897,7 +1906,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 1c1eeff667d44f..5d0ad6e3410c8f 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -1251,9 +1251,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a55f16a23d5b52..62610ceb41ac66 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1083,8 +1083,18 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1092,7 +1102,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 0c608dbd2a93bc..d760bec9854a8e 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -605,9 +605,18 @@ def prepare_inputs_for_generation( use_cache: Optional[bool] = None, **kwargs, ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for input_ids if past is not None + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 3d03503ddd402e..0de50afe9d6d24 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1836,9 +1836,18 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f178a6762005e6..16766e953c8574 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1995,9 +1995,17 @@ def prepare_inputs_for_generation( if decoder_attention_mask is not None: decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) - # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 21a82f95c33383..5c1ed05249ef5c 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1572,7 +1572,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -2054,7 +2063,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index f37f64627dfad4..3701bbecef2e73 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1808,7 +1808,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index d24211f039365e..8f3f246524348d 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -981,8 +981,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 67934520fbb6d9..55856f7b06b6be 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1466,7 +1466,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1719,7 +1728,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index def82bdbaa7182..e87e9c7164ab44 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1671,7 +1671,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c09657c065f2be..a0bc5726382336 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -847,8 +847,17 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -856,7 +865,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 288e31a126e675..e19761803e267d 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1798,9 +1798,18 @@ def prepare_inputs_for_generation( if decoder_attention_mask is None: decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "flattened_patches": flattened_patches, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93532f4b0d8c22..3a880839236d43 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1379,7 +1379,16 @@ def prepare_inputs_for_generation( ) -> Dict[str, Any]: # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1739,7 +1748,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 47546930ebdfc1..fead8fc0cf7f42 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -1151,9 +1151,18 @@ def prepare_inputs_for_generation( if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 745be26ebfc97f..235bff89f8a354 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -1147,9 +1147,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 67e0fee422c4cc..6d4cc991d22ca0 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -1007,9 +1007,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index ddd87fa9ce0c1e..da1cd6331bc314 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -1014,9 +1014,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 35d4be9f20e0c0..a5b1b63050b1ef 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -1560,9 +1560,18 @@ def prepare_inputs_for_generation( if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if input_shape_ids is not None: input_shape_ids = input_shape_ids[:, -1:] if input_pronunciation_ids is not None: diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 2c3feeda12708c..b9c36a305ff1cd 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -1178,9 +1178,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index bfd801b242719f..f9b5dec4209273 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -963,7 +963,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 48334deb377865..c4de7de09089ca 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2508,7 +2508,16 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "encoder_outputs": encoder_outputs, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6c2fe8269782b6..541db4382dd649 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1727,9 +1727,18 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index e6d9deefa14639..9716c7ffaffa0c 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1810,9 +1810,18 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 50829592a02e72..c0541814be466e 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -1003,7 +1003,16 @@ def prepare_inputs_for_generation( attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 8323054144f549..bd35111be16e9f 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -1307,9 +1307,18 @@ def prepare_inputs_for_generation( encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index be5f50dbffa2aa..de1565fa76127b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1810,9 +1810,17 @@ def prepare_inputs_for_generation( attention_mask=None, **kwargs, ): - # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "encoder_outputs": encoder_outputs, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 5f8778f98dcd2d..0c769dbbb5f324 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -851,21 +851,30 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) - - if past_key_values: - input_ids = input_ids[:, -1:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 761e96a11b7344..da454b1e3331f9 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -1011,9 +1011,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 025bab3887c0c7..26e0361abdb523 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -970,9 +970,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 61002bd2772e52..28fddc2fdbd6b5 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -1118,9 +1118,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f73e3f60a5530b..8e3079f748dfcc 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2906,3 +2906,89 @@ def test_default_max_length_warning(self): model.generation_config.max_length = 10 model.generate(input_ids) self.assertEqual(len(warning_list), 0) + + def test_model_kwarg_assisted_decoding_decoder_only(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with token_type_ids + outputs_tti = model.generate( + input_ids, + token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), + ) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + assistant.config.pad_token_id = tokenizer.eos_token_id + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) + + def test_model_kwarg_assisted_decoding_encoder_decoder(self): + # PT-only test: TF doesn't support assisted decoding yet. + # Bart subclass with a kwarg that distorts the output + class FakeBart(BartForConditionalGeneration): + def forward(self, input_ids, foo=False, **kwargs): + outs = super().forward(input_ids, **kwargs) + + if foo: + outs["logits"][:, :, :] = 0.0 + + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, **kwargs): + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + inputs["foo"] = foo + return inputs + + model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with foo + outputs_foo = model.generate( + input_ids, + foo=True, + ) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = AutoModelForSeq2SeqLM.from_pretrained( + "hf-internal-testing/tiny-random-BartForConditionalGeneration" + ).to(torch_device) + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + foo=True, + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) From 9f40639292eaf723d366607a707d3f405b41541e Mon Sep 17 00:00:00 2001 From: Ben Gubler Date: Wed, 11 Oct 2023 05:50:23 -0600 Subject: [PATCH 07/11] Update docs to explain disabling callbacks using report_to (#26155) * feat: update callback doc to explain disabling callbacks using report_to * docs: update report_to docstring --- docs/source/en/main_classes/callback.md | 4 +++- src/transformers/training_args.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/source/en/main_classes/callback.md b/docs/source/en/main_classes/callback.md index ccfdf256832472..87bf0d63af1fc2 100644 --- a/docs/source/en/main_classes/callback.md +++ b/docs/source/en/main_classes/callback.md @@ -25,7 +25,7 @@ Callbacks are "read only" pieces of code, apart from the [`TrainerControl`] obje cannot change anything in the training loop. For customizations that require changes in the training loop, you should subclass [`Trainer`] and override the methods you need (see [trainer](trainer) for examples). -By default a [`Trainer`] will use the following callbacks: +By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] will use the following callbacks. - [`DefaultFlowCallback`] which handles the default behavior for logging, saving and evaluation. - [`PrinterCallback`] or [`ProgressCallback`] to display progress and print the @@ -45,6 +45,8 @@ By default a [`Trainer`] will use the following callbacks: - [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed. - [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed. +If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`). + The main class that implements callbacks is [`TrainerCallback`]. It gets the [`TrainingArguments`] used to instantiate the [`Trainer`], can access that Trainer's internal state via [`TrainerState`], and can take some actions on the training loop via diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 635ab656ff699c..96cb467bcbeb7c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2345,10 +2345,11 @@ def set_logging( Logger log level to use on the main process. Possible choices are the log levels as strings: `"debug"`, `"info"`, `"warning"`, `"error"` and `"critical"`, plus a `"passive"` level which doesn't set anything and lets the application set the level. - report_to (`str` or `List[str]`, *optional*, defaults to `"none"`): + report_to (`str` or `List[str]`, *optional*, defaults to `"all"`): The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, - `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. Use `"all"` to report - to all integrations installed, `"none"` for no integrations. + `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`, + `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no + integrations. first_step (`bool`, *optional*, defaults to `False`): Whether to log and evaluate the first `global_step` or not. nan_inf_filter (`bool`, *optional*, defaults to `True`): From 5334796d204177b5ccfddc3471ee1ca4cec217a6 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:12:09 +0200 Subject: [PATCH 08/11] `Copied from` for test files (#26713) * copied statement for test files --------- Co-authored-by: ydshieh --- tests/models/biogpt/test_modeling_biogpt.py | 4 +- .../clap/test_feature_extraction_clap.py | 7 ++-- tests/models/llama/test_modeling_llama.py | 2 +- .../test_tokenization_longformer.py | 12 +++++- tests/models/mistral/test_modeling_mistral.py | 7 +++- .../test_tokenization_mobilebert.py | 21 +++++++++- .../persimmon/test_modeling_persimmon.py | 12 +++++- .../roberta/test_tokenization_roberta.py | 3 +- ...test_modeling_flax_roberta_prelayernorm.py | 4 +- .../test_modeling_roberta_prelayernorm.py | 26 ++++++++++--- .../test_modeling_tf_roberta_prelayernorm.py | 4 +- .../roc_bert/test_tokenization_roc_bert.py | 30 +++++++-------- .../whisper/test_tokenization_whisper.py | 2 +- utils/check_copies.py | 38 ++++++++++++++++--- 14 files changed, 127 insertions(+), 45 deletions(-) diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index e43fc1e41b8f9d..b7db0bbe28a7b7 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -386,7 +386,7 @@ def test_model_from_pretrained(self): model = BioGptModel.from_pretrained(model_name) self.assertIsNotNone(model) - # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common + # Copied from tests.models.opt.test_modeling_opt.OPTModelTest.test_opt_sequence_classification_model with OPT->BioGpt,opt->biogpt,prepare_config_and_inputs->prepare_config_and_inputs_for_common def test_biogpt_sequence_classification_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 @@ -399,7 +399,7 @@ def test_biogpt_sequence_classification_model(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common + # Copied from tests.models.opt.test_modeling_opt.OPTModelTest.test_opt_sequence_classification_model_for_multi_label with OPT->BioGpt,opt->biogpt,prepare_config_and_inputs->prepare_config_and_inputs_for_common def test_biogpt_sequence_classification_model_for_multi_label(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 diff --git a/tests/models/clap/test_feature_extraction_clap.py b/tests/models/clap/test_feature_extraction_clap.py index c49d045ba87407..d0e913df828b84 100644 --- a/tests/models/clap/test_feature_extraction_clap.py +++ b/tests/models/clap/test_feature_extraction_clap.py @@ -19,6 +19,7 @@ import unittest import numpy as np +from datasets import load_dataset from transformers import ClapFeatureExtractor from transformers.testing_utils import require_torch, require_torchaudio @@ -110,10 +111,10 @@ def _flatten(list_of_lists): @require_torch @require_torchaudio -# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest with Whisper->Clap class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): feature_extraction_class = ClapFeatureExtractor + # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.setUp with Whisper->Clap def setUp(self): self.feat_extract_tester = ClapFeatureExtractionTester(self) @@ -147,6 +148,7 @@ def test_call(self): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad def test_double_precision_pad(self): import torch @@ -160,9 +162,8 @@ def test_double_precision_pad(self): pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt") self.assertTrue(pt_processed.input_features.dtype == torch.float32) + # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest._load_datasamples def _load_datasamples(self, num_samples): - from datasets import load_dataset - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 0223acbbd72a8a..2402986900fda6 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -341,7 +341,7 @@ def test_llama_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @unittest.skip("LLaMA buffers include complex numbers, which breaks this test") + @unittest.skip("Llama buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/longformer/test_tokenization_longformer.py b/tests/models/longformer/test_tokenization_longformer.py index 2397a40bafa6b1..61d8653b60c608 100644 --- a/tests/models/longformer/test_tokenization_longformer.py +++ b/tests/models/longformer/test_tokenization_longformer.py @@ -27,7 +27,6 @@ from ...test_tokenization_common import TokenizerTesterMixin -# Copied from transformers.tests.roberta.test_modeling_roberta.py with Roberta->Longformer @require_tokenizers class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = LongformerTokenizer @@ -72,19 +71,23 @@ def setUp(self): with open(self.merges_file, "w", encoding="utf-8") as fp: fp.write("\n".join(merges)) + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_tokenizer def get_tokenizer(self, **kwargs): kwargs.update(self.special_tokens_map) return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_rust_tokenizer def get_rust_tokenizer(self, **kwargs): kwargs.update(self.special_tokens_map) return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_input_output_texts def get_input_output_texts(self, tokenizer): input_text = "lower newer" output_text = "lower newer" return input_text, output_text + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_full_tokenizer def test_full_tokenizer(self): tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map) text = "lower newer" @@ -96,6 +99,7 @@ def test_full_tokenizer(self): input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.roberta_dict_integration_testing with roberta->longformer def longformer_dict_integration_testing(self): tokenizer = self.get_tokenizer() @@ -106,6 +110,7 @@ def longformer_dict_integration_testing(self): ) @slow + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_sequence_builders with roberta-base->allenai/longformer-base-4096 def test_sequence_builders(self): tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096") @@ -125,6 +130,7 @@ def test_sequence_builders(self): assert encoded_sentence == encoded_text_from_decode assert encoded_pair == encoded_pair_from_decode + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_space_encoding def test_space_encoding(self): tokenizer = self.get_tokenizer() @@ -165,9 +171,11 @@ def test_space_encoding(self): first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0] self.assertNotEqual(first_char, space_encoding) + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_pretokenized_inputs def test_pretokenized_inputs(self): pass + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_embeded_special_tokens def test_embeded_special_tokens(self): for tokenizer, pretrained_name, kwargs in self.tokenizers_list: with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): @@ -200,6 +208,7 @@ def test_embeded_special_tokens(self): tokens_r_str, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""] ) + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_change_add_prefix_space_and_trim_offsets_args def test_change_add_prefix_space_and_trim_offsets_args(self): for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2): tokenizer_r = self.rust_tokenizer_class.from_pretrained( @@ -214,6 +223,7 @@ def test_change_add_prefix_space_and_trim_offsets_args(self): self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space) self.assertEqual(post_processor_state["trim_offsets"], trim_offsets) + # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self): # Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and # `trim_offsets` diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index df1143d2516afd..d2e9b2685f0beb 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -39,7 +39,6 @@ ) -# Copied from transformers.tests.mistral.test_modelling_mistral.MistralModelTest with Llama->Mistral class MistralModelTester: def __init__( self, @@ -93,6 +92,7 @@ def __init__( self.pad_token_id = pad_token_id self.scope = scope + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -134,6 +134,7 @@ def get_config(self): pad_token_id=self.pad_token_id, ) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -144,6 +145,7 @@ def create_and_check_model( result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Mistral def create_and_check_model_as_decoder( self, config, @@ -174,6 +176,7 @@ def create_and_check_model_as_decoder( result = model(input_ids, attention_mask=input_mask) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Mistral def create_and_check_for_causal_lm( self, config, @@ -192,6 +195,7 @@ def create_and_check_for_causal_lm( result = model(input_ids, attention_mask=input_mask, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Mistral def create_and_check_decoder_model_past_large_inputs( self, config, @@ -254,6 +258,7 @@ def create_and_check_decoder_model_past_large_inputs( # test that outputs are equal for slice self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( diff --git a/tests/models/mobilebert/test_tokenization_mobilebert.py b/tests/models/mobilebert/test_tokenization_mobilebert.py index 3ecc2e3238d512..babed7a8d9bfdc 100644 --- a/tests/models/mobilebert/test_tokenization_mobilebert.py +++ b/tests/models/mobilebert/test_tokenization_mobilebert.py @@ -32,7 +32,6 @@ from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english -# Copied from transformers.tests.models.bert.test_modeling_bert.py with Bert->MobileBert and pathfix @require_tokenizers class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = MobileBertTokenizer @@ -71,11 +70,13 @@ def setUp(self): for tokenizer_def in self.tokenizers_list ] + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.get_input_output_texts def get_input_output_texts(self, tokenizer): input_text = "UNwant\u00E9d,running" output_text = "unwanted, running" return input_text, output_text + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_full_tokenizer def test_full_tokenizer(self): tokenizer = self.tokenizer_class(self.vocab_file) @@ -83,6 +84,7 @@ def test_full_tokenizer(self): self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_rust_and_python_full_tokenizers def test_rust_and_python_full_tokenizers(self): if not self.test_rust_tokenizer: return @@ -124,11 +126,13 @@ def test_rust_and_python_full_tokenizers(self): rust_ids = rust_tokenizer.encode(sequence) self.assertListEqual(ids, rust_ids) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_chinese def test_chinese(self): tokenizer = BasicTokenizer() self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower def test_basic_tokenizer_lower(self): tokenizer = BasicTokenizer(do_lower_case=True) @@ -137,6 +141,7 @@ def test_basic_tokenizer_lower(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false def test_basic_tokenizer_lower_strip_accents_false(self): tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False) @@ -145,6 +150,7 @@ def test_basic_tokenizer_lower_strip_accents_false(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true def test_basic_tokenizer_lower_strip_accents_true(self): tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True) @@ -153,6 +159,7 @@ def test_basic_tokenizer_lower_strip_accents_true(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default def test_basic_tokenizer_lower_strip_accents_default(self): tokenizer = BasicTokenizer(do_lower_case=True) @@ -161,6 +168,7 @@ def test_basic_tokenizer_lower_strip_accents_default(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower def test_basic_tokenizer_no_lower(self): tokenizer = BasicTokenizer(do_lower_case=False) @@ -168,6 +176,7 @@ def test_basic_tokenizer_no_lower(self): tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"] ) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_false def test_basic_tokenizer_no_lower_strip_accents_false(self): tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False) @@ -175,6 +184,7 @@ def test_basic_tokenizer_no_lower_strip_accents_false(self): tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"] ) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_true def test_basic_tokenizer_no_lower_strip_accents_true(self): tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True) @@ -182,6 +192,7 @@ def test_basic_tokenizer_no_lower_strip_accents_true(self): tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"] ) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_respects_never_split_tokens def test_basic_tokenizer_respects_never_split_tokens(self): tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"]) @@ -189,6 +200,7 @@ def test_basic_tokenizer_respects_never_split_tokens(self): tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"] ) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_wordpiece_tokenizer def test_wordpiece_tokenizer(self): vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"] @@ -203,6 +215,7 @@ def test_wordpiece_tokenizer(self): self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_whitespace def test_is_whitespace(self): self.assertTrue(_is_whitespace(" ")) self.assertTrue(_is_whitespace("\t")) @@ -213,6 +226,7 @@ def test_is_whitespace(self): self.assertFalse(_is_whitespace("A")) self.assertFalse(_is_whitespace("-")) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_control def test_is_control(self): self.assertTrue(_is_control("\u0005")) @@ -221,6 +235,7 @@ def test_is_control(self): self.assertFalse(_is_control("\t")) self.assertFalse(_is_control("\r")) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_punctuation def test_is_punctuation(self): self.assertTrue(_is_punctuation("-")) self.assertTrue(_is_punctuation("$")) @@ -230,6 +245,7 @@ def test_is_punctuation(self): self.assertFalse(_is_punctuation("A")) self.assertFalse(_is_punctuation(" ")) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_clean_text def test_clean_text(self): tokenizer = self.get_tokenizer() rust_tokenizer = self.get_rust_tokenizer() @@ -242,6 +258,7 @@ def test_clean_text(self): ) @slow + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_sequence_builders with bert-base-uncased->google/mobilebert-uncased def test_sequence_builders(self): tokenizer = self.tokenizer_class.from_pretrained("google/mobilebert-uncased") @@ -254,6 +271,7 @@ def test_sequence_builders(self): assert encoded_sentence == [101] + text + [102] assert encoded_pair == [101] + text + [102] + text_2 + [102] + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_offsets_with_special_characters def test_offsets_with_special_characters(self): for tokenizer, pretrained_name, kwargs in self.tokenizers_list: with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): @@ -306,6 +324,7 @@ def test_offsets_with_special_characters(self): ) self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_change_tokenize_chinese_chars def test_change_tokenize_chinese_chars(self): list_of_commun_chinese_char = ["的", "人", "有"] text_with_chinese_char = "".join(list_of_commun_chinese_char) diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 3b67128c3b7372..60a5dabf1053f1 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -39,7 +39,7 @@ ) -# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon +# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon class PersimmonModelTester: def __init__( self, @@ -266,7 +266,6 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict -# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon @require_torch class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( @@ -288,23 +287,28 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester test_headmasking = False test_pruning = False + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon def setUp(self): self.model_tester = PersimmonModelTester(self) self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config def test_config(self): self.config_tester.run_common_tests() + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings def test_model_various_embeddings(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() for type in ["absolute", "relative_key", "relative_key_query"]: config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Persimmon,llama->persimmon def test_persimmon_sequence_classification_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 @@ -317,6 +321,7 @@ def test_persimmon_sequence_classification_model(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Persimmon,llama->persimmon def test_persimmon_sequence_classification_model_for_single_label(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 @@ -330,6 +335,7 @@ def test_persimmon_sequence_classification_model_for_single_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Persimmon,llama->persimmon def test_persimmon_sequence_classification_model_for_multi_label(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 @@ -346,10 +352,12 @@ def test_persimmon_sequence_classification_model_for_multi_label(self): self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) @unittest.skip("Persimmon buffers include complex numbers, which breaks this test") + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base def test_save_load_fast_init_from_base(self): pass @parameterized.expand([("linear",), ("dynamic",)]) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) diff --git a/tests/models/roberta/test_tokenization_roberta.py b/tests/models/roberta/test_tokenization_roberta.py index 78bac218351bf3..3190ab13be4ea1 100644 --- a/tests/models/roberta/test_tokenization_roberta.py +++ b/tests/models/roberta/test_tokenization_roberta.py @@ -76,8 +76,7 @@ def get_tokenizer(self, **kwargs): def get_rust_tokenizer(self, **kwargs): kwargs.update(self.special_tokens_map) - return RobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs) - return RobertaTokenizerFast(self.vocab_file, self.merges_file, **kwargs) + return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self, tokenizer): input_text = "lower newer" diff --git a/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py b/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py index 8500dfcb67a84f..65dbe65974d4c4 100644 --- a/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py +++ b/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py @@ -36,7 +36,7 @@ ) -# Copied from tests.models.roberta.test_modelling_flax_roberta.FlaxRobertaModelTester with Roberta->RobertaPreLayerNorm +# Copied from tests.models.roberta.test_modeling_flax_roberta.FlaxRobertaModelTester with Roberta->RobertaPreLayerNorm class FlaxRobertaPreLayerNormModelTester(unittest.TestCase): def __init__( self, @@ -134,7 +134,7 @@ def prepare_config_and_inputs_for_decoder(self): @require_flax -# Copied from tests.models.roberta.test_modelling_flax_roberta.FlaxRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta-base->andreasmadsen/efficient_mlm_m0.40 +# Copied from tests.models.roberta.test_modeling_flax_roberta.FlaxRobertaModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta-base->andreasmadsen/efficient_mlm_m0.40 class FlaxRobertaPreLayerNormModelTest(FlaxModelTesterMixin, unittest.TestCase): test_head_masking = True diff --git a/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py b/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py index ee0972eec32964..a2f56e31a09192 100644 --- a/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py +++ b/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py @@ -44,7 +44,7 @@ ) -# Copied from tests.models.roberta.test_modelling_roberta.RobertaModelTester with Roberta->RobertaPreLayerNorm +# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTester with Roberta->RobertaPreLayerNorm class RobertaPreLayerNormModelTester: def __init__( self, @@ -365,7 +365,6 @@ def prepare_config_and_inputs_for_common(self): @require_torch -# Copied from tests.models.roberta.test_modelling_roberta.RobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( ( @@ -397,27 +396,33 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, Pipe fx_compatible = False model_split_percents = [0.5, 0.8, 0.9] + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.setUp with Roberta->RobertaPreLayerNorm def setUp(self): self.model_tester = RobertaPreLayerNormModelTester(self) self.config_tester = ConfigTester(self, config_class=RobertaPreLayerNormConfig, hidden_size=37) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_config def test_config(self): self.config_tester.run_common_tests() + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_various_embeddings def test_model_various_embeddings(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() for type in ["absolute", "relative_key", "relative_key_query"]: config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_as_decoder def test_model_as_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_as_decoder_with_default_input_mask def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( @@ -446,42 +451,50 @@ def test_model_as_decoder_with_default_input_mask(self): encoder_attention_mask, ) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_causal_lm def test_for_causal_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_decoder_model_past_with_large_inputs def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_masked_lm def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_token_classification def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_multiple_choice def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_question_answering def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) @slow + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_from_pretrained with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm def test_model_from_pretrained(self): for model_name in ROBERTA_PRELAYERNORM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = RobertaPreLayerNormModel.from_pretrained(model_name) self.assertIsNotNone(model) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_create_position_ids_respects_padding_index with Roberta->RobertaPreLayerNorm def test_create_position_ids_respects_padding_index(self): """Ensure that the default position ids only assign a sequential . This is a regression test for https://github.com/huggingface/transformers/issues/1761 - The position ids should be masked with the embedding object's padding index. Therefore, the - first available non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1 + The position ids should be masked with the embedding object's padding index. Therefore, the first available + non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1 """ config = self.model_tester.prepare_config_and_inputs()[0] model = RobertaPreLayerNormEmbeddings(config=config) @@ -495,12 +508,13 @@ def test_create_position_ids_respects_padding_index(self): self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) + # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_create_position_ids_from_inputs_embeds with Roberta->RobertaPreLayerNorm def test_create_position_ids_from_inputs_embeds(self): """Ensure that the default position ids only assign a sequential . This is a regression test for https://github.com/huggingface/transformers/issues/1761 - The position ids should be masked with the embedding object's padding index. Therefore, the - first available non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1 + The position ids should be masked with the embedding object's padding index. Therefore, the first available + non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1 """ config = self.model_tester.prepare_config_and_inputs()[0] embeddings = RobertaPreLayerNormEmbeddings(config=config) diff --git a/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py b/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py index 9c1a25ccb982ea..384fa2e9e40013 100644 --- a/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py +++ b/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py @@ -42,7 +42,7 @@ ) -# Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaModelTester with Roberta->RobertaPreLayerNorm +# Copied from tests.models.roberta.test_modeling_tf_roberta.TFRobertaModelTester with Roberta->RobertaPreLayerNorm class TFRobertaPreLayerNormModelTester: def __init__( self, @@ -551,7 +551,7 @@ def prepare_config_and_inputs_for_common(self): @require_tf -# Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm +# Copied from tests.models.roberta.test_modeling_tf_roberta.TFRobertaModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( ( diff --git a/tests/models/roc_bert/test_tokenization_roc_bert.py b/tests/models/roc_bert/test_tokenization_roc_bert.py index 0f8fe08efd1517..6a24514b3c2c2b 100644 --- a/tests/models/roc_bert/test_tokenization_roc_bert.py +++ b/tests/models/roc_bert/test_tokenization_roc_bert.py @@ -68,13 +68,13 @@ def test_full_tokenizer(self): self.assertListEqual(tokenizer.convert_tokens_to_shape_ids(tokens), [5, 6, 2, 5, 7, 8]) self.assertListEqual(tokenizer.convert_tokens_to_pronunciation_ids(tokens), [5, 6, 2, 5, 7, 8]) - # Copied from tests.models.bert.test_tokenization_bert.test_chinese with BasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_chinese with BasicTokenizer->RoCBertBasicTokenizer def test_chinese(self): tokenizer = RoCBertBasicTokenizer() self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"]) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower with BasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_lower(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=True) @@ -83,7 +83,7 @@ def test_basic_tokenizer_lower(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_false with BasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_lower_strip_accents_false(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=True, strip_accents=False) @@ -92,7 +92,7 @@ def test_basic_tokenizer_lower_strip_accents_false(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"]) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_true with BertBasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_lower_strip_accents_true(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=True, strip_accents=True) @@ -101,7 +101,7 @@ def test_basic_tokenizer_lower_strip_accents_true(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_default with BasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_lower_strip_accents_default(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=True) @@ -110,7 +110,7 @@ def test_basic_tokenizer_lower_strip_accents_default(self): ) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower with BasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_no_lower(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=False) @@ -118,7 +118,7 @@ def test_basic_tokenizer_no_lower(self): tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"] ) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower_strip_accents_false with BertBasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_false with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_no_lower_strip_accents_false(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=False, strip_accents=False) @@ -126,7 +126,7 @@ def test_basic_tokenizer_no_lower_strip_accents_false(self): tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"] ) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower_strip_accents_true with BasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_true with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_no_lower_strip_accents_true(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=False, strip_accents=True) @@ -134,7 +134,7 @@ def test_basic_tokenizer_no_lower_strip_accents_true(self): tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"] ) - # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_respects_never_split_tokens with BasicTokenizer->RoCBertBertBasicTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_respects_never_split_tokens with BasicTokenizer->RoCBertBasicTokenizer def test_basic_tokenizer_respects_never_split_tokens(self): tokenizer = RoCBertBasicTokenizer(do_lower_case=False, never_split=["[UNK]"]) @@ -142,7 +142,7 @@ def test_basic_tokenizer_respects_never_split_tokens(self): tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"] ) - # Copied from tests.models.bert.test_tokenization_bert.test_wordpiece_tokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_wordpiece_tokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer def test_wordpiece_tokenizer(self): vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"] @@ -157,7 +157,7 @@ def test_wordpiece_tokenizer(self): self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) - # Copied from tests.models.bert.test_tokenization_bert.test_is_whitespace + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_whitespace def test_is_whitespace(self): self.assertTrue(_is_whitespace(" ")) self.assertTrue(_is_whitespace("\t")) @@ -168,7 +168,7 @@ def test_is_whitespace(self): self.assertFalse(_is_whitespace("A")) self.assertFalse(_is_whitespace("-")) - # Copied from tests.models.bert.test_tokenization_bert.test_is_control + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_control def test_is_control(self): self.assertTrue(_is_control("\u0005")) @@ -177,7 +177,7 @@ def test_is_control(self): self.assertFalse(_is_control("\t")) self.assertFalse(_is_control("\r")) - # Copied from tests.models.bert.test_tokenization_bert.test_is_punctuation + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_punctuation def test_is_punctuation(self): self.assertTrue(_is_punctuation("-")) self.assertTrue(_is_punctuation("$")) @@ -199,7 +199,7 @@ def test_clean_text(self): [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]] ) - # Copied from tests.models.bert.test_tokenization_bert. test_offsets_with_special_characters + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_offsets_with_special_characters def test_offsets_with_special_characters(self): for tokenizer, pretrained_name, kwargs in self.tokenizers_list: with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): @@ -252,7 +252,7 @@ def test_offsets_with_special_characters(self): ) self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) - # Copied from tests.models.bert.test_tokenization_bert. test_change_tokenize_chinese_chars + # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_change_tokenize_chinese_chars def test_change_tokenize_chinese_chars(self): list_of_commun_chinese_char = ["的", "人", "有"] text_with_chinese_char = "".join(list_of_commun_chinese_char) diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index be9e11de5401ea..fd1c135deb53b5 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -376,7 +376,7 @@ def test_tokenizer_special(self): def test_vocab_size(self): self.assertEqual(self.tokenizer.vocab_size, 50257) - # Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py + # Copied from tests.models.speech_to_text.test_tokenization_speech_to_text.SpeechToTextTokenizerMultilinguialTest.test_tokenizer_decode_ignores_language_codes def test_tokenizer_decode_ignores_language_codes(self): self.assertIn(ES_CODE, self.tokenizer.all_special_ids) generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2] diff --git a/utils/check_copies.py b/utils/check_copies.py index f198b9e062b38b..667c11ec724b42 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -51,6 +51,7 @@ # All paths are set with the intent you should run this script from the root of the repo with the command # python utils/check_copies.py TRANSFORMERS_PATH = "src/transformers" +MODEL_TEST_PATH = "tests/models" PATH_TO_DOCS = "docs/source/en" REPO_PATH = "." @@ -132,12 +133,15 @@ def _should_continue(line: str, indent: str) -> bool: return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None -def find_code_in_transformers(object_name: str) -> str: +def find_code_in_transformers(object_name: str, base_path: str = None) -> str: """ Find and return the source code of an object. Args: - object_name (`str`): The name of the object we want the source code of. + object_name (`str`): + The name of the object we want the source code of. + base_path (`str`, *optional*): + The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`. Returns: `str`: The source code of the object. @@ -145,9 +149,21 @@ def find_code_in_transformers(object_name: str) -> str: parts = object_name.split(".") i = 0 + # We can't set this as the default value in the argument, otherwise `CopyCheckTester` will fail, as it uses a + # patched temp directory. + if base_path is None: + base_path = TRANSFORMERS_PATH + + # Detail: the `Copied from` statement is originally designed to work with the last part of `TRANSFORMERS_PATH`, + # (which is `transformers`). The same should be applied for `MODEL_TEST_PATH`. However, its last part is `models` + # (to only check and search in it) which is a bit confusing. So we keep the copied statement staring with + # `tests.models.` and change it to `tests` here. + if base_path == MODEL_TEST_PATH: + base_path = "tests" + # First let's find the module where our object lives. module = parts[i] - while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")): + while i < len(parts) and not os.path.isfile(os.path.join(base_path, f"{module}.py")): i += 1 if i < len(parts): module = os.path.join(module, parts[i]) @@ -156,7 +172,7 @@ def find_code_in_transformers(object_name: str) -> str: f"`object_name` should begin with the name of a module of transformers but got {object_name}." ) - with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: + with open(os.path.join(base_path, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() # Now let's find the class / func in the code! @@ -186,6 +202,7 @@ def find_code_in_transformers(object_name: str) -> str: _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)") +_re_copy_warning_for_test_file = re.compile(r"^(\s*)#\s*Copied from\s+tests\.(\S+\.\S+)\s*($|\S.*$)") _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") _re_fill_pattern = re.compile(r"]*>") @@ -284,14 +301,20 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[ line_index = 0 # Not a for loop cause `lines` is going to change (if `overwrite=True`). while line_index < len(lines): - search = _re_copy_warning.search(lines[line_index]) + search_re = _re_copy_warning + if filename.startswith("tests"): + search_re = _re_copy_warning_for_test_file + + search = search_re.search(lines[line_index]) if search is None: line_index += 1 continue # There is some copied code here, let's retrieve the original. indent, object_name, replace_pattern = search.groups() - theoretical_code = find_code_in_transformers(object_name) + + base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH + theoretical_code = find_code_in_transformers(object_name, base_path=base_path) theoretical_indent = get_indent(theoretical_code) start_index = line_index + 1 if indent == theoretical_indent else line_index @@ -357,6 +380,9 @@ def check_copies(overwrite: bool = False): Whether or not to overwrite the copies when they don't match. """ all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) + all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True) + all_files = list(all_files) + list(all_test_files) + diffs = [] for filename in all_files: new_diffs = is_copy_consistent(filename, overwrite) From da69de17e86501b95396086a5b6479f645e8f70e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Oct 2023 15:52:20 +0200 Subject: [PATCH 09/11] [Assistant Generation] Improve Encoder Decoder (#26701) * [Assistant Generation] Improve enc dec * save more * Fix logit processor checks * Clean * make style * fix deprecation * fix generation test * Apply suggestions from code review * fix biogpt * make style --- .../generation/configuration_utils.py | 18 +++++++++ src/transformers/generation/utils.py | 39 +++++++++++-------- .../models/biogpt/modeling_biogpt.py | 6 ++- tests/generation/test_utils.py | 14 ++++++- 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 18ccdb2835b411..3bd85568dcb714 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin): decoder_start_token_id (`int`, *optional*): If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. + > Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192) + + num_assistant_tokens (`int`, *optional*, defaults to 5): + Defines the number of _speculative tokens_ that shall be generated by the assistant model before being + checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation + more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant + model requires lots of corrections, lower speed-ups are reached. + + num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`): + Defines the schedule at which max assistant tokens shall be changed during inference. + - `"_heuristic_`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2 else + reduce by 1 + - `"constant"`: `num_assistant_tokens` stays unchanged during generation + > Wild card generation_kwargs: @@ -294,6 +308,10 @@ def __init__(self, **kwargs): self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0) self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + # Assistant generation + self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) + self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") + # Wild card self.generation_kwargs = kwargs.pop("generation_kwargs", {}) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 49b213cc5e92cd..a104113af891ff 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1241,6 +1241,10 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): decoder_model_args = set(inspect.signature(decoder.forward).parameters) model_args |= {f"decoder_{x}" for x in decoder_model_args} + # allow assistant_encoder_outputs to be passed if we're doing assisted generating + if "assistant_encoder_outputs" in model_kwargs: + model_args |= {"assistant_encoder_outputs"} + for key, value in model_kwargs.items(): if value is not None and key not in model_args: unused_model_args.append(key) @@ -1612,7 +1616,7 @@ def generate( raise ValueError("assisted generate requires `use_cache=True`") # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs - if assistant_model.config.is_encoder_decoder: + if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: assistant_model_kwargs = copy.deepcopy(model_kwargs) inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs @@ -4347,8 +4351,14 @@ def assisted_decoding( ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # Assistant: initialize assistant-related variables - if not hasattr(assistant_model, "max_assistant_tokens"): - assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls + if hasattr(assistant_model, "num_assistant_tokens"): + warnings.warn( + "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", + FutureWarning, + ) + num_assistant_tokens = assistant_model.num_assistant_tokens + else: + num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() @@ -4421,26 +4431,23 @@ def assisted_decoding( # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we # need access to the assistant cache to secure strong speedups. candidate_input_ids = input_ids - for _ in range(int(assistant_model.max_assistant_tokens)): + for _ in range(int(num_assistant_tokens)): # 1.1. use the assistant model to obtain the next candidate logits if "assistant_past_key_values" in model_kwargs: prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) new_token_len = candidate_input_ids.shape[1] - prev_seq_len assist_inputs = candidate_input_ids[:, -new_token_len:] - assist_attn = torch.ones_like(candidate_input_ids) # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 if assistant_model.config.is_encoder_decoder: assistant_model_outputs = assistant_model( decoder_input_ids=assist_inputs, - decoder_attention_mask=assist_attn, past_key_values=model_kwargs["assistant_past_key_values"], encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: assistant_model_outputs = assistant_model( assist_inputs, - attention_mask=assist_attn, past_key_values=model_kwargs["assistant_past_key_values"], ) else: @@ -4495,18 +4502,18 @@ def assisted_decoding( # 2.3. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: - for i in range(candidate_length): + for i in range(candidate_length + 1): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) if len(logits_warper) > 0: - for i in range(candidate_length): + for i in range(candidate_length + 1): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) # 3. Obtain the next tokens from the original model logits. if do_sample: - probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1) + probs = new_logits.softmax(dim=-1) selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] else: - selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1) + selected_tokens = new_logits.argmax(dim=-1) # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep # the assistant forecasted tokens until the first mismatch, or until the max length is reached. @@ -4540,13 +4547,13 @@ def assisted_decoding( # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the # cost of forecasting incorrect assistant tokens. - if n_matches == int(assistant_model.max_assistant_tokens): - assistant_model.max_assistant_tokens += 2.0 - else: - assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0) + if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": + if n_matches == int(num_assistant_tokens): + num_assistant_tokens += 2.0 + else: + num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0) # Assistant: main logic end - if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index d1c471aa8090c9..ca084db5c7d0b9 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -544,7 +544,11 @@ def forward( inputs_embeds = self.embed_tokens(input) * self.embed_scale if attention_mask is None: - attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device) + attention_mask = torch.ones( + (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length), + dtype=torch.bool, + device=inputs_embeds.device, + ) elif attention_mask.shape[1] != past_key_values_length + input_shape[1]: raise ValueError( f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 8e3079f748dfcc..175861fd149e5e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2953,7 +2953,8 @@ def forward(self, input_ids, foo=False, **kwargs): return outs - def prepare_inputs_for_generation(self, *args, foo=False, **kwargs): + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs inputs = super().prepare_inputs_for_generation(*args, **kwargs) inputs["foo"] = foo @@ -2992,3 +2993,14 @@ def prepare_inputs_for_generation(self, *args, foo=False, **kwargs): assistant_model=assistant, ) self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + # Check that passing encoder_outputs directly also works as expected + encoder_outputs = assistant.get_encoder()(input_ids) + + outputs_assisted = model.generate( + foo=True, + assistant_model=assistant, + encoder_outputs=encoder_outputs, + assistant_encoder_outputs=encoder_outputs, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) From cc44ca8017b11c0086593f8666b23d4c420c5d51 Mon Sep 17 00:00:00 2001 From: Shivanand Date: Wed, 11 Oct 2023 19:23:32 +0530 Subject: [PATCH 10/11] [docstring] `SwinModel` docstring fix (#26679) * remove from utils * updated doc string * only in the model * Update src/transformers/models/swin/modeling_swin.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Update src/transformers/models/swin/modeling_swin.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --------- Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- src/transformers/models/swin/modeling_swin.py | 6 ++++++ utils/check_docstrings.py | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 2cf1d33a51139d..45a7aa718cf026 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -942,6 +942,12 @@ def _set_gradient_checkpointing(self, module, value=False): @add_start_docstrings( "The bare Swin Model transformer outputting raw hidden-states without any specific head on top.", SWIN_START_DOCSTRING, + """ + add_pooling_layer (`bool`, *optional*, defaults to `True`): + Whether or not to apply pooling layer. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether or not to create and apply mask tokens in the embedding layer. + """, ) class SwinModel(SwinPreTrainedModel): def __init__(self, config, add_pooling_layer=True, use_mask_token=False): diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index e140be28037d59..f142c5dbccd1df 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -499,7 +499,6 @@ "SqueezeBertTokenizerFast", "SummarizationPipeline", "Swin2SRImageProcessor", - "SwinModel", "Swinv2Model", "SwitchTransformersConfig", "T5Config", From 69873d529db9796eaf8dc52d2d93b0bea11d2001 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 11 Oct 2023 19:28:23 +0530 Subject: [PATCH 11/11] fix the model card issue as `use_cuda_amp` is no more available (#26731) --- src/transformers/modelcard.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 74acb55397a9ea..f1b2f70bc2ea61 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -895,10 +895,10 @@ def extract_hyperparameters_from_trainer(trainer): hyperparameters["num_epochs"] = trainer.args.num_train_epochs if trainer.args.fp16: - if trainer.use_cuda_amp: - hyperparameters["mixed_precision_training"] = "Native AMP" - elif trainer.use_apex: + if trainer.use_apex: hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}" + else: + hyperparameters["mixed_precision_training"] = "Native AMP" if trainer.args.label_smoothing_factor != 0.0: hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor