From 8ce4fefc52d6146fb6aaf2b896cbb07b9fc4d947 Mon Sep 17 00:00:00 2001 From: Aaron Jimenez Date: Fri, 10 May 2024 08:29:26 -0800 Subject: [PATCH 01/19] [docs] Update link in es/pipeline_webserver.md (#30745) * update link * run make style --- docs/source/es/pipeline_webserver.md | 8 +------- examples/pytorch/object-detection/run_object_detection.py | 4 +--- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/docs/source/es/pipeline_webserver.md b/docs/source/es/pipeline_webserver.md index e77e620f58b78b..e268daabcbe881 100644 --- a/docs/source/es/pipeline_webserver.md +++ b/docs/source/es/pipeline_webserver.md @@ -9,13 +9,7 @@ Crear un motor de inferencia es un tema complejo, y la "mejor" solución probabl -Lo fundamental para entender es que podemos usar un iterador, tal como [en un conjunto de datos](https://huggingface.co/docs/transformers/pipeline_tutorial#using-pipelines-on-a-dataset), ya que un servidor web es básicamente un sistema que espera solicitudes y las trata a medida que llegan. - - +Lo fundamental para entender es que podemos usar un iterador, tal como [en un conjunto de datos](pipeline_tutorial#uso-de-pipelines-en-un-conjunto-de-datos), ya que un servidor web es básicamente un sistema que espera solicitudes y las trata a medida que llegan. Por lo general, los servidores web están multiplexados (multihilo, asíncrono, etc.) para manejar varias solicitudes simultáneamente. Por otro lado, los flujos de trabajo (y principalmente los modelos subyacentes) no son realmente ideales para el paralelismo; consumen mucha RAM, por lo que es mejor darles todos los recursos disponibles cuando se están ejecutando o es un trabajo intensivo en cómputo. diff --git a/examples/pytorch/object-detection/run_object_detection.py b/examples/pytorch/object-detection/run_object_detection.py index 3f0769568f981a..ba6ee1e55a481a 100644 --- a/examples/pytorch/object-detection/run_object_detection.py +++ b/examples/pytorch/object-detection/run_object_detection.py @@ -244,9 +244,7 @@ class DataTrainingArguments: ) image_square_size: Optional[int] = field( default=600, - metadata={ - "help": "Image longest size will be resized to this value, then image will be padded to square." 
- }, + metadata={"help": "Image longest size will be resized to this value, then image will be padded to square."}, ) max_train_samples: Optional[int] = field( default=None, From e0c3cee17085914bbe505c159beeb8ae39bc37dd Mon Sep 17 00:00:00 2001 From: mobicham <37179323+mobicham@users.noreply.github.com> Date: Fri, 10 May 2024 19:29:35 +0200 Subject: [PATCH 02/19] hqq - fix weight check in check_quantized_param (#30748) * hqq - fix weight check in check_quantized_param * ruff format --- src/transformers/quantizers/quantizer_hqq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 06949d059a5de3..14be75369dec0e 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -101,7 +101,7 @@ def check_quantized_param( ) -> bool: module, tensor_name = get_module_from_name(model, param_name) - return isinstance(module, torch.nn.Linear) + return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") def create_quantized_param( self, From de6e0db184d565847356a6a08dde2f043e744c72 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Mon, 13 May 2024 11:41:03 +0200 Subject: [PATCH 03/19] [awq] replace scale when we have GELU (#30074) * fix awq test * style * add log * new fix * style * only modifying impacted model in the end * rename function --- src/transformers/integrations/__init__.py | 2 ++ src/transformers/integrations/awq.py | 31 +++++++++++++++++++- src/transformers/quantizers/quantizer_awq.py | 4 ++- tests/quantization/autoawq/test_awq.py | 19 ++++++++++++ 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 69fb0e3259b1d5..19a3f421caf49a 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -21,6 +21,7 @@ "awq": [ "fuse_awq_modules", "post_init_awq_exllama_modules", + "replace_quantization_scales", "replace_with_awq_linear", ], "bitsandbytes": [ @@ -92,6 +93,7 @@ from .awq import ( fuse_awq_modules, post_init_awq_exllama_modules, + replace_quantization_scales, replace_with_awq_linear, ) from .bitsandbytes import ( diff --git a/src/transformers/integrations/awq.py b/src/transformers/integrations/awq.py index a543860f100396..a83b27e95a090d 100644 --- a/src/transformers/integrations/awq.py +++ b/src/transformers/integrations/awq.py @@ -14,7 +14,7 @@ "AWQ (Activation aware Weight Quantization) integration file" from ..activations import ACT2FN from ..modeling_utils import PreTrainedModel -from ..utils import is_auto_awq_available, is_torch_available +from ..utils import is_auto_awq_available, is_torch_available, logging from ..utils.quantization_config import ( AwqBackendPackingMethod, AwqConfig, @@ -27,6 +27,7 @@ import torch import torch.nn as nn +logger = logging.get_logger(__name__) AWQ_FUSED_MAPPINGS = { "mistral": { @@ -56,6 +57,34 @@ }, } +AWQ_SCALES_MAPPINGS = { + "starcoder2": {"act": "act", "layer_before_act": "c_fc"}, + "RefinedWebModel": {"act": "act", "layer_before_act": "dense_h_to_4h"}, + "falcon": {"act": "act", "layer_before_act": "dense_h_to_4h"}, + "mpt": {"act": "act", "layer_before_act": "up_proj"}, + "gptj": {"act": "act", "layer_before_act": "fc_in"}, + "gpt_neox": {"act": "act", "layer_before_act": "dense_h_to_4h"}, + "gpt_bigcode": {"act": "act", "layer_before_act": "c_fc"}, + "bloom": {"act": "gelu_impl", 
"layer_before_act": "dense_h_to_4h"}, +} + + +def replace_quantization_scales(model, model_type): + from awq.modules.act import ScaledActivation + + if model_type not in AWQ_SCALES_MAPPINGS: + return model + for name, module in model.named_children(): + act_name = AWQ_SCALES_MAPPINGS[model_type]["act"] + layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"] + if name == act_name and hasattr(model, layer_before_act_name): + layer_before_act = getattr(model, AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]) + size = layer_before_act.out_features + scale_like = torch.ones(size) + model._modules[name] = ScaledActivation(module, scale_like) + _ = replace_quantization_scales(module, model_type) + return model + def replace_with_awq_linear( model, diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py index 5e66f9baf1c0a7..f9e4444f07015c 100644 --- a/src/transformers/quantizers/quantizer_awq.py +++ b/src/transformers/quantizers/quantizer_awq.py @@ -75,7 +75,7 @@ def update_torch_dtype(self, torch_dtype): return torch_dtype def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): - from ..integrations import get_keys_to_not_convert, replace_with_awq_linear + from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear self.modules_to_not_convert = get_keys_to_not_convert(model) @@ -86,6 +86,8 @@ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwarg model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert ) + model = replace_quantization_scales(model, model.config.model_type) + if not has_been_replaced: logger.warning( "You are loading an AWQ model but no linear modules were found in your model." 
diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index e2369f07b23121..20ecd783cf04e7 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -471,3 +471,22 @@ def test_generation_mixtral_fused(self): outputs = model.generate(**inputs, max_new_tokens=12) self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL) + + +@slow +@require_torch_gpu +@require_auto_awq +@require_accelerate +class AwqScaleTest(unittest.TestCase): + model_name = "TechxGenus/starcoder2-3b-AWQ" + + def test_load_quantized_model(self): + from awq.modules.act import ScaledActivation + + """ + Simple test that checks if the scales have been replaced in the quantized model + """ + quantized_model = AutoModelForCausalLM.from_pretrained( + "TechxGenus/starcoder2-3b-AWQ", torch_dtype=torch.float16, device_map="cuda" + ) + self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation)) From a4e530e3c89fcd1cba869587d6d04929bc28bbbe Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 13 May 2024 12:08:48 +0200 Subject: [PATCH 04/19] Workflow: Replace `actions/post-slack` with centrally defined workflow (#30737) * Remove commit details * remove old workflow --- .github/actions/post-slack/action.yml | 79 --------------------- .github/workflows/push-important-models.yml | 4 +- 2 files changed, 2 insertions(+), 81 deletions(-) delete mode 100644 .github/actions/post-slack/action.yml diff --git a/.github/actions/post-slack/action.yml b/.github/actions/post-slack/action.yml deleted file mode 100644 index 74075a4fedc427..00000000000000 --- a/.github/actions/post-slack/action.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: Send message to slack - -description: 'Send results to slack' -author: 'Hugging Face' -inputs: - slack_channel: - required: true - type: string - title: - required: true - type: string - status: - required: true - type: string - slack_token: - required: true - type: string - -runs: - using: "composite" - steps: - - name: Create content to post - id: create-message - run: | - if [ "${{ inputs.status }}" == "success" ]; then - echo STATUS_MESSAGE='🟢 Tests are passing!' >> $GITHUB_ENV - else - echo STATUS_MESSAGE='🔴 Tests failed! Please check the GitHub action link below' >> $GITHUB_ENV - fi - shell: bash - - - name: Post Canceled results Slack channel - id: post-slack - uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 - with: - # Slack channel id, channel name, or user id to post message. 
- # See also: https://api.slack.com/methods/chat.postMessage#channels - channel-id: ${{ inputs.slack_channel }} - # For posting a rich message using Block Kit - payload: | - { - "text": "${{ inputs.title }}", - "blocks": [ - { - "type": "header", - "text": { - "type": "plain_text", - "text": "${{ inputs.title }}" - } - }, - { - "type": "section", - "text": { - "type": "mrkdwn", - "text": "${{ env.STATUS_MESSAGE }}" - } - }, - { - "type": "section", - "text": {"type": "mrkdwn", "text": "*Click the button for more details about the commit*"}, - "accessory": { - "type": "button", - "text": {"type": "plain_text", "text": "Check Commit results"}, - "url": "${{ github.event.pull_request.html_url || github.event.head_commit.url }}" - } - }, - { - "type": "section", - "text": {"type": "mrkdwn", "text": "*Click here for more details about the action ran*"}, - "accessory": { - "type": "button", - "text": {"type": "plain_text", "text": "Check Action results"}, - "url": "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" - } - } - ] - } - env: - SLACK_BOT_TOKEN: ${{ inputs.slack_token }} \ No newline at end of file diff --git a/.github/workflows/push-important-models.yml b/.github/workflows/push-important-models.yml index cf86a8fc8d5354..ef965396361116 100644 --- a/.github/workflows/push-important-models.yml +++ b/.github/workflows/push-important-models.yml @@ -97,7 +97,7 @@ jobs: - name: Post to Slack if: always() - uses: ./.github/actions/post-slack + uses: huggingface/hf-workflows/.github/actions/post-slack@main with: slack_channel: ${{ env.OUTPUT_SLACK_CHANNEL_ID }} title: 🤗 Results of the FA2 tests - ${{ matrix.model-name }} @@ -119,7 +119,7 @@ jobs: - name: Post to Slack if: always() - uses: ./.github/actions/post-slack + uses: huggingface/hf-workflows/.github/actions/post-slack@main with: slack_channel: ${{ env.OUTPUT_SLACK_CHANNEL_ID }} title: 🤗 Results of the Integration tests - ${{ matrix.model-name }} From f63d822242e706e09308474b7b81c2f916d2d4a2 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Mon, 13 May 2024 13:20:16 +0200 Subject: [PATCH 05/19] Blip dynamic input resolution (#30722) * blip with interpolated pos encoding * feat: Add interpolate_pos_encoding option to other models from `BLIP` family. * include check for textual generated content in tests --- src/transformers/models/blip/modeling_blip.py | 81 +++++++++++++++++-- .../models/blip_2/modeling_blip_2.py | 66 +++++++++++++-- .../instructblip/modeling_instructblip.py | 62 ++++++++++++-- tests/models/blip/test_modeling_blip.py | 14 ++++ tests/models/blip_2/test_modeling_blip_2.py | 16 ++++ .../test_modeling_instructblip.py | 21 +++++ 6 files changed, 240 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index c99c3c06b9dd5b..576d4dd5c0cb50 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -14,6 +14,7 @@ # limitations under the License. 
""" PyTorch BLIP model.""" +import math import warnings from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -231,15 +232,51 @@ def __init__(self, config: BlipVisionConfig): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, 0, :] + patch_pos_embed = self.position_embedding[:, 1:, :] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings @@ -509,6 +546,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ BLIP_INPUTS_DOCSTRING = r""" @@ -545,6 +584,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. 
""" @@ -657,6 +698,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -671,7 +713,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -779,6 +821,7 @@ def get_image_features( self, pixel_values: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: r""" Returns: @@ -804,7 +847,11 @@ def get_image_features( ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) pooled_output = vision_outputs[1] # pooled_output image_features = self.visual_projection(pooled_output) @@ -818,6 +865,7 @@ def get_multimodal_features( pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> torch.FloatTensor: r""" Returns: @@ -846,6 +894,7 @@ def get_multimodal_features( output_attentions=True, output_hidden_states=True, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -876,6 +925,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipOutput]: r""" Returns: @@ -913,6 +963,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) text_outputs = self.text_model( @@ -999,6 +1050,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipForConditionalGenerationModelOutput]: r""" Returns: @@ -1033,6 +1085,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1065,6 +1118,7 @@ def generate( pixel_values: torch.FloatTensor, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: r""" @@ -1100,7 +1154,10 @@ def generate( """ batch_size = pixel_values.shape[0] - vision_outputs = self.vision_model(pixel_values=pixel_values) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) image_embeds = vision_outputs[0] @@ -1174,6 +1231,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipTextVisionModelOutput]: r""" Returns: @@ -1227,6 
+1285,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1279,6 +1338,7 @@ def generate( input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: r""" @@ -1316,7 +1376,10 @@ def generate( 2 ``` """ - vision_outputs = self.vision_model(pixel_values=pixel_values) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) image_embeds = vision_outputs[0] @@ -1408,6 +1471,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BlipTextVisionModelOutput]: r""" Returns: @@ -1441,6 +1505,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 12396bf286eaaf..8986eabc7d89f3 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -101,15 +101,51 @@ def __init__(self, config: Blip2VisionConfig): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, 0, :] + patch_pos_embed = self.position_embedding[:, 1:, :] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings @@ -321,6 +357,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ BLIP_2_TEXT_INPUTS_DOCSTRING = r""" @@ -402,6 +440,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. 
""" @@ -516,6 +556,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -530,7 +571,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1297,6 +1338,7 @@ def get_image_features( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ): r""" Returns: @@ -1330,6 +1372,7 @@ def get_image_features( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) return vision_outputs @@ -1341,6 +1384,7 @@ def get_qformer_features( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ): r""" Returns: @@ -1374,6 +1418,7 @@ def get_qformer_features( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1406,6 +1451,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: r""" Returns: @@ -1441,6 +1487,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1623,6 +1670,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: r""" Returns: @@ -1695,6 +1743,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1779,6 +1828,7 @@ def generate( pixel_values: torch.FloatTensor, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: """ @@ -1800,7 +1850,11 @@ def generate( self._preprocess_accelerate() batch_size = pixel_values.shape[0] - image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state + image_embeds = self.vision_model( + pixel_values, + return_dict=True, + interpolate_pos_encoding=interpolate_pos_encoding, + ).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 291db19721ed6d..0fb089b62729f4 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ 
b/src/transformers/models/instructblip/modeling_instructblip.py @@ -102,15 +102,51 @@ def __init__(self, config: InstructBlipVisionConfig): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) - def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embedding.shape[1] - 1 + + if num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, 0, :] + patch_pos_embed = self.position_embedding[:, 1:, :] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + if interpolate_pos_encoding: + position_embedding = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embedding = self.position_embedding + embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings @@ -328,6 +364,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. """ INSTRUCTBLIP_INPUTS_DOCSTRING = r""" @@ -391,6 +429,8 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. 
""" @@ -505,6 +545,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: @@ -519,7 +560,7 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - hidden_states = self.embeddings(pixel_values) + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1327,6 +1368,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1379,6 +1421,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) image_embeds = vision_outputs[0] @@ -1473,6 +1516,7 @@ def generate( qformer_attention_mask: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: """ @@ -1489,6 +1533,8 @@ def generate( The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the positional encoding of the image embeddings. Returns: captions (list): A list of strings of length batch_size * num_captions. 
@@ -1498,7 +1544,11 @@ def generate( self._preprocess_accelerate() batch_size = pixel_values.shape[0] - image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state + image_embeds = self.vision_model( + pixel_values, + return_dict=True, + interpolate_pos_encoding=interpolate_pos_encoding, + ).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 4caba63a310462..89404342f0b084 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -1381,6 +1381,20 @@ def test_inference_image_captioning_fp16(self): [30522, 1037, 3861, 1997, 1037, 2450, 1998, 2014, 3899, 2006, 1996, 3509, 102], ) + def test_inference_interpolate_pos_encoding(self): + model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(torch_device) + processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + processor.image_processor.size = {"height": 500, "width": 500} + + image = prepare_img() + inputs = processor(images=image, return_tensors="pt").to(torch_device) + + predictions = model.generate(**inputs, interpolate_pos_encoding=True) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + + self.assertEqual(predictions[0].tolist(), [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 1037, 3899, 102]) + self.assertEqual(generated_text, "a woman sitting on the beach with a dog") + def test_inference_vqa(self): model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(torch_device) processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 927f5341272f45..d2f3b2b719f27c 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -882,6 +882,22 @@ def test_inference_opt(self): ) self.assertEqual(generated_text, "it's not a city, it's a beach") + def test_inference_interpolate_pos_encoding(self): + processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16 + ).to(torch_device) + processor.image_processor.size = {"height": 500, "width": 500} + + image = prepare_img() + inputs = processor(images=image, return_tensors="pt").to(torch_device) + + predictions = model.generate(**inputs, interpolate_pos_encoding=True) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + + self.assertEqual(predictions[0].tolist(), [2, 102, 693, 8, 2335, 15, 5, 4105, 50118]) + self.assertEqual(generated_text, "a woman and dog on the beach") + def test_inference_opt_batched_beam_search(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") model = Blip2ForConditionalGeneration.from_pretrained( diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index dcb8040bfcf9da..86aea876fa507e 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -612,3 +612,24 @@ def test_inference_flant5_xl(self): generated_text, "The image depicts a man ironing clothes on the back of a yellow van in the middle of a busy city 
street. The man is wearing a yellow shirt with a bright yellow tie, and he is using an ironing board to complete his task. The image is unusual due to the fact that it shows a man ironing clothes on the back of a van in the middle of a busy city street. It is possible that the man is trying to save money by doing his laundry on the back of the van, but it is also possible that he is trying to save time by doing his laundry on the back of the van in the middle of a busy city street. Regardless of the reason for the man's actions, it is clear that he is trying to save time by doing his laundry on the back of the van in the middle of a busy city street.", ) + + def test_inference_interpolate_pos_encoding(self): + processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl") + model = InstructBlipForConditionalGeneration.from_pretrained( + "Salesforce/instructblip-flan-t5-xl", + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ).to(torch_device) + processor.image_processor.size = {"height": 500, "width": 500} + + image = prepare_img() + prompt = "What's in the image?" + inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device) + + predictions = model.generate(**inputs, interpolate_pos_encoding=True) + generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + + self.assertEqual( + predictions[0].tolist(), [0, 37, 1023, 753, 3, 9, 2335, 3823, 30, 8, 2608, 28, 3, 9, 1782, 5, 1] + ) + self.assertEqual(generated_text, "The image features a woman sitting on the beach with a dog.") From e52741f60117acea1ab113db5e3b4e4245dd0d45 Mon Sep 17 00:00:00 2001 From: Nilabhra Roy Chowdhury Date: Mon, 13 May 2024 13:32:43 +0200 Subject: [PATCH 06/19] Support for Falcon2-11B (#30771) * remove unrelated changes * remove unrelated changes on phi and stable LM * add: Test for Falcon 10B * fix: formatting * fix: loading the falcon 10B in 8 bit precision using bitsanbytes. * fix: device placement * fix: broken tests. * fix: backwards compatibility for falcon 1B architecture. * chore: updated test. * chore: test_modeling_falcon.py to use the 11B model. * chore: minor edit * chore: formating. --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Co-authored-by: ArthurZucker --- .../models/falcon/configuration_falcon.py | 5 ++++ .../models/falcon/modeling_falcon.py | 30 +++++++++++++------ tests/models/falcon/test_modeling_falcon.py | 21 ++++++++++++- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/falcon/configuration_falcon.py b/src/transformers/models/falcon/configuration_falcon.py index ce10faeb20cf91..0dd61047dd275f 100644 --- a/src/transformers/models/falcon/configuration_falcon.py +++ b/src/transformers/models/falcon/configuration_falcon.py @@ -42,6 +42,9 @@ class FalconConfig(PretrainedConfig): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 71): Number of attention heads for each attention layer in the Transformer encoder. + num_ln_in_parallel_attn (`int`, *optional*): + Set to 2 if separate layer norms are to be used for the MLP and the attention output when using parallel + attention, otherwise, 1. layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): The epsilon used by the layer normalization layers. 
initializer_range (`float`, *optional*, defaults to 0.02): @@ -115,6 +118,7 @@ def __init__( hidden_size=4544, num_hidden_layers=32, num_attention_heads=71, + num_ln_in_parallel_attn=None, layer_norm_epsilon=1e-5, initializer_range=0.02, use_cache=True, @@ -154,6 +158,7 @@ def __init__( self.multi_query = multi_query # Ignored when new_decoder_architecture is True self.parallel_attn = parallel_attn self.bias = bias + self.num_ln_in_parallel_attn = num_ln_in_parallel_attn self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta self.rope_scaling = rope_scaling diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 76ca4110e81848..a171c875dbdc0a 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -767,15 +767,20 @@ def __init__(self, config: FalconConfig): self.hidden_dropout = config.hidden_dropout self.config = config - if config.new_decoder_architecture: - # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - # The layer norm before the MLP - self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - else: + if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture: + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - if not config.parallel_attn: - self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) def forward( self, @@ -796,7 +801,7 @@ def forward( residual = hidden_states - if self.config.new_decoder_architecture: + if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -826,6 +831,13 @@ def forward( ) mlp_layernorm_out = self.post_attention_layernorm(residual) + if ( + self.config.new_decoder_architecture + and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1 + ): + mlp_layernorm_out = attention_layernorm_out + outputs = attn_outputs[1:] # MLP. 
diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index c8ad2d9b322725..59ab3161403421 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -602,6 +602,25 @@ def test_lm_generate_falcon(self): self.assertEqual(output_str, EXPECTED_OUTPUT) + @slow + @require_bitsandbytes + def test_lm_generate_falcon_11b(self): + tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-11B", padding_side="left") + model = FalconForCausalLM.from_pretrained( + "tiiuae/falcon-11B", device_map={"": torch_device}, load_in_8bit=True + ) + model.eval() + inputs = tokenizer( + "Two roads diverged in a yellow wood,", return_tensors="pt", return_token_type_ids=False + ).to(torch_device) + + EXPECTED_OUTPUT = "Two roads diverged in a yellow wood,\nAnd sorry I could not travel both\n" + + output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=9) + output_str = tokenizer.batch_decode(output_ids)[0] + + self.assertEqual(output_str, EXPECTED_OUTPUT) + @slow def test_lm_generation_big_models(self): # The big models are way too big for the CI, so we use tiny random models that resemble their @@ -647,7 +666,7 @@ def test_batched_generation(self): tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( "tiiuae/falcon-7b", - device_map="auto", + device_map={"": torch_device}, load_in_4bit=True, ) From 453893ed154e5cd31ecf7d4b1b71189aeb29a7f0 Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco <69953243+EduardoPach@users.noreply.github.com> Date: Mon, 13 May 2024 13:34:45 +0200 Subject: [PATCH 07/19] [GroundingDino] Adding ms_deform_attn kernels (#30768) * Adding ms_deform_attn kernels to GroundingDino * Pointing to deformable detr kernels --- .../models/grounding_dino/modeling_grounding_dino.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index dc3523f33d46bf..72fc14edebd560 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -73,7 +73,7 @@ def load_cuda_kernels(): global MultiScaleDeformableAttention - root = Path(__file__).resolve().parent.parent.parent / "kernels" / "grounding_dino" + root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr" src_files = [ root / filename for filename in [ From a0779b9e19093dc0371abbf516030491eec3d86c Mon Sep 17 00:00:00 2001 From: Poedator <24738311+poedator@users.noreply.github.com> Date: Mon, 13 May 2024 13:46:06 +0200 Subject: [PATCH 08/19] Llama: fix custom 4D masks, v2 (#30348) * 4d mask fixes * Update custom 4D mask logic * test moved to mixin * extra tests 4d mask * upd 4d mask and StaticCache handling * added Mask4DTestHard to mistral tests * post-rebase fixes * test fixes for StaticCache * make fix-copies * upd 1 after #30476 * fix common tests * rm elif attention_mask.dim() == 4: * tests combined, fixed, mixtral supported * bigbird style chg reverted * rm if attention_mask.dim() == 2 * modeling_llama formatting chg --------- Co-authored-by: Joao Gante --- src/transformers/modeling_attn_mask_utils.py | 26 +- .../modeling_bigbird_pegasus.py | 1 - .../models/cohere/modeling_cohere.py | 40 +-- src/transformers/models/dbrx/modeling_dbrx.py | 40 +-- .../models/gemma/modeling_gemma.py | 40 +-- .../models/llama/modeling_llama.py | 40 +-- 
src/transformers/models/olmo/modeling_olmo.py | 40 +-- tests/models/llama/test_modeling_llama.py | 257 +++++++++++++++++- tests/models/mistral/test_modeling_mistral.py | 124 +++++++++ tests/test_modeling_common.py | 74 +++++ tests/test_modeling_utils.py | 225 --------------- 11 files changed, 541 insertions(+), 366 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 8dcf40268d0324..fb85d018c9f979 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -250,7 +250,7 @@ def _ignore_causal_mask_sdpa( allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). """ - batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] key_value_length = query_length + past_key_values_length is_tracing = ( @@ -275,11 +275,7 @@ def _ignore_causal_mask_sdpa( ignore_causal_mask = True elif sliding_window is None or key_value_length < sliding_window: if len(attention_mask.shape) == 4: - expected_shape = (batch_size, 1, query_length, key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) + return False elif (is_training or not is_tracing) and torch.all(attention_mask == 1): if query_length == 1 or key_value_length == query_length: # For query_length == 1, causal attention and bi-directional attention are the same. @@ -387,12 +383,18 @@ def _prepare_4d_causal_attention_mask_for_sdpa( input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device ) else: - expanded_4d_mask = attn_mask_converter.to_4d( - attention_mask, - input_shape[-1], - dtype=inputs_embeds.dtype, - key_value_length=key_value_length, - ) + if attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + expanded_4d_mask = attention_mask + else: + expanded_4d_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + dtype=inputs_embeds.dtype, + key_value_length=key_value_length, + ) # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 74ec4432a57a66..b4e6419f9905e4 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -14,7 +14,6 @@ # limitations under the License. 
""" PyTorch BigBirdPegasus model.""" - import copy import math from typing import List, Optional, Tuple, Union diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index d96131d7705ea4..b25528dfe73e2e 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -995,37 +995,27 @@ def _update_causal_mask( else past_seen_tokens + sequence_length + 1 ) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - elif attention_mask.dim() == 4: - # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with - # cache. In that case, the 4D attention mask attends to the newest tokens only. 
-                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                    logger.warning_once(
-                        "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
-                        "transformers v4.42.0"
-                    )
-                    offset = cache_position[0]
-                else:
-                    offset = 0
-                mask_shape = attention_mask.shape
-                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
-
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 2e185aa885d6dd..38c1fc814b1cad 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -1241,37 +1241,27 @@ def _update_causal_mask(
             else past_seen_tokens + sequence_length + 1
         )
 
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            if attention_mask.dim() == 2:
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-            elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
-                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                    logger.warning_once(
-                        "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
-                        "transformers v4.42.0"
-                    )
-                    offset = cache_position[0]
-                else:
-                    offset = 0
-                mask_shape = attention_mask.shape
-                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
-
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 8f7893704780d1..12d01a6ea04d3e 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -986,37 +986,27 @@ def _update_causal_mask(
             else past_seen_tokens + sequence_length + 1
         )
 
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            if attention_mask.dim() == 2:
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-            elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
-                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                    logger.warning_once(
-                        "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
-                        "transformers v4.42.0"
-                    )
-                    offset = cache_position[0]
-                else:
-                    offset = 0
-                mask_shape = attention_mask.shape
-                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
-
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index d840b03faf71fb..c6da59fcfb3edc 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -1073,37 +1073,27 @@ def _update_causal_mask(
             else past_seen_tokens + sequence_length + 1
         )
 
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            if attention_mask.dim() == 2:
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-            elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
-                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                    logger.warning_once(
-                        "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
-                        "transformers v4.42.0"
-                    )
-                    offset = cache_position[0]
-                else:
-                    offset = 0
-                mask_shape = attention_mask.shape
-                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
-
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 5009ac84be2ea7..6a7b2f748fcf03 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -1052,37 +1052,27 @@ def _update_causal_mask(
             else past_seen_tokens + sequence_length + 1
         )
 
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            if attention_mask.dim() == 2:
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-            elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
-                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                    logger.warning_once(
-                        "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
-                        "transformers v4.42.0"
-                    )
-                    offset = cache_position[0]
-                else:
-                    offset = 0
-                mask_shape = attention_mask.shape
-                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
-
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index e63e53797462b4..5d402bd8599477 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Testing suite for the PyTorch LLaMA model. """
+"""Testing suite for the PyTorch LLaMA model."""
 
+import gc
 import tempfile
 import unittest
 
@@ -21,7 +22,7 @@
 from packaging import version
 from parameterized import parameterized
 
-from transformers import LlamaConfig, is_torch_available, set_seed
+from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
     require_flash_attn,
@@ -804,7 +805,7 @@ def test_model_7b_logits(self):
        '
 \ndef main():\n    factory = InterfaceManagerFactory(start=datetime.now())\n    managers = []\n    for i in range(10):\n        managers.append(factory.build(id=i))\n  class InterfaceManagerFactory(AbstractManagerFactory):\n    def __init__(',
             '
  = 0 :=\nbegin\nsplit,\n{ intros h f,\n    rw pi_1_etalisation at h,\n    simp [h],\n    refl\n},\n{ intro h,\n    have := @quasi_adjoint C D P,\n    simp [←pi_1_etalisation, this, h],\n    refl\n}\nend\n  /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ '
         ]
-        EXPECTED_IDS = torch.tensor([[    1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898,29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
+        EXPECTED_IDS = torch.tensor([[1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898, 29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
         # fmt: on
         self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT)
         input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"]
@@ -816,3 +817,253 @@ def test_model_7b_logits(self):
         ]
         infilling = tokenizer.batch_decode(generated_ids)
         self.assertEqual(infilling, EXPECTED_INFILLING)
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestHard(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def setUp(self):
+        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+        self.model_dtype = torch.float32
+        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+        self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+
+    def get_test_data(self):
+        template = "my favorite {}"
+        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
+
+        batch_separate = [template.format(x) for x in items]  # 3 separate lines
+        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated
+
+        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
+        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
+
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+        )
+
+        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
+
+        # building custom position ids based on the custom mask
+        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
+        # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
+
+        # inverting the mask
+        min_dtype = torch.finfo(self.model_dtype).min
+        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_stacked_causal_mask(self):
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # 2 forward runs with custom 4D masks
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
+        past_key_values_a = outs_1a["past_key_values"]
+
+        # Case 1: we pass a 4D attention mask covering only the current sequence length (i.e. [..., seq_len, full_len])
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+        outs_1b = self.model.forward(
+            input_1b,
+            attention_mask=mask_1b,
+            position_ids=position_ids_1b,
+            past_key_values=past_key_values_a,
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
+
+    def test_stacked_causal_mask_static_cache(self):
+        """same as above but with StaticCache"""
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # upgrade the model with StaticCache
+        max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
+        past_key_values = StaticCache(
+            config=self.model.config,
+            max_batch_size=1,
+            max_cache_len=max_cache_len,
+            device=torch_device,
+            dtype=self.model.dtype,
+        )
+
+        padded_attention_mask = torch.nn.functional.pad(
+            input=mask_shared_prefix,
+            pad=(0, max_cache_len - mask_shared_prefix.shape[-1]),
+            mode="constant",
+            value=torch.finfo(self.model_dtype).min,
+        )
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix,
+            attention_mask=padded_attention_mask,
+            position_ids=position_ids_shared_prefix,
+            cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device),
+            past_key_values=past_key_values,
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask_static_cache(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
+        # we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len]
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # upgrade the model with StaticCache
+        max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
+        past_key_values = StaticCache(
+            config=self.model.config,
+            max_batch_size=1,
+            max_cache_len=max_cache_len,
+            device=torch_device,
+            dtype=self.model.dtype,
+        )
+
+        # forward run for the first part of input
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        padded_mask_1a = torch.nn.functional.pad(
+            input=mask_1a,
+            pad=(0, max_cache_len - mask_1a.shape[-1]),
+            mode="constant",
+            value=torch.finfo(self.model_dtype).min,
+        )
+
+        _ = self.model.forward(
+            input_1a,
+            attention_mask=padded_mask_1a,
+            position_ids=position_ids_1a,
+            cache_position=torch.arange(part_a, device=torch_device),
+            past_key_values=past_key_values,
+        )
+
+        # forward run for the second part of input
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+
+        padded_mask_1b = torch.nn.functional.pad(
+            input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0
+        )
+
+        outs_1b = self.model.forward(
+            input_1b,
+            attention_mask=padded_mask_1b,
+            position_ids=position_ids_1b,
+            cache_position=torch.arange(
+                part_a,
+                input_ids_shared_prefix.shape[-1],
+                device=torch_device,
+            ),
+            past_key_values=past_key_values,
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
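
The StaticCache variants above mostly add shape bookkeeping on top of the earlier tests: the inverted 4D mask has to be padded along the key dimension up to max_cache_len, and cache_position has to say which cache slots the forward pass fills. A minimal illustrative sketch of that bookkeeping in plain PyTorch (sizes are made up, nothing below is taken from the diffs):

    import torch

    seq_len, max_cache_len = 6, 16          # max_cache_len is larger than attention_mask.shape[-1], as in the test
    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min

    # 1 = may attend, 0 = masked; shape (batch, heads, query_len, key_len)
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.int64))[None, None]

    # additive ("inverted") form expected by the models: 0 where attending, min_dtype elsewhere
    inverted = mask.eq(0).to(dtype) * min_dtype

    # pad the key dimension up to the static cache length, keeping the unused slots masked
    padded = torch.nn.functional.pad(inverted, (0, max_cache_len - seq_len), value=min_dtype)

    # cache_position tells the model which static-cache slots this forward pass writes
    cache_position = torch.arange(seq_len)

    print(padded.shape)      # torch.Size([1, 1, 6, 16])
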
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 3500024b3ea173..bbc36c050e23f0 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -627,3 +627,127 @@ def test_speculative_generation(self):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestHard(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def setUp(self):
+        model_name = "mistralai/Mistral-7B-v0.1"
+        self.model_dtype = torch.float32
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+        self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
+
+    def get_test_data(self):
+        template = "my favorite {}"
+        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
+
+        batch_separate = [template.format(x) for x in items]  # 3 separate lines
+        batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated
+
+        input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
+        input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
+
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
+                    ]
+                ]
+            ],
+            device=torch_device,
+        )
+
+        position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
+
+        # building custom position ids based on the custom mask
+        position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
+        # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
+
+        # inverting the mask
+        min_dtype = torch.finfo(self.model_dtype).min
+        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_stacked_causal_mask(self):
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # single forward run with 4D custom mask
+        logits_shared_prefix = self.model.forward(
+            input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
+        ).logits
+        logits_shared_prefix_last = logits_shared_prefix[
+            0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
+        ]  # last three tokens
+        decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
+
+        self.assertEqual(decoded, decoded_shared_prefix)
+
+    def test_partial_stacked_causal_mask(self):
+        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
+
+        (
+            input_ids,
+            position_ids,
+            input_ids_shared_prefix,
+            mask_shared_prefix,
+            position_ids_shared_prefix,
+        ) = self.get_test_data()
+
+        # regular batch
+        logits = self.model.forward(input_ids, position_ids=position_ids).logits
+        logits_last = logits[:, -1, :]  # last tokens in each batch line
+        decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
+
+        # 2 forward runs with custom 4D masks
+        part_a = 3  # split point
+
+        input_1a = input_ids_shared_prefix[:, :part_a]
+        position_ids_1a = position_ids_shared_prefix[:, :part_a]
+        mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
+
+        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
+        past_key_values_a = outs_1a["past_key_values"]
+
+        # Case 1: we pass a 4D attention mask covering only the current sequence length (i.e. [..., seq_len, full_len])
+        input_1b = input_ids_shared_prefix[:, part_a:]
+        position_ids_1b = position_ids_shared_prefix[:, part_a:]
+        mask_1b = mask_shared_prefix[:, :, part_a:, :]
+        outs_1b = self.model.forward(
+            input_1b, attention_mask=mask_1b, position_ids=position_ids_1b, past_key_values=past_key_values_a
+        )
+        decoded_1b = [
+            self.tokenizer.decode(t)
+            for t in outs_1b.logits.argmax(-1)[
+                0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
+            ]
+        ]
+        self.assertEqual(decoded, decoded_1b)
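
Both Mask4DTestHard classes exercise the same "shared prefix" packing: several continuations of a common prefix are stacked into a single row, a block mask keeps them from attending to each other, and position ids restart after the prefix. A small illustrative sketch of how such a mask and its position ids can be built (plain PyTorch, layout only; the 12-token sizes match the hand-written mask above but are otherwise arbitrary):

    import torch

    prefix_len, item_len, n_items = 3, 3, 3            # shared prefix plus three 3-token items
    total = prefix_len + item_len * n_items            # 12 packed tokens

    allowed = torch.zeros(total, total, dtype=torch.int64)
    allowed[:, :prefix_len] = 1                        # every token may look at the shared prefix
    for i in range(n_items):
        s = prefix_len + i * item_len
        allowed[s : s + item_len, s : s + item_len] = 1    # ...and at its own block
    mask_4d = (torch.tril(torch.ones(total, total, dtype=torch.int64)) * allowed)[None, None]

    # position ids derived from the mask restart after the prefix, as in get_test_data()
    position_ids = (mask_4d.sum(dim=-1) - 1).reshape(1, -1)
    # tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]])
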
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index df585f4afc65e1..daa438e9f1bdea 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4277,6 +4277,80 @@ def test_flash_attn_2_from_config(self):
 
                 self.assertFalse(fa2_correctly_converted)
 
+    def _get_custom_4d_mask_test_data(self):
+        # Sequence in which all but the last token is the same
+        input_ids = torch.tensor(
+            [[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64
+        )
+        position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
+
+        # Combining common prefix with the unique ending tokens:
+        input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
+
+        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
+        mask_shared_prefix = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0],
+                        [1, 1, 1, 0, 1, 0],
+                        [1, 1, 1, 0, 0, 1],
+                    ]
+                ]
+            ],
+        )
+        # inverting the attention mask
+        mask_dtype = torch.float32
+        min_dtype = torch.finfo(mask_dtype).min
+        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
+
+        # Creating a position_ids tensor. Note the repeated values at the end.
+        position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
+
+        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
+
+    def test_custom_4d_attention_mask(self):
+        if len(self.all_generative_model_classes) == 0:
+            self.skipTest("Model architecture has no generative classes, and thus does not necessarily support 4D masks")
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_cache_class:
+                self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config).to(device=torch_device, dtype=torch.float32)
+
+            (
+                input_ids,
+                position_ids,
+                input_ids_shared_prefix,
+                mask_shared_prefix,
+                position_ids_shared_prefix,
+            ) = self._get_custom_4d_mask_test_data()
+
+            logits = model.forward(input_ids, position_ids=position_ids).logits
+            # logits.shape == torch.Size([3, 4, ...])
+
+            logits_shared_prefix = model(
+                input_ids_shared_prefix,
+                attention_mask=mask_shared_prefix,
+                position_ids=position_ids_shared_prefix,
+            )[0]
+            # logits_shared_prefix.shape == torch.Size([1, 6, ...])
+
+            out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
+            out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens
+
+            # comparing greedily-chosen tokens:
+            assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
+
+            # comparing softmax-normalized logits:
+            normalized_0 = F.softmax(out_last_tokens, dim=-1)
+            normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
+            torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
+
 
 global_rng = random.Random()
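
The helper above hard-codes the convention that the reworked _update_causal_mask paths now enforce: a user-supplied 4D mask must already be in additive ("inverted") form, with 0.0 at positions that may be attended and a large negative number elsewhere, so that mask.max() == 0. An illustrative conversion in plain PyTorch (not taken from the diffs):

    import torch

    bool_mask = torch.tensor([[[[1, 0, 0],
                                [1, 1, 0],
                                [1, 1, 1]]]])                 # 1 = attend, 0 = masked
    min_dtype = torch.finfo(torch.float32).min
    inverted = bool_mask.eq(0).to(torch.float32) * min_dtype  # 0.0 where attending
    assert inverted.max() == 0                                # the check the models now perform
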
 
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index f98e1a2a239105..9a00340d14b658 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
-import gc
 import glob
 import json
 import os
@@ -53,7 +52,6 @@
     require_tf,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     require_usr_bin_time,
     slow,
@@ -2107,229 +2105,6 @@ def test_not_available_sdpa(self):
         self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
 
 
-@require_torch_gpu
-class Mask4DTestBase(unittest.TestCase):
-    def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def get_test_data(self):
-        texts = ["the cat sat", "the cat had", "the cat is"]
-        encoded = [self.tokenizer.encode(t) for t in texts]
-        input_0 = torch.tensor(encoded, device=torch_device)
-        # tensor([[   1,  278, 6635, 3290],
-        # [   1,  278, 6635,  750],
-        # [   1,  278, 6635,  338]], device='cuda:0')
-
-        position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
-
-        # Combining common prefix with the unique ending tokens:
-        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
-        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')
-
-        # Creating a 4D mask where each of the last 3 tokens do not attend to each other.
-        mask_1 = torch.tensor(
-            [
-                [
-                    [
-                        [1, 0, 0, 0, 0, 0],
-                        [1, 1, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0],
-                        [1, 1, 1, 1, 0, 0],
-                        [1, 1, 1, 0, 1, 0],
-                        [1, 1, 1, 0, 0, 1],
-                    ]
-                ]
-            ],
-            device="cuda:0",
-            dtype=torch.int64,
-        )
-
-        # Creating a position_ids tensor. note the repeating figures in the end.
-        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
-
-        return input_0, position_ids_0, input_1, mask_1, position_ids_1
-
-
-@require_torch_gpu
-class Mask4DTestFP32(Mask4DTestBase):
-    def setUp(self):
-        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        self.model_dtype = torch.float32
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
-
-    def test_attention(self):
-        """comparing outputs of attention layer"""
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-        causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min
-
-        hid_0 = self.model.model.embed_tokens(input_0)
-        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
-        # outs_0.shape == torch.Size([3, 4, 768])
-
-        hid_1 = self.model.model.embed_tokens(input_1)
-        outs_1 = self.model.model.layers[0].self_attn.forward(
-            hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
-        )[0]
-        # outs_1.shape == torch.Size([1, 6, 768])
-
-        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
-        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens)
-
-    def test_causal_model_logits(self):
-        """comparing logits outputs of whole inner model"""
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-
-        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
-        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens)
-
-
-@require_torch_gpu
-class Mask4DTestFP16(Mask4DTestBase):
-    test_attention = Mask4DTestFP32.test_attention
-
-    def setUp(self):
-        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        self.model_dtype = torch.float16
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
-
-    def test_causal_model_logits(self):
-        """comparing logits outputs of whole inner model"""
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-
-        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
-        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-
-        indices_0 = logits_0_last_tokens.sort(descending=True).indices
-        indices_1 = logits_1_last_tokens.sort(descending=True).indices
-
-        # checking logits, but note relaxed tolerances for FP16
-        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001)
-
-        # checking tokens order for the top tokens
-        for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
-            self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))
-
-
-@slow
-@require_torch_gpu
-class Mask4DTestHard(unittest.TestCase):
-    def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def setUp(self):
-        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-        self.model_dtype = torch.float32
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
-
-    def get_test_data(self):
-        template = "my favorite {}"
-        items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item
-
-        batch_0 = [template.format(x) for x in items]  # 3 separate lines
-        batch_1 = template.format(" ".join(items))  # 1 line with options concatenated
-
-        input_0 = self.tokenizer(batch_0, return_tensors="pt").input_ids.to(torch_device)
-        input_1 = self.tokenizer(batch_1, return_tensors="pt").input_ids.to(torch_device)
-
-        mask_1 = torch.tensor(
-            [
-                [
-                    [
-                        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
-                        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
-                    ]
-                ]
-            ],
-            device=torch_device,
-            dtype=torch.int64,
-        )
-
-        position_ids_0 = torch.arange(input_0.shape[1]).tile(input_0.shape[0], 1).to(torch_device)
-        # equivalent: position_ids_1 = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
-        position_ids_1 = (mask_1.sum(dim=-1) - 1).reshape(1, -1)  # same but nicer
-
-        return input_0, position_ids_0, input_1, mask_1, position_ids_1
-
-    def test_stacked_causal_mask(self):
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        # regular batch
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_0_last = logits_0[:, -1, :]  # last tokens in each batch line
-        decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
-
-        # single forward run with 4D custom mask
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-        logits_1_last = logits_1[0, torch.where(position_ids_1 == position_ids_1.max())[1], :]  # last three tokens
-        decoded_1 = [self.tokenizer.decode(t) for t in logits_1_last.argmax(dim=-1)]
-
-        self.assertEqual(decoded_0, decoded_1)
-
-    def test_partial_stacked_causal_mask(self):
-        # Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention
-        # masks
-
-        # Input 0: one row per sentence; Input 1: same data, but stacked into a single row with custom attention
-        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        # regular batch
-        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
-        logits_0_last = logits_0[:, -1, :]  # last tokens in each batch line
-        decoded_0 = [self.tokenizer.decode(t) for t in logits_0_last.argmax(dim=-1)]
-
-        # 2 forward runs with custom 4D masks
-        part_a = 3  # split point
-
-        input_1a = input_1[:, :part_a]
-        position_ids_1a = position_ids_1[:, :part_a]
-        mask_1a = mask_1[:, :, :part_a, :part_a]
-
-        outs_1a = self.model.forward(input_1a, attention_mask=mask_1a.bool(), position_ids=position_ids_1a)
-        past_key_values_a = outs_1a["past_key_values"]
-
-        input_1b = input_1[:, part_a:]
-        position_ids_1b = position_ids_1[:, part_a:]
-        mask_1b = mask_1[:, :, part_a:, :]
-
-        outs_1b = self.model.forward(
-            input_1b, attention_mask=mask_1b.bool(), position_ids=position_ids_1b, past_key_values=past_key_values_a
-        )
-
-        decoded_1b = [
-            self.tokenizer.decode(t)
-            for t in outs_1b.logits.argmax(-1)[0, torch.where(position_ids_1 == position_ids_1.max())[1] - part_a]
-        ]
-
-        self.assertEqual(decoded_0, decoded_1b)
-
-
 @require_torch
 class TestTensorSharing(TestCasePlus):
     def test_disjoint(self):

From f823fec53e6c18542a6064d28e4b4ada03e50c5b Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Mon, 13 May 2024 14:35:45 +0200
Subject: [PATCH 09/19] Generation / FIX: Fix multi-device generation (#30746)

* attempt to fix multi-device generation

* fix

* final fix

* final fix

* fix

* fix

* fix

* fix

* add joao suggestion

* fix
---
 src/transformers/generation/utils.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 3bff8eea50f0c6..9135bb204846ee 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -476,6 +476,7 @@ def _prepare_attention_mask_for_generation(
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
+
         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
         )
@@ -1340,7 +1341,10 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa
         return self._static_cache
 
     def _prepare_special_tokens(
-        self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None
+        self,
+        generation_config: GenerationConfig,
+        kwargs_has_attention_mask: Optional[bool] = None,
+        device: Optional[Union[torch.device, str]] = None,
     ):
         """
         Prepares the special tokens for generation, overwriting the generation config with their processed versions
@@ -1352,15 +1356,18 @@ def _prepare_special_tokens(
         """
 
         # Convert special tokens to tensors (if they exist)
-        def _tensor_or_none(token):
+        def _tensor_or_none(token, device=None):
+            if device is None:
+                device = self.device
+
             if token is None or isinstance(token, torch.Tensor):
                 return token
-            return torch.tensor(token, device=self.device, dtype=torch.long)
+            return torch.tensor(token, device=device, dtype=torch.long)
 
-        bos_token_id = _tensor_or_none(generation_config.bos_token_id)
-        eos_token_id = _tensor_or_none(generation_config.eos_token_id)
-        pad_token_id = _tensor_or_none(generation_config.pad_token_id)
-        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id)
+        bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
+        eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
+        pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
+        decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
 
         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
@@ -1511,7 +1518,6 @@ def generate(
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
-        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
 
         # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
@@ -1519,6 +1525,9 @@ def generate(
         )
         batch_size = inputs_tensor.shape[0]
 
+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
         # decoder-only models must use left-padding for batched generation.
         if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
             # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
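 
 The change above is about device placement: with device_map="auto" the embedding and the head of a model can sit on different GPUs, so special-token tensors are now created on the device of the input tensor rather than on self.device. A rough standalone illustration of the idea; the helper below is a simplified stand-in for _tensor_or_none, not the library code:

    import torch

    def tensor_or_none(token, device):
        # simplified stand-in: ints become tensors on the requested device, tensors/None pass through
        if token is None or isinstance(token, torch.Tensor):
            return token
        return torch.tensor(token, device=device, dtype=torch.long)

    input_ids = torch.tensor([[1, 2, 3]])        # would live on the first shard's GPU in practice
    device = input_ids.device                    # generate() now derives the device from the inputs
    pad_token_id = tensor_or_none(0, device)
    eos_token_id = tensor_or_none(2, device)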

From f4dc26d46687f5f4baf3fe64a1d87cafefbeec53 Mon Sep 17 00:00:00 2001
From: Joao Gante 
Date: Mon, 13 May 2024 14:12:58 +0100
Subject: [PATCH 10/19] Qwen: incorrect setup flag (#30776)

qwen does not support the new cache classes
---
 src/transformers/models/qwen2/modeling_qwen2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index b5a1370ae1fc8f..709504aba7157c 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -821,7 +821,6 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range

From 69d9bca55af0f25a3b1a2cde5df3b3c8d42cb3ad Mon Sep 17 00:00:00 2001
From: Fanli Lin 
Date: Mon, 13 May 2024 22:00:39 +0800
Subject: [PATCH 11/19] enable Pipeline to get device from model  (#30534)

* check model.device

* fix

* style fix

* move model device

* remove print

* add comment

* fix

* add unit test

* optimize

* change test names and add more cases

* Update tests/pipelines/test_pipelines_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/pipelines/base.py       |  7 ++--
 tests/pipelines/test_pipelines_common.py | 47 ++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 4bb5cffb1287a9..b318e1b12b414c 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -845,6 +845,8 @@ def __init__(
                 device = -1
 
         if is_torch_available() and self.framework == "pt":
+            if device == -1 and self.model.device is not None:
+                device = self.model.device
             if isinstance(device, torch.device):
                 if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
                     raise ValueError(f'{device} is not available, you should use device="cpu" instead')
@@ -871,11 +873,10 @@ def __init__(
             self.device = device if device is not None else -1
 
         self.binary_output = binary_output
-
-        # We shouldn't call `model.to()` for models loaded with accelerate
+        # We shouldn't call `model.to()` for models loaded with accelerate, nor when the model is already on the requested device
         if (
             self.framework == "pt"
-            and self.device is not None
+            and self.model.device != self.device
             and not (isinstance(self.device, int) and self.device < 0)
             and hf_device_map is None
         ):
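
In user terms, the change above means a pipeline built without an explicit device now inherits the device of a model that was already placed, instead of leaving it on the default. A short illustrative snippet (the tiny checkpoint is the one the new tests use; on a CPU-only machine the behaviour is unchanged):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    checkpoint = "hf-internal-testing/tiny-random-bert"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)  # no device= argument
    print(pipe.model.device)  # stays on the device the model was already on
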
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index c680b4c634de40..763c7d1a883314 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -48,6 +48,7 @@
     require_tf,
     require_torch,
     require_torch_accelerator,
+    require_torch_multi_accelerator,
     require_torch_or_tf,
     slow,
     torch_device,
@@ -519,6 +520,52 @@ def test_pipeline_negative_device(self):
         actual_output = classifier("Test input.")
         self.assertEqual(expected_output, actual_output)
 
+    @require_torch_accelerator
+    def test_pipeline_no_device(self):
+        # Test when no device is passed to pipeline
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        # Case 1: Model is manually moved to device
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(torch_device)
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 2: Model is loaded by accelerate
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 3: device_map is passed to model and device is passed to pipeline
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        with self.assertRaises(ValueError):
+            pipe = pipeline("text-generation", model=model, device="cpu", tokenizer=tokenizer)
+
+    @require_torch_multi_accelerator
+    def test_pipeline_device_not_equal_model_device(self):
+        # Test when device ids are different, pipeline should move the model to the passed device id
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model_device = f"{torch_device}:1"
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(model_device)
+        target_device = f"{torch_device}:0"
+        self.assertNotEqual(model_device, target_device)
+        pipe = pipeline("text-generation", model=model, device=target_device, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, torch.device(target_device))
+
     @slow
     @require_torch
     def test_load_default_pipelines_pt(self):

From ce87dca1d7336a46747bf32645b93712c95d0606 Mon Sep 17 00:00:00 2001
From: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Date: Mon, 13 May 2024 16:47:58 +0200
Subject: [PATCH 12/19] [Object detection pipeline] Lower threshold (#30710)

* Lower threshold

* Address comment
---
 src/transformers/pipelines/object_detection.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/pipelines/object_detection.py b/src/transformers/pipelines/object_detection.py
index d6ae63f4bd19f3..36946cbf8a4511 100644
--- a/src/transformers/pipelines/object_detection.py
+++ b/src/transformers/pipelines/object_detection.py
@@ -83,7 +83,7 @@ def __call__(self, *args, **kwargs) -> Union[Predictions, List[Prediction]]:
 
                 The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                 same format: all as HTTP(S) links, all as local paths, or all as PIL images.
-            threshold (`float`, *optional*, defaults to 0.9):
+            threshold (`float`, *optional*, defaults to 0.5):
                 The probability necessary to make a prediction.
             timeout (`float`, *optional*, defaults to None):
                 The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
@@ -120,7 +120,7 @@ def _forward(self, model_inputs):
             model_outputs["bbox"] = model_inputs["bbox"]
         return model_outputs
 
-    def postprocess(self, model_outputs, threshold=0.9):
+    def postprocess(self, model_outputs, threshold=0.5):
         target_size = model_outputs["target_size"]
         if self.tokenizer is not None:
             # This is a LayoutLMForTokenClassification variant.
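
For callers, the only visible effect of this patch is that the object-detection pipeline now keeps more low-confidence boxes by default; the old behaviour is one keyword away. An illustrative call (the image URL is the usual COCO sample, any image works):

    from transformers import pipeline

    detector = pipeline("object-detection")  # postprocessing now defaults to threshold=0.5
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    strict = detector(url, threshold=0.9)    # pass the old default explicitly to keep strict filtering
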

From de2f722172089473a0d1ff0c037cd6b29460493f Mon Sep 17 00:00:00 2001
From: Joao Gante 
Date: Mon, 13 May 2024 15:48:20 +0100
Subject: [PATCH 13/19] Generate: remove near-duplicate sample/greedy copy
 (#30773)

---
 src/transformers/generation/utils.py          | 500 +++---------------
 .../models/musicgen/modeling_musicgen.py      |   4 +-
 .../modeling_musicgen_melody.py               |   4 +-
 src/transformers/models/rag/modeling_rag.py   |   2 +-
 4 files changed, 92 insertions(+), 418 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 9135bb204846ee..1c90fdd30753e5 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1683,17 +1683,6 @@ def generate(
                 streamer=streamer,
                 **model_kwargs,
             )
-        if generation_mode == GenerationMode.GREEDY_SEARCH:
-            # 11. run greedy search
-            result = self._greedy_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
-                **model_kwargs,
-            )
 
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
@@ -1709,9 +1698,11 @@ def generate(
                 **model_kwargs,
             )
 
-        elif generation_mode == GenerationMode.SAMPLE:
+        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1721,11 +1712,11 @@ def generate(
                 **model_kwargs,
             )
 
-            # 13. run sample
+            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -1733,38 +1724,11 @@ def generate(
                 **model_kwargs,
             )
 
-        elif generation_mode == GenerationMode.BEAM_SEARCH:
-            # 11. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                max_length=generation_config.max_length,
-            )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
-            result = self._beam_search(
-                input_ids,
-                beam_scorer,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                **model_kwargs,
-            )
-
-        elif generation_mode == GenerationMode.BEAM_SAMPLE:
+        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
 
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -1786,11 +1750,11 @@ def generate(
             )
 
             # 14. run beam sample
-            result = self._beam_sample(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -2284,162 +2248,32 @@ def _greedy_search(
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
-        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            streamer (`BaseStreamer`, *optional*):
-                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
-                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            model_kwargs:
-                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._sample()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
 
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
-        this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_tokens_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            this_peer_finished = unfinished_sequences.max() == 0
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
+        logger.warning_once(
+            "Calling `._greedy_search()` directly is deprecated and will be removed in v4.42. Use `._sample()` "
+            "instead, passing the same arguments."
+        )
+        return self._sample(
+            input_ids=input_ids,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            streamer=streamer,
+            **model_kwargs,
+        )
 
     def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        logits_warper: LogitsProcessorList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -2455,10 +2289,6 @@ def _sample(
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             generation_config ([`~generation.GenerationConfig`]):
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
@@ -2466,6 +2296,11 @@ def _sample(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2485,6 +2320,12 @@ def _sample(
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
@@ -2525,7 +2366,8 @@ def _sample(
 
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -2547,9 +2389,12 @@ def _sample(
                         else (outputs.hidden_states,)
                     )
 
-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -2622,6 +2467,7 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx):
             past_key_values.reorder_cache(beam_idx)
         return past_key_values
 
+    # TODO (joao, v4.42): remove default for `logits_warper`
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -2630,6 +2476,7 @@ def _beam_search(
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2652,6 +2499,11 @@ def _beam_search(
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2672,6 +2524,12 @@ def _beam_search(
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
 
         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
@@ -2768,6 +2626,8 @@ def _beam_search(
             )  # (batch_size * num_beams, vocab_size)
 
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -2795,11 +2655,20 @@ def _beam_search(
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
+            # non eos token per beam.
             n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
-            next_token_scores, next_tokens = torch.topk(
-                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
-            )
+            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
+                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+                next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, _indices)
+            else:
+                next_token_scores, next_tokens = torch.topk(
+                    next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
+                )
 
             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
             next_tokens = next_tokens % vocab_size
@@ -2897,219 +2766,24 @@ def _beam_sample(
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
-        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            beam_scorer (`BeamScorer`):
-                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
-                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
-                an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._beam_search()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        eos_token_id = generation_config.eos_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
 
-        batch_size = len(beam_scorer._beam_hyps)
-        num_beams = beam_scorer.num_beams
-
-        batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        beam_indices = (
-            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
+        logger.warning_once(
+            "Calling `._beam_sample()` directly is deprecated and will be removed in v4.42. Use `._beam_search()` "
+            "instead, passing the same arguments."
         )
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores = beam_scores.view((batch_size * num_beams,))
-
-        this_peer_finished = False
-
-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            next_token_scores = nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )  # (batch_size * num_beams, vocab_size)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
-                next_token_scores_processed
-            )
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores_processed,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # reshape for beam search
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-
-            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-
-            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
-            next_tokens = next_tokens % vocab_size
-
-            # stateless
-            beam_outputs = beam_scorer.process(
-                input_ids,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                beam_indices=beam_indices,
-                decoder_prompt_len=decoder_prompt_len,
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-
-            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            if model_kwargs.get("past_key_values", None) is not None:
-                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
-                    model_kwargs["past_key_values"], beam_idx
-                )
-
-            if return_dict_in_generate and output_scores:
-                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
-
-            # increase cur_len
-            cur_len = cur_len + 1
-
-            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                this_peer_finished = True
-
-        sequence_outputs = beam_scorer.finalize(
-            input_ids,
-            beam_scores,
-            next_tokens,
-            next_indices,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            max_length=stopping_criteria.max_length,
-            beam_indices=beam_indices,
-            decoder_prompt_len=decoder_prompt_len,
+        return self._beam_search(
+            input_ids=input_ids,
+            beam_scorer=beam_scorer,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            **model_kwargs,
         )
 
-        if return_dict_in_generate:
-            if not output_scores:
-                sequence_outputs["sequence_scores"] = None
-
-            if self.config.is_encoder_decoder:
-                return GenerateBeamEncoderDecoderOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateBeamDecoderOnlyOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return sequence_outputs["sequences"]
-
     def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
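
The net effect of the hunks above: greedy decoding and multinomial sampling now share a
single token-selection step inside `_sample`, with the logits warper applied only when
`do_sample=True`. A minimal standalone sketch of that step (dummy scores, not the actual
library code):

    import torch
    import torch.nn.functional as F

    def select_next_tokens(next_token_scores: torch.Tensor, do_sample: bool) -> torch.Tensor:
        # Mirrors the unified branch above: sampling draws from the softmax of the
        # (warped) scores, while greedy decoding is simply argmax with no warper involved.
        if do_sample:
            probs = F.softmax(next_token_scores, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(1)
        return torch.argmax(next_token_scores, dim=-1)

    scores = torch.randn(2, 32000)  # dummy (batch_size, vocab_size) scores
    greedy_tokens = select_next_tokens(scores, do_sample=False)
    sampled_tokens = select_next_tokens(scores, do_sample=True)
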
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 9d1cf6e568f64d..8e8b1fe2842f23 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1739,7 +1739,7 @@ def generate(
                 )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2832,7 +2832,7 @@ def generate(
                 )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 63fc638f164424..9865a4b9179ac8 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1676,7 +1676,7 @@ def generate(
                 )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2691,7 +2691,7 @@ def generate(
                 )
 
             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 7eac28ca77e913..3590369d5b907a 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1550,7 +1550,7 @@ def extend_enc_output(tensor, num_beams=None):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self._greedy_search(
+            return self._sample(
                 input_ids,
                 logits_processor=pre_processor,
                 stopping_criteria=prepared_stopping_criteria,
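
At the user-facing level nothing changes: both calls below are now served by the same
`_sample` loop, the first with argmax token selection and no logits warper, the second
with a warper and multinomial sampling. A usage-level sketch; the tiny checkpoint name
is an assumption, any causal LM checkpoint would do:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "sshleifer/tiny-gpt2"  # assumed small test checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    inputs = tokenizer("Hello", return_tensors="pt")

    # GenerationMode.GREEDY_SEARCH: handled by `_sample` with argmax selection
    greedy = model.generate(**inputs, do_sample=False, max_new_tokens=5)
    # GenerationMode.SAMPLE: handled by `_sample` with a logits warper and multinomial sampling
    sampled = model.generate(**inputs, do_sample=True, top_k=50, max_new_tokens=5)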

From 94306352f489c7c2a8dc18af89e2efe0a76a5159 Mon Sep 17 00:00:00 2001
From: Alazar 
Date: Mon, 13 May 2024 17:59:46 +0300
Subject: [PATCH 14/19] Port IDEFICS to tensorflow (#26870)

* Initial commit

* Just a copy of modeling_idefics.py that will be ported to TF

- Prepend TF to the names of all classes
- Convert PyTorch ops to TF (not all operations are converted yet)

* Add TF imports

* Add autotranslated files

* Add TF classes to model_tf_auto.py

* Add the TF classes in model_doc

* include auto-translated code

* Adopted from auto-translated version

* Add a forgotten super().build

* Add test code for TF version.

* Fix indentation and load pytorch weights for now

* Some fixes. Many tests are still failing but some are passing now.

- I have added TODOs for some of the hacks I made to unblock myself,
  and I will address them soon
- I have hacked processing_idefics.py in my local checkout to support TF temporarily

* Add ALL_LAYERNORM_LAYERS to match pytorch

* Revert "Add ALL_LAYERNORM_LAYERS to match pytorch"

This reverts commit 7e0a35119b4d7a6284d04d8c543fba1b29e573c9 as it
is not needed in the tf implementation.

* Fix freeze_relevant_params()

* Some more fixes

* Fix test_attention_outputs

* Add tf stuff to processing_idefics.py

processing_idefics.py supports both PyTorch and TF now.

test_processor_idefics.py for PyTorch is passing, so I didn't break anything,
but there are still some issues with TF. I also need to add TF tests in
test_processor_idefics.py.

* Pass return_tensors to image processing code and fix test

* Pass return_tensors to the image processor __init__

* Fix several test cases

- Make the inputs to some of the forward passes of type `TFModelInputType`
- Decorate the main layer's forward pass with `@unpack_inputs`
- Decorate the main layer with `@keras_serializable`
- Pass `inputs` to TFIdeficsModel

* Some more fixes forgotten in last commit

* Fix processing code and vision_tf.py

* Fix perceiver bug

* Import from

* Auto-add build() methods + style pass

* Fix build() errors due to `None` being passed as shape to some layers

* Change name in TFIdeficsForVisionText2Text to attribute in IdeficsForVisionText2Text

* Fix pytorch weights load for tf2

A lot of `name=` arguments were missing in the weight initialization code.

* Attempt to fix CI

* Add back accidentally removed line

* Remove torch-specific stuff from the TF test file

* make fix-copies, make style, remove autotranslated files

* Fixes to imports/docstrings

* Let's try the from future import in desperation

* Fix the core random_attention_mask fn to match the torch/flax behaviour

* Clean random_attention_mask up correctly

* Remove torch-only test

* Fix loss shape, couple of nits

* make style

* Don't test for OOB embeddings because IDEFICS uses those deliberately

* Fix loss computation to handle masking

* Fix test failures when flattening

* Fix some test failures

- Add the cross-attention gate, which was missing and wasn't being passed around
- Fix overwriting of image_attention_mask due to a hack I had for dummy inputs

* Add a proper stateless scaled_dot_product_attention

* make style

* Adding missing attribute from the PyTorch version

* Small cleanups to decoupledlinearlayer in case that helps

* Pass epsilon to LayerNormalization

* Attempt to fix PyTorch weight cross-loading for TFIdeficsEmbedding

* Fix a bug in TFIdeficsGatedCrossAttentionLayer

* Patching up build() methods

* Constant self.inv_freq

* Constant self.inv_freq

* First working version

The TF implementation works now; there was a bug in TFIdeficsDecoupledLinear
where the weights were mis-initialized as (in_features, out_features)
when they should be (out_features, in_features). A small layout sketch follows
the diffstat below.

I have tested this so far with tiny-random and idefics-9b-instruct
and it gives correct output.

I also dumped the final outputs for both PyTorch and TF
and they are identical.

* Fix some test failures

* remove print statement

* Fix return_tensors

* Fix CI test failure check_code_quality

* Attempt to fix CI failures by running `make fixup`

The hardcoded IDs in test_modeling_tf_idefics.py are for the integration
test; they make that file unreadable and should probably be moved to a separate file.

* Attempt to fix tests_pr_documentation_tests

* Fix a test failure in test_image_processing_idefics.py

* Fix test test_pt_tf_model_equivalence

* Fix a few failures

* Tiny fix

* Some minor fixes

* Remove a duplicate test

* Override a few test failures for IDEFICS

- `test_keras_save_load` is passing now
- `test_compile_tf_model` is still failing

* Fix processing_idefics.py after rebase

* Guard import keras with is_tf_available

* fix check code quality

* fix check code quality

* Minor fixes

* Skip test_save_load temporarily

This test passed on my local box but fails on the CI; skipping it
for now to see if there are other remaining failures on the CI.

* Run `ruff format tests src utils`

* Fix last failing test, `test_compile_tf_model`

* Add fixes for vision_tf.py

I forgot to add this file in last commit.

* Minor fixes

* Replace "<<<" with "<<" for doc tests

IDEFICS-9B is too big for the doctest runner, so don't run it there

* Make code more readable

* Fix bug after code review

I added a layer_norm_eps to IdeficsConfig but I don't even need it
since the vision config has a layer_norm_eps.

* Fix after code review

Use original code tokenizer.convert_tokens_to_ids

* Keep PyTorch as the default return_tensors

* Fixes to modeling_tf after code review

* Fixes from code review

- Remove all references of `TF_IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST`
- Pass 1e-5 to LayerNormalization in perceiver

* Run ruff

* Undo a change

* Refactor processing code after Matt's suggestion

* Remove TODO's that aren't needed anymore

* For pytorch, Use original pytorch processing code from main

Since this PR is a TF port, it shouldn't make any modifications
to the PyTorch IDEFICS code. This change undoes the PyTorch processing
modifications I made and uses the original code from main.

* Update tests/models/idefics/test_modeling_idefics.py

* Update tests/models/idefics/test_modeling_tf_idefics.py

* Add missing imports for is_pt_tf_cross_test

* [DO NOT MERGE]: This is a commit for debugging and will be reverted

The cross test `test_pt_tf_model_equivalence` passes locally but
fails when running on the CI. This commit is to help debug that
and will be reverted.

* Revert "[DO NOT MERGE]: This is a commit for debugging and will be reverted"

This reverts commit 8f0d709ec5bd46685fb0b4259d914ffee794875b.

* [DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted

* [DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted

* Revert "[DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted"

This reverts commit 998cc38b8c3d313bf5e5eb55a7f5b7b881897b89.

* Revert "[DO NOT MERGE]: This commit is for debugging a CI failure and will be reverted"

This reverts commit 1c695ac4219c4ae4d39b330b01744dc27deb7dd4.

* Don't skip test_save_load

IIRC test_save_load was also failing on the CI but not on my local
box; it might be easier to debug that on the CI first, rather than the cross tests.

* Debugging commit, will be reverted

* Revert "Debugging commit, will be reverted"

This reverts commit 8eafc8e41e20c4e95a3a90834f06a6e9f445e2d5.

* Override `test_save_load` and push model to save

Maybe this will help me repro this weird bug

* pass my repo_id

* add endpoint

* Pass a temp (write) token just for this CI

* Undo last few commits, still pushing to hub for model debugging

The issue seems to be with save_pretrained(): when I looked at the model saved
from the CI test failure, it was basically empty and had no weights.
`self.save_weights(...)` seems to be failing in save_pretrained but needs
more debugging.

* Add logging to modeling tf utils, will be reverted just for debugging

* Debugging, will revert

* Revert "Debugging, will revert"

This reverts commit 9d0d3075fb7c82d8cde3a5c76bc8f3876c5c55d3.

* Revert "Add logging to modeling tf utils, will be reverted just for debugging"

This reverts commit 774b6b7b1c17b3ce5d7634ade768f2f686cee617.

* Remove `test_save_load`

The CI failures are gone after my latest rebase (no idea why),
but I was still saving the model to my hub on HF and the tf_model.h5
file now has everything.

* Run make fix-copies

* Run ruff format tests src utils

* Debugging commit, will be reverted

* Run ruff, also trigger CI run

* Run ruff again

* Undo debugging commit

---------

Co-authored-by: Matt 
Co-authored-by: Matt 
---
 docs/source/en/index.md                       |    2 +-
 docs/source/en/model_doc/idefics.md           |   10 +
 src/transformers/__init__.py                  |   14 +
 .../models/auto/modeling_tf_auto.py           |    2 +
 src/transformers/models/idefics/__init__.py   |   30 +-
 .../idefics/image_processing_idefics.py       |    6 +-
 .../models/idefics/modeling_tf_idefics.py     | 1812 +++++++++++++++++
 .../models/idefics/perceiver_tf.py            |  194 ++
 .../models/idefics/processing_idefics.py      |  146 +-
 src/transformers/models/idefics/vision_tf.py  |  573 ++++++
 src/transformers/tf_utils.py                  |   27 +
 src/transformers/utils/dummy_tf_objects.py    |   21 +
 .../idefics/test_image_processing_idefics.py  |    6 +-
 tests/models/idefics/test_modeling_idefics.py |    6 +
 .../idefics/test_modeling_tf_idefics.py       |  565 +++++
 .../models/idefics/test_processor_idefics.py  |   13 +-
 tests/test_modeling_tf_common.py              |   14 +-
 17 files changed, 3392 insertions(+), 49 deletions(-)
 create mode 100644 src/transformers/models/idefics/modeling_tf_idefics.py
 create mode 100644 src/transformers/models/idefics/perceiver_tf.py
 create mode 100644 src/transformers/models/idefics/vision_tf.py
 create mode 100644 tests/models/idefics/test_modeling_tf_idefics.py

diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 419d3d5b1dc2cc..9adb669e2cad66 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -160,7 +160,7 @@ Flax), PyTorch, and/or TensorFlow.
 |                       [HerBERT](model_doc/herbert)                       |       ✅        |         ✅         |      ✅      |
 |                        [Hubert](model_doc/hubert)                        |       ✅        |         ✅         |      ❌      |
 |                        [I-BERT](model_doc/ibert)                         |       ✅        |         ❌         |      ❌      |
-|                       [IDEFICS](model_doc/idefics)                       |       ✅        |         ❌         |      ❌      |
+|                       [IDEFICS](model_doc/idefics)                       |       ✅        |         ✅         |      ❌      |
 |                      [Idefics2](model_doc/idefics2)                      |       ✅        |         ❌         |      ❌      |
 |                      [ImageGPT](model_doc/imagegpt)                      |       ✅        |         ❌         |      ❌      |
 |                      [Informer](model_doc/informer)                      |       ✅        |         ❌         |      ❌      |
diff --git a/docs/source/en/model_doc/idefics.md b/docs/source/en/model_doc/idefics.md
index 9989f89d682e8f..ab66bd555a71d5 100644
--- a/docs/source/en/model_doc/idefics.md
+++ b/docs/source/en/model_doc/idefics.md
@@ -52,6 +52,16 @@ To train a new IDEFICS model from scratch use the m4 codebase (a link will be pr
 [[autodoc]] IdeficsForVisionText2Text
     - forward
 
+## TFIdeficsModel
+
+[[autodoc]] TFIdeficsModel
+    - call
+
+## TFIdeficsForVisionText2Text
+
+[[autodoc]] TFIdeficsForVisionText2Text
+    - call
+
 ## IdeficsImageProcessor
 
 [[autodoc]] IdeficsImageProcessor
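
A hedged loading sketch for the two newly documented classes; the checkpoint id below is
an assumption, and `from_pt=True` relies on the PyTorch-weight cross-loading added in
this patch:

    from transformers import TFIdeficsForVisionText2Text, TFIdeficsModel

    checkpoint = "HuggingFaceM4/tiny-random-idefics"  # assumed test checkpoint id
    base_model = TFIdeficsModel.from_pretrained(checkpoint, from_pt=True)
    lm_model = TFIdeficsForVisionText2Text.from_pretrained(checkpoint, from_pt=True)
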
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 21222be3fb414a..97a4e89684eb7e 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -3862,6 +3862,15 @@
             "TFHubertPreTrainedModel",
         ]
     )
+
+    _import_structure["models.idefics"].extend(
+        [
+            "TFIdeficsForVisionText2Text",
+            "TFIdeficsModel",
+            "TFIdeficsPreTrainedModel",
+        ]
+    )
+
     _import_structure["models.layoutlm"].extend(
         [
             "TFLayoutLMForMaskedLM",
@@ -7905,6 +7914,11 @@
             TFHubertModel,
             TFHubertPreTrainedModel,
         )
+        from .models.idefics import (
+            TFIdeficsForVisionText2Text,
+            TFIdeficsModel,
+            TFIdeficsPreTrainedModel,
+        )
         from .models.layoutlm import (
             TFLayoutLMForMaskedLM,
             TFLayoutLMForQuestionAnswering,
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py
index a3df614b9b7922..756da20dbc51a6 100644
--- a/src/transformers/models/auto/modeling_tf_auto.py
+++ b/src/transformers/models/auto/modeling_tf_auto.py
@@ -58,6 +58,7 @@
         ("gptj", "TFGPTJModel"),
         ("groupvit", "TFGroupViTModel"),
         ("hubert", "TFHubertModel"),
+        ("idefics", "TFIdeficsModel"),
         ("layoutlm", "TFLayoutLMModel"),
         ("layoutlmv3", "TFLayoutLMv3Model"),
         ("led", "TFLEDModel"),
@@ -112,6 +113,7 @@
         ("funnel", "TFFunnelForPreTraining"),
         ("gpt-sw3", "TFGPT2LMHeadModel"),
         ("gpt2", "TFGPT2LMHeadModel"),
+        ("idefics", "TFIdeficsForVisionText2Text"),
         ("layoutlm", "TFLayoutLMForMaskedLM"),
         ("lxmert", "TFLxmertForPreTraining"),
         ("mobilebert", "TFMobileBertForPreTraining"),
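
With the two mapping entries above, the TF auto classes can resolve an IDEFICS config to
the new TF implementation. A short sketch, reusing the assumed checkpoint id from the
previous example:

    from transformers import TFAutoModel

    model = TFAutoModel.from_pretrained("HuggingFaceM4/tiny-random-idefics", from_pt=True)
    print(type(model).__name__)  # TFIdeficsModel, via the "idefics" entry in the base model mapping
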
diff --git a/src/transformers/models/idefics/__init__.py b/src/transformers/models/idefics/__init__.py
index 7a4e8056f540d5..3b32064789cabe 100644
--- a/src/transformers/models/idefics/__init__.py
+++ b/src/transformers/models/idefics/__init__.py
@@ -13,7 +13,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_tf_available,
+    is_torch_available,
+    is_vision_available,
+)
 
 
 _import_structure = {"configuration_idefics": ["IdeficsConfig"]}
@@ -39,6 +45,17 @@
     ]
     _import_structure["processing_idefics"] = ["IdeficsProcessor"]
 
+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_tf_idefics"] = [
+        "TFIdeficsForVisionText2Text",
+        "TFIdeficsModel",
+        "TFIdeficsPreTrainedModel",
+    ]
 
 if TYPE_CHECKING:
     from .configuration_idefics import IdeficsConfig
@@ -64,6 +81,17 @@
         )
         from .processing_idefics import IdeficsProcessor
 
+    try:
+        if not is_tf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_tf_idefics import (
+            TFIdeficsForVisionText2Text,
+            TFIdeficsModel,
+            TFIdeficsPreTrainedModel,
+        )
 
 else:
     import sys
diff --git a/src/transformers/models/idefics/image_processing_idefics.py b/src/transformers/models/idefics/image_processing_idefics.py
index ee8dfbb4077c66..f4998020daf642 100644
--- a/src/transformers/models/idefics/image_processing_idefics.py
+++ b/src/transformers/models/idefics/image_processing_idefics.py
@@ -92,8 +92,9 @@ def preprocess(
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         transform: Callable = None,
+        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
         **kwargs,
-    ) -> TensorType.PYTORCH:
+    ) -> TensorType:
         """
         Preprocess a batch of images.
 
@@ -162,7 +163,6 @@ def preprocess(
         images = [self.rescale(image=image, scale=1 / 255) for image in images]
         images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
         images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
-        # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
-        images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
+        images = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)["pixel_values"]
 
         return images
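
The new `return_tensors` argument above keeps PyTorch as the default output type while
letting the TF code path ask for TensorFlow tensors. A small sketch with dummy image
data (the exact call pattern is an assumption based on the signature in this hunk):

    import numpy as np
    from transformers import IdeficsImageProcessor

    image_processor = IdeficsImageProcessor()
    images = [np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)]

    pt_pixels = image_processor.preprocess(images)                       # PyTorch tensors (default)
    tf_pixels = image_processor.preprocess(images, return_tensors="tf")  # TensorFlow tensors
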
diff --git a/src/transformers/models/idefics/modeling_tf_idefics.py b/src/transformers/models/idefics/modeling_tf_idefics.py
new file mode 100644
index 00000000000000..8d9322b0edc272
--- /dev/null
+++ b/src/transformers/models/idefics/modeling_tf_idefics.py
@@ -0,0 +1,1812 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 Idefics model. """
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ... import TFPreTrainedModel
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import ModelOutput
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFModelInputType,
+    keras_serializable,
+    shape_list,
+    unpack_inputs,
+)
+from ...tf_utils import invert_attention_mask, scaled_dot_product_attention
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_idefics import IdeficsConfig
+from .perceiver_tf import TFIdeficsPerceiverResampler
+from .vision_tf import TFIdeficsVisionTransformer
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "IdeficsConfig"
+
+
+@dataclass
+class TFIdeficsBaseModelOutputWithPast(ModelOutput):
+    """
+    Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(tf.Tensor)`, *optional*):
+            Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+
+            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+    """
+
+    last_hidden_state: tf.Tensor = None
+    past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None
+    hidden_states: Optional[Tuple[tf.Tensor]] = None
+    attentions: Optional[Tuple[tf.Tensor]] = None
+    image_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+@dataclass
+class TFIdeficsCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Idefics causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(tf.Tensor)`, *optional*):
+            Tuple of `tf.Tensor` (one for the output of the image embeddings, `(batch_size, num_images,
+            sequence_length, hidden_size)`.
+
+            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
+    """
+
+    loss: Optional[tf.Tensor] = None
+    logits: tf.Tensor = None
+    past_key_values: Optional[List[tf.Tensor]] = None
+    hidden_states: Optional[Tuple[tf.Tensor]] = None
+    attentions: Optional[Tuple[tf.Tensor]] = None
+    image_hidden_states: Optional[Tuple[tf.Tensor]] = None
+
+
+def expand_inputs_for_generation(
+    input_ids,
+    expand_size=1,
+    is_encoder_decoder=False,
+    attention_mask=None,
+    encoder_outputs=None,
+    **model_kwargs,
+):
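+    # Repeat every example `expand_size` times along the batch axis (e.g. for multiple return sequences):
+    # for a batch of size 2 and expand_size=3, expanded_return_idx is [0, 0, 0, 1, 1, 1].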
+    expanded_return_idx = tf.reshape(tf.repeat(tf.range(tf.shape(input_ids)[0]), expand_size), [-1])
+    input_ids = tf.gather(input_ids, expanded_return_idx)
+    model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
+    model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings", None)
+    model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None)
+    model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
+
+    if "token_type_ids" in model_kwargs:
+        token_type_ids = model_kwargs["token_type_ids"]
+        model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx)
+
+    if attention_mask is not None:
+        model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx)
+
+    if model_kwargs["image_attention_mask"] is not None:
+        model_kwargs["image_attention_mask"] = tf.gather(model_kwargs["image_attention_mask"], expanded_return_idx)
+
+    if model_kwargs["pixel_values"] is not None:
+        model_kwargs["pixel_values"] = tf.gather(model_kwargs["pixel_values"], expanded_return_idx)
+
+    elif model_kwargs["image_encoder_embeddings"] is not None:
+        model_kwargs["image_encoder_embeddings"] = tf.gather(
+            model_kwargs["image_encoder_embeddings"], expanded_return_idx
+        )
+
+    elif model_kwargs["perceiver_embeddings"] is not None:
+        model_kwargs["perceiver_embeddings"] = tf.gather(model_kwargs["perceiver_embeddings"], expanded_return_idx)
+
+    return input_ids, model_kwargs
+
+
+def update_model_kwargs_for_generation(outputs, model_kwargs):
+    # must have this key set to at least None
+    if "past_key_values" in outputs:
+        model_kwargs["past_key_values"] = outputs.past_key_values
+    else:
+        model_kwargs["past_key_values"] = None
+
+    # update token_type_ids with last value
+    if "token_type_ids" in model_kwargs:
+        token_type_ids = model_kwargs["token_type_ids"]
+        model_kwargs["token_type_ids"] = tf.concat([token_type_ids, token_type_ids[:, -1:, ...]], axis=-1)
+
+    # update attention masks
+    if "attention_mask" in model_kwargs:
+        attention_mask = model_kwargs["attention_mask"]
+        model_kwargs["attention_mask"] = tf.concat(
+            [attention_mask, tf.ones_like(attention_mask[:, -1:, ...])], axis=-1
+        )
+    if "image_attention_mask" in model_kwargs:
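+        # keep only the mask for the most recent position; it is reused at the next decoding step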
+        image_attention_mask = model_kwargs["image_attention_mask"]
+        last_mask = image_attention_mask[:, -1:, ...]
+        model_kwargs["image_attention_mask"] = last_mask
+
+    # Get the precomputed image_hidden_states
+    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
+
+    return model_kwargs
+
+
+def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
+    token_type_ids = kwargs.get("token_type_ids", None)
+    # only last token for inputs_ids if past is defined in kwargs
+    if past_key_values is not None:
+        input_ids = input_ids[:, -1:]
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids[:, -1:]
+
+    attention_mask = kwargs.get("attention_mask", None)
+    position_ids = kwargs.get("position_ids", None)
+
+    if attention_mask is not None and position_ids is None:
+        # create position_ids on the fly for batch generation
+        position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int64), axis=-1) - 1
+        position_ids = tf.where(attention_mask == 0, 1, position_ids)
+        if past_key_values is not None:
+            position_ids = position_ids[:, -1:]
+
+    pixel_values = kwargs.get("pixel_values", None)
+    image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
+    perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
+    image_attention_mask = kwargs.get("image_attention_mask", None)
+    interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False)
+
+    return {
+        "input_ids": input_ids,
+        "past_key_values": past_key_values,
+        "use_cache": kwargs.get("use_cache"),
+        "position_ids": position_ids,
+        "attention_mask": attention_mask,
+        "token_type_ids": token_type_ids,
+        "pixel_values": pixel_values,
+        "image_encoder_embeddings": image_encoder_embeddings,
+        "perceiver_embeddings": perceiver_embeddings,
+        "image_attention_mask": image_attention_mask,
+        "interpolate_pos_encoding": interpolate_pos_encoding,
+    }
+
+
+def freeze_model(model, module_exceptions=[]):
+    mapping = {
+        "LayerNorm": tf.keras.layers.LayerNormalization,
+        "Dense": tf.keras.layers.Dense,
+        "Embedding": tf.keras.layers.Embedding,
+    }
+    module_exceptions_mapped = [mapping[m] for m in module_exceptions]
+    if not hasattr(model, "layers"):
+        model.trainable = False  # It is just a layer
+        return model
+    for layer in model.layers:
+        if module_exceptions and any(isinstance(layer, t) for t in module_exceptions_mapped):
+            layer.trainable = True  # Explicitly setting it to true to avoid any mistakes
+        else:
+            layer.trainable = False
+    return model
+
+
+class TFIdeficsDecoupledEmbedding(tf.keras.layers.Embedding):
+    """
+    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
+    regular `weight` can be trained or frozen (it is frozen when `partially_freeze=True`), and if
+    `num_additional_embeddings` > 0, it creates `num_additional_embeddings` additional parameters that are always
+    trained. If `num_additional_embeddings=0`, the layer defaults back to the regular behavior of
+    `tf.keras.layers.Embedding`.
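+
+    Example (illustrative sketch; the vocabulary sizes below are arbitrary):
+
+    ```python
+    >>> import tensorflow as tf
+    >>> embed = TFIdeficsDecoupledEmbedding(num_embeddings=32000, num_additional_embeddings=2, embedding_dim=16)
+    >>> hidden = embed(tf.constant([[1, 31999, 32001]]))  # ids >= 32000 are routed to the always-trainable extra table
+    >>> # `hidden` has shape (1, 3, 16)
+    ```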
+    """
+
+    def __init__(
+        self,
+        num_embeddings,
+        num_additional_embeddings,
+        embedding_dim,
+        partially_freeze: Optional[bool] = False,
+        dtype=None,
+        **kwargs,
+    ) -> None:
+        """
+        Args:
+            num_embeddings (`int`):
+                Size of the dictionary of embeddings
+            num_additional_embeddings (`int`):
+                Number of additional embeddings. Only useful when `partially_freeze=True`.
+            embedding_dim (`int`):
+                The size of each embedding vector
+            partially_freeze (`bool`, *optional*, defaults to `False`):
+                If `True`, the regular `weight` will be frozen. The `additional_embedding` weights are never frozen.
+
+        Note: there are a lot of other parameters to initialize a standard `tf.keras.layers.Embedding` such as `mask_zero`,
+        `input_length` or `embeddings_initializer`. We are not supporting these.
+        """
+        super().__init__(
+            input_dim=num_embeddings,
+            output_dim=embedding_dim,
+            dtype=dtype,
+            **kwargs,
+        )
+        self.num_embeddings = num_embeddings
+        self.num_additional_embeddings = num_additional_embeddings
+        self.partially_freeze = partially_freeze
+
+        if partially_freeze:
+            self.trainable = False
+
+        if self.num_additional_embeddings > 0:
+            self.additional_embedding = tf.keras.layers.Embedding(
+                input_dim=self.num_additional_embeddings,
+                output_dim=embedding_dim,
+                dtype=dtype,
+                name="additional_embedding",
+            )
+
+    def call(self, input_ids):
+        """
+        We have two embeddings with different indices: the pretrained table of the parent `tf.keras.layers.Embedding`
+        and `self.additional_embedding`, which is always trainable.
+
+        To look up the input ids we:
+        1. find the indices of the entries belonging to the 2nd embedding
+        2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
+           embedding starts from 0 and not num_embeddings
+        3. perform the 2nd embedding lookup
+        4. then handle the 1st embedding: overwrite the indices belonging to the 2nd embedding with a padding index
+        5. perform the 1st embedding lookup
+        6. finally, overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
+
+        Note: for the 1st embedding lookup we could have looked up only the low indices and skipped the padding, but
+        that would require creating a new tensor and populating it with two tensors spread across various indices,
+        i.e. not a simple concat. The complex case has not been benchmarked; given that sequence lengths are usually
+        relatively short, it is probably not faster (or not by much), but it might be worth measuring.
+        """
+        if self.num_additional_embeddings == 0:
+            return super().call(input_ids)
+
+        # Clone so that we don't modify the original input_ids later on
+        input_ids = tf.identity(input_ids)
+        additional_vocab_indices = tf.where(input_ids >= self.num_embeddings)
+        input_ids_additional_vocab = tf.gather_nd(input_ids, additional_vocab_indices)
+        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
+
+        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
+        input_ids = tf.tensor_scatter_nd_update(
+            input_ids,
+            additional_vocab_indices,
+            # tensor filled with 0, having the same length as additional_vocab_indices
+            tf.zeros(tf.shape(additional_vocab_indices)[0], dtype=input_ids.dtype),
+        )
+        full_vector = super().call(input_ids)
+
+        # overwrite the records with high indices
+        full_vector = tf.tensor_scatter_nd_update(full_vector, additional_vocab_indices, additional_embeddings)
+
+        return full_vector
+
+    def extra_repr(self) -> str:
+        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
+            self.num_embeddings,
+            self.num_additional_embeddings,
+            self.output_dim,
+            self.partially_freeze,
+        )
+
+
+class TFIdeficsDecoupledLinear(tf.keras.layers.Layer):
+    """
+    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
+    regular `weight` can be trained or frozen (it is frozen when `partially_freeze=True`), and if
+    `out_additional_features` > 0, it creates `out_additional_features * in_features` additional parameters that are
+    always trained. If `out_additional_features=0`, the layer defaults back to the regular behavior of
+    `tf.keras.layers.Dense`.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        out_additional_features: int = 0,
+        bias: bool = True,
+        partially_freeze: bool = True,
+        **kwargs,
+    ) -> None:
+        """
+        Args:
+            out_additional_features (`int`, *optional*, defaults to 0):
+                Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
+            partially_freeze (`bool`, *optional*, defaults to `True`):
+                If `True`, the regular `weight` is frozen and the extra parameters (if any) stay trainable. If
+                `False`, the layer defaults to the regular behavior of `tf.keras.layers.Dense`.
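+
+        Example (illustrative sketch; the feature sizes below are arbitrary):
+
+        ```python
+        >>> import tensorflow as tf
+        >>> layer = TFIdeficsDecoupledLinear(in_features=16, out_features=8, out_additional_features=2)
+        >>> out = layer(tf.random.normal((4, 16)))  # shape (4, 8 + 2): frozen projection concatenated with the trainable extra one
+        ```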
+        """
+        super().__init__(**kwargs)
+        self.out_additional_features = out_additional_features
+        self.partially_freeze = partially_freeze
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        if out_additional_features > 0:
+            self.additional_fc = tf.keras.layers.Dense(
+                units=out_additional_features, use_bias=bias, name="additional_fc"
+            )
+
+    def call(self, inputs: tf.Tensor) -> tf.Tensor:
+        output = tf.linalg.matmul(a=inputs, b=self.weight, transpose_b=True)
+        if self.bias is not None:
+            output = tf.nn.bias_add(output, self.bias)
+
+        if self.out_additional_features > 0:
+            additional_features = self.additional_fc(inputs)
+            output = tf.concat([output, additional_features], axis=-1)
+
+        return output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "in_features": self.in_features,
+                "out_features": self.out_features,
+                "out_additional_features": self.out_additional_features,
+                "bias": self.bias is not None,
+                "partially_freeze": self.partially_freeze,
+            }
+        )
+        return config
+
+    def extra_repr(self) -> str:
+        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
+        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
+            self.in_features,
+            self.out_features,
+            self.out_additional_features,
+            self.bias is not None,
+            self.partially_freeze,
+        )
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        self.weight = self.add_weight(
+            shape=(self.out_features, self.in_features), trainable=not self.partially_freeze, name="weight"
+        )
+        if self.use_bias:
+            self.bias = self.add_weight(shape=(self.out_features,), trainable=not self.partially_freeze, name="bias")
+        else:
+            self.bias = None
+        if getattr(self, "additional_fc", None) is not None:
+            with tf.name_scope(self.additional_fc.name):
+                self.additional_fc.build(self.in_features)
+
+
+def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
+    """
+    Make the causal mask used for auto-regressive self-attention, supporting both static and dynamic shapes.
+    """
+    bsz, tgt_len = input_ids_shape
+
+    # Create a matrix where only the lower triangle and diagonal are filled with zeros (causal mask)
+    mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min)
+    mask_cond = tf.range(tgt_len)
+    mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], 0.0, mask)
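+    # e.g. for tgt_len=3 the rows are [0, min, min], [0, 0, min], [0, 0, 0], so position i attends only to j <= i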
+
+    if past_key_values_length > 0:
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)
+
+    if bsz is None:
+        # When batch size is dynamic, expand and tile
+        # so we can compile a functional model
+        mask = tf.expand_dims(mask, 0)
+        mask = tf.expand_dims(mask, 0)  # shape: (1, 1, tgt_len, tgt_len + past_key_values_length)
+        mask = tf.tile(mask, [bsz, 1, 1, 1])
+    else:
+        # When batch size is static, directly use broadcast_to
+        mask = tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
+
+    return mask
+
+
+def _expand_mask(mask, dtype, tgt_len=None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = shape_list(mask)
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = tf.expand_dims(tf.expand_dims(mask, 1), 1)
+    expanded_mask = tf.broadcast_to(expanded_mask, [bsz, 1, tgt_len, src_len])
+
+    inverted_mask = 1.0 - tf.cast(expanded_mask, dtype)
+
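+    # attended positions stay 0.0, masked positions are filled with the large negative value `tf.float32.min`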
+    return tf.where(
+        tf.cast(inverted_mask, bool), tf.fill(dims=shape_list(inverted_mask), value=tf.float32.min), inverted_mask
+    )
+
+
+class TFIdeficsRMSNorm(tf.keras.layers.Layer):
+    def __init__(self, hidden_size, eps=1e-6, **kwargs):
+        """
+        TFIdeficsRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.variance_epsilon = eps
+
+    def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
+        self.weight = self.add_weight(name="weight", shape=[self.hidden_size], initializer="ones")
+
+        super().build(input_shape)
+
+    def call(self, hidden_states):
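+        # RMSNorm: normalize by the root mean square (computed in float32 for stability), then scale by the learned weight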
+        variance = tf.math.reduce_mean(tf.math.square(tf.cast(hidden_states, tf.float32)), axis=-1, keepdims=True)
+        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
+
+        # convert into half-precision if necessary
+        if self.weight.dtype in [tf.float16, tf.bfloat16]:
+            hidden_states = tf.cast(hidden_states, self.weight.dtype)
+
+        return self.weight * hidden_states
+
+
+class TFIdeficsEmbedding(tf.keras.layers.Layer):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.inv_freq = tf.constant(
+            1.0 / (self.base ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
+        )
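+        # standard RoPE frequencies: inv_freq[i] = 1 / base ** (2 * i / dim); cos/sin are recomputed from seq_len at call time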
+
+    def _compute_cos_sin(self, seq_len):
+        t = tf.range(seq_len, dtype=self.inv_freq.dtype)
+        freqs = tf.einsum("i, j -> ij", t, self.inv_freq)  # outer product
+        emb = tf.concat((freqs, freqs), axis=-1)
+
+        return tf.cos(emb), tf.sin(emb)
+
+    def call(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len is None:
+            seq_len = shape_list(x)[2]
+        return self._compute_cos_sin(seq_len=seq_len)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return tf.concat((-x2, x1), axis=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    cos = tf.gather(cos, position_ids)  # [seq_len, dim] -> [batch_size, seq_len, head_dim]
+    sin = tf.gather(sin, position_ids)
+    cos = tf.expand_dims(cos, 1)  # -> [batch_size, 1, seq_len, head_dim] so it broadcasts over the heads axis
+    sin = tf.expand_dims(sin, 1)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class TFIdeficsMLP(tf.keras.layers.Layer):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.gate_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="gate_proj")
+        self.down_proj = tf.keras.layers.Dense(hidden_size, use_bias=False, name="down_proj")
+        self.up_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="up_proj")
+        self.act_fn = get_tf_activation(hidden_act)
+        self.intermediate_size = intermediate_size
+        self.hidden_size = hidden_size
+
+    def call(self, x):
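+        # gated feed-forward (SwiGLU when `hidden_act` is "silu"): act(gate_proj(x)) * up_proj(x), projected back with down_proj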
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "gate_proj", None) is not None:
+            with tf.name_scope(self.gate_proj.name):
+                self.gate_proj.build(self.hidden_size)
+        if getattr(self, "down_proj", None) is not None:
+            with tf.name_scope(self.down_proj.name):
+                self.down_proj.build(self.intermediate_size)
+        if getattr(self, "up_proj", None) is not None:
+            with tf.name_scope(self.up_proj.name):
+                self.up_proj.build(self.hidden_size)
+
+
+class TFIdeficsAttention(tf.keras.layers.Layer):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_cross_attention: bool = False,
+        config: IdeficsConfig = None,
+        qk_layer_norms: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+        self.dropout = dropout
+        self.config = config
+        self.is_causal = True
+
+        if (self.head_dim * num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {num_heads})."
+            )
+
+        self.is_cross_attention = is_cross_attention
+
+        self.q_proj = tf.keras.layers.Dense(
+            num_heads * self.head_dim,
+            use_bias=False,
+            name="q_proj",
+        )
+        self.k_proj = tf.keras.layers.Dense(
+            num_heads * self.head_dim,
+            use_bias=False,
+            name="k_proj",
+        )
+        self.v_proj = tf.keras.layers.Dense(
+            num_heads * self.head_dim,
+            use_bias=False,
+            name="v_proj",
+        )
+        self.o_proj = tf.keras.layers.Dense(
+            hidden_size,
+            use_bias=False,
+            name="o_proj",
+        )
+        self.rotary_emb = TFIdeficsEmbedding(self.head_dim, name="rotary_emb")
+
+        self.qk_layer_norms = qk_layer_norms
+        if self.qk_layer_norms:
+            self.q_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="q_layer_norm")
+            self.k_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="k_layer_norm")
+
+    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
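+        # (bsz, seq_len, num_heads * head_dim) -> (bsz, num_heads, seq_len, head_dim)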
+        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        key_value_states: Optional[tf.Tensor] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
+        past_key_value: Optional[Tuple[tf.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = shape_list(hidden_states)
+
+        query_states = self._shape(self.q_proj(hidden_states), q_len, bsz)
+        if not is_cross_attention:
+            key_states = self._shape(self.k_proj(hidden_states), q_len, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), q_len, bsz)
+        else:
+            _, kv_len, _ = shape_list(key_value_states)  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = self._shape(self.k_proj(key_value_states), kv_len, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), kv_len, bsz)
+
+        kv_seq_len = shape_list(key_states)[-2]
+        if past_key_value is not None:
+            kv_seq_len += shape_list(past_key_value[0])[-2]
+        if not is_cross_attention:
+            # Below is to allow symbolic tensors compilation
+            if tf.is_tensor(kv_seq_len):
+                seq_len = tf.maximum(kv_seq_len, q_len)
+            else:
+                seq_len = max(kv_seq_len, q_len)
+            cos, sin = self.rotary_emb(value_states, seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = tf.concat([past_key_value[0], key_states], axis=2)
+            value_states = tf.concat([past_key_value[1], value_states], axis=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        if attention_mask is not None:
+            tf.debugging.assert_equal(
+                tf.shape(attention_mask),
+                [bsz, 1, q_len, kv_seq_len],
+                message=f"Attention mask should be of size {[bsz, 1, q_len, kv_seq_len]}, but is {tf.shape(attention_mask)}",
+            )
+
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        tf.debugging.assert_equal(
+            tf.shape(attn_output),
+            [bsz, self.num_heads, q_len, self.head_dim],
+            message=f"attn_output should be of size {[bsz, self.num_heads, q_len, self.head_dim]}, but is {tf.shape(attn_output)}",
+        )
+
+        attn_output = tf.reshape(tf.transpose(attn_output, perm=[0, 2, 1, 3]), (bsz, q_len, self.hidden_size))
+
+        attn_output = self.o_proj(attn_output)
+
+        attn_weights = None
+        if output_attentions:
+            logger.warning_once(
+                "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
+            )
+
+        return attn_output, attn_weights, past_key_value
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if self.is_cross_attention:
+            kv_input_dim = (
+                self.hidden_size
+                if not hasattr(self.config.vision_config, "embed_dim")
+                else self.config.vision_config.embed_dim
+            )
+        else:
+            kv_input_dim = self.hidden_size
+        if getattr(self, "o_proj", None) is not None:
+            with tf.name_scope(self.o_proj.name):
+                self.o_proj.build(self.num_heads * self.head_dim)
+        if getattr(self, "q_proj", None) is not None:
+            with tf.name_scope(self.q_proj.name):
+                self.q_proj.build(self.hidden_size)
+        if getattr(self, "k_proj", None) is not None:
+            with tf.name_scope(self.k_proj.name):
+                self.k_proj.build(kv_input_dim)
+        if getattr(self, "v_proj", None) is not None:
+            with tf.name_scope(self.v_proj.name):
+                self.v_proj.build(kv_input_dim)
+        if getattr(self, "rotary_emb", None) is not None:
+            with tf.name_scope(self.rotary_emb.name):
+                self.rotary_emb.build(None)
+
+
+class TFIdeficsDecoderLayer(tf.keras.layers.Layer):
+    def __init__(self, config: IdeficsConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_size = config.hidden_size
+        self.self_attn = TFIdeficsAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            dropout=config.dropout,
+            config=config,
+            name="self_attn",
+        )
+        self.mlp = TFIdeficsMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            name="mlp",
+        )
+        self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm")
+        self.post_attention_layernorm = TFIdeficsRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm"
+        )
+        self.dropout = config.dropout
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
+        past_key_value: Optional[Tuple[tf.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        training=False,
+    ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]:
+        """
+        Args:
+            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`tf.Tensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "self_attn", None) is not None:
+            with tf.name_scope(self.self_attn.name):
+                self.self_attn.build(None)
+        if getattr(self, "mlp", None) is not None:
+            with tf.name_scope(self.mlp.name):
+                self.mlp.build(None)
+        if getattr(self, "input_layernorm", None) is not None:
+            with tf.name_scope(self.input_layernorm.name):
+                self.input_layernorm.build(None)
+        if getattr(self, "post_attention_layernorm", None) is not None:
+            with tf.name_scope(self.post_attention_layernorm.name):
+                self.post_attention_layernorm.build(None)
+
+
+class TFIdeficsGatedCrossAttentionLayer(tf.keras.layers.Layer):
+    def __init__(self, config: IdeficsConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_size = config.hidden_size
+        self.cross_attn = TFIdeficsAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            is_cross_attention=True,
+            dropout=config.dropout,
+            config=config,
+            qk_layer_norms=config.qk_layer_norms,
+            name="cross_attn",
+        )
+        self.mlp = TFIdeficsMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            name="mlp",
+        )
+        self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm")
+        self.post_attention_layernorm = TFIdeficsRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm"
+        )
+        self.dropout = config.dropout
+
+        self.act_cross_attn = tf.keras.activations.tanh
+        self.act_dense = tf.keras.activations.tanh
+
+        self.alpha_initializer = config.alpha_initializer
+        self.alpha_type = config.alpha_type
+        self.alphas_initializer_range = config.alphas_initializer_range
+
+    def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
+        if self.alpha_initializer == "zeros":
+            if self.alpha_type == "vector":
+                self.alpha_cross_attn = self.add_weight(
+                    shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_cross_attn"
+                )
+                self.alpha_dense = self.add_weight(
+                    shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_dense"
+                )
+            elif self.alpha_type == "float":
+                self.alpha_cross_attn = self.add_weight(
+                    shape=(1,), initializer="zeros", trainable=True, name="alpha_cross_attn"
+                )
+                self.alpha_dense = self.add_weight(shape=(1,), initializer="zeros", trainable=True, name="alpha_dense")
+            else:
+                raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})")
+
+        elif self.alpha_initializer == "ones":
+            if self.alpha_type == "vector":
+                self.alpha_cross_attn = self.add_weight(
+                    shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_cross_attn"
+                )
+                self.alpha_dense = self.add_weight(
+                    shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_dense"
+                )
+            elif self.alpha_type == "float":
+                self.alpha_cross_attn = self.add_weight(
+                    shape=(1,), initializer="ones", trainable=True, name="alpha_cross_attn"
+                )
+                self.alpha_dense = self.add_weight(shape=(1,), initializer="ones", trainable=True, name="alpha_dense")
+            else:
+                raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})")
+
+        elif self.alpha_initializer in {"normal", "gaussian", "random"}:
+            if self.alpha_type == "vector":
+                self.alpha_cross_attn = self.add_weight(
+                    shape=(1, 1, self.hidden_size),
+                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
+                    trainable=True,
+                    name="alpha_cross_attn",
+                )
+                self.alpha_dense = self.add_weight(
+                    shape=(1, 1, self.hidden_size),
+                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
+                    trainable=True,
+                    name="alpha_dense",
+                )
+            elif self.alpha_type == "float":
+                self.alpha_cross_attn = self.add_weight(
+                    shape=(1,),
+                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
+                    trainable=True,
+                    name="alpha_cross_attn",
+                )
+                self.alpha_dense = self.add_weight(
+                    shape=(1,),
+                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
+                    trainable=True,
+                    name="alpha_dense",
+                )
+            else:
+                raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})")
+
+        else:
+            raise NotImplementedError(f"Alpha initialization scheme {self.alpha_initializer} not yet implemented!")
+
+        if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
+            raise ValueError("Alpha parameters not initialized correctly!")
+        with tf.name_scope(self.cross_attn.name):
+            self.cross_attn.build(None)
+        with tf.name_scope(self.mlp.name):
+            self.mlp.build(None)
+        with tf.name_scope(self.input_layernorm.name):
+            self.input_layernorm.build(None)
+        with tf.name_scope(self.post_attention_layernorm.name):
+            self.post_attention_layernorm.build(None)
+        super().build(input_shape)
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[tf.Tensor] = None,
+        image_hidden_states: Optional[tf.Tensor] = None,
+        image_attention_mask: Optional[tf.Tensor] = None,
+        cross_attention_gate: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        past_key_value: Optional[Tuple[tf.Tensor]] = None,
+    ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]:
+        """
+        Args:
+            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`tf.Tensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
+            image_hidden_states (`tf.Tensor`, *optional*): the visual features the layer cross-attends to
+            image_attention_mask (`tf.Tensor`, *optional*): mask for the image tokens
+            cross_attention_gate (`tf.Tensor`, *optional*): gate used to zero out the cross-attention output for
+                tokens that attend to no images
+        """
+        if image_hidden_states is None:
+            raise ValueError(
+                "`image_hidden_states` is required for Idefics cross attention module which are visual features to be"
+                " conditioned on."
+            )
+
+        if cross_attention_gate is None:
+            raise ValueError(
+                "`cross_attention_gate` is required for Idefics cross attention module to zero-out the cross-attention hidden_states attending to no images."
+            )
+
+        if past_key_value is not None:
+            raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.")
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Cross Attention
+        hidden_states, self_attn_weights, present_key_value = self.cross_attn(
+            hidden_states=hidden_states,
+            key_value_states=image_hidden_states,
+            attention_mask=image_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
+        mask = tf.cast(cross_attention_gate == 0, dtype=hidden_states.dtype)
+        # Expand dimensions of mask to match hidden_states
+        mask = tf.expand_dims(mask, -1)
+        hidden_states = tf.where(
+            tf.broadcast_to(mask, tf.shape(hidden_states)) == 1, tf.zeros_like(hidden_states), hidden_states
+        )
+        # when there are no images the model is used in pure language mode
+        # gate = 0 if no_images else 1
+        hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
+        hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a TensorFlow [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) subclass.
+    Use it as a regular TensorFlow layer and refer to the TensorFlow documentation for all matters related to general
+    usage and behavior.
+
+    Parameters:
+        config ([`IdeficsConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    LLAMA_START_DOCSTRING,
+)
+class TFIdeficsPreTrainedModel(TFPreTrainedModel):
+    config_class = IdeficsConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["TFIdeficsDecoderLayer", "TFIdeficsGatedCrossAttentionLayer"]
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    LLAMA_START_DOCSTRING,
+)
+@keras_serializable
+class TFIdeficsMainLayer(tf.keras.layers.Layer):
+    """
+    Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`TFIdeficsDecoderLayer`]
+
+    Args:
+        config: IdeficsConfig
+    """
+
+    config_class = IdeficsConfig
+
+    def __init__(self, config: IdeficsConfig, add_pooling_layer: bool = True, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = TFIdeficsDecoupledEmbedding(
+            num_embeddings=config.vocab_size,
+            num_additional_embeddings=config.additional_vocab_size,
+            embedding_dim=config.hidden_size,
+            partially_freeze=config.freeze_text_layers,
+            name="embed_tokens",
+        )
+
+        self.image_size = config.vision_config.image_size
+        self.vision_config = config.vision_config
+        self.vision_model = TFIdeficsVisionTransformer(config.vision_config, name="vision_model")
+
+        # Perceiver Resampler
+        if config.use_resampler:
+            perceiver_config = config.perceiver_config
+            self.perceiver_resampler = TFIdeficsPerceiverResampler(
+                config,
+                config.vision_config.embed_dim,
+                perceiver_config.resampler_depth,
+                perceiver_config.resampler_n_heads,
+                perceiver_config.resampler_head_dim,
+                perceiver_config.resampler_n_latents,
+                name="perceiver_resampler",
+            )
+
+        self.decoder_layers = [
+            TFIdeficsDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
+        ]
+
+        self.cross_layer_interval = config.cross_layer_interval
+        num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
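+        # a gated cross-attention block runs before every `cross_layer_interval`-th decoder layer (starting with layer 0)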
+        self.gated_cross_attn_layers = [
+            TFIdeficsGatedCrossAttentionLayer(config, name=f"gated_cross_attn_layers.{i}")
+            for i in range(num_cross_layers)
+        ]
+        self.gradient_checkpointing = False
+
+        self.norm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm")
+
+        self.freeze_relevant_params(config)
+
+    def freeze_relevant_params(self, config=None):
+        if config is None:
+            config = self.config
+
+        if config.freeze_text_layers:
+            self.freeze_text_layers(config.freeze_text_module_exceptions)
+
+        if config.freeze_vision_layers:
+            freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
+
+    def freeze_text_layers(self, module_exceptions=[]):
+        for module in [self.decoder_layers, self.norm]:
+            freeze_model(module, module_exceptions=module_exceptions)
+
+    def freeze_vision_layers(self, module_exceptions=[]):
+        freeze_model(self.vision_model, module_exceptions=module_exceptions)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        # if input_shape[-1] > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape,
+            inputs_embeds.dtype,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
+        past_key_values: Optional[List[tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        pixel_values: Optional[tf.Tensor] = None,
+        image_encoder_embeddings: Optional[tf.Tensor] = None,
+        perceiver_embeddings: Optional[tf.Tensor] = None,
+        image_attention_mask: Optional[tf.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = None,
+    ) -> Union[TFIdeficsBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = shape_list(inputs_embeds)
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = shape_list(past_key_values[0][0])[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int32), axis=-1) - 1
+            position_ids = tf.where(attention_mask == 0, 1, position_ids)
+        elif position_ids is None:
+            position_ids = tf.range(past_key_values_length, seq_length + past_key_values_length, dtype=tf.int32)
+            position_ids = tf.expand_dims(position_ids, 0)
+
+        no_images = False
+        if (
+            sum((int(pixel_values is None), int(image_encoder_embeddings is None), int(perceiver_embeddings is None)))
+            != 2
+        ):
+            raise ValueError(
+                "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
+            )
+
+        elif pixel_values is not None:
+            no_images = tf.reduce_sum(tf.cast(pixel_values, dtype=tf.int32)) == 0
+            pixel_values = tf.cast(pixel_values, dtype=self.dtype)  # fp16 compatibility
+            # Below hack is because when cross-loading pytorch weights, there is an
+            # initial forward pass with dummy input and code below is here to handle that
+            if len(pixel_values.shape) == 4:
+                batch_size = shape_list(pixel_values)[0]
+                num_images = shape_list(pixel_values)[0]
+                # pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[1:]])
+            elif len(pixel_values.shape) == 5:
+                batch_size, num_images = shape_list(pixel_values)[:2]
+                pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[2:]])
+
+            # Get sequence from the vision encoder
+            image_hidden_states = self.vision_model(
+                pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+            ).last_hidden_state
+
+        elif image_encoder_embeddings is not None:
+            batch_size, num_images, image_seq_len, image_hidden_size = shape_list(image_encoder_embeddings)
+            image_hidden_states = tf.cast(image_encoder_embeddings, dtype=self.dtype)
+            image_hidden_states = tf.reshape(
+                image_hidden_states, (batch_size * num_images, image_seq_len, image_hidden_size)
+            )
+
+        if self.config.use_resampler:
+            if perceiver_embeddings is None:
+                perceiver_embeddings = self.perceiver_resampler(image_hidden_states)
+                image_seq_len, image_hidden_size = shape_list(perceiver_embeddings)[1:3]
+            else:
+                batch_size, num_images, image_seq_len, image_hidden_size = shape_list(perceiver_embeddings)
+            image_hidden_states = perceiver_embeddings
+        elif perceiver_embeddings is None:
+            image_seq_len, image_hidden_size = shape_list(image_hidden_states)[1:3]
+        else:
+            raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True")
+
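+        # flatten all image tokens into a single sequence per sample: (batch_size, num_images * image_seq_len, image_hidden_size)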
+        image_hidden_states = tf.reshape(
+            image_hidden_states, (batch_size, num_images * image_seq_len, image_hidden_size)
+        )
+        # # Hack to use the model in full language modeling mode
+        # image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32)
+
+        # this is to account for the dummy inputs
+        if pixel_values is not None and len(pixel_values.shape) == 4 and image_attention_mask is None:
+            image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32)
+
+        text_seq_len = shape_list(image_attention_mask)[1]
+        image_attention_mask = tf.expand_dims(image_attention_mask, -1)
+        image_attention_mask = tf.repeat(image_attention_mask, repeats=image_seq_len)
+        image_attention_mask = tf.reshape(image_attention_mask, (batch_size, text_seq_len, num_images * image_seq_len))
+
+        if image_hidden_states is not None:
+            image_batch_size, image_sequence_length, _ = shape_list(image_hidden_states)
+            image_hidden_shape = (image_batch_size, image_sequence_length)
+            if image_attention_mask is None:
+                image_attention_mask = tf.ones(image_hidden_shape, dtype=tf.int32)
+            image_attention_mask = invert_attention_mask(image_attention_mask)
+        else:
+            image_attention_mask = None
+
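+        # cross_attention_gate is 1.0 for text tokens that can attend to at least one image token and 0.0 otherwise;
+        # the gated cross-attention layers use it to zero out their output for image-less tokens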
+        cross_attention_gate = tf.squeeze(
+            tf.cast(tf.reduce_any(image_attention_mask == 0, axis=-1), dtype=self.dtype), axis=1
+        )
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        if attention_mask is None:
+            attention_mask = tf.ones((batch_size, seq_length_with_past), dtype=tf.bool)
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.decoder_layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
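+            # `vblock` bundles one decoder layer with its (optional) gated cross-attention block:
+            # every `cross_layer_interval`-th layer first cross-attends over `image_hidden_states`
+            # (modulated by `cross_attention_gate`), then the regular self-attention decoder layer runs.
+            # Wrapping both in a single function lets tf.recompute_grad checkpoint them as one unit.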
+            def vblock(
+                main_block,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                image_hidden_states,
+                image_attention_mask,
+                cross_attention_gate,
+                output_attentions,
+                use_cache,
+                layer_idx,
+                cross_layer_interval,
+                gated_cross_attn_layers,
+            ):
+                # TODO(ls): Add cross attention values to respective lists
+                if layer_idx % cross_layer_interval == 0:
+                    xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
+                    outputs = xblock(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        image_hidden_states=image_hidden_states,
+                        image_attention_mask=image_attention_mask,
+                        cross_attention_gate=cross_attention_gate,
+                        output_attentions=output_attentions,
+                        use_cache=use_cache,
+                        past_key_value=None,  # not implemented
+                    )
+                    hidden_states = outputs[0]
+
+                layer_outputs = main_block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+                return layer_outputs
+
+            if self.gradient_checkpointing and training:
+                past_key_value = None
+                if use_cache:
+                    logger.warning_once(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                layer_outputs = tf.recompute_grad(vblock)(
+                    decoder_layer,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    image_hidden_states,
+                    image_attention_mask,
+                    cross_attention_gate,
+                    output_attentions,
+                    use_cache,
+                    idx,
+                    self.cross_layer_interval,
+                    self.gated_cross_attn_layers,
+                )
+            else:
+                layer_outputs = vblock(
+                    decoder_layer,
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    image_hidden_states=image_hidden_states,
+                    image_attention_mask=image_attention_mask,
+                    cross_attention_gate=cross_attention_gate,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    layer_idx=idx,
+                    cross_layer_interval=self.cross_layer_interval,
+                    gated_cross_attn_layers=self.gated_cross_attn_layers,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        image_hidden_states = tf.reshape(
+            image_hidden_states, (batch_size, num_images, image_seq_len, image_hidden_size)
+        )
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
+                if v is not None
+            )
+        return TFIdeficsBaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            image_hidden_states=image_hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embed_tokens", None) is not None:
+            with tf.name_scope(self.embed_tokens.name):
+                self.embed_tokens.build(None)
+        if getattr(self, "vision_model", None) is not None:
+            with tf.name_scope(self.vision_model.name):
+                self.vision_model.build(None)
+        if getattr(self, "norm", None) is not None:
+            with tf.name_scope(self.norm.name):
+                self.norm.build(None)
+        if getattr(self, "perceiver_resampler", None) is not None:
+            with tf.name_scope(self.perceiver_resampler.name):
+                self.perceiver_resampler.build(None)
+        if getattr(self, "decoder_layers", None) is not None:
+            for layer in self.decoder_layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+        if getattr(self, "gated_cross_attn_layers", None) is not None:
+            for layer in self.gated_cross_attn_layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+
+class TFIdeficsModel(TFIdeficsPreTrainedModel):
+    def __init__(self, config: IdeficsConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.model = TFIdeficsMainLayer(config, name="model")
+
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
+        past_key_values: Optional[List[tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        pixel_values: Optional[tf.Tensor] = None,
+        image_encoder_embeddings: Optional[tf.Tensor] = None,
+        perceiver_embeddings: Optional[tf.Tensor] = None,
+        image_attention_mask: Optional[tf.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = None,
+    ) -> Union[TFIdeficsBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            pixel_values=pixel_values,
+            image_encoder_embeddings=image_encoder_embeddings,
+            perceiver_embeddings=perceiver_embeddings,
+            image_attention_mask=image_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            training=training,
+        )
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "model", None) is not None:
+            with tf.name_scope(self.model.name):
+                self.model.build(None)
+
+
+class TFIdeficsForVisionText2Text(TFPreTrainedModel, TFCausalLanguageModelingLoss):
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+    _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
+    config_class = IdeficsConfig
+
+    def __init__(self, config, vision_model=None, **kwargs):
+        super().__init__(config, **kwargs)
+        self.model = TFIdeficsMainLayer(config, name="model")
+        self.lm_head = TFIdeficsDecoupledLinear(
+            config.hidden_size,
+            config.vocab_size,
+            config.additional_vocab_size,
+            bias=False,
+            partially_freeze=config.freeze_lm_head,
+            name="lm_head",
+        )
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def tie_weights(self):
+        """
+        Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of
+        IdeficsDecoupledLinear and IdeficsDecoupledEmbedding.
+        """
+        output_embeddings = self.get_output_embeddings()
+        input_embeddings = self.get_input_embeddings()
+
+        if getattr(self.config, "tie_word_embeddings", True):
+            output_embeddings.weight = input_embeddings.weight
+            if input_embeddings.num_additional_embeddings > 0:
+                assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
+                output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
+
+        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+            output_embeddings.out_features = input_embeddings.num_embeddings
+            if hasattr(output_embeddings, "out_additional_features") and hasattr(
+                input_embeddings, "num_additional_embeddings"
+            ):
+                output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFIdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
+        past_key_values: Optional[List[tf.Tensor]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        pixel_values: Optional[tf.Tensor] = None,
+        image_encoder_embeddings: Optional[tf.Tensor] = None,
+        perceiver_embeddings: Optional[tf.Tensor] = None,
+        image_attention_mask: Optional[tf.Tensor] = None,
+        labels: Optional[tf.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
+        training=False,
+    ) -> Union[TFIdeficsCausalLMOutputWithPast, Tuple[tf.Tensor]]:
+        r"""
+        Args:
+            labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >> from transformers import AutoTokenizer, TFIdeficsForVisionText2Text
+
+        >> model = TFIdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b")
+        >> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics-9b")
+
+        >> prompt = "Hey, are you consciours? Can you talk to me?"
+        >> inputs = tokenizer(prompt, return_tensors="tf")
+
+        >> # Generate
+        >> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            pixel_values=pixel_values,
+            image_encoder_embeddings=image_encoder_embeddings,
+            perceiver_embeddings=perceiver_embeddings,
+            image_attention_mask=image_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            if attention_mask is not None:
+                shift_attention_mask = attention_mask[..., 1:]
+                # TF tensors don't support boolean indexing with `[]`, so use tf.boolean_mask
+                shift_logits = tf.boolean_mask(logits[..., :-1, :], shift_attention_mask != 0)
+                shift_labels = tf.boolean_mask(labels[..., 1:], shift_attention_mask != 0)
+            else:
+                shift_logits = logits[..., :-1, :]
+                shift_labels = labels[..., 1:]
+            # Flatten the tokens
+            loss = self.hf_compute_loss(
+                labels=tf.reshape(shift_labels, [-1]), logits=tf.reshape(shift_logits, [-1, shift_logits.shape[-1]])
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return TFIdeficsCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
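+        # After the first forward pass, generation reuses the cached `image_hidden_states` (as
+        # `perceiver_embeddings` when the resampler is used, otherwise as `image_encoder_embeddings`)
+        # instead of re-running the vision tower on `pixel_values` at every step.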
+        image_hidden_states = kwargs.pop("image_hidden_states", None)
+        if image_hidden_states is not None:
+            if self.config.use_resampler:
+                kwargs["perceiver_embeddings"] = image_hidden_states
+            else:
+                kwargs["image_encoder_embeddings"] = image_hidden_states
+            kwargs["pixel_values"] = None
+        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
+        unwanted_kwargs = ["token_type_ids"]
+        for kwarg in unwanted_kwargs:
+            inputs.pop(kwarg, None)
+        return inputs
+
+    @staticmethod
+    def _expand_inputs_for_generation(
+        *args,
+        **model_kwargs,
+    ):
+        return expand_inputs_for_generation(*args, **model_kwargs)
+
+    @staticmethod
+    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
+        return update_model_kwargs_for_generation(outputs, model_kwargs)
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),)
+        return reordered_past
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "model", None) is not None:
+            with tf.name_scope(self.model.name):
+                self.model.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build(None)
diff --git a/src/transformers/models/idefics/perceiver_tf.py b/src/transformers/models/idefics/perceiver_tf.py
new file mode 100644
index 00000000000000..c9e76004a70ddc
--- /dev/null
+++ b/src/transformers/models/idefics/perceiver_tf.py
@@ -0,0 +1,194 @@
+# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
+#
+# MIT License
+#
+# Copyright (c) 2020  The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+"""
+
+Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
+time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
+that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
+prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
+to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
+
+References:
+    - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
+    - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
+
+"""
+from typing import Optional, Tuple
+
+import tensorflow as tf
+
+from ...modeling_tf_utils import shape_list
+from .configuration_idefics import IdeficsConfig
+
+
+class TFIdeficsPerceiverResampler(tf.keras.layers.Layer):
+    def __init__(
+        self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, **kwargs
+    ) -> None:
+        """
+        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
+        MAE) of a given dimension, performs `depth` blocks of cross-attention against a fixed set of `n_latents`
+        inputs, and returns a Tensor of shape [bsz, n_latents, embed_dim]. `embed_dim` is both the dimensionality of
+        the embeddings fed to the Perceiver Resampler and of the latent embeddings it *returns* (e.g. the ViT
+        embed_dim, the ResNet pool dim, and so on).
+
+        Args:
+            config (`IdeficsConfig`): config object
+            embed_dim (`int`): The size of each embedding vector
+            depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
+            n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
+            head_dim (`int`): Dimensionality of each head projection in the Transformer block.
+            n_latents (`int`):
+                Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
+
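+        Example (a minimal sketch; the depth, head, latent and sequence values below are illustrative and do not
+        correspond to the pretrained checkpoints):
+
+        ```python
+        import tensorflow as tf
+
+        from transformers import IdeficsConfig
+        from transformers.models.idefics.perceiver_tf import TFIdeficsPerceiverResampler
+
+        config = IdeficsConfig()
+        dim = config.vision_config.embed_dim  # the resampler width has to match the vision embed dim
+        resampler = TFIdeficsPerceiverResampler(
+            config, embed_dim=dim, depth=2, n_heads=8, head_dim=dim // 8, n_latents=64
+        )
+        context = tf.random.normal((2, 257, dim))  # [bsz, seq, embed_dim], e.g. ViT patch features
+        latents = resampler(context)  # [bsz, n_latents, embed_dim] -> (2, 64, dim)
+        ```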
+        """
+        super().__init__(**kwargs)
+        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
+        self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
+
+        self.intermediate_dim = (
+            self.embed_dim * 4
+            if not hasattr(config.vision_config, "embed_dim")
+            else config.vision_config.embed_dim * 4
+        )
+        # Create Transformer Blocks
+        self.blocks = []
+        for i in range(depth):
+            self.blocks.append(
+                [
+                    TFIdeficsPerceiverAttention(
+                        self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms, name=f"blocks.{i}.0"
+                    ),
+                    TFIdeficsMLP(self.intermediate_dim, config, name=f"blocks.{i}.1"),
+                ]
+            )
+
+        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+    def build(self, input_shape):
+        # Create Latents for Perceiver
+        self.latents = self.add_weight(
+            shape=(self.n_latents, self.embed_dim), initializer="random_normal", trainable=True, name="latents"
+        )
+        super().build(input_shape)
+
+    def call(self, context: tf.Tensor) -> tf.Tensor:
+        """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
+        # tf.repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
+        latents = tf.expand_dims(self.latents, axis=0)
+        latents = tf.tile(latents, [tf.shape(context)[0], 1, 1])
+        # Feed through Perceiver Attention blocks...
+        for attn, ff in self.blocks:
+            latents = attn(context, latents) + latents
+            latents = ff(latents) + latents
+        return self.layer_norm(latents)
+
+
+class TFIdeficsPerceiverAttention(tf.keras.layers.Layer):
+    def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool, **kwargs) -> None:
+        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
+        super().__init__(**kwargs)
+        self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
+        self.qk_layer_norms = qk_layer_norms
+        # Normalization & Scaling
+        self.context_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="context_layer_norm")
+        self.latents_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="latents_layer_norm")
+        if self.qk_layer_norms:
+            self.q_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="q_layer_norm")
+            self.k_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="k_layer_norm")
+
+        self.qk_scale = self.head_dim**-0.5
+
+        # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
+        self.q_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="q_proj")
+        self.k_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="k_proj")
+        self.v_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="v_proj")
+
+        self.output_proj = tf.keras.layers.Dense(embed_dim, use_bias=False, name="output_proj")
+
+    def call(self, context: tf.Tensor, latents: tf.Tensor) -> tf.Tensor:
+        """
+        Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
+
+        Args:
+            context (`tf.Tensor`):
+                Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
+            latents (`tf.Tensor`):
+                Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.
+
+        Returns:
+            `tf.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
+            from context.
+        """
+        context = self.context_layer_norm(context)
+        latents = self.latents_layer_norm(latents)
+        batch_size, seq_length, embed_dim = shape_list(context)
+
+        # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
+        #   Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
+        q = self.q_proj(latents)
+        k = self.k_proj(tf.concat([context, latents], axis=-2))
+        v = self.v_proj(tf.concat([context, latents], axis=-2))
+
+        # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
+        #   =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
+        q, k, v = [
+            tf.transpose(tf.reshape(x, (batch_size, x.shape[1], self.n_heads, self.head_dim)), perm=[0, 2, 1, 3])
+            for x in (q, k, v)
+        ]
+
+        if self.qk_layer_norms:
+            q = self.q_layer_norm(q)
+            k = self.k_layer_norm(k)
+
+        scores = tf.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
+        stabilized_scores = scores - tf.reduce_max(scores, axis=-1, keepdims=True)
+        attn = tf.nn.softmax(stabilized_scores, axis=-1)
+
+        # Attend & project back to output...
+        resampled = tf.einsum("... i j, ... j d -> ... i d", attn, v)
+        return self.output_proj(
+            tf.reshape(tf.transpose(resampled, perm=[0, 2, 1, 3]), (batch_size, -1, self.n_heads * self.head_dim))
+        )
+
+
+class TFIdeficsMLP(tf.keras.layers.Layer):
+    def __init__(self, intermediate_size, config: IdeficsConfig, **kwargs):
+        """Simple MLP block with intermediate_size and embedding size"""
+        super().__init__(**kwargs)
+        self.embed_dim = config.vision_config.embed_dim
+        self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="ln")
+        self.fc = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="fc")
+        self.act = tf.keras.layers.ReLU(name="act")
+        self.c_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="c_proj")
+
+    def call(self, hidden_states: Optional[Tuple[tf.Tensor]]) -> tf.Tensor:
+        hidden_states = self.ln(hidden_states)
+        hidden_states = self.fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+
+        return hidden_states
diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py
index d7fd8c8de6555e..2afe2a49781245 100644
--- a/src/transformers/models/idefics/processing_idefics.py
+++ b/src/transformers/models/idefics/processing_idefics.py
@@ -22,34 +22,53 @@
 from ...feature_extraction_utils import BatchFeature
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
-from ...utils import TensorType, is_torch_available
+from ...utils import is_tf_available, is_torch_available
 
 
 if is_torch_available():
     import torch
 
+if is_tf_available():
+    import tensorflow as tf
 
 IMAGE_TOKEN = ""
 
 
 # copied from m4.training.packing
-def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
-    # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]
-
-    # If any of images index are more than num_classes, set them to -1.
-    # Words after the max number of images allowed have been seen don't attend on anything
+def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
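+    # Converts an incremental "which image does this token attend to" index into a binary one-hot
+    # mask, e.g. [-1, 0, 1] with num_classes=2 => [[0, 0], [1, 0], [0, 1]]
+    # (-1 means "attend to no image" and maps to an all-zero row).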
+    # Set elements >= num_classes to -1
     if num_classes != -1:
-        incremental_mask[incremental_mask >= num_classes] = -1
+        if return_tensors == "pt":
+            incremental_mask[incremental_mask >= num_classes] = -1
+        elif return_tensors == "tf":
+            incremental_mask = tf.where(incremental_mask >= num_classes, -1, incremental_mask)
+
+    # Create mask for negative values
+    if return_tensors == "pt":
+        negatives = incremental_mask == -1
+        incremental_mask[negatives] = 0
+        attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
+        attn_mask[negatives, :] = 0
+    elif return_tensors == "tf":
+        negatives = tf.equal(incremental_mask, -1)
+        incremental_mask = tf.where(negatives, 0, incremental_mask)
+        attn_mask = tf.one_hot(incremental_mask, depth=num_classes)
+        # Reshape 'negatives' to add an extra dimension, making it [batch_size, seq_length, 1]
+        negatives_expanded = tf.expand_dims(negatives, -1)
+        attn_mask = tf.where(negatives_expanded, tf.zeros_like(attn_mask), attn_mask)
 
-    negatives = incremental_mask == -1
-    incremental_mask[negatives] = 0
-    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
-    attn_mask[negatives, :] = 0
     return attn_mask
 
 
 # copied from m4.training.packing
-def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
+def image_attention_mask_for_packed_input_ids(input_ids, tokenizer, return_tensors):
+    if return_tensors == "pt":
+        return image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer)
+    elif return_tensors == "tf":
+        return image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer)
+
+
+def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer):
     image_attention_mask = torch.full_like(input_ids, fill_value=-1)
     next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
     image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
@@ -96,6 +115,39 @@ def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
     return image_attention_mask, next_image_attention_mask
 
 
+def image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer):
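+    # TF counterpart of `image_attention_mask_for_packed_input_ids_pt`: scanning each sequence from
+    # the end, every image token gets an incremental image index (later converted to a binary mask),
+    # while -1 means "no image". TF tensors are immutable, so updates go through
+    # tf.tensor_scatter_nd_update instead of in-place assignment.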
+    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+    eod_token_id = tokenizer.eos_token_id
+    batch_size = tf.shape(input_ids)[0]
+    image_attention_mask = tf.fill(tf.shape(input_ids), -1)
+    next_image_attention_mask = tf.fill(tf.shape(input_ids), -1)
+
+    for batch_idx in range(batch_size):
+        count = -1
+        seen_eod = False
+        seq_length = tf.shape(input_ids)[1]
+
+        for idx in range(seq_length - 1, -1, -1):
+            token_id = input_ids[batch_idx, idx].numpy()
+            if token_id == image_token_id:
+                count += 1
+                indices = [[batch_idx, idx]]
+                updates = [count]
+                image_attention_mask = tf.tensor_scatter_nd_update(image_attention_mask, indices, updates)
+                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
+            elif token_id == eod_token_id and not seen_eod:
+                seen_eod = True
+                count = 0
+                indices = [[batch_idx, idx]]
+                updates = [count]
+                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
+            if seen_eod and token_id != eod_token_id:
+                indices = [[batch_idx, idx]]
+                updates = [-1]
+                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
+    return image_attention_mask, next_image_attention_mask
+
+
 def is_url(string):
     """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
     invalidated the url"""
@@ -156,7 +208,7 @@ def __call__(
         add_eos_token=False,
         add_end_of_utterance_token=None,
         debug=False,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        return_tensors="pt",
     ) -> BatchEncoding:
         """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
         the model was trained on and prepares the image pixel values for the model to process.
@@ -268,7 +320,6 @@ def __call__(
         # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it
         if add_end_of_utterance_token is None:
             add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
-
         # turn non-batched prompts into batched
         if not any(isinstance(i, list) for i in prompts):
             prompts = [prompts]
@@ -322,7 +373,7 @@ def image_tokens(last_was_image):
             if debug is True:
                 print(f"{full_text=}")
 
-            image_objects = self.image_processor(image_objects, transform=transform)
+            image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
 
             all_prompts.append(full_text)
             all_images.append(image_objects)
@@ -345,39 +396,72 @@ def image_tokens(last_was_image):
         output_input_ids = []
         output_images = []
         output_attention_masks = []
+
         for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
             padded_input_ids = text
-
             image_count = padded_input_ids.count(self.image_token_id)
             local_max_num_images = min(image_count, max_num_images)
 
             current_images = images[:local_max_num_images]
 
             if len(current_images) > 0:
-                padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
-                padded_image_tensor[: current_images.size(0)] = current_images
+                if return_tensors == "pt":
+                    padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
+                    padded_image_tensor[: current_images.size(0)] = current_images
+                elif return_tensors == "tf":
+                    # Assuming current_images is a TensorFlow tensor
+                    # Get the shape of current_images, excluding the first dimension
+                    image_shape = tf.shape(current_images)[1:]
+                    # Create a shape for the padded_image_tensor
+                    padded_shape = tf.concat([[max_num_images], image_shape], axis=0)
+                    # Create the padded_image_tensor of zeros
+                    padded_image_tensor = tf.zeros(padded_shape, dtype=current_images.dtype)
+                    # Get the number of images (assuming current_images has shape [num_images, height, width, channels])
+                    num_images = tf.shape(current_images)[0]
+                    # Update the padded_image_tensor with the values from current_images
+                    indices = tf.reshape(tf.range(num_images), (-1, 1))
+                    updates = current_images
+                    padded_image_tensor = tf.tensor_scatter_nd_update(padded_image_tensor, indices, updates)
             else:
-                padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
+                if return_tensors == "pt":
+                    padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
+                elif return_tensors == "tf":
+                    padded_image_tensor = tf.zeros((max_num_images, *self.default_image_dims))
 
             output_images.append(padded_image_tensor)
-            output_input_ids.append(torch.tensor(padded_input_ids))
-            output_attention_masks.append(torch.tensor(attention_mask))
-
-        output_input_ids = torch.stack(output_input_ids)
-        output_images = torch.stack(output_images)
-        output_attention_masks = torch.stack(output_attention_masks)
+            if return_tensors == "pt":
+                output_input_ids.append(torch.tensor(padded_input_ids))
+                output_attention_masks.append(torch.tensor(attention_mask))
+            elif return_tensors == "tf":
+                output_input_ids.append(tf.convert_to_tensor(padded_input_ids, dtype=tf.int32))
+                output_attention_masks.append(attention_mask)
+
+        if return_tensors == "pt":
+            output_input_ids = torch.stack(output_input_ids)
+            output_images = torch.stack(output_images)
+            output_attention_masks = torch.stack(output_attention_masks)
+        elif return_tensors == "tf":
+            output_input_ids = tf.stack(output_input_ids)
+            output_images = tf.stack(output_images)
+            output_attention_masks = tf.stack(output_attention_masks)
 
         if at_least_one_image:
-            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer)
+            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
+                output_input_ids, self.tokenizer, return_tensors
+            )
             image_attention_mask = incremental_to_binary_attention_mask(
-                image_attention_mask, num_classes=max_num_images
+                image_attention_mask, return_tensors, num_classes=max_num_images
             )
         else:
             # in full language mode we set the image mask to all-0s
-            image_attention_mask = torch.zeros(
-                output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
-            )
-
+            if return_tensors == "pt":
+                image_attention_mask = torch.zeros(
+                    output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
+                )
+            elif return_tensors == "tf":
+                image_attention_mask = tf.zeros(
+                    (output_input_ids.shape[0], output_input_ids.shape[1], 1), dtype=tf.bool
+                )
         return BatchFeature(
             data={
                 "input_ids": output_input_ids,
diff --git a/src/transformers/models/idefics/vision_tf.py b/src/transformers/models/idefics/vision_tf.py
new file mode 100644
index 00000000000000..0060bb7ac9a7fb
--- /dev/null
+++ b/src/transformers/models/idefics/vision_tf.py
@@ -0,0 +1,573 @@
+# coding=utf-8
+# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
+
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
+from ...modeling_tf_utils import TFPreTrainedModel, shape_list
+from ...tf_utils import flatten
+from ...utils import ModelOutput, logging
+from .configuration_idefics import IdeficsVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TFIdeficsVisionModelOutput(ModelOutput):
+    """
+    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+
+    Args:
+        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
+            The image embeddings obtained by applying the projection layer to the pooler_output.
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    image_embeds: Optional[tf.Tensor] = None
+    last_hidden_state: tf.Tensor = None
+    hidden_states: Optional[Tuple[tf.Tensor]] = None
+    attentions: Optional[Tuple[tf.Tensor]] = None
+
+
+class TFIdeficsVisionEmbeddings(tf.keras.layers.Layer):
+    def __init__(self, config: IdeficsVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = tf.keras.layers.Conv2D(
+            filters=self.embed_dim,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            use_bias=False,
+            padding="valid",
+            data_format="channels_last",
+            name="patch_embedding",
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = tf.keras.layers.Embedding(
+            self.num_positions, self.embed_dim, name="position_embedding"
+        )
+        # self.position_ids = tf.range(self.num_positions)[tf.newaxis, :]
+
+    def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
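+        # Allows running the vision tower at a resolution other than the pretraining one: the class
+        # token's position embedding is kept as-is, while the patch-grid position embeddings are
+        # reshaped to a 2D grid and bicubically resized (tf.image.resize) to the new patch grid.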
+        num_patches = shape_list(embeddings)[1] - 1
+        pos_embed = self.position_embedding(self.position_ids)
+        num_positions = shape_list(pos_embed)[1] - 1
+        if num_patches == num_positions and height == width:
+            return pos_embed
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+
+        embed_dim = shape_list(embeddings)[-1]
+        num_h_patches = height // self.config.patch_size
+        num_w_patches = width // self.config.patch_size
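+        # add a small number to the patch counts to avoid floating point errors when the scaled grid
+        # size is an exact integer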
+        num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
+        sqrt_num_positions = math.sqrt(float(num_positions))
+        patch_pos_embed = tf.reshape(patch_pos_embed, (1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim))
+
+        scale_height = num_h_patches / sqrt_num_positions
+        scale_width = num_w_patches / sqrt_num_positions
+        original_height = tf.cast(tf.shape(patch_pos_embed)[1], tf.float32)
+        original_width = tf.cast(tf.shape(patch_pos_embed)[2], tf.float32)
+        # Apply scaling
+        new_height = tf.cast(original_height * scale_height, tf.int32)
+        new_width = tf.cast(original_width * scale_width, tf.int32)
+
+        patch_pos_embed = tf.image.resize(
+            patch_pos_embed, size=[new_height, new_width], method=tf.image.ResizeMethod.BICUBIC
+        )
+
+        if (
+            int(num_h_patches) != shape_list(patch_pos_embed)[-3]
+            or int(num_w_patches) != shape_list(patch_pos_embed)[-2]
+        ):
+            raise ValueError(
+                f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
+                f"shape of position embedding ({shape_list(patch_pos_embed)[-2], shape_list(patch_pos_embed)[-1]})"
+            )
+        patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, embed_dim))
+        return tf.concat((class_pos_embed[tf.newaxis, :], patch_pos_embed), axis=1)
+
+    def call(self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False) -> tf.Tensor:
+        # Input `pixel_values` comes in NCHW format, which Conv2D doesn't support on CPU, so the first
+        # thing we do is transpose it to NHWC. We don't bother transposing it back because the Conv2D
+        # layer is only hit once per forward pass
+
+        if isinstance(pixel_values, dict):
+            pixel_values = pixel_values["pixel_values"]
+
+        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+        batch_size, height, width, num_channels = shape_list(pixel_values)
+        if not interpolate_pos_encoding:
+            if height != self.image_size or width != self.image_size:
+                raise ValueError(
+                    f"Input image size ({height}*{width}) doesn't match model"
+                    f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
+                )
+
+        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        # Change the 2D spatial dimensions to a single temporal dimension.
+        # shape = (batch_size, num_patches, out_channels=embed_dim)
+        patch_embeds = flatten(patch_embeds, 1, 2)
+
+        class_embeds = tf.broadcast_to(
+            self.class_embedding[tf.newaxis, tf.newaxis, :], [batch_size, 1, self.embed_dim]
+        )
+        embeddings = tf.concat([class_embeds, patch_embeds], axis=1)
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        self.position_ids = tf.range(self.num_positions, name="self.position_ids")[tf.newaxis, :]
+        self.class_embedding = self.add_weight(shape=(self.embed_dim,), name="class_embedding")
+        if getattr(self, "patch_embedding", None) is not None:
+            with tf.name_scope(self.patch_embedding.name):
+                self.patch_embedding.build([None, None, None, self.config.num_channels])
+        if getattr(self, "position_embedding", None) is not None:
+            with tf.name_scope(self.position_embedding.name):
+                self.position_embedding.build(None)
+
+
+class TFIdeficsVisionAttention(tf.keras.layers.Layer):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = tf.keras.layers.Dense(self.embed_dim, name="k_proj")
+        self.v_proj = tf.keras.layers.Dense(self.embed_dim, name="v_proj")
+        self.q_proj = tf.keras.layers.Dense(self.embed_dim, name="q_proj")
+        self.out_proj = tf.keras.layers.Dense(self.embed_dim, name="out_proj")
+
+    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[tf.Tensor] = None,
+        causal_attention_mask: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scale
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+        key_states = tf.reshape(key_states, proj_shape)
+        value_states = tf.reshape(value_states, proj_shape)
+
+        src_len = shape_list(key_states)[1]
+        attn_weights = tf.linalg.matmul(query_states, key_states, transpose_b=True)
+
+        tf.debugging.assert_equal(
+            tf.shape(attn_weights),
+            [bsz * self.num_heads, tgt_len, src_len],
+            message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, src_len]}, but is {tf.shape(attn_weights)}",
+        )
+
+        # apply the causal_attention_mask first
+        if causal_attention_mask is not None:
+            if shape_list(causal_attention_mask) != [bsz, 1, tgt_len, src_len]:
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {shape_list(causal_attention_mask)}"
+                )
+            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + causal_attention_mask
+            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+        if attention_mask is not None:
+            if shape_list(attention_mask) != [bsz, 1, tgt_len, src_len]:
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}"
+                )
+            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights has to be reshaped
+            # twice and then reused in the computation below
+            attn_weights_reshaped = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+            attn_weights = tf.reshape(attn_weights_reshaped, (bsz * self.num_heads, tgt_len, src_len))
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
+
+        attn_output = tf.linalg.matmul(attn_probs, value_states)
+
+        tf.debugging.assert_equal(
+            tf.shape(attn_output),
+            [bsz * self.num_heads, tgt_len, self.head_dim],
+            message=f"Attention weights should be of size {[bsz * self.num_heads, tgt_len, self.head_dim]}, but is {tf.shape(attn_output)}",
+        )
+
+        attn_output = tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim))
+        attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
+        attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "k_proj", None) is not None:
+            with tf.name_scope(self.k_proj.name):
+                self.k_proj.build((self.embed_dim, self.embed_dim))
+        if getattr(self, "v_proj", None) is not None:
+            with tf.name_scope(self.v_proj.name):
+                self.v_proj.build((self.embed_dim, self.embed_dim))
+        if getattr(self, "q_proj", None) is not None:
+            with tf.name_scope(self.q_proj.name):
+                self.q_proj.build((self.embed_dim, self.embed_dim))
+        if getattr(self, "out_proj", None) is not None:
+            with tf.name_scope(self.out_proj.name):
+                self.out_proj.build((self.embed_dim, self.embed_dim))
+
+
+class TFIdeficsVisionMLP(tf.keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.activation_fn = get_tf_activation(config.hidden_act)
+        self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1")
+        self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2")
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "fc1", None) is not None:
+            with tf.name_scope(self.fc1.name):
+                self.fc1.build(self.config.hidden_size)
+        if getattr(self, "fc2", None) is not None:
+            with tf.name_scope(self.fc2.name):
+                self.fc2.build(self.config.intermediate_size)
+
+
+class TFIdeficsVisionEncoderLayer(tf.keras.layers.Layer):
+    def __init__(self, config: IdeficsVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.embed_dim = config.hidden_size
+        self.self_attn = TFIdeficsVisionAttention(config, name="self_attn")
+        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
+        self.mlp = TFIdeficsVisionMLP(config, name="mlp")
+        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor,
+        causal_attention_mask: tf.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[tf.Tensor]:
+        """
+        Args:
+            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`tf.Tensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "layer_norm1", None) is not None:
+            with tf.name_scope(self.layer_norm1.name):
+                self.layer_norm1.build([None, None, self.embed_dim])
+        if getattr(self, "layer_norm2", None) is not None:
+            with tf.name_scope(self.layer_norm2.name):
+                self.layer_norm2.build([None, None, self.embed_dim])
+
+
+class TFIdeficsVisionEncoder(tf.keras.layers.Layer):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`TFIdeficsVisionEncoderLayer`].
+
+    Args:
+        config: IdeficsVisionConfig
+    """
+
+    def __init__(self, config: IdeficsVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.layers = [
+            TFIdeficsVisionEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
+        ]
+        self.gradient_checkpointing = False
+
+    def call(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[tf.Tensor] = None,
+        causal_attention_mask: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = None,
+    ) -> Union[Tuple, TFBaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            causal_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = tf.recompute_grad(
+                    create_custom_forward(encoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return TFBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "layers", None) is not None:
+            for layer in self.layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+
+class TFIdeficsVisionTransformer(TFPreTrainedModel):
+    def __init__(self, config: IdeficsVisionConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.config = config
+        self.embed_dim = config.hidden_size
+
+        self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings")
+        self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm")
+        self.encoder = TFIdeficsVisionEncoder(config, name="encoder")
+        self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")
+
+    # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
+    def call(
+        self,
+        pixel_values: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+        hidden_states = self.pre_layrnorm(hidden_states)
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        pooled_output = last_hidden_state[:, 0, :]
+        pooled_output = self.post_layernorm(pooled_output)
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "pre_layrnorm", None) is not None:
+            with tf.name_scope(self.pre_layrnorm.name):
+                self.pre_layrnorm.build([None, None, self.embed_dim])
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "post_layernorm", None) is not None:
+            with tf.name_scope(self.post_layernorm.name):
+                self.post_layernorm.build([None, self.embed_dim])
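For orientation, here is a minimal sketch of how the new TF classes are exercised end to end; it mirrors the integration test added further below and assumes the tiny random checkpoint used there:

    from transformers import IdeficsProcessor, TFIdeficsForVisionText2Text

    checkpoint = "HuggingFaceM4/tiny-random-idefics"  # tiny checkpoint used by the new tests
    processor = IdeficsProcessor.from_pretrained(checkpoint)
    model = TFIdeficsForVisionText2Text.from_pretrained(checkpoint, from_pt=True)

    prompts = [["User:", "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg", "Describe this image.\nAssistant:"]]
    inputs = processor(prompts, return_tensors="tf")
    generated_ids = model.generate(**inputs, max_length=50)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))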
diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py
index 75e302947e8066..b91a2ea520f0d0 100644
--- a/src/transformers/tf_utils.py
+++ b/src/transformers/tf_utils.py
@@ -104,6 +104,33 @@ def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
     return outputs
 
 
+def scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: float = None
+):
+    """TF equivalent for torch's nn.functional.scaled_dot_product_attention"""
+    if dropout_p != 0.0:
+        raise ValueError(
+            "Dropout is not supported in this implementation - file an issue "
+            "with Transformers and ping @Rocketknight1 if you need it for a port!"
+        )
+    if is_causal and attn_mask is not None:
+        raise ValueError("You cannot specify an attn_mask and is_causal at the same time!")
+    if is_causal:
+        attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32)
+        attn_mask = tf.experimental.numpy.tril(attn_mask, k=0)
+    if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool):
+        # Convert boolean mask to a negative logit bias
+        attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype))
+    logits = tf.einsum("...qd, ...kd -> ...qk", query, key)
+    if scale is None:
+        scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5
+    logits *= scale  # scale by 1/sqrt(key_dim)
+    if attn_mask is not None:
+        logits += attn_mask
+    probs = tf.nn.softmax(logits)
+    return probs @ value
+
+
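The helper above computes standard softmax(Q K^T / sqrt(d)) V attention; passing `is_causal=True` builds a lower-triangular mask internally, and an integer or boolean `attn_mask` is converted into an additive bias. A minimal usage sketch (assuming the patch is applied; the shapes are arbitrary):

    import tensorflow as tf
    from transformers.tf_utils import scaled_dot_product_attention

    q = tf.random.normal((2, 4, 5, 8))  # (batch, heads, query_len, head_dim)
    k = tf.random.normal((2, 4, 5, 8))
    v = tf.random.normal((2, 4, 5, 8))
    out = scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)  # (2, 4, 5, 8)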
 def flatten(input, start_dim=0, end_dim=-1):
     # Replicates the behavior of torch.flatten in TF
 
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
index 5d4c28cbcc4595..e0b396c7164a75 100644
--- a/src/transformers/utils/dummy_tf_objects.py
+++ b/src/transformers/utils/dummy_tf_objects.py
@@ -1542,6 +1542,27 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["tf"])
 
 
+class TFIdeficsForVisionText2Text(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFIdeficsModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFIdeficsPreTrainedModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFLayoutLMForMaskedLM(metaclass=DummyObject):
     _backends = ["tf"]
 
diff --git a/tests/models/idefics/test_image_processing_idefics.py b/tests/models/idefics/test_image_processing_idefics.py
index 6c682ce4a8f8c6..de42a421cd877e 100644
--- a/tests/models/idefics/test_image_processing_idefics.py
+++ b/tests/models/idefics/test_image_processing_idefics.py
@@ -152,7 +152,7 @@ def test_torchvision_numpy_transforms_equivalency(self):
         # they both do the same
 
         image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
-        image_processor = self.image_processing_class(**self.image_processor_dict)
+        image_processor = self.image_processing_class(**self.image_processor_dict, return_tensors="pt")
 
         print(image_inputs)
 
@@ -181,8 +181,8 @@ def convert_to_rgb(image):
             ]
         )
 
-        pixel_values_transform_implied = image_processor(image_inputs, transform=None)
-        pixel_values_transform_supplied = image_processor(image_inputs, transform=transform)
+        pixel_values_transform_implied = image_processor(image_inputs, transform=None, return_tensors="pt")
+        pixel_values_transform_supplied = image_processor(image_inputs, transform=transform, return_tensors="pt")
 
         torch.testing.assert_close(pixel_values_transform_implied, pixel_values_transform_supplied, rtol=0.0, atol=0.0)
 
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 9f8f177617d200..5c3d45d2e81bcb 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -21,6 +21,7 @@
 from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
 from transformers.testing_utils import (
     TestCasePlus,
+    is_pt_tf_cross_test,
     require_bitsandbytes,
     require_torch,
     require_torch_sdpa,
@@ -559,6 +560,11 @@ def check_hidden_states_output(inputs_dict, config, model_class):
 
             check_hidden_states_output(inputs_dict, config, model_class)
 
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        self.has_attentions = False
+        super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
+
     @slow
     def test_model_from_pretrained(self):
         model_name = "HuggingFaceM4/idefics-9b"
diff --git a/tests/models/idefics/test_modeling_tf_idefics.py b/tests/models/idefics/test_modeling_tf_idefics.py
new file mode 100644
index 00000000000000..eeb3faafa223d9
--- /dev/null
+++ b/tests/models/idefics/test_modeling_tf_idefics.py
@@ -0,0 +1,565 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the TF Idefics model. """
+
+import os
+import tempfile
+import unittest
+from importlib import import_module
+
+from transformers import IdeficsConfig, is_tf_available, is_vision_available
+from transformers.testing_utils import TestCasePlus, is_pt_tf_cross_test, require_tf, require_vision, slow
+from transformers.utils import cached_property
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_tf_available():
+    import tensorflow as tf
+
+    from transformers import IdeficsProcessor, TFIdeficsForVisionText2Text, TFIdeficsModel
+    from transformers.modeling_tf_utils import keras
+    from transformers.models.idefics.configuration_idefics import IdeficsPerceiverConfig, IdeficsVisionConfig
+
+if is_vision_available():
+    from PIL import Image
+
+
+IDEFICS_TINY_RANDOM_MODEL = "HuggingFaceM4/tiny-random-idefics"
+
+
+class IdeficsModelTester:
+    def __init__(
+        self,
+        parent,
+        batch_size=1,
+        seq_length=7,
+        image_size=30,
+        patch_size=2,
+        num_channels=3,
+        is_training=True,
+        use_input_mask=True,
+        use_token_type_ids=True,
+        use_labels=True,
+        vocab_size=99,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=16,
+        type_sequence_label_size=2,
+        initializer_range=0.02,
+        num_labels=3,
+        scope=None,
+        modality_type_vocab_size=2,
+        vision_embed_dim=32,
+        vision_patch_size=2,
+        vision_image_size=30,
+        vision_num_attention_heads=4,
+        vision_num_hidden_layers=5,
+        vision_intermediate_size=37,
+        perceiver_qk_layer_norms_perceiver=False,
+        perceiver_resampler_depth=2,
+        perceiver_resampler_head_dim=8,
+        perceiver_resampler_n_heads=2,
+        perceiver_resampler_n_latents=16,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.is_training = is_training
+        self.use_input_mask = use_input_mask
+        self.use_token_type_ids = use_token_type_ids
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+        self.num_labels = num_labels
+        self.scope = scope
+        self.modality_type_vocab_size = modality_type_vocab_size
+
+        self.vision_embed_dim = vision_embed_dim
+        self.vision_patch_size = vision_patch_size
+        self.vision_image_size = vision_image_size
+        self.vision_num_attention_heads = vision_num_attention_heads
+        self.vision_num_hidden_layers = vision_num_hidden_layers
+        self.vision_intermediate_size = vision_intermediate_size
+
+        self.vision_config = IdeficsVisionConfig(
+            embed_dim=self.vision_embed_dim,
+            patch_size=self.vision_patch_size,
+            image_size=self.vision_image_size,
+            num_attention_heads=self.vision_num_attention_heads,
+            num_hidden_layers=self.vision_num_hidden_layers,
+            intermediate_size=self.vision_intermediate_size,
+        )
+
+        self.perceiver_qk_layer_norms_perceiver = perceiver_qk_layer_norms_perceiver
+        self.perceiver_resampler_depth = perceiver_resampler_depth
+        self.perceiver_resampler_head_dim = perceiver_resampler_head_dim
+        self.perceiver_resampler_n_heads = perceiver_resampler_n_heads
+        self.perceiver_resampler_n_latents = perceiver_resampler_n_latents
+
+        self.perceiver_config = IdeficsPerceiverConfig(
+            qk_layer_norms_perceiver=self.perceiver_qk_layer_norms_perceiver,
+            resampler_depth=self.perceiver_resampler_depth,
+            resampler_head_dim=self.perceiver_resampler_head_dim,
+            resampler_n_heads=self.perceiver_resampler_n_heads,
+            resampler_n_latents=self.perceiver_resampler_n_latents,
+        )
+
+        # we set the expected sequence length (which is used in several tests)
+        # this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token
+        self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1
+
+    def prepare_config_and_inputs(self, num_images=1, interpolate_pos_encoding=False, image_expansion=0):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+        pixel_values = floats_tensor(
+            [
+                self.batch_size,
+                num_images,
+                self.num_channels,
+                self.image_size + image_expansion,
+                self.image_size + image_expansion,
+            ]
+        )
+        input_mask = None
+        if self.use_input_mask:
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+        image_attention_mask = random_attention_mask([self.batch_size, self.seq_length, num_images])
+
+        config = self.get_config()
+        return (config, input_ids, input_mask, pixel_values, image_attention_mask, interpolate_pos_encoding)
+
+    def get_config(self):
+        return IdeficsConfig(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            num_channels=self.num_channels,
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+            num_labels=self.num_labels,
+            modality_type_vocab_size=self.modality_type_vocab_size,
+            vision_config=self.vision_config,
+        )
+
+    def create_and_check_model(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        pixel_values,
+        image_attention_mask,
+        interpolate_pos_encoding,
+    ):
+        model = TFIdeficsModel(config=config)
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            pixel_values=pixel_values,
+            image_attention_mask=image_attention_mask,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
+        self.parent.assertEqual(
+            result.last_hidden_state.shape, (self.batch_size, input_ids.shape[1], self.hidden_size)
+        )
+
+    def create_and_check_model_gen(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        pixel_values,
+        image_attention_mask,
+        interpolate_pos_encoding,
+    ):
+        model = TFIdeficsForVisionText2Text(config)
+        model.generate(
+            input_ids,
+            attention_mask=input_mask,
+            pixel_values=pixel_values,
+            image_attention_mask=image_attention_mask,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            max_length=self.seq_length + 2,
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            input_mask,
+            pixel_values,
+            image_attention_mask,
+            interpolate_pos_encoding,
+        ) = config_and_inputs
+        inputs_dict = {
+            "input_ids": input_ids,
+            "attention_mask": input_mask,
+            "pixel_values": pixel_values,
+            "image_attention_mask": image_attention_mask,
+            "interpolate_pos_encoding": interpolate_pos_encoding,
+        }
+        return config, inputs_dict
+
+    def prepare_pixel_values(self):
+        return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+
+@require_tf
+class TFIdeficsModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+    all_model_classes = (TFIdeficsModel, TFIdeficsForVisionText2Text) if is_tf_available() else ()
+    pipeline_model_mapping = {"feature-extraction": TFIdeficsModel} if is_tf_available() else {}
+    test_pruning = False
+    test_headmasking = False
+    test_onnx = False
+    test_resize_embeddings = False
+
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+        # XXX: IdeficsForVisionText2Text has no MODEL_FOR group yet, but it should be treated the same
+        # as MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, so for now we add the labels manually here,
+        # since super()._prepare_for_class won't do it
+        if return_labels:
+            inputs_dict["labels"] = tf.zeros(
+                (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int64
+            )
+        return inputs_dict
+
+    def test_model_outputs_equivalence(self):
+        try:
+            orig = self.all_model_classes
+            # IdeficsModel.forward doesn't have labels input arg - only IdeficsForVisionText2Text does
+            self.all_model_classes = (TFIdeficsForVisionText2Text,) if is_tf_available() else ()
+            super().test_model_outputs_equivalence()
+        finally:
+            self.all_model_classes = orig
+
+    def setUp(self):
+        self.model_tester = IdeficsModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model_single_image(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=1, interpolate_pos_encoding=False, image_expansion=0
+        )
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_model_multiple_images(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=2, interpolate_pos_encoding=False, image_expansion=0
+        )
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_model_with_image_pos_embeddings_interpolation_single_image(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=1, interpolate_pos_encoding=True, image_expansion=2
+        )
+        self.model_tester.create_and_check_model(*config_and_inputs)
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=1, interpolate_pos_encoding=True, image_expansion=0
+        )
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_model_with_image_pos_embeddings_interpolation_multiple_images(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=2, interpolate_pos_encoding=True, image_expansion=2
+        )
+        self.model_tester.create_and_check_model(*config_and_inputs)
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=2, interpolate_pos_encoding=True, image_expansion=0
+        )
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_generate_with_image_pos_embeddings_interpolation_single_image(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=1, interpolate_pos_encoding=True, image_expansion=2
+        )
+        self.model_tester.create_and_check_model_gen(*config_and_inputs)
+
+    def test_generate_with_image_pos_embeddings_interpolation_multiple_images(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(
+            num_images=2, interpolate_pos_encoding=True, image_expansion=2
+        )
+        self.model_tester.create_and_check_model_gen(*config_and_inputs)
+
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(reason="""IDEFICS does not support retaining the gradients of the hidden states and attention""")
+    def test_retain_grad_hidden_states_attentions(self):
+        return
+
+    @unittest.skip(reason="IDEFICS uses out-of-bounds embeddings deliberately.")
+    def test_embeddings_out_of_bounds_raise_exception(self):
+        pass
+
+    @unittest.skip(reason="IDEFICS attention weights are not extracted in scaled_dot_product_attention")
+    def test_prepare_serving_output(self):
+        pass
+
+    def test_model_common_attributes(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
+
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.attentions
+            # IDEFICS does not support outputting attention scores because it uses SDPA under the hood
+            self.assertTrue(attentions[0] is None)
+            out_len = len(outputs)
+
+            # Check attention is always last and order is fine
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = True
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            self.assertEqual(out_len + 1, len(outputs))
+
+            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+            # IDEFICS does not support outputting attention scores because it uses SDPA under the hood
+            self.assertTrue(self_attentions[0] is None)
+
+    def test_hidden_states_output(self):
+        def check_hidden_states_output(inputs_dict, config, model_class):
+            model = model_class(config)
+            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+            expected_num_layers = getattr(
+                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+            )
+            self.assertEqual(len(hidden_states), expected_num_layers)
+
+            seq_length = self.model_tester.seq_length
+
+            self.assertListEqual(
+                list(hidden_states[0].shape[-2:]),
+                [seq_length, self.model_tester.hidden_size],
+            )
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_hidden_states"] = True
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+            # check that output_hidden_states also work using config
+            del inputs_dict["output_hidden_states"]
+            config.output_hidden_states = True
+
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        self.has_attentions = False
+        super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
+
+    def test_keras_save_load(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        tf_main_layer_classes = {
+            module_member
+            for model_class in self.all_model_classes
+            for module in (import_module(model_class.__module__),)
+            for module_member_name in dir(module)
+            if module_member_name.endswith("MainLayer")
+            for module_member in (getattr(module, module_member_name),)
+            if isinstance(module_member, type)
+            and keras.layers.Layer in module_member.__bases__
+            and getattr(module_member, "_keras_serializable", False)
+        }
+
+        for main_layer_class in tf_main_layer_classes:
+            main_layer = main_layer_class(config)
+
+            symbolic_inputs = {
+                name: keras.Input(tensor.shape[1:], dtype=tensor.dtype, batch_size=2)
+                for name, tensor in inputs_dict.items()
+                if tf.is_tensor(tensor)
+            }
+            model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
+            outputs = model(inputs_dict)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                filepath = os.path.join(tmpdirname, "keras_model.h5")
+                model.save(filepath)
+                model = keras.models.load_model(filepath, custom_objects={main_layer_class.__name__: main_layer_class})
+                assert isinstance(model, keras.Model)
+                after_outputs = model(inputs_dict)
+                self.assert_outputs_same(after_outputs, outputs)
+
+    @unittest.skip(reason="IDEFICS test_keras_fit testing done in TFIdeficsForVisionText2TextTest")
+    def test_keras_fit(self):
+        pass
+
+    @slow
+    def test_model_from_pretrained(self):
+        model = TFIdeficsModel.from_pretrained(IDEFICS_TINY_RANDOM_MODEL, from_pt=True)
+        self.assertIsNotNone(model)
+
+    @unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
+    def test_saved_model_creation(self):
+        pass
+
+    @unittest.skip(reason="""IDEFICS loss computation not implemented yet""")
+    def test_loss_computation(self):
+        pass
+
+
+@require_tf
+class TFIdeficsForVisionText2TextTest(TFIdeficsModelTest, unittest.TestCase):
+    all_model_classes = (TFIdeficsForVisionText2Text,) if is_tf_available() else ()
+    test_resize_embeddings = False
+
+    def setUp(self):
+        self.model_tester = IdeficsModelTester(
+            self,
+            modality_type_vocab_size=3,
+        )
+        self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
+
+    @unittest.skip("We only test the model that takes in multiple images")
+    def test_model(self):
+        pass
+
+    @unittest.skip("We only test the model that takes in multiple images")
+    def test_for_token_classification(self):
+        pass
+
+    @unittest.skip(reason="""IDEFICS does not support retaining the gradients of the hidden states and attention""")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(reason="""IDEFICS loss computation not implemented yet""")
+    def test_loss_computation(self):
+        pass
+
+    @slow
+    def test_keras_fit(self):
+        super().test_keras_fit()
+
+
+# Below is the expected output for the integration test TFIdeficsModelIntegrationTest.
+# Since we are using the tiny-random checkpoint to be able to fit it on the CI GPU, it is better to assert on the
+# ids because the generated text is gibberish
+
+# fmt: off
+EXPECTED_GENERATED_IDS = [[0, 0, 1, 4911, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 530, 1967, 310, 1023, 26361, 29889, 13, 2659, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 25519, 22326, 8071, 26357, 28004, 4428, 5916, 14383, 1033, 12358, 10536, 21834, 10447, 21201, 18102, 16886, 8875, 25388, 25914, 28304, 8558, 31048, 1322, 25952, 189, 31600, 3600, 12824, 7045, 28090, 20228, 32001, 5385, 29186, 2165, 11822, 13825, 23077, 7883, 22504, 2078, 18893, 2179, 10556, 9515, 7672, 3491, 12403, 5398, 27299, 6463, 16349, 23037, 28956, 16960, 22664, 7724, 17587, 17424, 10175, 17417, 5930, 30855, 17695, 16170, 14474, 29996, 313, 14502, 3241, 13618, 32001, 5385, 29186, 2165, 11822, 13825, 19934, 4875, 27142, 3230, 2709, 28054, 3270, 19148, 10917, 1060, 26443, 12259, 1347, 28482, 3830, 25519, 199, 12782, 9144, 12289, 1142, 18400, 21390, 19129, 7292, 28430, 24711, 5551, 30349, 30533, 13271, 17697, 4982, 8713, 5380, 17869, 12490, 5398, 27299, 11593, 19918, 15924, 29430, 10175, 17417, 5930, 30855, 17695, 16170, 14474, 19234],
+                          [1, 4911, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 530, 1967, 310, 1023, 413, 986, 575, 29889, 13, 2659, 29901, 32000, 32001, 32000, 20355, 915, 445, 1967, 29889, 13, 7900, 22137, 29901, 25519, 22326, 8071, 26357, 28004, 4428, 17554, 20500, 21714, 27834, 4798, 12195, 30379, 5427, 20228, 10473, 14351, 8049, 15605, 14491, 212, 2711, 32000, 21714, 31259, 24368, 19036, 22970, 26083, 19394, 20372, 7672, 9939, 25388, 30533, 8200, 30271, 2114, 24749, 13224, 10603, 21118, 2179, 3759, 16515, 6587, 1287, 23998, 17793, 32001, 5385, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 2943, 1221, 16043, 18244, 24965, 14383, 19840, 5980, 13488, 28531, 735, 26146, 22504, 2078, 18893, 20372, 7672, 32001, 5385, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 19551, 220, 10528, 28940, 4453, 28266, 15416, 18693, 8199, 1153, 27706, 29231, 29186, 2165, 11822, 13825, 29732, 17503, 2729, 6722, 19551, 8231, 10739, 31992, 25906, 22254, 23127, 7689, 19614, 1149, 18844, 23037, 28956, 16960, 22664, 6975, 28938, 24002, 11026, 15020, 21964, 16307], ]
+
+@require_tf
+@require_vision
+class TFIdeficsModelIntegrationTest(TestCasePlus):
+    @cached_property
+    def default_processor(self):
+        return IdeficsProcessor.from_pretrained(IDEFICS_TINY_RANDOM_MODEL) if is_vision_available() else None
+
+    @slow
+    def test_inference_natural_language_visual_reasoning(self):
+        cat_image_path = self.tests_dir / "fixtures/tests_samples/COCO/000000039769.png"
+        cats_image_obj = Image.open(cat_image_path)  # 2 cats
+        dogs_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg"
+
+        prompts = [
+            [
+                "User:",
+                dogs_image_url,
+                "Describe this image.\nAssistant: An image of two dogs.\n",
+                "User:",
+                cats_image_obj,
+                "Describe this image.\nAssistant:",
+            ],
+            [
+                "User:",
+                cats_image_obj,
+                "Describe this image.\nAssistant: An image of two kittens.\n",
+                "User:",
+                dogs_image_url,
+                "Describe this image.\nAssistant:",
+            ],
+        ]
+
+        model = TFIdeficsForVisionText2Text.from_pretrained(IDEFICS_TINY_RANDOM_MODEL, from_pt=True)
+        processor = self.default_processor
+        inputs = processor(prompts, return_tensors="tf")
+        generated_ids = model.generate(**inputs, max_length=100)
+        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+        # keep for debugging
+        for i, t in enumerate(generated_text):
+            t = bytes(t, "utf-8").decode("unicode_escape")
+            print(f"{i}:\n{t}\n")
+
+        self.assertListEqual(EXPECTED_GENERATED_IDS[0], generated_ids[0].numpy().tolist())
+        self.assertListEqual(EXPECTED_GENERATED_IDS[1], generated_ids[1].numpy().tolist())
diff --git a/tests/models/idefics/test_processor_idefics.py b/tests/models/idefics/test_processor_idefics.py
index 2e319413d4c5e2..26dcbb1c0f1566 100644
--- a/tests/models/idefics/test_processor_idefics.py
+++ b/tests/models/idefics/test_processor_idefics.py
@@ -41,7 +41,7 @@ def setUp(self):
 
         self.checkpoint_path = self.get_auto_remove_tmp_dir()
 
-        image_processor = IdeficsImageProcessor()
+        image_processor = IdeficsImageProcessor(return_tensors="pt")
         tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/tiny-random-idefics")
 
         processor = IdeficsProcessor(image_processor, tokenizer)
@@ -132,7 +132,7 @@ def test_tokenizer_decode(self):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer()
 
-        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt")
 
         predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
 
@@ -145,7 +145,7 @@ def test_tokenizer_padding(self):
         image_processor = self.get_image_processor()
         tokenizer = self.get_tokenizer(padding_side="right")
 
-        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor, return_tensors="pt")
 
         predicted_tokens = [
             " Describe this image.\nAssistant:",
@@ -156,8 +156,9 @@ def test_tokenizer_padding(self):
             ([1] * 10) + ([0] * 10),
         ]
         prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
-        max_length = processor(prompts, padding="max_length", truncation=True, max_length=20)
-        longest = processor(prompts, padding="longest", truncation=True, max_length=30)
+
+        max_length = processor(prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt")
+        longest = processor(prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt")
 
         decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
         decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
@@ -203,7 +204,7 @@ def test_model_input_names(self):
         processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
         prompts = self.prepare_prompts()
 
-        inputs = processor(prompts, padding="longest")
+        inputs = processor(prompts, padding="longest", return_tensors="pt")
 
         # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask']
         self.assertSetEqual(set(inputs.keys()), set(self.input_keys))
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index f396875570c98d..2cf272f4aac10d 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -380,7 +380,9 @@ def test_keras_save_load(self):
                 main_layer = main_layer_class(config)
 
             symbolic_inputs = {
-                name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
+                name: keras.Input(tensor.shape[1:], dtype=tensor.dtype)
+                for name, tensor in inputs_dict.items()
+                if tf.is_tensor(tensor)
             }
 
             model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
@@ -1689,7 +1691,11 @@ def test_dataset_conversion(self):
                 tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                 if "labels" not in tf_inputs_dict:
                     return  # This model isn't giving us labels after all, don't try training with it
-                tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
+                tf_inputs_dict = {
+                    key: val
+                    for key, val in tf_inputs_dict.items()
+                    if "head_mask" not in key and isinstance(val, tf.Tensor)
+                }
                 tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
                 input_dataset = Dataset.from_dict(tf_inputs_dict)
                 tf_dataset = model.prepare_tf_dataset(
@@ -1853,8 +1859,8 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
 
 def random_attention_mask(shape, rng=None, name=None, dtype=None):
     attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype)
-    # make sure that at least one token is attended to for each batch
-    attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
+    # Mark the first token as 1 (matches behaviour of PyTorch/Flax function)
+    attn_mask = tf.concat([tf.ones_like(attn_mask[:, :1]), attn_mask[:, 1:]], axis=1)
     return attn_mask
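For reference, a tiny illustration of the updated helper's guarantee (only the first column is forced to 1; the remaining entries stay random):

    import tensorflow as tf

    attn_mask = tf.constant([[0, 1, 0], [1, 0, 0]])  # pretend output of ids_tensor(shape, vocab_size=2)
    attn_mask = tf.concat([tf.ones_like(attn_mask[:, :1]), attn_mask[:, 1:]], axis=1)
    # attn_mask is now [[1, 1, 0], [1, 0, 0]]: each row attends to at least its first token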
 
 

From 2e27291ce4adbea9d2cb2f9bd6c43ec492e2bb5c Mon Sep 17 00:00:00 2001
From: Joao Gante 
Date: Mon, 13 May 2024 16:08:45 +0100
Subject: [PATCH 15/19] Generate: assistant should be greedy in assisted
 decoding (#30778)

* assistant should be greedy

* better comment

* Update src/transformers/generation/candidate_generator.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/generation/candidate_generator.py | 6 ++++++
 src/transformers/generation/configuration_utils.py | 5 +++++
 2 files changed, 11 insertions(+)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index a958228d9be27f..52371d94dc56d1 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -150,6 +150,12 @@ def __init__(
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True
 
+        # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
+        # greedily to maximize matches. Sampling-related flags are also unset to prevent spurious warnings.
+        self.generation_config.do_sample = False
+        for attr in ("temperature", "top_p", "min_p", "typical_p", "top_k", "epsilon_cutoff", "eta_cutoff"):
+            setattr(self.generation_config, attr, None)
+
         # avoid unnecessary warnings that min_length is larger than max_new_tokens
         # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
         self.main_model_min_length = self.generation_config.min_length
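With this change the assistant always drafts candidate tokens greedily, while any sampling flags continue to apply only to the main model that validates the draft. An illustrative sketch of assisted generation (the checkpoint pairing is an arbitrary example, not mandated by the patch):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
    model = AutoModelForCausalLM.from_pretrained("gpt2-large")
    assistant = AutoModelForCausalLM.from_pretrained("gpt2")  # smaller draft model

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    # do_sample/temperature affect only the main model; the assistant's config is forced greedy internally
    outputs = model.generate(**inputs, assistant_model=assistant, do_sample=True, temperature=0.7, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))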
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 85fcc055948c41..2bdf20c68613e9 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -496,6 +496,11 @@ def validate(self, is_init=False):
                     greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
                     UserWarning,
                 )
+            if self.min_p is not None:
+                warnings.warn(
+                    greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
+                    UserWarning,
+                )
             if self.typical_p is not None and self.typical_p != 1.0:
                 warnings.warn(
                     greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),

From 82c1625ec3cb2c54f8a40cb00ca03eade3206e1b Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Mon, 13 May 2024 17:27:44 +0200
Subject: [PATCH 16/19] Save other CI jobs' results (torch/tf pipeline, example,
 deepspeed etc.) (#30699)

* update

* update

* update

* update

* update

* update

* update

* update

* Update utils/notification_service.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: ydshieh 
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 .github/workflows/slack-report.yml         | 15 +++++---
 utils/notification_service.py              | 41 +++++++++++++++-------
 utils/notification_service_quantization.py |  7 ++++
 3 files changed, 46 insertions(+), 17 deletions(-)

diff --git a/.github/workflows/slack-report.yml b/.github/workflows/slack-report.yml
index 77cfdc8c140241..0d1197a05d122a 100644
--- a/.github/workflows/slack-report.yml
+++ b/.github/workflows/slack-report.yml
@@ -60,12 +60,10 @@ jobs:
 
       # Upload complete failure tables, as they might be big and only truncated versions could be sent to Slack.
       - name: Failure table artifacts
-        # Only the model testing job is concerned for this step
-        if: ${{ inputs.job == 'run_models_gpu' }}
         uses: actions/upload-artifact@v4
         with:
-          name: ci_results
-          path: ci_results
+          name: ci_results_${{ inputs.job }}
+          path: ci_results_${{ inputs.job }}
       
       - uses: actions/checkout@v4
       - uses: actions/download-artifact@v4
@@ -77,6 +75,7 @@ jobs:
           SLACK_REPORT_CHANNEL: ${{ inputs.slack_report_channel }}
           CI_EVENT: scheduled
           CI_SHA: ${{ github.sha }}
+          CI_TEST_JOB: ${{ inputs.job }}
           SETUP_STATUS: ${{ inputs.setup_status }}
         # We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change
         # `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`.
@@ -85,3 +84,11 @@ jobs:
           pip install slack_sdk
           pip show slack_sdk
           python utils/notification_service_quantization.py "${{ inputs.quantization_matrix }}" 
+
+      # Upload complete failure tables, as they might be big and only truncated versions could be sent to Slack.
+      - name: Failure table artifacts
+        if: ${{ inputs.job == 'run_quantization_torch_gpu' }}
+        uses: actions/upload-artifact@v4
+        with:
+          name: ci_results_${{ inputs.job }}
+          path: ci_results_${{ inputs.job }}
\ No newline at end of file
diff --git a/utils/notification_service.py b/utils/notification_service.py
index 0598278368cb4b..cf126cd68a3385 100644
--- a/utils/notification_service.py
+++ b/utils/notification_service.py
@@ -416,7 +416,7 @@ def per_model_sum(model_category_dict):
             reports=sorted_model_reports,
             to_truncate=False,
         )
-        file_path = os.path.join(os.getcwd(), "ci_results/model_failures_report.txt")
+        file_path = os.path.join(os.getcwd(), f"ci_results_{job_name}/model_failures_report.txt")
         with open(file_path, "w", encoding="UTF-8") as fp:
             fp.write(model_failures_report)
 
@@ -426,18 +426,18 @@ def per_model_sum(model_category_dict):
             reports=sorted_module_reports,
             to_truncate=False,
         )
-        file_path = os.path.join(os.getcwd(), "ci_results/module_failures_report.txt")
+        file_path = os.path.join(os.getcwd(), f"ci_results_{job_name}/module_failures_report.txt")
         with open(file_path, "w", encoding="UTF-8") as fp:
             fp.write(module_failures_report)
 
         if self.prev_ci_artifacts is not None:
-            # if the last run produces artifact named `ci_results`
+            # if the last run produces artifact named `ci_results_{job_name}`
             if (
-                "ci_results" in self.prev_ci_artifacts
-                and "model_failures_report.txt" in self.prev_ci_artifacts["ci_results"]
+                f"ci_results_{job_name}" in self.prev_ci_artifacts
+                and "model_failures_report.txt" in self.prev_ci_artifacts[f"ci_results_{job_name}"]
             ):
                 # Compute the difference of the previous/current (model failure) table
-                prev_model_failures = self.prev_ci_artifacts["ci_results"]["model_failures_report.txt"]
+                prev_model_failures = self.prev_ci_artifacts[f"ci_results_{job_name}"]["model_failures_report.txt"]
                 entries_changed = self.compute_diff_for_failure_reports(model_failures_report, prev_model_failures)
                 if len(entries_changed) > 0:
                     # Save the complete difference
@@ -447,7 +447,7 @@ def per_model_sum(model_category_dict):
                         reports=entries_changed,
                         to_truncate=False,
                     )
-                    file_path = os.path.join(os.getcwd(), "ci_results/changed_model_failures_report.txt")
+                    file_path = os.path.join(os.getcwd(), f"ci_results_{job_name}/changed_model_failures_report.txt")
                     with open(file_path, "w", encoding="UTF-8") as fp:
                         fp.write(diff_report)
 
@@ -643,8 +643,11 @@ def get_new_model_failure_blocks(self, with_header=True):
         sorted_dict = sorted(self.model_results.items(), key=lambda t: t[0])
 
         prev_model_results = {}
-        if "ci_results" in self.prev_ci_artifacts and "model_results.json" in self.prev_ci_artifacts["ci_results"]:
-            prev_model_results = json.loads(self.prev_ci_artifacts["ci_results"]["model_results.json"])
+        if (
+            f"ci_results_{job_name}" in self.prev_ci_artifacts
+            and "model_results.json" in self.prev_ci_artifacts[f"ci_results_{job_name}"]
+        ):
+            prev_model_results = json.loads(self.prev_ci_artifacts[f"ci_results_{job_name}"]["model_results.json"])
 
         all_failure_lines = {}
         for job, job_result in sorted_dict:
@@ -1139,20 +1142,32 @@ def prepare_reports(title, header, reports, to_truncate=True):
             with open(os.path.join(directory, "selected_warnings.json")) as fp:
                 selected_warnings = json.load(fp)
 
-    if not os.path.isdir(os.path.join(os.getcwd(), "ci_results")):
-        os.makedirs(os.path.join(os.getcwd(), "ci_results"))
+    if not os.path.isdir(os.path.join(os.getcwd(), f"ci_results_{job_name}")):
+        os.makedirs(os.path.join(os.getcwd(), f"ci_results_{job_name}"))
 
     # Only the model testing job is concerned: this condition is to avoid other jobs to upload the empty list as
     # results.
     if job_name == "run_models_gpu":
-        with open("ci_results/model_results.json", "w", encoding="UTF-8") as fp:
+        with open(f"ci_results_{job_name}/model_results.json", "w", encoding="UTF-8") as fp:
             json.dump(model_results, fp, indent=4, ensure_ascii=False)
 
+    # Must have the same keys as in `additional_results`.
+    # The values are used as the file names where to save the corresponding CI job results.
+    test_to_result_name = {
+        "PyTorch pipelines": "torch_pipeline",
+        "TensorFlow pipelines": "tf_pipeline",
+        "Examples directory": "example",
+        "Torch CUDA extension tests": "deepspeed",
+    }
+    for job, job_result in additional_results.items():
+        with open(f"ci_results_{job_name}/{test_to_result_name[job]}_results.json", "w", encoding="UTF-8") as fp:
+            json.dump(job_result, fp, indent=4, ensure_ascii=False)
+
     prev_ci_artifacts = None
     target_workflow = "huggingface/transformers/.github/workflows/self-scheduled.yml@refs/heads/main"
     if os.environ.get("CI_WORKFLOW_REF") == target_workflow:
         # Get the last previously completed CI's failure tables
-        artifact_names = ["ci_results"]
+        artifact_names = [f"ci_results_{job_name}"]
         output_dir = os.path.join(os.getcwd(), "previous_reports")
         os.makedirs(output_dir, exist_ok=True)
         prev_ci_artifacts = get_last_daily_ci_reports(
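The net effect is that every CI job now writes its reports into a folder namespaced by the job name instead of the single shared `ci_results` directory. A hypothetical illustration of the resulting paths (the job name here is invented for the example):

    import os

    job_name = "run_pipelines_torch_gpu"  # hypothetical value; the real one comes from the CI workflow
    report_dir = os.path.join(os.getcwd(), f"ci_results_{job_name}")
    os.makedirs(report_dir, exist_ok=True)
    # following the mapping above, the "PyTorch pipelines" results would be stored as:
    print(os.path.join(report_dir, "torch_pipeline_results.json"))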
diff --git a/utils/notification_service_quantization.py b/utils/notification_service_quantization.py
index 1687eeaa25f32f..6d026bc0d053dc 100644
--- a/utils/notification_service_quantization.py
+++ b/utils/notification_service_quantization.py
@@ -242,6 +242,13 @@ def post_reply(self):
                             {"line": line, "trace": stacktraces.pop(0)}
                         )
 
+    job_name = os.getenv("CI_TEST_JOB")
+    if not os.path.isdir(os.path.join(os.getcwd(), f"ci_results_{job_name}")):
+        os.makedirs(os.path.join(os.getcwd(), f"ci_results_{job_name}"))
+
+    with open(f"ci_results_{job_name}/quantization_results.json", "w", encoding="UTF-8") as fp:
+        json.dump(quantization_results, fp, indent=4, ensure_ascii=False)
+
     message = QuantizationMessage(
         title,
         results=quantization_results,

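As an aside on the change above: each CI job now writes its reports into a per-job `ci_results_{job_name}` directory instead of a single shared `ci_results` folder, so jobs running in the same workspace no longer clobber each other's JSON files. A minimal sketch of that write path, with a made-up helper name and result payload (not part of the patch):

    import json
    import os


    def save_job_results(job_name: str, results: dict) -> str:
        # Each CI job writes into its own ci_results_<job_name> directory,
        # mirroring the per-job paths introduced in the patch above.
        output_dir = os.path.join(os.getcwd(), f"ci_results_{job_name}")
        os.makedirs(output_dir, exist_ok=True)

        path = os.path.join(output_dir, "model_results.json")
        with open(path, "w", encoding="UTF-8") as fp:
            json.dump(results, fp, indent=4, ensure_ascii=False)
        return path


    # Two different jobs no longer overwrite each other's reports.
    print(save_job_results("run_models_gpu", {"failed": 0}))
    print(save_job_results("run_quantization_torch_gpu", {"failed": 1}))
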
From 0f8fefd4818a575dd749ea919abf4543afffbabf Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Mon, 13 May 2024 16:30:55 +0100
Subject: [PATCH 17/19] Deprecate models script (#30184)

* Add utility for finding candidate models for deprecation

* Update model init

* Make into configurable script

* Fix path

* Add sorting of base object alphabetically

* Tidy

* Refactor __init__ alpha ordering

* Update script with logging

* fix import

* Fix logger

* Fix logger

* Get config file before moving files

* Take models from CLI

* Split models into lines to make it easier to feed to the deprecate_models script

* Update

* Use posix path

* Print instead

* Add example in module docstring

* Fix up

* Add clarifying comments; add models to DEPRECATE_MODELS

* Address PR comments

* Don't update relative paths on the same level
---
 utils/deprecate_models.py    | 357 +++++++++++++++++++++++++++++++++++
 utils/models_to_deprecate.py |   4 +-
 2 files changed, 359 insertions(+), 2 deletions(-)
 create mode 100644 utils/deprecate_models.py

diff --git a/utils/deprecate_models.py b/utils/deprecate_models.py
new file mode 100644
index 00000000000000..d5160e93842095
--- /dev/null
+++ b/utils/deprecate_models.py
@@ -0,0 +1,357 @@
+"""
+Script which deprecates a list of given models
+
+Example usage:
+python utils/deprecate_models.py --models bert distilbert
+"""
+
+import argparse
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Optional, Tuple
+
+import requests
+from custom_init_isort import sort_imports_in_all_inits
+from git import Repo
+from packaging import version
+
+from transformers import CONFIG_MAPPING, logging
+from transformers import __version__ as current_version
+
+
+REPO_PATH = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
+repo = Repo(REPO_PATH)
+
+logger = logging.get_logger(__name__)
+
+
+def get_last_stable_minor_release():
+    # Get the last stable release of transformers
+    url = "https://pypi.org/pypi/transformers/json"
+    release_data = requests.get(url).json()
+
+    # Find the last stable release of transformers (version below current version)
+    major_version, minor_version, patch_version, _ = current_version.split(".")
+    last_major_minor = f"{major_version}.{int(minor_version) - 1}"
+    last_stable_minor_releases = [
+        release for release in release_data["releases"] if release.startswith(last_major_minor)
+    ]
+    last_stable_release = sorted(last_stable_minor_releases, key=version.parse)[-1]
+
+    return last_stable_release
+
+
+def build_tip_message(last_stable_release):
+    return (
+        """
+    <Tip warning={true}>
+
+    This model is in maintenance mode only, we don't accept any new PRs changing its code.
+    """
+        + f"""If you run into any issues running this model, please reinstall the last version that supported this model: v{last_stable_release}.
+    You can do so by running the following command: `pip install -U transformers=={last_stable_release}`.
+
+    </Tip>"""
+    )
+
+
+def insert_tip_to_model_doc(model_doc_path, tip_message):
+    tip_message_lines = tip_message.split("\n")
+
+    with open(model_doc_path, "r") as f:
+        model_doc = f.read()
+
+    # Add the tip message to the model doc page directly underneath the title
+    lines = model_doc.split("\n")
+
+    new_model_lines = []
+    for line in lines:
+        if line.startswith("# "):
+            new_model_lines.append(line)
+            new_model_lines.extend(tip_message_lines)
+        else:
+            new_model_lines.append(line)
+
+    with open(model_doc_path, "w") as f:
+        f.write("\n".join(new_model_lines))
+
+
+def get_model_doc_path(model: str) -> Tuple[Optional[str], Optional[str]]:
+    # Possible variants of the model name in the model doc path
+    model_doc_paths = [
+        REPO_PATH / f"docs/source/en/model_doc/{model}.md",
+        # Try replacing _ with - in the model name
+        REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '-')}.md",
+        # Try replacing _ with "" in the model name
+        REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '')}.md",
+    ]
+
+    for model_doc_path in model_doc_paths:
+        if os.path.exists(model_doc_path):
+            return model_doc_path, model
+
+    return None, None
+
+
+def extract_model_info(model):
+    model_info = {}
+    model_doc_path, model_doc_name = get_model_doc_path(model)
+    model_path = REPO_PATH / f"src/transformers/models/{model}"
+
+    if model_doc_path is None:
+        print(f"Model doc path does not exist for {model}")
+        return None
+    model_info["model_doc_path"] = model_doc_path
+    model_info["model_doc_name"] = model_doc_name
+
+    if not os.path.exists(model_path):
+        print(f"Model path does not exist for {model}")
+        return None
+    model_info["model_path"] = model_path
+
+    return model_info
+
+
+def update_relative_imports(filename, model):
+    with open(filename, "r") as f:
+        filelines = f.read()
+
+    new_file_lines = []
+    for line in filelines.split("\n"):
+        if line.startswith("from .."):
+            new_file_lines.append(line.replace("from ..", "from ..."))
+        else:
+            new_file_lines.append(line)
+
+    with open(filename, "w") as f:
+        f.write("\n".join(new_file_lines))
+
+
+def move_model_files_to_deprecated(model):
+    model_path = REPO_PATH / f"src/transformers/models/{model}"
+    deprecated_model_path = REPO_PATH / f"src/transformers/models/deprecated/{model}"
+
+    if not os.path.exists(deprecated_model_path):
+        os.makedirs(deprecated_model_path)
+
+    for file in os.listdir(model_path):
+        if file == "__pycache__":
+            continue
+        repo.git.mv(f"{model_path}/{file}", f"{deprecated_model_path}/{file}")
+
+        # For deprecated files, we then need to update the relative imports
+        update_relative_imports(f"{deprecated_model_path}/{file}", model)
+
+
+def delete_model_tests(model):
+    tests_path = REPO_PATH / f"tests/models/{model}"
+
+    if os.path.exists(tests_path):
+        repo.git.rm("-r", tests_path)
+
+
+def get_line_indent(s):
+    return len(s) - len(s.lstrip())
+
+
+def update_main_init_file(models):
+    """
+    Replace all instances of models.model_name with models.deprecated.model_name in the __init__.py file
+
+    Args:
+        models (List[str]): The models to mark as deprecated
+    """
+    filename = REPO_PATH / "src/transformers/__init__.py"
+    with open(filename, "r") as f:
+        init_file = f.read()
+
+    # 1. For each model, find all the instances of model.model_name and replace with model.deprecated.model_name
+    for model in models:
+        init_file = init_file.replace(f"models.{model}", f"models.deprecated.{model}")
+
+    with open(filename, "w") as f:
+        f.write(init_file)
+
+    # 2. Resort the imports
+    sort_imports_in_all_inits(check_only=False)
+
+
+def remove_model_references_from_file(filename, models, condition):
+    """
+    Remove all references to the given models from the given file
+
+    Args:
+        filename (str): The file to remove the references from
+        models (List[str]): The models to remove
+        condition (Callable): A function that takes the line and model and returns True if the line should be removed
+    """
+    with open(filename, "r") as f:
+        init_file = f.read()
+
+    new_file_lines = []
+    for i, line in enumerate(init_file.split("\n")):
+        if any(condition(line, model) for model in models):
+            continue
+        new_file_lines.append(line)
+
+    with open(filename, "w") as f:
+        f.write("\n".join(new_file_lines))
+
+
+def remove_model_config_classes_from_config_check(model_config_classes):
+    """
+    Remove the deprecated model config classes from the check_config_attributes.py file
+
+    Args:
+        model_config_classes (List[str]): The model config classes to remove e.g. ["BertConfig", "DistilBertConfig"]
+    """
+    filename = REPO_PATH / "utils/check_config_attributes.py"
+    with open(filename, "r") as f:
+        check_config_attributes = f.read()
+
+    # Keep track of where we are, as the comment above each removed entry must be deleted too
+    in_special_cases_to_allow = False
+    in_indent = False
+    new_file_lines = []
+
+    for line in check_config_attributes.split("\n"):
+        indent = get_line_indent(line)
+        if (line.strip() == "SPECIAL_CASES_TO_ALLOW = {") or (line.strip() == "SPECIAL_CASES_TO_ALLOW.update("):
+            in_special_cases_to_allow = True
+
+        elif in_special_cases_to_allow and indent == 0 and line.strip() in ("}", ")"):
+            in_special_cases_to_allow = False
+
+        if in_indent:
+            if line.strip().endswith(("]", "],")):
+                in_indent = False
+            continue
+
+        if in_special_cases_to_allow and any(
+            model_config_class in line for model_config_class in model_config_classes
+        ):
+            # Remove comments above the model config class to remove
+            while new_file_lines[-1].strip().startswith("#"):
+                new_file_lines.pop()
+
+            if line.strip().endswith("["):
+                in_indent = True
+
+            continue
+
+        elif any(model_config_class in line for model_config_class in model_config_classes):
+            continue
+
+        new_file_lines.append(line)
+
+    with open(filename, "w") as f:
+        f.write("\n".join(new_file_lines))
+
+
+def add_models_to_deprecated_models_in_config_auto(models):
+    """
+    Add the models to the DEPRECATED_MODELS list in configuration_auto.py and sort the list
+    alphabetically.
+    """
+    filepath = REPO_PATH / "src/transformers/models/auto/configuration_auto.py"
+    with open(filepath, "r") as f:
+        config_auto = f.read()
+
+    new_file_lines = []
+    deprecated_models_list = []
+    in_deprecated_models = False
+    for line in config_auto.split("\n"):
+        if line.strip() == "DEPRECATED_MODELS = [":
+            in_deprecated_models = True
+            new_file_lines.append(line)
+        elif in_deprecated_models and line.strip() == "]":
+            in_deprecated_models = False
+            # Add the new models to deprecated models list
+            deprecated_models_list.extend([f'"{model}",' for model in models])
+            # Sort so they're in alphabetical order in the file
+            deprecated_models_list = sorted(deprecated_models_list)
+            new_file_lines.extend(deprecated_models_list)
+            # Make sure we still have the closing bracket
+            new_file_lines.append(line)
+        elif in_deprecated_models:
+            deprecated_models_list.append(line.strip())
+        else:
+            new_file_lines.append(line)
+
+    with open(filepath, "w") as f:
+        f.write("\n".join(new_file_lines))
+
+
+def deprecate_models(models):
+    # Get model info
+    skipped_models = []
+    models_info = defaultdict(dict)
+    for model in models:
+        single_model_info = extract_model_info(model)
+        if single_model_info is None:
+            skipped_models.append(model)
+        else:
+            models_info[model] = single_model_info
+
+    model_config_classes = []
+    for model, model_info in models_info.items():
+        if model in CONFIG_MAPPING:
+            model_config_classes.append(CONFIG_MAPPING[model].__name__)
+        elif model_info["model_doc_name"] in CONFIG_MAPPING:
+            model_config_classes.append(CONFIG_MAPPING[model_info["model_doc_name"]].__name__)
+        else:
+            skipped_models.append(model)
+            print(f"Model config class not found for model: {model}")
+
+    # Filter out skipped models
+    models = [model for model in models if model not in skipped_models]
+
+    if skipped_models:
+        print(f"Skipped models: {skipped_models} as the model doc or model path could not be found.")
+    print(f"Models to deprecate: {models}")
+
+    # Remove model config classes from config check
+    print("Removing model config classes from config checks")
+    remove_model_config_classes_from_config_check(model_config_classes)
+
+    tip_message = build_tip_message(get_last_stable_minor_release())
+
+    for model, model_info in models_info.items():
+        print(f"Processing model: {model}")
+        # Add the tip message to the model doc page directly underneath the title
+        print("Adding tip message to model doc page")
+        insert_tip_to_model_doc(model_info["model_doc_path"], tip_message)
+
+        # Move the model files to deprecated: src/transformers/models/model -> src/transformers/models/deprecated/model
+        print("Moving model files to deprecated for model")
+        move_model_files_to_deprecated(model)
+
+        # Delete the model tests: tests/models/model
+        print("Deleting model tests")
+        delete_model_tests(model)
+
+    # We do the following with all models passed at once to avoid having to re-write the file multiple times
+    print("Updating __init__.py file to point to the deprecated models")
+    update_main_init_file(models)
+
+    # Remove model references from other files
+    print("Removing model references from other files")
+    remove_model_references_from_file(
+        "src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(",")
+    )
+    remove_model_references_from_file(
+        "utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line
+    )
+    remove_model_references_from_file("utils/not_doctested.txt", models, lambda line, model: "/" + model + "/" in line)
+
+    # Add models to DEPRECATED_MODELS in the configuration_auto.py
+    print("Adding models to DEPRECATED_MODELS in configuration_auto.py")
+    add_models_to_deprecated_models_in_config_auto(models)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--models", nargs="+", help="List of models to deprecate")
+    args = parser.parse_args()
+    deprecate_models(args.models)
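
A note on `remove_model_references_from_file` above: all of the file-specific logic lives in the condition lambdas passed at the call sites. A small, self-contained illustration of that filtering pattern, using invented sample lines (not taken from the repository):

    def filter_lines(lines, models, condition):
        # Keep only the lines for which no model triggers the removal condition,
        # mirroring remove_model_references_from_file above.
        return [line for line in lines if not any(condition(line, model) for model in models)]


    init_lines = ["    bert,", "    distilbert,", "    llama,"]
    path_lines = [
        "src/transformers/models/bert/modeling_bert.py",
        "src/transformers/models/llama/modeling_llama.py",
    ]
    models = ["bert", "distilbert"]

    # models/__init__.py style: the stripped line is exactly the model name plus a comma.
    print(filter_lines(init_lines, models, lambda line, model: model == line.strip().strip(",")))
    # path style: the model name appears as a path segment.
    print(filter_lines(path_lines, models, lambda line, model: "/" + model + "/" in line))
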
diff --git a/utils/models_to_deprecate.py b/utils/models_to_deprecate.py
index dcf191aa060f31..ebdecf22eb8a68 100644
--- a/utils/models_to_deprecate.py
+++ b/utils/models_to_deprecate.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """
 Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last commit.
 """
@@ -149,7 +148,7 @@ def get_list_of_models_to_deprecate(
         with open("models_info.json", "w") as f:
             json.dump(models_info, f, indent=4)
 
-    print("\nModels to deprecate:")
+    print("\nFinding models to deprecate:")
     n_models_to_deprecate = 0
     models_to_deprecate = {}
     for model, info in models_info.items():
@@ -160,6 +159,7 @@ def get_list_of_models_to_deprecate(
             print(f"\nModel: {model}")
             print(f"Downloads: {n_downloads}")
             print(f"Date: {info['first_commit_datetime']}")
+    print("\nModels to deprecate: ", "\n" + "\n".join(models_to_deprecate.keys()))
     print(f"\nNumber of models to deprecate: {n_models_to_deprecate}")
     print("Before deprecating make sure to verify the models, including if they're used as a module in other models.")
 

From 539ed75d501f9af87b28dc0325b8d2f62b791cd9 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Mon, 13 May 2024 18:00:43 +0200
Subject: [PATCH 18/19] skip low_cpu_mem_usage tests (#30782)

---
 tests/test_modeling_common.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index daa438e9f1bdea..cd46934b5fcfe4 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -21,6 +21,7 @@
 import random
 import re
 import tempfile
+import unittest
 import warnings
 from collections import defaultdict
 from typing import Dict, List, Tuple
@@ -440,6 +441,7 @@ class CopyClass(model_class):
     @slow
     @require_accelerate
     @mark.accelerate_tests
+    @unittest.skip("Need to fix since we have a device mismatch")
     def test_save_load_low_cpu_mem_usage(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         with tempfile.TemporaryDirectory() as saved_model_path:
@@ -452,6 +454,7 @@ def test_save_load_low_cpu_mem_usage(self):
     @slow
     @require_accelerate
     @mark.accelerate_tests
+    @unittest.skip("Need to fix since we have a device mismatch")
     def test_save_load_low_cpu_mem_usage_checkpoints(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         with tempfile.TemporaryDirectory() as saved_model_path:
@@ -465,6 +468,7 @@ def test_save_load_low_cpu_mem_usage_checkpoints(self):
     @slow
     @require_accelerate
     @mark.accelerate_tests
+    @unittest.skip("Need to fix since we have a device mismatch")
     def test_save_load_low_cpu_mem_usage_no_safetensors(self):
         with tempfile.TemporaryDirectory() as saved_model_path:
             for model_class in self.all_model_classes:

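For context, `unittest.skip` disables a test at collection time while keeping it (and the reason string) visible in reports, which is why the three tests above are decorated rather than deleted. A generic sketch with illustrative class and test names:

    import unittest


    class ExampleTest(unittest.TestCase):
        @unittest.skip("Need to fix since we have a device mismatch")
        def test_temporarily_disabled(self):
            # Never executed while the skip decorator is in place; the runner
            # reports it as skipped with the reason string above.
            self.fail("should not run")

        def test_still_runs(self):
            self.assertTrue(True)


    if __name__ == "__main__":
        unittest.main()
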
From 37bba2a32d2742a10216ffd925bb8f145a732ce1 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 13 May 2024 18:14:36 +0200
Subject: [PATCH 19/19] CI: update to ROCm 6.0.2 and test MI300 (#30266)

* update to ROCm 6.0.2 and test MI300

* add callers for mi300

* update dockerfile

* fix trainer tests

* remove apex

* style

* Update tests/trainer/test_trainer_seq2seq.py

* Update tests/trainer/test_trainer_seq2seq.py

* Update tests/trainer/test_trainer_seq2seq.py

* Update tests/trainer/test_trainer_seq2seq.py

* update to torch 2.3

* add workflow dispatch target

* we may need branches: mi300-ci after all

* nit

* fix docker build

* nit

* add check runner

* remove docker-gpu

* fix issues

* fix

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh 
---
 .../workflows/self-push-amd-mi300-caller.yml  |  25 ++++
 .github/workflows/self-push-amd.yml           |   8 +-
 .../self-scheduled-amd-mi210-caller.yml       |   1 +
 .../self-scheduled-amd-mi250-caller.yml       |   1 +
 .../self-scheduled-amd-mi300-caller.yml       |  21 +++
 .github/workflows/self-scheduled-amd.yml      |  26 ++--
 .../transformers-pytorch-amd-gpu/Dockerfile   |  19 +--
 docs/source/en/perf_infer_gpu_one.md          |   2 +-
 .../integrations/integration_utils.py         |   5 +
 src/transformers/training_args.py             |   7 +
 tests/extended/test_trainer_ext.py            |   1 +
 tests/trainer/test_trainer.py                 | 125 +++++++++++-------
 tests/trainer/test_trainer_distributed.py     |   2 +-
 tests/trainer/test_trainer_seq2seq.py         |   8 +-
 14 files changed, 170 insertions(+), 81 deletions(-)
 create mode 100644 .github/workflows/self-push-amd-mi300-caller.yml
 create mode 100644 .github/workflows/self-scheduled-amd-mi300-caller.yml

diff --git a/.github/workflows/self-push-amd-mi300-caller.yml b/.github/workflows/self-push-amd-mi300-caller.yml
new file mode 100644
index 00000000000000..a8ee4e540ecf3f
--- /dev/null
+++ b/.github/workflows/self-push-amd-mi300-caller.yml
@@ -0,0 +1,25 @@
+name: Self-hosted runner (AMD mi300 CI caller)
+
+on:
+  workflow_run:
+    workflows: ["Self-hosted runner (push-caller)"]
+    branches: ["main"]
+    types: [completed]
+  push:
+    branches:
+      - run_amd_push_ci_caller*
+    paths:
+      - "src/**"
+      - "tests/**"
+      - ".github/**"
+      - "templates/**"
+      - "utils/**"
+
+jobs:
+  run_amd_ci:
+    name: AMD mi300
+    if: (cancelled() != true) && ((github.event_name == 'workflow_run') || ((github.event_name == 'push') && (startsWith(github.ref_name, 'run_amd_push_ci_caller') || startsWith(github.ref_name, 'mi300-ci'))))
+    uses: ./.github/workflows/self-push-amd.yml
+    with:
+      gpu_flavor: mi300
+    secrets: inherit
diff --git a/.github/workflows/self-push-amd.yml b/.github/workflows/self-push-amd.yml
index 8705f398b2b510..8d68002e329418 100644
--- a/.github/workflows/self-push-amd.yml
+++ b/.github/workflows/self-push-amd.yml
@@ -36,7 +36,7 @@ jobs:
     strategy:
       matrix:
         machine_type: [single-gpu, multi-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu-push-ci  # <--- We test only for PyTorch for now
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -57,7 +57,7 @@ jobs:
     strategy:
       matrix:
         machine_type: [single-gpu, multi-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu-push-ci  # <--- We test only for PyTorch for now
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -155,7 +155,7 @@ jobs:
       matrix:
         folders: ${{ fromJson(needs.setup_gpu.outputs.matrix) }}
         machine_type: [single-gpu, multi-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu-push-ci  # <--- We test only for PyTorch for now
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -230,7 +230,7 @@ jobs:
       - name: Run all non-slow selected tests on GPU
         working-directory: /transformers
         run: |
-          python3 -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports ${{ fromJson(needs.setup_gpu.outputs.test_map)[matrix.folders] }}
+          python3 -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports ${{ fromJson(needs.setup_gpu.outputs.test_map)[matrix.folders] }} -m "not not_device_test"
 
       - name: Failure short reports
         if: ${{ failure() }}
diff --git a/.github/workflows/self-scheduled-amd-mi210-caller.yml b/.github/workflows/self-scheduled-amd-mi210-caller.yml
index cdb968901058b6..6abba6894aaffa 100644
--- a/.github/workflows/self-scheduled-amd-mi210-caller.yml
+++ b/.github/workflows/self-scheduled-amd-mi210-caller.yml
@@ -16,4 +16,5 @@ jobs:
     uses: ./.github/workflows/self-scheduled-amd.yml
     with:
       gpu_flavor: mi210
+      slack_report_channel: "#transformers-ci-daily-amd"
     secrets: inherit
diff --git a/.github/workflows/self-scheduled-amd-mi250-caller.yml b/.github/workflows/self-scheduled-amd-mi250-caller.yml
index dc7d12f173935e..36365d4a67f1e2 100644
--- a/.github/workflows/self-scheduled-amd-mi250-caller.yml
+++ b/.github/workflows/self-scheduled-amd-mi250-caller.yml
@@ -16,4 +16,5 @@ jobs:
     uses: ./.github/workflows/self-scheduled-amd.yml
     with:
       gpu_flavor: mi250
+      slack_report_channel: "#transformers-ci-daily-amd"
     secrets: inherit
diff --git a/.github/workflows/self-scheduled-amd-mi300-caller.yml b/.github/workflows/self-scheduled-amd-mi300-caller.yml
new file mode 100644
index 00000000000000..a9e7b934c34b77
--- /dev/null
+++ b/.github/workflows/self-scheduled-amd-mi300-caller.yml
@@ -0,0 +1,21 @@
+name: Self-hosted runner (AMD mi300 scheduled CI caller)
+
+on:
+  workflow_run:
+    workflows: ["Self-hosted runner (AMD scheduled CI caller)"]
+    branches: ["main"]
+    types: [completed]
+  push:
+    branches:
+      - run_amd_scheduled_ci_caller*
+
+jobs:
+  run_amd_ci:
+    name: AMD mi300
+    needs: build-docker-containers
+    if: (cancelled() != true) && ((github.event_name == 'workflow_run') || ((github.event_name == 'push') && (startsWith(github.ref_name, 'run_amd_push_ci_caller') || startsWith(github.ref_name, 'mi300-ci'))))
+    uses: ./.github/workflows/self-scheduled-amd.yml
+    with:
+      gpu_flavor: mi300
+      slack_report_channel: "#transformers-ci-daily-amd"
+    secrets: inherit
diff --git a/.github/workflows/self-scheduled-amd.yml b/.github/workflows/self-scheduled-amd.yml
index d2ab90d1331848..e9f280f51ab43d 100644
--- a/.github/workflows/self-scheduled-amd.yml
+++ b/.github/workflows/self-scheduled-amd.yml
@@ -34,7 +34,7 @@ jobs:
           fetch-depth: 2
 
       - name: Check Runner Status
-        run: python utils/check_self_hosted_runner.py --target_runners hf-amd-mi210-ci-1gpu-1,hf-amd-mi250-ci-1gpu-1 --token ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
+        run: python utils/check_self_hosted_runner.py --target_runners hf-amd-mi210-ci-1gpu-1,hf-amd-mi250-ci-1gpu-1,hf-amd-mi300-ci-1gpu-1 --token ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
 
   check_runners:
     name: Check Runners
@@ -42,7 +42,7 @@ jobs:
     strategy:
       matrix:
         machine_type: [single-gpu, multi-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -63,7 +63,7 @@ jobs:
     strategy:
       matrix:
         machine_type: [single-gpu, multi-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -116,7 +116,7 @@ jobs:
       matrix:
         folders: ${{ fromJson(needs.setup.outputs.matrix) }}
         machine_type: [single-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -162,7 +162,7 @@ jobs:
 
       - name: Run all tests on GPU
         working-directory: /transformers
-        run: python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }}
+        run: python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }} -m "not not_device_test"
 
       - name: Failure short reports
         if: ${{ failure() }}
@@ -184,7 +184,7 @@ jobs:
       matrix:
         folders: ${{ fromJson(needs.setup.outputs.matrix) }}
         machine_type: [multi-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -230,7 +230,7 @@ jobs:
 
       - name: Run all tests on GPU
         working-directory: /transformers
-        run: python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }}
+        run: python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }} -m "not not_device_test"
 
       - name: Failure short reports
         if: ${{ failure() }}
@@ -250,7 +250,7 @@ jobs:
       fail-fast: false
       matrix:
         machine_type: [single-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -287,7 +287,7 @@ jobs:
         working-directory: /transformers
         run: |
           pip install -r examples/pytorch/_tests_requirements.txt
-          python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_examples_gpu_test_reports examples/pytorch
+          python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_examples_gpu_test_reports examples/pytorch -m "not not_device_test"
 
       - name: Failure short reports
         if: ${{ failure() }}
@@ -307,7 +307,7 @@ jobs:
       fail-fast: false
       matrix:
         machine_type: [single-gpu, multi-gpu]
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     container:
       image: huggingface/transformers-pytorch-amd-gpu
       options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -343,7 +343,7 @@ jobs:
       - name: Run all pipeline tests on GPU
         working-directory: /transformers
         run: |
-          python3 -m pytest -n 1 -v --dist=loadfile --make-reports=${{ matrix.machine_type }}_run_pipelines_torch_gpu_test_reports tests/pipelines
+          python3 -m pytest -n 1 -v --dist=loadfile --make-reports=${{ matrix.machine_type }}_run_pipelines_torch_gpu_test_reports tests/pipelines -m "not not_device_test"
 
       - name: Failure short reports
         if: ${{ failure() }}
@@ -364,7 +364,7 @@ jobs:
       matrix:
         machine_type: [single-gpu, multi-gpu]
 
-    runs-on: [self-hosted, docker-gpu, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
+    runs-on: [self-hosted, amd-gpu, '${{ matrix.machine_type }}', '${{ inputs.gpu_flavor }}']
     needs: setup
     container:
       image: huggingface/transformers-pytorch-deepspeed-amd-gpu
@@ -400,7 +400,7 @@ jobs:
 
       - name: Run all tests on GPU
         working-directory: /transformers
-        run: python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_torch_cuda_extensions_gpu_test_reports tests/deepspeed tests/extended
+        run: python3 -m pytest -v --make-reports=${{ matrix.machine_type }}_run_torch_cuda_extensions_gpu_test_reports tests/deepspeed tests/extended -m "not not_device_test"
 
       - name: Failure short reports
         if: ${{ failure() }}
diff --git a/docker/transformers-pytorch-amd-gpu/Dockerfile b/docker/transformers-pytorch-amd-gpu/Dockerfile
index 0b070c93a64f3d..da91906d621429 100644
--- a/docker/transformers-pytorch-amd-gpu/Dockerfile
+++ b/docker/transformers-pytorch-amd-gpu/Dockerfile
@@ -1,24 +1,19 @@
-FROM rocm/dev-ubuntu-20.04:5.6
+FROM rocm/dev-ubuntu-22.04:6.0.2
 # rocm/pytorch has no version with 2.1.0
 LABEL maintainer="Hugging Face"
 
 ARG DEBIAN_FRONTEND=noninteractive
 
-ARG PYTORCH='2.1.0'
-ARG TORCH_VISION='0.16.0'
-ARG TORCH_AUDIO='2.1.0'
-ARG ROCM='5.6'
-
 RUN apt update && \
-    apt install -y --no-install-recommends git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-dev python3-pip ffmpeg && \
+    apt install -y --no-install-recommends git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-dev python3-pip ffmpeg && \
     apt clean && \
     rm -rf /var/lib/apt/lists/*
 
-RUN python3 -m pip install --no-cache-dir --upgrade pip
+RUN python3 -m pip install --no-cache-dir --upgrade pip numpy
 
-RUN python3 -m pip install torch==$PYTORCH torchvision==$TORCH_VISION torchaudio==$TORCH_AUDIO --index-url https://download.pytorch.org/whl/rocm$ROCM
+RUN python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0
 
-RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools ninja git+https://github.com/facebookresearch/detectron2.git pytesseract "itsdangerous<2.1.0"
+RUN python3 -m pip install --no-cache-dir --upgrade importlib-metadata setuptools ninja git+https://github.com/facebookresearch/detectron2.git pytesseract "itsdangerous<2.1.0"
 
 ARG REF=main
 WORKDIR /
@@ -35,5 +30,5 @@ RUN python3 -m pip uninstall -y tensorflow flax
 # this line must be added in order for python to be aware of transformers.
 RUN cd transformers && python3 setup.py develop
 
-# Remove nvml as it is not compatible with ROCm
-RUN python3 -m pip uninstall py3nvml pynvml -y
+# Remove nvml as it is not compatible with ROCm. apex is not tested on NVIDIA either.
+RUN python3 -m pip uninstall py3nvml pynvml apex -y
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index de49d4427b5687..c8e99c1d43f5bc 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -94,7 +94,7 @@ We strongly suggest referring to the detailed [installation instructions](https:
 
 
 
-FlashAttention-2 is also supported on AMD GPUs and current support is limited to **Instinct MI210** and **Instinct MI250**. We strongly suggest using this [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
+FlashAttention-2 is also supported on AMD GPUs and current support is limited to **Instinct MI210**, **Instinct MI250** and **Instinct MI300**. We strongly suggest using this [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.
 
 
 
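The documentation line above only extends the supported-hardware list to MI300; requesting FlashAttention-2 itself is unchanged. A hedged sketch of the usual way to enable it when loading a model (the checkpoint name is a placeholder, and a `flash-attn` build that supports your GPU must be installed):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; any supported architecture works

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,                # FlashAttention-2 needs fp16 or bf16
        attn_implementation="flash_attention_2",  # opt in to the FA2 attention backend
        device_map="auto",
    )
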
diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index 4543cf9f98b5f3..39de7d6d326bb3 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -1545,6 +1545,11 @@ def __init__(self):
             raise RuntimeError(
                 "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
             )
+        elif torch.version.hip:
+            raise RuntimeError(
+                "CodeCarbonCallback requires `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). When using the Trainer, please specify the `report_to` argument (https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to) to disable CodeCarbonCallback."
+            )
+
         import codecarbon
 
         self._codecarbon = codecarbon
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 6ea2a6674b4034..2807c9951aa6d6 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1735,6 +1735,13 @@ def __post_init__(self):
             from .integrations import get_available_reporting_integrations
 
             self.report_to = get_available_reporting_integrations()
+
+            if "codecarbon" in self.report_to and torch.version.hip:
+                logger.warning(
+                    "When using the Trainer, CodeCarbonCallback requires the `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). Automatically disabling the codecarbon callback. Reference: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to."
+                )
+                self.report_to.remove("codecarbon")
+
         elif self.report_to == "none" or self.report_to == ["none"]:
             self.report_to = []
         elif not isinstance(self.report_to, list):
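
Both guards above rely on `torch.version.hip`, which is `None` on CUDA and CPU builds of PyTorch and a version string on ROCm builds. A minimal sketch of the same check in isolation (the `report_to` list here is a stand-in, not the Trainer's attribute):

    import torch


    def drop_codecarbon_on_rocm(report_to):
        # torch.version.hip is None unless this PyTorch build targets ROCm.
        if "codecarbon" in report_to and torch.version.hip:
            print("codecarbon is not compatible with AMD ROCm; disabling it.")
            report_to = [r for r in report_to if r != "codecarbon"]
        return report_to


    print(drop_codecarbon_on_rocm(["tensorboard", "codecarbon"]))
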
diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py
index 4bda892162fdad..a35ea1a8e7eba7 100644
--- a/tests/extended/test_trainer_ext.py
+++ b/tests/extended/test_trainer_ext.py
@@ -301,6 +301,7 @@ def run_trainer(
             --label_smoothing_factor 0.1
             --target_lang ro_RO
             --source_lang en_XX
+            --report_to none
         """.split()
 
         args_eval = f"""
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index da6dcb2a4b72fb..c420da4052f186 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -607,7 +607,7 @@ def test_trainer_with_datasets(self):
 
         # Base training. Should have the same results as test_reproducible_training
         model = RegressionModel()
-        args = TrainingArguments("./regression", learning_rate=0.1)
+        args = TrainingArguments("./regression", learning_rate=0.1, report_to="none")
         trainer = Trainer(model, args, train_dataset=train_dataset)
         trainer.train()
         self.check_trained_model(trainer.model)
@@ -629,7 +629,7 @@ def test_trainer_with_datasets(self):
 
     def test_model_init(self):
         train_dataset = RegressionDataset()
-        args = TrainingArguments("./regression", learning_rate=0.1)
+        args = TrainingArguments("./regression", learning_rate=0.1, report_to="none")
         trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
         trainer.train()
         self.check_trained_model(trainer.model)
@@ -692,7 +692,7 @@ def test_training_loss(self):
 
     def test_custom_optimizer(self):
         train_dataset = RegressionDataset()
-        args = TrainingArguments("./regression")
+        args = TrainingArguments("./regression", report_to="none")
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
         lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
@@ -716,6 +716,7 @@ def test_lr_scheduler_kwargs(self):
             lr_scheduler_kwargs=extra_kwargs,
             learning_rate=0.2,
             warmup_steps=num_warmup_steps,
+            report_to="none",
         )
         trainer = Trainer(model, args, train_dataset=train_dataset)
         trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
@@ -742,6 +743,7 @@ def test_cosine_with_min_lr_scheduler(self):
             lr_scheduler_kwargs=extra_kwargs,
             learning_rate=0.2,
             warmup_steps=num_warmup_steps,
+            report_to="none",
         )
         trainer = Trainer(model, args, train_dataset=train_dataset)
         trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)
@@ -762,6 +764,7 @@ def test_reduce_lr_on_plateau_args(self):
             "./regression",
             eval_strategy="epoch",
             metric_for_best_model="eval_loss",
+            report_to="none",
         )
         model = RegressionModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
@@ -796,6 +799,7 @@ def log(self, logs):
             metric_for_best_model="eval_loss",
             num_train_epochs=10,
             learning_rate=0.2,
+            report_to="none",
         )
         model = RegressionModel()
         trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
@@ -828,7 +832,7 @@ def test_adafactor_lr_none(self):
         from transformers.optimization import Adafactor, AdafactorSchedule
 
         train_dataset = RegressionDataset()
-        args = TrainingArguments("./regression")
+        args = TrainingArguments("./regression", report_to="none")
         model = RegressionModel()
         optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
         lr_scheduler = AdafactorSchedule(optimizer)
@@ -879,7 +883,7 @@ def test_trainer_works_with_dict(self):
         train_dataset = RegressionDataset()
         eval_dataset = RegressionDataset()
         model = RegressionDictModel()
-        args = TrainingArguments("./regression")
+        args = TrainingArguments("./regression", report_to="none")
         trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
         trainer.train()
         _ = trainer.evaluate()
@@ -890,7 +894,7 @@ def test_evaluation_with_keys_to_drop(self):
         tiny_gpt2 = GPT2LMHeadModel(config)
         x = torch.randint(0, 100, (128,))
         eval_dataset = RepeatDataset(x)
-        args = TrainingArguments("./test")
+        args = TrainingArguments("./test", report_to="none")
         trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset)
         # By default the past_key_values are removed
         result = trainer.predict(eval_dataset)
@@ -1100,7 +1104,12 @@ def test_neftune(self):
 
         # Trainer without inf/nan filter
         args = TrainingArguments(
-            "./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4
+            "./test",
+            learning_rate=1e-9,
+            logging_steps=5,
+            logging_nan_inf_filter=False,
+            neftune_noise_alpha=0.4,
+            report_to="none",
         )
         trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
 
@@ -1117,7 +1126,12 @@ def test_neftune(self):
         tiny_gpt2 = GPT2LMHeadModel(config)
         # Trainer without inf/nan filter
         args = TrainingArguments(
-            "./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4
+            "./test",
+            learning_rate=1e-9,
+            logging_steps=5,
+            logging_nan_inf_filter=False,
+            neftune_noise_alpha=0.4,
+            report_to="none",
         )
         trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
 
@@ -1143,13 +1157,17 @@ def test_logging_inf_nan_filter(self):
         train_dataset = RepeatDataset(x)
 
         # Trainer without inf/nan filter
-        args = TrainingArguments("./test", learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=False)
+        args = TrainingArguments(
+            "./test", learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=False, report_to="none"
+        )
         trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
         trainer.train()
         log_history_no_filter = trainer.state.log_history
 
         # Trainer with inf/nan filter
-        args = TrainingArguments("./test", learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=True)
+        args = TrainingArguments(
+            "./test", learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=True, report_to="none"
+        )
         trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
         trainer.train()
         log_history_filter = trainer.state.log_history
@@ -1196,11 +1214,16 @@ def test_train_and_eval_dataloaders(self):
     # tests that we do not require dataloader to have a .dataset attribute
     def test_dataloader_without_dataset(self):
         train_dataset = RegressionDataset(length=128)
-        trainer = CustomDataloaderTrainer(
-            model=RegressionModel(), train_dataset=train_dataset, eval_dataset=train_dataset
-        )
-        trainer.train()
-        trainer.evaluate()
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = CustomDataloaderTrainer(
+                model=RegressionModel(),
+                train_dataset=train_dataset,
+                eval_dataset=train_dataset,
+                args=TrainingArguments(output_dir=tmp_dir, report_to="none"),
+            )
+
+            trainer.train()
+            trainer.evaluate()
 
     def test_galore_matched_modules(self):
         regex_patterns = [r".*.attn.*", r".*.mlp.*"]
@@ -1495,7 +1518,9 @@ def test_data_is_not_parallelized_when_model_is_parallel(self):
         # Make the Trainer believe it's a parallelized model
         model.is_parallelizable = True
         model.model_parallel = True
-        args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
+        args = TrainingArguments(
+            "./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16, report_to="none"
+        )
         trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
         # Check the Trainer was fooled
         self.assertTrue(trainer.is_model_parallel)
@@ -1849,7 +1874,7 @@ def test_predict_with_ipex(self):
     def test_dynamic_shapes(self):
         eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
         model = RegressionModel(a=2, b=1)
-        args = TrainingArguments("./regression")
+        args = TrainingArguments("./regression", report_to="none")
         trainer = Trainer(model, args, eval_dataset=eval_dataset)
 
         # Check evaluation can run to completion
@@ -1866,7 +1891,7 @@ def test_dynamic_shapes(self):
             self.assertTrue(np.all(seen[expected.shape[0] :] == -100))
 
         # Same tests with eval accumulation
-        args = TrainingArguments("./regression", eval_accumulation_steps=2)
+        args = TrainingArguments("./regression", eval_accumulation_steps=2, report_to="none")
         trainer = Trainer(model, args, eval_dataset=eval_dataset)
 
         # Check evaluation can run to completion
@@ -2984,13 +3009,14 @@ def test_bf16_full_eval(self):
 
     def test_no_wd_param_group(self):
         model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
-        trainer = Trainer(model=model)
-        trainer.create_optimizer_and_scheduler(10)
-        wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']  # fmt: skip
-        wd_params = [p for n, p in model.named_parameters() if n in wd_names]
-        no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
-        self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
-        self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir, report_to="none"))
+            trainer.create_optimizer_and_scheduler(10)
+            wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']  # fmt: skip
+            wd_params = [p for n, p in model.named_parameters() if n in wd_names]
+            no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
+            self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
+            self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
 
     @slow
     @require_torch_multi_accelerator
@@ -4134,32 +4160,35 @@ def test_get_num_trainable_parameters(self):
         # in_features * out_features + bias
         layer_1 = 128 * 64 + 64
         layer_2 = 64 * 32 + 32
-        trainer = Trainer(model=model)
-        self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
-        # Freeze the last layer
-        for param in model[-1].parameters():
-            param.requires_grad = False
-        self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir, report_to="none"))
+            self.assertEqual(trainer.get_num_trainable_parameters(), layer_1 + layer_2)
+            # Freeze the last layer
+            for param in model[-1].parameters():
+                param.requires_grad = False
+            self.assertEqual(trainer.get_num_trainable_parameters(), layer_1)
 
     def test_get_learning_rates(self):
         model = nn.Sequential(nn.Linear(128, 64))
-        trainer = Trainer(model=model)
-        with self.assertRaises(ValueError):
-            trainer.get_learning_rates()
-        trainer.create_optimizer()
-        self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir, report_to="none"))
+            with self.assertRaises(ValueError):
+                trainer.get_learning_rates()
+            trainer.create_optimizer()
+            self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])
 
     def test_get_optimizer_group(self):
         model = nn.Sequential(nn.Linear(128, 64))
-        trainer = Trainer(model=model)
-        # ValueError is raised if optimizer is None
-        with self.assertRaises(ValueError):
-            trainer.get_optimizer_group()
-        trainer.create_optimizer()
-        # Get groups
-        num_groups = len(trainer.get_optimizer_group())
-        self.assertEqual(num_groups, 2)
-        # Get group of parameter
-        param = next(model.parameters())
-        group = trainer.get_optimizer_group(param)
-        self.assertIn(param, group["params"])
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir, report_to="none"))
+            # ValueError is raised if optimizer is None
+            with self.assertRaises(ValueError):
+                trainer.get_optimizer_group()
+            trainer.create_optimizer()
+            # Get groups
+            num_groups = len(trainer.get_optimizer_group())
+            self.assertEqual(num_groups, 2)
+            # Get group of parameter
+            param = next(model.parameters())
+            group = trainer.get_optimizer_group(param)
+            self.assertIn(param, group["params"])
diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py
index 8f867cf0beba37..968f800174a64a 100644
--- a/tests/trainer/test_trainer_distributed.py
+++ b/tests/trainer/test_trainer_distributed.py
@@ -153,7 +153,7 @@ def test_trainer(self):
             {self.test_file_dir}/test_trainer_distributed.py
         """.split()
         output_dir = self.get_auto_remove_tmp_dir()
-        args = f"--output_dir {output_dir}".split()
+        args = f"--output_dir {output_dir} --report_to none".split()
         cmd = ["torchrun"] + distributed_args + args
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
diff --git a/tests/trainer/test_trainer_seq2seq.py b/tests/trainer/test_trainer_seq2seq.py
index d8722c67836f26..61d2163b9e815c 100644
--- a/tests/trainer/test_trainer_seq2seq.py
+++ b/tests/trainer/test_trainer_seq2seq.py
@@ -119,6 +119,7 @@ def _compute_metrics(pred):
             warmup_steps=0,
             eval_steps=2,
             logging_steps=2,
+            report_to="none",
         )
 
         # instantiate trainer
@@ -152,7 +153,7 @@ def test_return_sequences(self):
             "google-t5/t5-small", max_length=None, min_length=None, max_new_tokens=256, min_new_tokens=1, num_beams=5
         )
 
-        training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True)
+        training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True, report_to="none")
 
         trainer = Seq2SeqTrainer(
             model=model,
@@ -160,6 +161,7 @@ def test_return_sequences(self):
             tokenizer=tokenizer,
             data_collator=data_collator,
             compute_metrics=lambda x: {"samples": x[0].shape[0]},
+            report_to="none",
         )
 
         def prepare_data(examples):
@@ -191,7 +193,9 @@ def test_bad_generation_config_fail_early(self):
         data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
         gen_config = GenerationConfig(do_sample=False, top_p=0.9)  # bad: top_p is not compatible with do_sample=False
 
-        training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True, generation_config=gen_config)
+        training_args = Seq2SeqTrainingArguments(
+            ".", predict_with_generate=True, generation_config=gen_config, report_to="none"
+        )
         with self.assertRaises(ValueError) as exc:
             _ = Seq2SeqTrainer(
                 model=model,