diff --git a/docs/source/en/main_classes/callback.md b/docs/source/en/main_classes/callback.md
index ccfdf256832472..87bf0d63af1fc2 100644
--- a/docs/source/en/main_classes/callback.md
+++ b/docs/source/en/main_classes/callback.md
@@ -25,7 +25,7 @@ Callbacks are "read only" pieces of code, apart from the [`TrainerControl`] obje
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
subclass [`Trainer`] and override the methods you need (see [trainer](trainer) for examples).
-By default a [`Trainer`] will use the following callbacks:
+By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] will use the following callbacks:
- [`DefaultFlowCallback`] which handles the default behavior for logging, saving and evaluation.
- [`PrinterCallback`] or [`ProgressCallback`] to display progress and print the
@@ -45,6 +45,8 @@ By default a [`Trainer`] will use the following callbacks:
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
+If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
+
The main class that implements callbacks is [`TrainerCallback`]. It gets the
[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
Trainer's internal state via [`TrainerState`], and can take some actions on the training loop via
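
For instance, to report only to AzureML and Weights & Biases even when other supported packages are installed, a minimal sketch (`output_dir` is a placeholder):

```python
from transformers import TrainingArguments

# Restrict reporting to the listed integrations instead of the default "all"
args = TrainingArguments(output_dir="out", report_to=["azure_ml", "wandb"])
```
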
diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py
index bbc79977a46793..d8c89c1a242f14 100644
--- a/examples/flax/image-captioning/run_image_captioning_flax.py
+++ b/examples/flax/image-captioning/run_image_captioning_flax.py
@@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py
index 95e175d494bfe2..c61b24f4d7e615 100755
--- a/examples/flax/language-modeling/run_clm_flax.py
+++ b/examples/flax/language-modeling/run_clm_flax.py
@@ -326,7 +326,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py
index 9cd90f285a02c4..0d35f302f8f37a 100644
--- a/examples/flax/question-answering/run_qa.py
+++ b/examples/flax/question-answering/run_qa.py
@@ -389,7 +389,7 @@ def cross_entropy_loss(logits, labels):
# region Create learning rate function
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py
index 4a2915a31ac7dc..8af835b6a4b4d3 100644
--- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py
+++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py
@@ -360,7 +360,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
def create_learning_rate_fn(
num_train_steps: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py
index d57aa1769135db..782e9ee88c498f 100644
--- a/examples/flax/summarization/run_summarization_flax.py
+++ b/examples/flax/summarization/run_summarization_flax.py
@@ -409,7 +409,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py
index b42a2565310799..1535ff8492781b 100755
--- a/examples/flax/text-classification/run_flax_glue.py
+++ b/examples/flax/text-classification/run_flax_glue.py
@@ -288,7 +288,7 @@ def cross_entropy_loss(logits, labels):
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py
index f4b40220ff12c7..e06a85cb67c0a0 100644
--- a/examples/flax/token-classification/run_flax_ner.py
+++ b/examples/flax/token-classification/run_flax_ner.py
@@ -340,7 +340,7 @@ def cross_entropy_loss(logits, labels):
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py
index 66505014ec57dd..4bed9b663f6f00 100644
--- a/examples/flax/vision/run_image_classification.py
+++ b/examples/flax/vision/run_image_classification.py
@@ -249,7 +249,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index 74c83eee49f215..5f20aac6cbc9c8 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -264,7 +264,7 @@ class DataTrainingArguments:
},
)
source_prefix: Optional[str] = field(
- default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
index f54641408f80a2..c5a4a202534b87 100644
--- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
+++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
@@ -283,7 +283,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
index a6da8729f0ce3b..bb297e3e0db6e4 100644
--- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
+++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
@@ -214,7 +214,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
-) -> Callable[[int], jnp.array]:
+) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 18ccdb2835b411..3bd85568dcb714 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin):
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
+ > Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)
+
+ num_assistant_tokens (`int`, *optional*, defaults to 5):
+     Defines the number of _speculative tokens_ generated by the assistant model before each check by the
+     target model. Higher values for `num_assistant_tokens` make the generation more _speculative_: if the
+     assistant model is performant, larger speed-ups can be reached; if the assistant model requires lots of
+     corrections, speed-ups are lower.
+
+ num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
+         Defines the schedule used to change the maximum number of assistant tokens during inference.
+     - `"heuristic"`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2, else
+       reduce it by 1
+ - `"constant"`: `num_assistant_tokens` stays unchanged during generation
+
> Wild card
generation_kwargs:
@@ -294,6 +308,10 @@ def __init__(self, **kwargs):
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+ # Assistant generation
+ self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
+ self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
+
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
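
The two new fields are read by assisted generation from the assistant model's generation config. A minimal usage sketch (the checkpoints are illustrative; any target/assistant pair that shares a tokenizer works):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")

# The new fields live on the assistant's generation config
assistant.generation_config.num_assistant_tokens = 8
assistant.generation_config.num_assistant_tokens_schedule = "heuristic"

inputs = tokenizer("Speculative decoding lets a small model draft tokens that", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
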
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 3b1bef6f040084..a104113af891ff 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1241,6 +1241,10 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
+ # allow assistant_encoder_outputs to be passed if we're doing assisted generating
+ if "assistant_encoder_outputs" in model_kwargs:
+ model_args |= {"assistant_encoder_outputs"}
+
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
@@ -1297,6 +1301,43 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de
UserWarning,
)
+ def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]:
+ if self.config.is_encoder_decoder:
+ key = "decoder_attention_mask"
+ else:
+ key = "attention_mask"
+
+ if key not in model_kwargs:
+ return model_kwargs
+
+ mask = model_kwargs[key]
+ mask_extension_length = new_mask_length - mask.shape[1]
+
+ if mask_extension_length < 0:
+ raise ValueError("Cannot extend attention mask to a length less than it already is")
+
+ model_kwargs[key] = torch.cat(
+ [mask, mask.new_ones((mask.shape[0], mask_extension_length))],
+ dim=-1,
+ )
+
+ return model_kwargs
+
+ def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
+ if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
+ return model_kwargs
+
+ token_type_ids = model_kwargs["token_type_ids"]
+ final_token_type = token_type_ids[:, -1].unsqueeze(-1)
+ extension_length = new_length - token_type_ids.shape[1]
+ token_type_copies = final_token_type.repeat(1, extension_length)
+ model_kwargs["token_type_ids"] = torch.cat(
+ [model_kwargs["token_type_ids"], token_type_copies],
+ dim=-1,
+ )
+
+ return model_kwargs
+
@torch.no_grad()
def generate(
self,
@@ -1575,7 +1616,7 @@ def generate(
raise ValueError("assisted generate requires `use_cache=True`")
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
- if assistant_model.config.is_encoder_decoder:
+ if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
assistant_model_kwargs = copy.deepcopy(model_kwargs)
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
@@ -4310,8 +4351,14 @@ def assisted_decoding(
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
```"""
# Assistant: initialize assistant-related variables
- if not hasattr(assistant_model, "max_assistant_tokens"):
- assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
+ if hasattr(assistant_model, "num_assistant_tokens"):
+ warnings.warn(
+ "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be removed in v.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
+ FutureWarning,
+ )
+ num_assistant_tokens = assistant_model.num_assistant_tokens
+ else:
+ num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -4384,26 +4431,23 @@ def assisted_decoding(
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
- for _ in range(int(assistant_model.max_assistant_tokens)):
+ for _ in range(int(num_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
- assist_attn = torch.ones_like(candidate_input_ids)
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=assist_inputs,
- decoder_attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(
assist_inputs,
- attention_mask=assist_attn,
past_key_values=model_kwargs["assistant_past_key_values"],
)
else:
@@ -4441,61 +4485,35 @@ def assisted_decoding(
# `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
# we use this forward pass to also pick the subsequent logits in the original model.
- # 2.1. Run a forward pass on the candidate sequence
- if "past_key_values" in model_kwargs:
- model_attn = torch.ones_like(candidate_input_ids)
- model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
- if self.config.is_encoder_decoder:
- outputs = self(
- decoder_input_ids=model_input_ids,
- decoder_attention_mask=model_attn,
- past_key_values=model_kwargs["past_key_values"],
- encoder_outputs=model_kwargs["encoder_outputs"],
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- use_cache=True,
- )
- else:
- outputs = self(
- model_input_ids,
- attention_mask=model_attn,
- past_key_values=model_kwargs["past_key_values"],
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- use_cache=True,
- )
- else:
- if self.config.is_encoder_decoder:
- outputs = self(
- decoder_input_ids=candidate_input_ids,
- encoder_outputs=model_kwargs["encoder_outputs"],
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- use_cache=True,
- )
- else:
- outputs = self(
- candidate_input_ids,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- use_cache=True,
- )
+ # 2.1. Prepare the model inputs
+ candidate_kwargs = copy.copy(model_kwargs)
+ candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
+ candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+
+ model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
- # 2.2. Process the new logits
+ # 2.2. Run a forward pass on the candidate sequence
+ outputs = self(
+ **model_inputs,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ # 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
- for i in range(candidate_length):
+ for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
if len(logits_warper) > 0:
- for i in range(candidate_length):
+ for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
# 3. Obtain the next tokens from the original model logits.
if do_sample:
- probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
+ probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
- selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
+ selected_tokens = new_logits.argmax(dim=-1)
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
@@ -4529,13 +4547,13 @@ def assisted_decoding(
# 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
# cost of forecasting incorrect assistant tokens.
- if n_matches == int(assistant_model.max_assistant_tokens):
- assistant_model.max_assistant_tokens += 2.0
- else:
- assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
+ if assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
+ if n_matches == int(num_assistant_tokens):
+ num_assistant_tokens += 2.0
+ else:
+ num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
# Assistant: main logic end
-
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
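
For review purposes, the `"heuristic"` schedule applied at the end of each assisted-decoding iteration can be restated in isolation; this is a sketch of the update rule in the hunk above, not a library API:

```python
def update_num_assistant_tokens(num_assistant_tokens: float, n_matches: int) -> float:
    """Grow the speculative window when every candidate token was accepted,
    otherwise shrink it, never dropping below one token."""
    if n_matches == int(num_assistant_tokens):
        return num_assistant_tokens + 2.0
    return max(1.0, num_assistant_tokens - 1.0)
```
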
diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py
index 74acb55397a9ea..f1b2f70bc2ea61 100644
--- a/src/transformers/modelcard.py
+++ b/src/transformers/modelcard.py
@@ -895,10 +895,10 @@ def extract_hyperparameters_from_trainer(trainer):
hyperparameters["num_epochs"] = trainer.args.num_train_epochs
if trainer.args.fp16:
- if trainer.use_cuda_amp:
- hyperparameters["mixed_precision_training"] = "Native AMP"
- elif trainer.use_apex:
+ if trainer.use_apex:
hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
+ else:
+ hyperparameters["mixed_precision_training"] = "Native AMP"
if trainer.args.label_smoothing_factor != 0.0:
hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index bdafb6347755d3..649719e0eefa5d 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -483,9 +483,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = kwargs.get("position_ids", None)
if past_key_values is not None:
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
seq_len = input_ids.shape[1]
- input_ids = input_ids[:, [-1]]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# input_embeds have already been used and is not required anymore
input_embeds = None
@@ -507,7 +516,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
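
The same cache-aware slicing recurs in every `prepare_inputs_for_generation` below; a toy illustration of the rule with hypothetical values:

```python
import torch

input_ids = torch.tensor([[5, 6, 7, 8, 9]])  # full candidate sequence
past_length = 3                              # tokens already covered by past_key_values

if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length       # keep only the uncached suffix
else:
    remove_prefix_length = input_ids.shape[1] - 1  # old behavior: keep just the final ID

print(input_ids[:, remove_prefix_length:])   # tensor([[8, 9]])
```
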
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index dd4223a47f4735..309717310e701f 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1717,7 +1717,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -2208,7 +2217,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py
index 9858eb2d1bf416..6abfcdc398422f 100644
--- a/src/transformers/models/bart/modeling_flax_bart.py
+++ b/src/transformers/models/bart/modeling_flax_bart.py
@@ -217,7 +217,7 @@
"""
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 29846b8051f867..1b0fad3f9d6546 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -1282,7 +1282,16 @@ def prepare_inputs_for_generation(
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py
index 99dfa2a0e2f9ab..bb2af0e0602aba 100644
--- a/src/transformers/models/bert/modeling_flax_bert.py
+++ b/src/transformers/models/bert/modeling_flax_bert.py
@@ -295,7 +295,7 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
- key_value_states: Optional[jnp.array] = None,
+ key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index f245ac155e75ca..abe2d828b28bb9 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -993,9 +993,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 867aca67e99e8c..e266b1a67b7d41 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -2628,9 +2628,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py
index afdac2645f2652..c6d8b7c1612ec9 100644
--- a/src/transformers/models/big_bird/modeling_flax_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -316,7 +316,7 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
- key_value_states: Optional[jnp.array] = None,
+ key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index a32f3ecde76fdb..4e279f9dc059fe 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2627,7 +2627,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index 7534ed17fe849a..ca084db5c7d0b9 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -544,7 +544,11 @@ def forward(
inputs_embeds = self.embed_tokens(input) * self.embed_scale
if attention_mask is None:
- attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+ attention_mask = torch.ones(
+ (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
@@ -729,9 +733,18 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs
):
- # only last token for inputs_ids if past is defined in kwargs
- if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ # only last tokens for inputs_ids if past is defined in kwargs
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
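
When no mask is provided, the default attention mask now has to cover the cached tokens as well as the new ones; a toy check of the expected shape with hypothetical sizes:

```python
import torch

batch_size, new_tokens, past_key_values_length = 2, 1, 7
inputs_embeds = torch.zeros(batch_size, new_tokens, 16)

attention_mask = torch.ones(
    (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
    dtype=torch.bool,
)
print(attention_mask.shape)  # torch.Size([2, 8])
```
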
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index bdb8c52a552041..1db81905210b63 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1392,7 +1392,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1622,7 +1631,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
index 1035272fd05350..61239335be3b63 100644
--- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -204,7 +204,7 @@
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index a1e888aec90807..129de3dd1456e3 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1359,7 +1359,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1589,7 +1598,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
index 2bf8b59e2757bc..b5272fb3bca9e2 100644
--- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -216,7 +216,7 @@
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index 2ae3ac053beab9..49b958afc2ebae 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -920,7 +920,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index d12ec1724f7097..d90bb6ad8fdfd5 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -844,9 +844,18 @@ def prepare_inputs_for_generation(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
- # only last token for input_ids if past is not None
- if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ # only last tokens for input_ids if past is not None
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
        # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py
index 4635c061980b53..8d7d279579e3e9 100644
--- a/src/transformers/models/camembert/modeling_camembert.py
+++ b/src/transformers/models/camembert/modeling_camembert.py
@@ -1542,9 +1542,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index 93d5aa7ee47650..172a45544bac0d 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -617,11 +617,20 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@@ -631,7 +640,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
return {
"input_ids": input_ids,
diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py
index 70cd4ec0597a14..cec68de07dda75 100644
--- a/src/transformers/models/ctrl/modeling_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_ctrl.py
@@ -526,9 +526,18 @@ def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
- # only last token for inputs_ids if past is defined in kwargs
- if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ # only last tokens for inputs_ids if past is defined in kwargs
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}
diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py
index 7cbaee692564b4..a521ccb39aaf0c 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_text.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_text.py
@@ -1009,9 +1009,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
index c975aa40877c26..6853f5333f137c 100644
--- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
+++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -843,8 +843,17 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
- if past_key_values:
- input_ids = input_ids[:, -1:]
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -852,7 +861,7 @@ def prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index c06d306c1a241d..da3ee8e51d3602 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -1667,9 +1667,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py
index 32e76b8b586f4f..8fced6ff1ea25e 100644
--- a/src/transformers/models/electra/modeling_flax_electra.py
+++ b/src/transformers/models/electra/modeling_flax_electra.py
@@ -263,7 +263,7 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
- key_value_states: Optional[jnp.array] = None,
+ key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
@@ -1228,13 +1228,13 @@ def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
Compute a single vector summary of a sequence hidden states.
Args:
- hidden_states (`jnp.array` of shape `[batch_size, seq_len, hidden_size]`):
+ hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
- cls_index (`jnp.array` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
+ cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
Returns:
- `jnp.array`: The summary of the sequence hidden states.
+ `jnp.ndarray`: The summary of the sequence hidden states.
"""
        # NOTE: this does the "first" type summary always
output = hidden_states[:, 0]
diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py
index 7ee6f4381290ae..d55155f80093bc 100644
--- a/src/transformers/models/ernie/modeling_ernie.py
+++ b/src/transformers/models/ernie/modeling_ernie.py
@@ -1223,7 +1223,16 @@ def prepare_inputs_for_generation(
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index ab29322613bea3..33b9fdde739e58 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -1228,7 +1228,16 @@ def prepare_inputs_for_generation(
**kwargs,
) -> dict:
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
@@ -1236,7 +1245,7 @@ def prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
return {
"input_ids": input_ids,
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 714f0351b3e4df..838e7ca2992520 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -1005,11 +1005,20 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@@ -1019,7 +1028,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
@@ -1038,6 +1047,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
"token_type_ids": token_type_ids,
}
)
+
return model_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@@ -1201,11 +1211,20 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@@ -1215,7 +1234,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
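
Because more than one uncached token can now reach the model at once, the `position_ids` slice must match the length of the remaining `input_ids` rather than a single token; a toy check with hypothetical values:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 1]])
uncached_len = 2  # e.g. the last verified token plus one candidate token

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids[:, -uncached_len:]
print(position_ids)  # tensor([[3, 4]])
```
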
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index d58e00af1dac13..be90f61e45bf1b 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -737,11 +737,23 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ if self.config.multi_query:
+ past_length = past_key_values[0].shape[1]
+ else:
+ past_length = past_key_values[0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@@ -751,7 +763,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 6364cfc316220a..3ad49554c0ac8f 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -680,11 +680,20 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@@ -694,7 +703,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index b4aa4154459cf7..9391805a77b851 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -808,10 +808,21 @@ def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
- if past_key_values and past_key_values[0] is not None:
- input_ids = input_ids[:, -1:]
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -819,7 +828,7 @@ prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
@@ -830,7 +839,7 @@ prepare_inputs_for_generation(
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index a93bdeaacd9d23..6b5607f235b1a6 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -785,11 +785,20 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@@ -799,7 +808,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index 5f193a137b00cc..54edcd30fc870d 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -912,11 +912,20 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
- # only last token for inputs_ids if past is defined in kwargs
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+ token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
@@ -926,7 +935,7 @@ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
return {
diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 5bebd936d65e15..f3da8ab4cdc242 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -58,11 +58,6 @@ class LlamaConfig(PretrainedConfig):
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
- pretraining_tp (`int`, *optional*, defaults to `1`):
- Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
- document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
- necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
- issue](https://github.com/pytorch/pytorch/issues/76232).
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
@@ -70,12 +65,23 @@ class LlamaConfig(PretrainedConfig):
Llama 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-12):
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
- tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
@@ -87,10 +93,9 @@ class LlamaConfig(PretrainedConfig):
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
- attention_bias (`bool`, defaults to `False`):
+        attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
- Example:
```python
>>> from transformers import LlamaModel, LlamaConfig
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index aef4fdab90616a..4eee4c205acd48 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1080,8 +1080,17 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
- if past_key_values:
- input_ids = input_ids[:, -1:]
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -1089,7 +1098,7 @@ def prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py
index 6b7bc7c28fcf7b..36e273d5725a4f 100644
--- a/src/transformers/models/longt5/modeling_flax_longt5.py
+++ b/src/transformers/models/longt5/modeling_flax_longt5.py
@@ -56,7 +56,7 @@
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
@@ -545,7 +545,7 @@ def __call__(
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
- key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
index d08ed83af07ea1..4e8aef0678367d 100644
--- a/src/transformers/models/longt5/modeling_longt5.py
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -2103,9 +2103,18 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
):
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"decoder_input_ids": input_ids,
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 88e543b54b5249..6db8bbb5213b14 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -1367,7 +1367,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py
index a713fdb05dcfd9..5197c906895917 100644
--- a/src/transformers/models/marian/modeling_flax_marian.py
+++ b/src/transformers/models/marian/modeling_flax_marian.py
@@ -227,7 +227,7 @@ def create_sinusoidal_positions(n_pos, dim):
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index b4e3aac5be0b42..69de5b2e7d0e6f 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -1509,7 +1509,16 @@ def prepare_inputs_for_generation(
) -> Dict:
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1740,7 +1749,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py
index ca6bea40337257..530c66a0c80b36 100755
--- a/src/transformers/models/markuplm/modeling_markuplm.py
+++ b/src/transformers/models/markuplm/modeling_markuplm.py
@@ -948,7 +948,16 @@ def prepare_inputs_for_generation(
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 276f94aebdbb9e..b53ad8848dd3c2 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -1413,7 +1413,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1897,7 +1906,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
index 1c1eeff667d44f..5d0ad6e3410c8f 100755
--- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -1251,9 +1251,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index a55f16a23d5b52..62610ceb41ac66 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -1083,8 +1083,18 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
+ # Omit tokens covered by past_key_values
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -1092,7 +1102,7 @@ def prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index 0c608dbd2a93bc..d760bec9854a8e 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -605,9 +605,18 @@ def prepare_inputs_for_generation(
use_cache: Optional[bool] = None,
**kwargs,
) -> dict:
- # only last token for input_ids if past is not None
- if past_key_values:
- input_ids = input_ids[:, -1].unsqueeze(-1)
+ # only last tokens for input_ids if past is not None
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py
index 86ddf477ffab56..0046e02ca7308e 100644
--- a/src/transformers/models/mt5/modeling_flax_mt5.py
+++ b/src/transformers/models/mt5/modeling_flax_mt5.py
@@ -27,7 +27,7 @@
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index 3d03503ddd402e..0de50afe9d6d24 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -1836,9 +1836,18 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
):
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"decoder_input_ids": input_ids,
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index f178a6762005e6..16766e953c8574 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1995,9 +1995,17 @@ def prepare_inputs_for_generation(
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
- # cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py
index 21a82f95c33383..5c1ed05249ef5c 100644
--- a/src/transformers/models/mvp/modeling_mvp.py
+++ b/src/transformers/models/mvp/modeling_mvp.py
@@ -1572,7 +1572,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -2054,7 +2063,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
index f37f64627dfad4..3701bbecef2e73 100644
--- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py
+++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -1808,7 +1808,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index d24211f039365e..8f3f246524348d 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -981,8 +981,17 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
- if past_key_values:
- input_ids = input_ids[:, -1:]
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py
index c5189746b1065f..17772251bf0629 100644
--- a/src/transformers/models/pegasus/modeling_flax_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -210,7 +210,7 @@
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 67934520fbb6d9..55856f7b06b6be 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -1466,7 +1466,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1719,7 +1728,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
index def82bdbaa7182..e87e9c7164ab44 100755
--- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py
+++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -1671,7 +1671,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index c09657c065f2be..a0bc5726382336 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -847,8 +847,17 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
- if past_key_values:
- input_ids = input_ids[:, -1:]
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -856,7 +865,7 @@ def prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py
index 288e31a126e675..e19761803e267d 100644
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -1798,9 +1798,18 @@ def prepare_inputs_for_generation(
if decoder_attention_mask is None:
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"flattened_patches": flattened_patches,
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 93532f4b0d8c22..3a880839236d43 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1379,7 +1379,16 @@ def prepare_inputs_for_generation(
) -> Dict[str, Any]:
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -1739,7 +1748,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py
index 47546930ebdfc1..fead8fc0cf7f42 100755
--- a/src/transformers/models/qdqbert/modeling_qdqbert.py
+++ b/src/transformers/models/qdqbert/modeling_qdqbert.py
@@ -1151,9 +1151,18 @@ def prepare_inputs_for_generation(
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py
index 745be26ebfc97f..235bff89f8a354 100755
--- a/src/transformers/models/rembert/modeling_rembert.py
+++ b/src/transformers/models/rembert/modeling_rembert.py
@@ -1147,9 +1147,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py
index 845fcea4429787..6bc72f12b40790 100644
--- a/src/transformers/models/roberta/modeling_flax_roberta.py
+++ b/src/transformers/models/roberta/modeling_flax_roberta.py
@@ -256,7 +256,7 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
- key_value_states: Optional[jnp.array] = None,
+ key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 67e0fee422c4cc..6d4cc991d22ca0 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -1007,9 +1007,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
index b7c347693d951b..e98897993742db 100644
--- a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
+++ b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
@@ -258,7 +258,7 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
- key_value_states: Optional[jnp.array] = None,
+ key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
index ddd87fa9ce0c1e..da1cd6331bc314 100644
--- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
+++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
@@ -1014,9 +1014,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py
index 35d4be9f20e0c0..a5b1b63050b1ef 100644
--- a/src/transformers/models/roc_bert/modeling_roc_bert.py
+++ b/src/transformers/models/roc_bert/modeling_roc_bert.py
@@ -1560,9 +1560,18 @@ def prepare_inputs_for_generation(
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
if input_shape_ids is not None:
input_shape_ids = input_shape_ids[:, -1:]
if input_pronunciation_ids is not None:
diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index 2c3feeda12708c..b9c36a305ff1cd 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -1178,9 +1178,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
index bfd801b242719f..f9b5dec4209273 100755
--- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
@@ -963,7 +963,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 48334deb377865..c4de7de09089ca 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -2508,7 +2508,16 @@ def prepare_inputs_for_generation(
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"encoder_outputs": encoder_outputs,
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 2cf1d33a51139d..45a7aa718cf026 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -942,6 +942,12 @@ def _set_gradient_checkpointing(self, module, value=False):
@add_start_docstrings(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
SWIN_START_DOCSTRING,
+ """
+ add_pooling_layer (`bool`, *optional*, defaults to `True`):
+            Whether or not to apply the pooling layer.
+ use_mask_token (`bool`, *optional*, defaults to `False`):
+ Whether or not to create and apply mask tokens in the embedding layer.
+ """,
)
class SwinModel(SwinPreTrainedModel):
def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 6c2fe8269782b6..541db4382dd649 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -1727,9 +1727,18 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
):
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"decoder_input_ids": input_ids,
diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py
index b2a7181421527c..09575fdcc3b82e 100644
--- a/src/transformers/models/t5/modeling_flax_t5.py
+++ b/src/transformers/models/t5/modeling_flax_t5.py
@@ -56,7 +56,7 @@
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
@@ -405,7 +405,7 @@ def __call__(
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
- key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index e6d9deefa14639..9716c7ffaffa0c 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1810,9 +1810,18 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
):
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"decoder_input_ids": input_ids,
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 50829592a02e72..c0541814be466e 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -1003,7 +1003,16 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_ids.shape)
if past_key_values:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 8323054144f549..bd35111be16e9f 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -1307,9 +1307,18 @@ def prepare_inputs_for_generation(
encoder_outputs=None,
**kwargs,
):
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {
"decoder_input_ids": input_ids,
diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py
index 0f158fb602084a..ffcaeb53ad7153 100644
--- a/src/transformers/models/whisper/modeling_flax_whisper.py
+++ b/src/transformers/models/whisper/modeling_flax_whisper.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" Flax whisper model."""
+import math
import random
from functools import partial
from typing import Optional, Tuple
@@ -58,6 +59,19 @@
remat = nn_partitioning.remat
+def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
+ """Returns sinusoids for positional embedding"""
+ length, channels = shape
+ if channels % 2 != 0:
+ raise ValueError(
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+ )
+ log_timescale_increment = math.log(10000) / (channels // 2 - 1)
+ inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(channels // 2))
+ scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
+ return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
+
+
WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -649,7 +663,13 @@ def setup(self) -> None:
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
- self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)
+
+ self.embed_positions = nn.Embed(
+ self.config.max_source_positions,
+ self.config.d_model,
+ dtype=self.dtype,
+ embedding_init=sinusoidal_embedding_init,
+ )
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -673,6 +693,8 @@ def __call__(
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
+ # freeze the sinusoidal embeddings by stopping the back-prop
+ embed_positions = jax.lax.stop_gradient(embed_positions)
hidden_states = hidden_states + embed_positions
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py
index 27b6ff63cedacb..1dfe413da2ae46 100644
--- a/src/transformers/models/whisper/modeling_tf_whisper.py
+++ b/src/transformers/models/whisper/modeling_tf_whisper.py
@@ -59,6 +59,19 @@
LARGE_NEGATIVE = -1e8
+def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
+ """Returns sinusoids for positional embedding"""
+ length, channels = shape
+ if channels % 2 != 0:
+ raise ValueError(
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+ )
+ log_timescale_increment = math.log(10000) / (channels // 2 - 1)
+ inv_timescales = tf.exp(-log_timescale_increment * tf.range(channels // 2, dtype=tf.float32))
+ scaled_time = tf.reshape(tf.range(length, dtype=tf.float32), (-1, 1)) * tf.reshape(inv_timescales, (1, -1))
+ return tf.cast(tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1), dtype)
+
+
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
@@ -117,16 +130,25 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
- def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None, **kwargs):
+ def __init__(
+ self,
+ num_positions: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ embedding_initializer=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.num_positions = num_positions
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
+ self.embedding_initializer = tf.keras.initializers.get(embedding_initializer)
def build(self, input_shape):
self.weight = self.add_weight(
name="weight",
shape=[self.num_positions, self.embedding_dim],
+ initializer=self.embedding_initializer,
trainable=True,
)
super().build(input_shape)
@@ -620,8 +642,12 @@ def __init__(self, config: WhisperConfig, **kwargs):
self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")
self.embed_positions = TFWhisperPositionalEmbedding(
- self.max_source_positions, self.embed_dim, name="embed_positions"
+ num_positions=self.max_source_positions,
+ embedding_dim=self.embed_dim,
+ embedding_initializer=sinusoidal_embedding_init,
+ name="embed_positions",
)
+ self.embed_positions.trainable = False
self.encoder_layers = [TFWhisperEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b15d05e45a2f25..b0f39cb8d32435 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -20,6 +20,7 @@
import numpy as np
import torch
import torch.utils.checkpoint
+import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -73,6 +74,18 @@ def _get_unpad_data(padding_mask):
)
+def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
+ """Returns sinusoids for positional embedding"""
+ if channels % 2 != 0:
+ raise ValueError(
+ f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
+ )
+ log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
+ return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
+
+
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
@@ -913,6 +926,10 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, WhisperEncoder):
+ with torch.no_grad():
+ embed_positions = module.embed_positions.weight
+ embed_positions.copy_(sinusoids(*embed_positions.shape))
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (WhisperDecoder, WhisperEncoder)):
@@ -1080,6 +1097,7 @@ def __init__(self, config: WhisperConfig):
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
+ self.embed_positions.requires_grad_(False)
self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -2038,9 +2056,17 @@ def prepare_inputs_for_generation(
attention_mask=None,
**kwargs,
):
- # cut decoder_input_ids if past is used
if past_key_values is not None:
- decoder_input_ids = decoder_input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
return {
"encoder_outputs": encoder_outputs,
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 5f8778f98dcd2d..0c769dbbb5f324 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -851,21 +851,30 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
+
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
-
- if past_key_values:
- input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
index f6f39ee93ba687..fb03c390f6f419 100644
--- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
@@ -266,7 +266,7 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
- key_value_states: Optional[jnp.array] = None,
+ key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index 761e96a11b7344..da454b1e3331f9 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -1011,9 +1011,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index 025bab3887c0c7..26e0361abdb523 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -970,9 +970,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py
index 61002bd2772e52..28fddc2fdbd6b5 100644
--- a/src/transformers/models/xmod/modeling_xmod.py
+++ b/src/transformers/models/xmod/modeling_xmod.py
@@ -1118,9 +1118,18 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
- # cut decoder_input_ids if past is used
+ # cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
- input_ids = input_ids[:, -1:]
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 635ab656ff699c..96cb467bcbeb7c 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2345,10 +2345,11 @@ def set_logging(
Logger log level to use on the main process. Possible choices are the log levels as strings: `"debug"`,
`"info"`, `"warning"`, `"error"` and `"critical"`, plus a `"passive"` level which doesn't set anything
and lets the application set the level.
- report_to (`str` or `List[str]`, *optional*, defaults to `"none"`):
+ report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
- `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. Use `"all"` to report
- to all integrations installed, `"none"` for no integrations.
+ `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
+ `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
+ integrations.
first_step (`bool`, *optional*, defaults to `False`):
Whether to log and evaluate the first `global_step` or not.
nan_inf_filter (`bool`, *optional*, defaults to `True`):
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
index 83263a6a47ef11..63b5d83d308ad1 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
@@ -251,7 +251,7 @@ def __call__(
hidden_states,
attention_mask,
layer_head_mask,
- key_value_states: Optional[jnp.array] = None,
+ key_value_states: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic=True,
output_attentions: bool = False,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index f73e3f60a5530b..175861fd149e5e 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2906,3 +2906,101 @@ def test_default_max_length_warning(self):
model.generation_config.max_length = 10
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)
+
+ def test_model_kwarg_assisted_decoding_decoder_only(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+ model.config.pad_token_id = tokenizer.eos_token_id
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ # Traditional way of generating text
+ outputs_normal = model.generate(input_ids)
+ self.assertEqual(outputs_normal.shape, (1, 20))
+
+ # Should be different with token_type_ids
+ outputs_tti = model.generate(
+ input_ids,
+ token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+ )
+ with self.assertRaises(AssertionError):
+ self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist())
+
+ # Assistant model
+ assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+ assistant.config.pad_token_id = tokenizer.eos_token_id
+
+ # If assisted generation passes model_kwargs correctly, should be same as previous
+ outputs_assisted = model.generate(
+ input_ids,
+ token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+ assistant_model=assistant,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
+
+ def test_model_kwarg_assisted_decoding_encoder_decoder(self):
+ # PT-only test: TF doesn't support assisted decoding yet.
+ # Bart subclass with a kwarg that distorts the output
+ class FakeBart(BartForConditionalGeneration):
+ def forward(self, input_ids, foo=False, **kwargs):
+ outs = super().forward(input_ids, **kwargs)
+
+ if foo:
+ outs["logits"][:, :, :] = 0.0
+
+ return outs
+
+ def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
+ kwargs["encoder_outputs"] = encoder_outputs
+ inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+
+ inputs["foo"] = foo
+ return inputs
+
+ model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+ torch_device
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
+
+ text = "Hello world"
+ tokenized_inputs = tokenizer([text], return_tensors="pt")
+ input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+ # Traditional way of generating text
+ outputs_normal = model.generate(input_ids)
+ self.assertEqual(outputs_normal.shape, (1, 20))
+
+ # Should be different with foo
+ outputs_foo = model.generate(
+ input_ids,
+ foo=True,
+ )
+ with self.assertRaises(AssertionError):
+ self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
+
+ # Assistant model
+ assistant = AutoModelForSeq2SeqLM.from_pretrained(
+ "hf-internal-testing/tiny-random-BartForConditionalGeneration"
+ ).to(torch_device)
+
+ # If assisted generation passes model_kwargs correctly, should be same as previous
+ outputs_assisted = model.generate(
+ input_ids,
+ foo=True,
+ assistant_model=assistant,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
+
+ # Check that passing encoder_outputs directly also works as expected
+ encoder_outputs = assistant.get_encoder()(input_ids)
+
+ outputs_assisted = model.generate(
+ foo=True,
+ assistant_model=assistant,
+ encoder_outputs=encoder_outputs,
+ assistant_encoder_outputs=encoder_outputs,
+ )
+ self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py
index e43fc1e41b8f9d..b7db0bbe28a7b7 100644
--- a/tests/models/biogpt/test_modeling_biogpt.py
+++ b/tests/models/biogpt/test_modeling_biogpt.py
@@ -386,7 +386,7 @@ def test_model_from_pretrained(self):
model = BioGptModel.from_pretrained(model_name)
self.assertIsNotNone(model)
- # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
+ # Copied from tests.models.opt.test_modeling_opt.OPTModelTest.test_opt_sequence_classification_model with OPT->BioGpt,opt->biogpt,prepare_config_and_inputs->prepare_config_and_inputs_for_common
def test_biogpt_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -399,7 +399,7 @@ def test_biogpt_sequence_classification_model(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
- # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
+ # Copied from tests.models.opt.test_modeling_opt.OPTModelTest.test_opt_sequence_classification_model_for_multi_label with OPT->BioGpt,opt->biogpt,prepare_config_and_inputs->prepare_config_and_inputs_for_common
def test_biogpt_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
diff --git a/tests/models/clap/test_feature_extraction_clap.py b/tests/models/clap/test_feature_extraction_clap.py
index c49d045ba87407..d0e913df828b84 100644
--- a/tests/models/clap/test_feature_extraction_clap.py
+++ b/tests/models/clap/test_feature_extraction_clap.py
@@ -19,6 +19,7 @@
import unittest
import numpy as np
+from datasets import load_dataset
from transformers import ClapFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio
@@ -110,10 +111,10 @@ def _flatten(list_of_lists):
@require_torch
@require_torchaudio
-# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest with Whisper->Clap
class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = ClapFeatureExtractor
+ # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.setUp with Whisper->Clap
def setUp(self):
self.feat_extract_tester = ClapFeatureExtractionTester(self)
@@ -147,6 +148,7 @@ def test_call(self):
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+ # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
def test_double_precision_pad(self):
import torch
@@ -160,9 +162,8 @@ def test_double_precision_pad(self):
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
+ # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest._load_datasamples
def _load_datasamples(self, num_samples):
- from datasets import load_dataset
-
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 0223acbbd72a8a..2402986900fda6 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -341,7 +341,7 @@ def test_llama_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
- @unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
+ @unittest.skip("Llama buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
diff --git a/tests/models/longformer/test_tokenization_longformer.py b/tests/models/longformer/test_tokenization_longformer.py
index 2397a40bafa6b1..61d8653b60c608 100644
--- a/tests/models/longformer/test_tokenization_longformer.py
+++ b/tests/models/longformer/test_tokenization_longformer.py
@@ -27,7 +27,6 @@
from ...test_tokenization_common import TokenizerTesterMixin
-# Copied from transformers.tests.roberta.test_modeling_roberta.py with Roberta->Longformer
@require_tokenizers
class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LongformerTokenizer
@@ -72,19 +71,23 @@ def setUp(self):
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_tokenizer
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_rust_tokenizer
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_input_output_texts
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
output_text = "lower newer"
return input_text, output_text
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_full_tokenizer
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
@@ -96,6 +99,7 @@ def test_full_tokenizer(self):
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.roberta_dict_integration_testing with roberta->longformer
def longformer_dict_integration_testing(self):
tokenizer = self.get_tokenizer()
@@ -106,6 +110,7 @@ def longformer_dict_integration_testing(self):
)
@slow
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_sequence_builders with roberta-base->allenai/longformer-base-4096
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096")
@@ -125,6 +130,7 @@ def test_sequence_builders(self):
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_space_encoding
def test_space_encoding(self):
tokenizer = self.get_tokenizer()
@@ -165,9 +171,11 @@ def test_space_encoding(self):
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
self.assertNotEqual(first_char, space_encoding)
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_pretokenized_inputs
def test_pretokenized_inputs(self):
pass
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_embeded_special_tokens
def test_embeded_special_tokens(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -200,6 +208,7 @@ def test_embeded_special_tokens(self):
                    tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
)
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_change_add_prefix_space_and_trim_offsets_args
def test_change_add_prefix_space_and_trim_offsets_args(self):
for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
@@ -214,6 +223,7 @@ def test_change_add_prefix_space_and_trim_offsets_args(self):
self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space)
self.assertEqual(post_processor_state["trim_offsets"], trim_offsets)
+ # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments
def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self):
# Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and
# `trim_offsets`
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index df1143d2516afd..d2e9b2685f0beb 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -39,7 +39,6 @@
)
-# Copied from transformers.tests.mistral.test_modelling_mistral.MistralModelTest with Llama->Mistral
class MistralModelTester:
def __init__(
self,
@@ -93,6 +92,7 @@ def __init__(
self.pad_token_id = pad_token_id
self.scope = scope
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -134,6 +134,7 @@ def get_config(self):
pad_token_id=self.pad_token_id,
)
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -144,6 +145,7 @@ def create_and_check_model(
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Mistral
def create_and_check_model_as_decoder(
self,
config,
@@ -174,6 +176,7 @@ def create_and_check_model_as_decoder(
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Mistral
def create_and_check_for_causal_lm(
self,
config,
@@ -192,6 +195,7 @@ def create_and_check_for_causal_lm(
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Mistral
def create_and_check_decoder_model_past_large_inputs(
self,
config,
@@ -254,6 +258,7 @@ def create_and_check_decoder_model_past_large_inputs(
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
diff --git a/tests/models/mobilebert/test_tokenization_mobilebert.py b/tests/models/mobilebert/test_tokenization_mobilebert.py
index 3ecc2e3238d512..babed7a8d9bfdc 100644
--- a/tests/models/mobilebert/test_tokenization_mobilebert.py
+++ b/tests/models/mobilebert/test_tokenization_mobilebert.py
@@ -32,7 +32,6 @@
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
-# Copied from transformers.tests.models.bert.test_modeling_bert.py with Bert->MobileBert and pathfix
@require_tokenizers
class MobileBERTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MobileBertTokenizer
@@ -71,11 +70,13 @@ def setUp(self):
for tokenizer_def in self.tokenizers_list
]
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.get_input_output_texts
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
return input_text, output_text
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_full_tokenizer
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file)
@@ -83,6 +84,7 @@ def test_full_tokenizer(self):
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_rust_and_python_full_tokenizers
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
@@ -124,11 +126,13 @@ def test_rust_and_python_full_tokenizers(self):
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_chinese
def test_chinese(self):
tokenizer = BasicTokenizer()
self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower
def test_basic_tokenizer_lower(self):
tokenizer = BasicTokenizer(do_lower_case=True)
@@ -137,6 +141,7 @@ def test_basic_tokenizer_lower(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false
def test_basic_tokenizer_lower_strip_accents_false(self):
tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
@@ -145,6 +150,7 @@ def test_basic_tokenizer_lower_strip_accents_false(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true
def test_basic_tokenizer_lower_strip_accents_true(self):
tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
@@ -153,6 +159,7 @@ def test_basic_tokenizer_lower_strip_accents_true(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default
def test_basic_tokenizer_lower_strip_accents_default(self):
tokenizer = BasicTokenizer(do_lower_case=True)
@@ -161,6 +168,7 @@ def test_basic_tokenizer_lower_strip_accents_default(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower
def test_basic_tokenizer_no_lower(self):
tokenizer = BasicTokenizer(do_lower_case=False)
@@ -168,6 +176,7 @@ def test_basic_tokenizer_no_lower(self):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
)
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_false
def test_basic_tokenizer_no_lower_strip_accents_false(self):
tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
@@ -175,6 +184,7 @@ def test_basic_tokenizer_no_lower_strip_accents_false(self):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
)
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_true
def test_basic_tokenizer_no_lower_strip_accents_true(self):
tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
@@ -182,6 +192,7 @@ def test_basic_tokenizer_no_lower_strip_accents_true(self):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
)
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_respects_never_split_tokens
def test_basic_tokenizer_respects_never_split_tokens(self):
tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
@@ -189,6 +200,7 @@ def test_basic_tokenizer_respects_never_split_tokens(self):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
)
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_wordpiece_tokenizer
def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
@@ -203,6 +215,7 @@ def test_wordpiece_tokenizer(self):
self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_whitespace
def test_is_whitespace(self):
self.assertTrue(_is_whitespace(" "))
self.assertTrue(_is_whitespace("\t"))
@@ -213,6 +226,7 @@ def test_is_whitespace(self):
self.assertFalse(_is_whitespace("A"))
self.assertFalse(_is_whitespace("-"))
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_control
def test_is_control(self):
self.assertTrue(_is_control("\u0005"))
@@ -221,6 +235,7 @@ def test_is_control(self):
self.assertFalse(_is_control("\t"))
self.assertFalse(_is_control("\r"))
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_punctuation
def test_is_punctuation(self):
self.assertTrue(_is_punctuation("-"))
self.assertTrue(_is_punctuation("$"))
@@ -230,6 +245,7 @@ def test_is_punctuation(self):
self.assertFalse(_is_punctuation("A"))
self.assertFalse(_is_punctuation(" "))
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_clean_text
def test_clean_text(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
@@ -242,6 +258,7 @@ def test_clean_text(self):
)
@slow
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_sequence_builders with bert-base-uncased->google/mobilebert-uncased
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("google/mobilebert-uncased")
@@ -254,6 +271,7 @@ def test_sequence_builders(self):
assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102]
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_offsets_with_special_characters
def test_offsets_with_special_characters(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -306,6 +324,7 @@ def test_offsets_with_special_characters(self):
)
self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_change_tokenize_chinese_chars
def test_change_tokenize_chinese_chars(self):
list_of_commun_chinese_char = ["的", "人", "有"]
text_with_chinese_char = "".join(list_of_commun_chinese_char)
diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py
index 3b67128c3b7372..60a5dabf1053f1 100644
--- a/tests/models/persimmon/test_modeling_persimmon.py
+++ b/tests/models/persimmon/test_modeling_persimmon.py
@@ -39,7 +39,7 @@
)
-# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
+# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
class PersimmonModelTester:
def __init__(
self,
@@ -266,7 +266,6 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict
-# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
@require_torch
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
@@ -288,23 +287,28 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
test_headmasking = False
test_pruning = False
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon
def setUp(self):
self.model_tester = PersimmonModelTester(self)
self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37)
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -317,6 +321,7 @@ def test_persimmon_sequence_classification_model(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -330,6 +335,7 @@ def test_persimmon_sequence_classification_model_for_single_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -346,10 +352,12 @@ def test_persimmon_sequence_classification_model_for_multi_label(self):
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
def test_save_load_fast_init_from_base(self):
pass
@parameterized.expand([("linear",), ("dynamic",)])
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
diff --git a/tests/models/roberta/test_tokenization_roberta.py b/tests/models/roberta/test_tokenization_roberta.py
index 78bac218351bf3..3190ab13be4ea1 100644
--- a/tests/models/roberta/test_tokenization_roberta.py
+++ b/tests/models/roberta/test_tokenization_roberta.py
@@ -76,8 +76,7 @@ def get_tokenizer(self, **kwargs):
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
- return RobertaTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
- return RobertaTokenizerFast(self.vocab_file, self.merges_file, **kwargs)
+ return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
diff --git a/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py b/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py
index 8500dfcb67a84f..65dbe65974d4c4 100644
--- a/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py
+++ b/tests/models/roberta_prelayernorm/test_modeling_flax_roberta_prelayernorm.py
@@ -36,7 +36,7 @@
)
-# Copied from tests.models.roberta.test_modelling_flax_roberta.FlaxRobertaModelTester with Roberta->RobertaPreLayerNorm
+# Copied from tests.models.roberta.test_modeling_flax_roberta.FlaxRobertaModelTester with Roberta->RobertaPreLayerNorm
class FlaxRobertaPreLayerNormModelTester(unittest.TestCase):
def __init__(
self,
@@ -134,7 +134,7 @@ def prepare_config_and_inputs_for_decoder(self):
@require_flax
-# Copied from tests.models.roberta.test_modelling_flax_roberta.FlaxRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta-base->andreasmadsen/efficient_mlm_m0.40
+# Copied from tests.models.roberta.test_modeling_flax_roberta.FlaxRobertaModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta-base->andreasmadsen/efficient_mlm_m0.40
class FlaxRobertaPreLayerNormModelTest(FlaxModelTesterMixin, unittest.TestCase):
test_head_masking = True
diff --git a/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py b/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py
index ee0972eec32964..a2f56e31a09192 100644
--- a/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py
+++ b/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py
@@ -44,7 +44,7 @@
)
-# Copied from tests.models.roberta.test_modelling_roberta.RobertaModelTester with Roberta->RobertaPreLayerNorm
+# Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTester with Roberta->RobertaPreLayerNorm
class RobertaPreLayerNormModelTester:
def __init__(
self,
@@ -365,7 +365,6 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
-# Copied from tests.models.roberta.test_modelling_roberta.RobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
@@ -397,27 +396,33 @@ class RobertaPreLayerNormModelTest(ModelTesterMixin, GenerationTesterMixin, Pipe
fx_compatible = False
model_split_percents = [0.5, 0.8, 0.9]
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.setUp with Roberta->RobertaPreLayerNorm
def setUp(self):
self.model_tester = RobertaPreLayerNormModelTester(self)
self.config_tester = ConfigTester(self, config_class=RobertaPreLayerNormConfig, hidden_size=37)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_various_embeddings
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_as_decoder
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_as_decoder_with_default_input_mask
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(
@@ -446,42 +451,50 @@ def test_model_as_decoder_with_default_input_mask(self):
encoder_attention_mask,
)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_causal_lm
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_decoder_model_past_with_large_inputs
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_masked_lm
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_token_classification
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_multiple_choice
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_for_question_answering
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
@slow
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_from_pretrained with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
def test_model_from_pretrained(self):
for model_name in ROBERTA_PRELAYERNORM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = RobertaPreLayerNormModel.from_pretrained(model_name)
self.assertIsNotNone(model)
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_create_position_ids_respects_padding_index with Roberta->RobertaPreLayerNorm
def test_create_position_ids_respects_padding_index(self):
"""Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
- The position ids should be masked with the embedding object's padding index. Therefore, the
- first available non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
+ The position ids should be masked with the embedding object's padding index. Therefore, the first available
+ non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
model = RobertaPreLayerNormEmbeddings(config=config)
@@ -495,12 +508,13 @@ def test_create_position_ids_respects_padding_index(self):
self.assertEqual(position_ids.shape, expected_positions.shape)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
+ # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_create_position_ids_from_inputs_embeds with Roberta->RobertaPreLayerNorm
def test_create_position_ids_from_inputs_embeds(self):
"""Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
- The position ids should be masked with the embedding object's padding index. Therefore, the
- first available non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
+ The position ids should be masked with the embedding object's padding index. Therefore, the first available
+ non-padding position index is RobertaPreLayerNormEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
embeddings = RobertaPreLayerNormEmbeddings(config=config)
diff --git a/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py b/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py
index 9c1a25ccb982ea..384fa2e9e40013 100644
--- a/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py
+++ b/tests/models/roberta_prelayernorm/test_modeling_tf_roberta_prelayernorm.py
@@ -42,7 +42,7 @@
)
-# Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaModelTester with Roberta->RobertaPreLayerNorm
+# Copied from tests.models.roberta.test_modeling_tf_roberta.TFRobertaModelTester with Roberta->RobertaPreLayerNorm
class TFRobertaPreLayerNormModelTester:
def __init__(
self,
@@ -551,7 +551,7 @@ def prepare_config_and_inputs_for_common(self):
@require_tf
-# Copied from tests.models.roberta.test_modelling_tf_roberta.TFRobertaPreLayerNormModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
+# Copied from tests.models.roberta.test_modeling_tf_roberta.TFRobertaModelTest with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm
class TFRobertaPreLayerNormModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
diff --git a/tests/models/roc_bert/test_tokenization_roc_bert.py b/tests/models/roc_bert/test_tokenization_roc_bert.py
index 0f8fe08efd1517..6a24514b3c2c2b 100644
--- a/tests/models/roc_bert/test_tokenization_roc_bert.py
+++ b/tests/models/roc_bert/test_tokenization_roc_bert.py
@@ -68,13 +68,13 @@ def test_full_tokenizer(self):
self.assertListEqual(tokenizer.convert_tokens_to_shape_ids(tokens), [5, 6, 2, 5, 7, 8])
self.assertListEqual(tokenizer.convert_tokens_to_pronunciation_ids(tokens), [5, 6, 2, 5, 7, 8])
- # Copied from tests.models.bert.test_tokenization_bert.test_chinese with BasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_chinese with BasicTokenizer->RoCBertBasicTokenizer
def test_chinese(self):
tokenizer = RoCBertBasicTokenizer()
self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower with BasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True)
@@ -83,7 +83,7 @@ def test_basic_tokenizer_lower(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_false with BasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_false with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_false(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True, strip_accents=False)
@@ -92,7 +92,7 @@ def test_basic_tokenizer_lower_strip_accents_false(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_true with BertBasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_true with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_true(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True, strip_accents=True)
@@ -101,7 +101,7 @@ def test_basic_tokenizer_lower_strip_accents_true(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_lower_strip_accents_default with BasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_lower_strip_accents_default with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_lower_strip_accents_default(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=True)
@@ -110,7 +110,7 @@ def test_basic_tokenizer_lower_strip_accents_default(self):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower with BasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_no_lower(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False)
@@ -118,7 +118,7 @@ def test_basic_tokenizer_no_lower(self):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
)
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower_strip_accents_false with BertBasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_false with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_no_lower_strip_accents_false(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False, strip_accents=False)
@@ -126,7 +126,7 @@ def test_basic_tokenizer_no_lower_strip_accents_false(self):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
)
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_no_lower_strip_accents_true with BasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_no_lower_strip_accents_true with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_no_lower_strip_accents_true(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False, strip_accents=True)
@@ -134,7 +134,7 @@ def test_basic_tokenizer_no_lower_strip_accents_true(self):
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
)
- # Copied from tests.models.bert.test_tokenization_bert.test_basic_tokenizer_respects_never_split_tokens with BasicTokenizer->RoCBertBertBasicTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_basic_tokenizer_respects_never_split_tokens with BasicTokenizer->RoCBertBasicTokenizer
def test_basic_tokenizer_respects_never_split_tokens(self):
tokenizer = RoCBertBasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
@@ -142,7 +142,7 @@ def test_basic_tokenizer_respects_never_split_tokens(self):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
)
- # Copied from tests.models.bert.test_tokenization_bert.test_wordpiece_tokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_wordpiece_tokenizer with WordpieceTokenizer->RoCBertWordpieceTokenizer
def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
@@ -157,7 +157,7 @@ def test_wordpiece_tokenizer(self):
self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
- # Copied from tests.models.bert.test_tokenization_bert.test_is_whitespace
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_whitespace
def test_is_whitespace(self):
self.assertTrue(_is_whitespace(" "))
self.assertTrue(_is_whitespace("\t"))
@@ -168,7 +168,7 @@ def test_is_whitespace(self):
self.assertFalse(_is_whitespace("A"))
self.assertFalse(_is_whitespace("-"))
- # Copied from tests.models.bert.test_tokenization_bert.test_is_control
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_control
def test_is_control(self):
self.assertTrue(_is_control("\u0005"))
@@ -177,7 +177,7 @@ def test_is_control(self):
self.assertFalse(_is_control("\t"))
self.assertFalse(_is_control("\r"))
- # Copied from tests.models.bert.test_tokenization_bert.test_is_punctuation
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_is_punctuation
def test_is_punctuation(self):
self.assertTrue(_is_punctuation("-"))
self.assertTrue(_is_punctuation("$"))
@@ -199,7 +199,7 @@ def test_clean_text(self):
[rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
)
- # Copied from tests.models.bert.test_tokenization_bert. test_offsets_with_special_characters
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_offsets_with_special_characters
def test_offsets_with_special_characters(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@@ -252,7 +252,7 @@ def test_offsets_with_special_characters(self):
)
self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
- # Copied from tests.models.bert.test_tokenization_bert. test_change_tokenize_chinese_chars
+ # Copied from tests.models.bert.test_tokenization_bert.BertTokenizationTest.test_change_tokenize_chinese_chars
def test_change_tokenize_chinese_chars(self):
list_of_commun_chinese_char = ["的", "人", "有"]
text_with_chinese_char = "".join(list_of_commun_chinese_char)
diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py
index 7ec5f90f0fcd53..982dcb4827a0d1 100644
--- a/tests/models/whisper/test_modeling_flax_whisper.py
+++ b/tests/models/whisper/test_modeling_flax_whisper.py
@@ -46,6 +46,7 @@
WhisperProcessor,
)
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
+ from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init
@require_flax
@@ -387,6 +388,19 @@ def test_save_load_to_base(self):
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+ def test_encoder_sinusoidal_embed_positions(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ params = model.params
+ if model.base_model_prefix in params:
+ params = model.params[model.base_model_prefix]
+
+ embeds = params["encoder"]["embed_positions"]["embedding"]
+ sinusoids = sinusoidal_embedding_init(None, embeds.shape)
+ self.assertTrue(jax.numpy.allclose(embeds, sinusoids))
+
@slow
@require_flax
diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py
index 7fae1e466e7a6e..75c62ae1ad07e6 100644
--- a/tests/models/whisper/test_modeling_tf_whisper.py
+++ b/tests/models/whisper/test_modeling_tf_whisper.py
@@ -42,7 +42,11 @@
import tensorflow as tf
from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
- from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
+ from transformers.models.whisper.modeling_tf_whisper import (
+ TFWhisperDecoder,
+ TFWhisperEncoder,
+ sinusoidal_embedding_init,
+ )
def prepare_whisper_inputs_dict(
@@ -297,6 +301,23 @@ def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
+ def test_requires_grad_encoder_embed_positions(self):
+ config = self.model_tester.get_config()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ encoder = model.get_encoder()
+ self.assertFalse(encoder.embed_positions.trainable)
+
+ def test_encoder_sinusoidal_embed_positions(self):
+ config = self.model_tester.get_config()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ model.build()
+
+ embeds = model.get_encoder().embed_positions.get_weights()[0]
+ sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
+ self.assertTrue(np.allclose(embeds, sinusoids))
+
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 9decb7192aee00..337d33485210cd 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -49,7 +49,7 @@
WhisperProcessor,
set_seed,
)
- from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
+ from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
if is_flax_available():
import jax.numpy as jnp
@@ -351,6 +351,20 @@ def test_requires_grad_with_frozen_encoder(self):
self.assertFalse(all(encoder_grads))
self.assertTrue(all(decoder_grads))
+ def test_requires_grad_encoder_embed_positions(self):
+ config = self.model_tester.get_config()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ encoder = model.get_encoder()
+ self.assertFalse(encoder.embed_positions.weight.requires_grad)
+
+ def test_encoder_sinusoidal_embed_positions(self):
+ config = self.model_tester.get_config()
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ embeds = model.get_encoder().embed_positions.weight
+ self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))
+
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index be9e11de5401ea..fd1c135deb53b5 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -376,7 +376,7 @@ def test_tokenizer_special(self):
def test_vocab_size(self):
self.assertEqual(self.tokenizer.vocab_size, 50257)
- # Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py
+ # Copied from tests.models.speech_to_text.test_tokenization_speech_to_text.SpeechToTextTokenizerMultilinguialTest.test_tokenizer_decode_ignores_language_codes
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
diff --git a/utils/check_copies.py b/utils/check_copies.py
index f198b9e062b38b..667c11ec724b42 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -51,6 +51,7 @@
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
TRANSFORMERS_PATH = "src/transformers"
+MODEL_TEST_PATH = "tests/models"
PATH_TO_DOCS = "docs/source/en"
REPO_PATH = "."
@@ -132,12 +133,15 @@ def _should_continue(line: str, indent: str) -> bool:
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
-def find_code_in_transformers(object_name: str) -> str:
+def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
"""
Find and return the source code of an object.
Args:
- object_name (`str`): The name of the object we want the source code of.
+ object_name (`str`):
+ The name of the object we want the source code of.
+ base_path (`str`, *optional*):
+ The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
Returns:
`str`: The source code of the object.
@@ -145,9 +149,21 @@ def find_code_in_transformers(object_name: str) -> str:
parts = object_name.split(".")
i = 0
+ # We can't set this as the default value in the argument: default values are evaluated only once, at function
+ # definition time, so `CopyCheckTester` would fail, as it uses a patched temp directory.
+ if base_path is None:
+ base_path = TRANSFORMERS_PATH
+
+ # Detail: the `Copied from` statement was originally designed to work with the last part of `TRANSFORMERS_PATH`
+ # (which is `transformers`). The same should apply to `MODEL_TEST_PATH`, but its last part is `models` (so that
+ # only that folder is checked and searched), which would be confusing. We therefore keep the copied statement
+ # starting with `tests.models.` and change the base path to `tests` here.
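+ # For example, a statement such as
+ #     # Copied from tests.models.speech_to_text.test_tokenization_speech_to_text.SpeechToTextTokenizerMultilinguialTest.test_tokenizer_decode_ignores_language_codes
+ # is resolved against `tests/models/speech_to_text/test_tokenization_speech_to_text.py`.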
+ if base_path == MODEL_TEST_PATH:
+ base_path = "tests"
+
# First let's find the module where our object lives.
module = parts[i]
- while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
+ while i < len(parts) and not os.path.isfile(os.path.join(base_path, f"{module}.py")):
i += 1
if i < len(parts):
module = os.path.join(module, parts[i])
@@ -156,7 +172,7 @@ def find_code_in_transformers(object_name: str) -> str:
f"`object_name` should begin with the name of a module of transformers but got {object_name}."
)
- with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
+ with open(os.path.join(base_path, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Now let's find the class / func in the code!
@@ -186,6 +202,7 @@ def find_code_in_transformers(object_name: str) -> str:
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
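+# Same as `_re_copy_warning`, but for `# Copied from tests.xxx` statements used in test files.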
+_re_copy_warning_for_test_file = re.compile(r"^(\s*)#\s*Copied from\s+tests\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
@@ -284,14 +301,20 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
line_index = 0
# Not a for loop cause `lines` is going to change (if `overwrite=True`).
while line_index < len(lines):
- search = _re_copy_warning.search(lines[line_index])
+ search_re = _re_copy_warning
+ if filename.startswith("tests"):
+ search_re = _re_copy_warning_for_test_file
+
+ search = search_re.search(lines[line_index])
if search is None:
line_index += 1
continue
# There is some copied code here, let's retrieve the original.
indent, object_name, replace_pattern = search.groups()
- theoretical_code = find_code_in_transformers(object_name)
+
+ base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
+ theoretical_code = find_code_in_transformers(object_name, base_path=base_path)
theoretical_indent = get_indent(theoretical_code)
start_index = line_index + 1 if indent == theoretical_indent else line_index
@@ -357,6 +380,9 @@ def check_copies(overwrite: bool = False):
Whether or not to overwrite the copies when they don't match.
"""
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
+ all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
+ all_files = list(all_files) + list(all_test_files)
+
diffs = []
for filename in all_files:
new_diffs = is_copy_consistent(filename, overwrite)
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index 2832e347ab5f1e..f142c5dbccd1df 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -361,7 +361,6 @@
"LevitConfig",
"LiltConfig",
"LiltModel",
- "LlamaConfig",
"LlamaTokenizer",
"LlamaTokenizerFast",
"LongT5Config",
@@ -500,7 +499,6 @@
"SqueezeBertTokenizerFast",
"SummarizationPipeline",
"Swin2SRImageProcessor",
- "SwinModel",
"Swinv2Model",
"SwitchTransformersConfig",
"T5Config",