From 1360801a69c0b169e3efdbb0cd05d9a0e72bfb70 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Tue, 14 May 2024 22:07:15 +0200 Subject: [PATCH] Add PaliGemma (#30814) * add new model like * add state dict slicing + new model config * update palma config and weights, passes vision activations * fix * update * reorder loading/unpacking * clean up * add debug statements * change device * fix * debugging * fix noncausal mask * fixup sdpa + causal mask * fix activation function * remove debug before changing modeling file * add variants * debug attention mask in generate * revert to non-debug sdpa * revert gemma modifications * add custom language modeling * use Processor * add language modeling file to init * try thin wrapper around generate * Update * update mask * breakpoints galore * remove conflict * switch to left-padding * add incomplete model doc * add paligemma global files * batch rename paligemma * make generation match outputs and captioning * style * style * remove copied from + doc * remove more copied from * remove copy from projector * minor fix * update config and style * add readme - dummy * CORRECT image captioning * moving to args * add siglip proper + fix merging image + text features * take update_causal_mask from upstream * remove breakpoint * leverage AutoModel * fix input_ids slicing * make siglip head conditional * remove encoder_decoder value * remove unneeded modeling file * add commented 4d attention mask * FIXED generation with 4D mask * Update src/transformers/models/siglip/modeling_siglip.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix left padding detection * shuffle order of verifications * fix missing labels for training * fix * vectorize merging of features, improve slicing * improve testing before conversion * handle merging in processor * image token index depends on checkpoint * add variants, save processor too * save processors, base tokenizer off spm file * expand model embeddings due to additional image token * pass image processing args * add convert rgb to siglip processor * add \n token separately * fix tokenizer and prompts * fix docstrings * change to camel * fix casing * debug pos_ids and sdpa * pass and use cache_position * add flag for newline tokenization * Update src/transformers/models/paligemma/processing_paligemma.py Co-authored-by: Merve Noyan * simplify conversion script * add copied from * add precision to conversion script * Update src/transformers/models/paligemma/modeling_paligemma.py Co-authored-by: Pedro Cuenca * clean up * Shift attention mask from `1:` After discussion with @molbap * add docs, fix quality * quality, tied weights inheritance, and logits/label alignment * fix more tests * pass attn_implementation to language model correctly * add SiglipVisionTransformer to no split modules * skip paligemma test for sdpa dispatch to flash * skip incompatible tests * quality * [broken archive maps] * Apply suggestions - remove archive lists - style - take shape of inputs_embeds for batch Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/utils/dummy_pt_objects.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * simplify conversion script * add suggestions * add suggestions * add copied from * fix * move labels out * revert * fix * remove placeholder labels if None * use cache_position * fix quality + docstrings * fix quality * fix paligemma 4d gemma mask incompatibility * fix config 
docstring * fix query and attn_mask dtype --------- Co-authored-by: ArthurZucker Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Merve Noyan Co-authored-by: Pedro Cuenca --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/paligemma.md | 38 ++ docs/source/en/perf_infer_gpu_one.md | 1 + src/transformers/__init__.py | 16 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + .../models/gemma/modeling_gemma.py | 3 - src/transformers/models/paligemma/__init__.py | 54 ++ .../paligemma/configuration_paligemma.py | 130 +++++ .../convert_paligemma_weights_to_hf.py | 349 +++++++++++ .../models/paligemma/modeling_paligemma.py | 552 ++++++++++++++++++ .../models/paligemma/processing_paligemma.py | 268 +++++++++ .../models/siglip/image_processing_siglip.py | 13 + .../models/siglip/modeling_siglip.py | 12 +- src/transformers/utils/dummy_pt_objects.py | 21 + tests/models/paligemma/__init__.py | 0 .../paligemma/test_modeling_paligemma.py | 426 ++++++++++++++ tests/test_modeling_common.py | 4 + 23 files changed, 1890 insertions(+), 8 deletions(-) create mode 100644 docs/source/en/model_doc/paligemma.md create mode 100644 src/transformers/models/paligemma/__init__.py create mode 100644 src/transformers/models/paligemma/configuration_paligemma.py create mode 100644 src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py create mode 100644 src/transformers/models/paligemma/modeling_paligemma.py create mode 100644 src/transformers/models/paligemma/processing_paligemma.py create mode 100644 tests/models/paligemma/__init__.py create mode 100644 tests/models/paligemma/test_modeling_paligemma.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index cd60486e161cb9..325ecb0c4d2c80 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -784,6 +784,8 @@ title: OWL-ViT - local: model_doc/owlv2 title: OWLv2 + - local: model_doc/paligemma + title: PaliGemma - local: model_doc/perceiver title: Perceiver - local: model_doc/pix2struct diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 943f98c2183cc2..16018c32e57ec0 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -230,6 +230,7 @@ Flax), PyTorch, and/or TensorFlow. | [OPT](model_doc/opt) | ✅ | ✅ | ✅ | | [OWL-ViT](model_doc/owlvit) | ✅ | ❌ | ❌ | | [OWLv2](model_doc/owlv2) | ✅ | ❌ | ❌ | +| [PaliGemma](model_doc/paligemma) | ✅ | ❌ | ❌ | | [PatchTSMixer](model_doc/patchtsmixer) | ✅ | ❌ | ❌ | | [PatchTST](model_doc/patchtst) | ✅ | ❌ | ❌ | | [Pegasus](model_doc/pegasus) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/paligemma.md b/docs/source/en/model_doc/paligemma.md new file mode 100644 index 00000000000000..10946caa83d769 --- /dev/null +++ b/docs/source/en/model_doc/paligemma.md @@ -0,0 +1,38 @@ + + +# PaliGemma + +## Overview + +The PaliGemma model was proposed by Google. It is a 3B VLM composed by a Siglip-400m vision encoder and a Gemma-2B decoder linked by a multimodal linear projection. It is not a chat model with images. It cuts an image into a fixed number of VIT tokens and prepends it to an optional prompt. One particularity is that the model uses full block attention on all the image tokens plus the input text tokens. 
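For a quick picture of the intended usage, here is a minimal inference sketch mirroring the doctest added in `modeling_paligemma.py` later in this PR; the checkpoint id and image URL are taken from those docstrings and may not correspond to the final released checkpoints.

```python
import requests
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/PaliGemma-test-224px-hf"  # placeholder id from this PR's docstrings
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "answer en Where is the cow standing?"
url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
image = Image.open(requests.get(url, stream=True).raw)

# The processor prepends the image tokens to the (optional) text prompt.
inputs = processor(text=prompt, images=image, return_tensors="pt")

generate_ids = model.generate(**inputs, max_length=30)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```

The number of ViT tokens per image depends on the input resolution: with a patch size of 14, a 224x224 image contributes (224/14)^2 = 256 image tokens.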
It comes in 3 resolutions, 224x224, 448x448 and 896x896 with 3 base models, with 55 fine-tuned versions for different tasks, and 2 mix models. + + +This model was contributed by [Molbap](https://huggingface.co/Molbap). + + +## PaliGemmaConfig + +[[autodoc]] PaliGemmaConfig + +## PaliGemmaProcessor + +[[autodoc]] PaliGemmaProcessor + +## PaliGemmaForConditionalGeneration + +[[autodoc]] PaliGemmaForConditionalGeneration + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 151ca765a216f8..2dfabcb6a4a135 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -203,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) +* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 16970ac22af502..ea1b0de9dd5d74 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -582,6 +582,7 @@ "OwlViTTextConfig", "OwlViTVisionConfig", ], + "models.paligemma": ["PaliGemmaConfig"], "models.patchtsmixer": ["PatchTSMixerConfig"], "models.patchtst": ["PatchTSTConfig"], "models.pegasus": [ @@ -2651,6 +2652,13 @@ "OwlViTVisionModel", ] ) + _import_structure["models.paligemma"].extend( + [ + "PaliGemmaForConditionalGeneration", + "PaliGemmaPreTrainedModel", + "PaliGemmaProcessor", + ] + ) _import_structure["models.patchtsmixer"].extend( [ "PatchTSMixerForPrediction", @@ -5126,6 +5134,9 @@ OwlViTTextConfig, OwlViTVisionConfig, ) + from .models.paligemma import ( + PaliGemmaConfig, + ) from .models.patchtsmixer import ( PatchTSMixerConfig, ) @@ -6956,6 +6967,11 @@ OwlViTTextModel, OwlViTVisionModel, ) + from .models.paligemma import ( + PaliGemmaForConditionalGeneration, + PaliGemmaPreTrainedModel, + PaliGemmaProcessor, + ) from .models.patchtsmixer import ( PatchTSMixerForPrediction, PatchTSMixerForPretraining, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index c0fd6cc6d2cae4..d3b931c92694a5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -173,6 +173,7 @@ opt, owlv2, owlvit, + paligemma, patchtsmixer, patchtst, pegasus, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 634399ad64d4b6..7ff94b6d8c4bab 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -182,6 +182,7 @@ ("opt", "OPTConfig"), ("owlv2", "Owlv2Config"), ("owlvit", "OwlViTConfig"), + ("paligemma", "PaliGemmaConfig"), ("patchtsmixer", "PatchTSMixerConfig"), ("patchtst", "PatchTSTConfig"), ("pegasus", "PegasusConfig"), @@ -464,6 +465,7 @@ ("opt", "OPT"), ("owlv2", "OWLv2"), ("owlvit", "OWL-ViT"), + ("paligemma", "PaliGemma"), ("patchtsmixer", "PatchTSMixer"), ("patchtst", "PatchTST"), 
("pegasus", "Pegasus"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 9e8daefb397a58..bed78fb5630cca 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -93,6 +93,7 @@ ("oneformer", "OneFormerImageProcessor"), ("owlv2", "Owlv2ImageProcessor"), ("owlvit", "OwlViTImageProcessor"), + ("paligemma", "CLIPImageProcessor"), ("perceiver", "PerceiverImageProcessor"), ("pix2struct", "Pix2StructImageProcessor"), ("poolformer", "PoolFormerImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index cb3e84efc289b2..3ae91ac176f216 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -313,6 +313,7 @@ ("nezha", "NezhaForPreTraining"), ("nllb-moe", "NllbMoeForConditionalGeneration"), ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("paligemma", "PaliGemmaForConditionalGeneration"), ("retribert", "RetriBertModel"), ("roberta", "RobertaForMaskedLM"), ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), @@ -697,6 +698,7 @@ ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), + ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 1bbce2d85aa1bd..d4babb88acd4e4 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -74,6 +74,7 @@ ("oneformer", "OneFormerProcessor"), ("owlv2", "Owlv2Processor"), ("owlvit", "OwlViTProcessor"), + ("paligemma", "PaliGemmaProcessor"), ("pix2struct", "Pix2StructProcessor"), ("pop2piano", "Pop2PianoProcessor"), ("sam", "SamProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 8bb8757cb659ca..7b213e49912bd5 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -331,6 +331,7 @@ ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ( "pegasus", ( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 12d01a6ea04d3e..9af816839445e6 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -794,7 +794,6 @@ def _init_weights(self, module): "The bare Gemma Model outputting raw hidden-states without any specific head on top.", GEMMA_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma class GemmaModel(GemmaPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`GemmaDecoderLayer`] @@ -988,8 +987,6 @@ def _update_causal_mask( if attention_mask is not None and attention_mask.dim() == 4: # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing - if attention_mask.max() != 0: - raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") causal_mask = attention_mask else: causal_mask = torch.full( diff --git a/src/transformers/models/paligemma/__init__.py b/src/transformers/models/paligemma/__init__.py new file mode 100644 index 00000000000000..11ba4f3edd09e8 --- /dev/null +++ b/src/transformers/models/paligemma/__init__.py @@ -0,0 +1,54 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = {"configuration_paligemma": ["PaliGemmaConfig"]} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_paligemma"] = [ + "PaliGemmaForConditionalGeneration", + "PaliGemmaPreTrainedModel", + ] + _import_structure["processing_paligemma"] = ["PaliGemmaProcessor"] + + +if TYPE_CHECKING: + from .configuration_paligemma import PaliGemmaConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_paligemma import ( + PaliGemmaForConditionalGeneration, + PaliGemmaPreTrainedModel, + ) + from .processing_paligemma import PaliGemmaProcessor + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/paligemma/configuration_paligemma.py b/src/transformers/models/paligemma/configuration_paligemma.py new file mode 100644 index 00000000000000..7252425b855ed9 --- /dev/null +++ b/src/transformers/models/paligemma/configuration_paligemma.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
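A note on the `_update_causal_mask` change in `modeling_gemma.py` above: the check requiring custom 4D attention masks to be passed in inverted form is dropped so that PaliGemma can feed its own block mask through Gemma. As a rough sketch (toy shapes, not the exact model code), that block mask is the outer product of the 2D padding mask with itself, the same construction used in `_merge_input_ids_with_image_features` later in this diff, so every non-padded position attends to every other non-padded position instead of following a causal pattern:

```python
import torch

# One sequence of 5 tokens, the last of which is padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
num_key_value_heads = 1  # Gemma-2B uses a single key/value head

# (batch, 1, 1, seq) * (batch, 1, seq, 1) -> (batch, 1, seq, seq)
mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
mask_4d = mask_4d.float().expand(-1, num_key_value_heads, -1, -1)
print(mask_4d.shape)   # torch.Size([1, 1, 5, 5])
print(mask_4d[0, 0])   # 1 where both query and key are real tokens, 0 elsewhere
```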
+""" PaliGemmamodel configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class PaliGemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate an + PaliGemmamodel according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PaliGemma-2B. + + e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`PaliGemmaVisionConfig`, *optional*): + Custom vision config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + image_token_index (`int`, *optional*, defaults to 256000): + The image token index to encode the image prompt. + vocab_size (`int`, *optional*, defaults to 257152): + Vocabulary size of the PaliGemmamodel. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`~PaliGemmaForConditionalGeneration`] + projection_dim (`int`, *optional*, defaults to 2048): + Dimension of the multimodal projection space. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden layer of the Language model. + + Example: + + ```python + >>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a PaliGemma config + >>> text_config = GemmaConfig() + + >>> # Initializing a PaliGemma paligemma-3b-224 style configuration + >>> configuration = PaliGemmaConfig(vision_config, text_config) + + >>> # Initializing a model from the paligemma-3b-224 style configuration + >>> model = PaliGemmaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "paligemma" + is_composition = False + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + image_token_index=256000, + vocab_size=257152, + projection_dim=2048, + hidden_size=2048, + **kwargs, + ): + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.vocab_size = vocab_size + self.projection_dim = projection_dim + self.hidden_size = hidden_size + self.vision_config = vision_config + self.is_encoder_decoder = False + + if isinstance(self.vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model" + ) + self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + self.vision_config = CONFIG_MAPPING["siglip_vision_model"]( + intermediate_size=4096, + hidden_size=1152, + patch_size=14, + image_size=224, + num_hidden_layers=27, + num_attention_heads=16, + vocab_size=257152, + vision_use_head=False, + ) + self.vocab_size = self.vocab_size + + self.text_config = text_config + + if 
isinstance(self.text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma" + self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + self.vocab_size = self.text_config.vocab_size + elif text_config is None: + self.text_config = CONFIG_MAPPING["gemma"]( + hidden_size=2048, + num_hidden_layers=18, + intermediate_size=16384, + num_attention_heads=8, + num_key_value_heads=1, + is_encoder_decoder=False, + ) + self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2 + self.vision_config.projection_dim = projection_dim + super().__init__(**kwargs) diff --git a/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py b/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py new file mode 100644 index 00000000000000..0d43cac0167c6b --- /dev/null +++ b/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py @@ -0,0 +1,349 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert PaliGemma checkpoints from the original repository. +""" + + +import argparse +import collections + +import torch +from numpy import load + +from transformers import ( + AutoTokenizer, + GemmaTokenizer, + GemmaTokenizerFast, + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaProcessor, + SiglipImageProcessor, +) +from transformers.tokenization_utils_base import AddedToken +from transformers.utils import logging + + +device = "cuda" # "cpu" + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# TODO add sequence length variations here + +PALIGEMMA_VARIANTS = ["2b-test", "3b-224px", "3b-448px", "3b-896px"] + + +def get_paligemma_config(variant: str, precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + + image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896} + + if variant in PALIGEMMA_VARIANTS: + image_size = image_sizes[variant] + patch_size = 14 + num_image_tokens = (image_size**2) // (patch_size**2) + + config["image_token_index"] = 257152 if variant != "2b-test" else 256000 + text_config = { + "vocab_size": 257152, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "head_dim": 256, + "torch_dtype": precision, + "hidden_size": 2048, + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": 8, + "intermediate_size": 16384, + "is_encoder_decoder": False, + } + vision_config = { + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + } + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + else: + raise ValueError(f"Identifier {variant} not supported. 
Available: {PALIGEMMA_VARIANTS}") + return final_config + + +def slice_state_dict(state_dict, config): + # fmt: off + # patch embeddings + state_dict["vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop("img/embedding/kernel").transpose( + 3, 2, 0, 1 + ) + state_dict["vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop("img/embedding/bias") + # positional embeddings + state_dict["vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop("img/pos_embedding").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. + encoderblock_layernorm0_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/scale") + encoderblock_layernorm0_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_0/bias") + encoderblock_layernorm1_scale = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/scale") + encoderblock_layernorm1_bias = state_dict.pop("img/Transformer/encoderblock/LayerNorm_1/bias") + + encoderblock_mlp_dense0_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel") + encoderblock_mlp_dense0_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias") + encoderblock_mlp_dense1_kernel= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel") + encoderblock_mlp_dense1_bias= state_dict.pop("img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias") + + encoderblock_attention_0_key_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel") + encoderblock_attention_0_key_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias") + encoderblock_attention_0_value_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel") + encoderblock_attention_0_value_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias") + encoderblock_attention_0_query_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel") + encoderblock_attention_0_query_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias") + encoderblock_attention_0_out_kernel = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel") + encoderblock_attention_0_out_bias = state_dict.pop("img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + 
state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop("img/Transformer/encoder_norm/scale").transpose() + state_dict["vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop("img/Transformer/encoder_norm/bias") + + # multimodal projector + + state_dict['multi_modal_projector.linear.weight'] = state_dict.pop("img/head/kernel").transpose() + state_dict['multi_modal_projector.linear.bias'] = state_dict.pop("img/head/bias") + + # text decoder (gemma) + + embedding_vector = state_dict.pop("llm/embedder/input_embedding") + state_dict["language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 
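Before the per-layer loop below, it may help to see the weight-layout conversion in isolation. The checkpoint stores the attention projections in an einsum layout (the exact shapes are documented in the loop's comments); a minimal numpy sketch of how a per-layer query weight is flattened into the `(num_heads * head_dim, hidden_size)` layout expected by `nn.Linear`:

```python
import numpy as np

# Per-layer q_einsum weight as stored in the npz checkpoint: (num_heads=8, hidden_size=2048, head_dim=256)
q_einsum_layer = np.zeros((8, 2048, 256), dtype=np.float32)

# Move head_dim next to num_heads, then flatten both into the output dimension of q_proj.
q_proj_weight = q_einsum_layer.transpose(0, 2, 1).reshape(8 * 256, 2048)
print(q_proj_weight.shape)  # (2048, 2048)
```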
+ + llm_attention_attn_vec_einsum = state_dict.pop("llm/layers/attn/attn_vec_einsum/w") + llm_attention_kv_einsum = state_dict.pop("llm/layers/attn/kv_einsum/w") + llm_attention_q_einsum = state_dict.pop("llm/layers/attn/q_einsum/w") + + llm_mlp_gating_einsum = state_dict.pop("llm/layers/mlp/gating_einsum") + llm_mlp_linear = state_dict.pop("llm/layers/mlp/linear") + # TODO verify correctness of layer norm loading + + llm_input_layernorm = state_dict.pop("llm/layers/pre_attention_norm/scale") + llm_post_attention_layernorm = state_dict.pop("llm/layers/pre_ffw_norm/scale") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + + # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) + k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() + state_dict[f"language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) + v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() + state_dict[f"language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. + + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + + state_dict["language_model.model.norm.weight"] = state_dict.pop("llm/final_norm/scale") + state_dict["language_model.lm_head.weight"] = embedding_vector # weights are tied. + + # fmt: on + for key, value in state_dict.items(): + state_dict[key] = torch.from_numpy(value) + return state_dict + + +def flatten_nested_dict(params, parent_key="", sep="/"): + items = [] + + for k, v in params.items(): + k = k.removeprefix("params/") + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_nested_dict(v, parent_key=new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +@torch.no_grad() +def convert_paligemma_checkpoint( + checkpoint_path, + tokenizer_model_file, + pytorch_dump_folder_path, + variant: str, + precision: str, + do_convert_weights=False, +): + """ + Read checkpoints from flax npz files, rename/reshape, send result to state dict and verify logits if needed. 
+ """ + config = get_paligemma_config(variant, precision=precision) + if do_convert_weights: + if variant == "2b-test": + # for the test model, the vocabulary was smaller + tokenizer_id = "google/gemma-2b" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + else: + tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast + tokenizer = tokenizer_class(tokenizer_model_file) + image_token = AddedToken("", normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + + # tokenizer.padding_side = 'right' # uncomment for testing purposes only. + + image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384") + image_processor.size = {"width": config.vision_config.image_size, "height": config.vision_config.image_size} + image_processor.image_seq_length = config.vision_config.num_image_tokens + + processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer) + data = load(checkpoint_path) + state_dict = flatten_nested_dict(data) + del data + state_dict_transformers = slice_state_dict(state_dict, config) + del state_dict + + model = PaliGemmaForConditionalGeneration(config).to(device).eval() + model.load_state_dict(state_dict_transformers) + del state_dict_transformers + + else: + processor = PaliGemmaProcessor.from_pretrained(pytorch_dump_folder_path) + model = ( + PaliGemmaForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, attn_implementation="sdpa") + .to(device) + .eval() + ) + model.config.text_config._attn_implementation = "sdpa" + + # model expansion to get random embeds of image tokens + pad_shape = 64 # for performance reasons + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[257152:].shape[0]))), + dim=0, + ) + model.language_model.lm_head.weight.data[257152:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[257152:].shape[0]))), + dim=0, + ) + + model.save_pretrained(pytorch_dump_folder_path, max_shard_size="2GB", safe_serialization=True) + processor.save_pretrained(pytorch_dump_folder_path) + + +# + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_path", + required=True, + type=str, + help="Path to the .npz checkpoint", + ) + + parser.add_argument( + "--tokenizer_model_file", + required=True, + type=str, + help="Path to the sentencepiece tokenizer.model file", + ) + + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=str, + help="Path to the output directory where model and processor will be saved.", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + + parser.add_argument( + "--variant", + default="2b-test", + 
choices=PALIGEMMA_VARIANTS, + type=str, + help="String identifier of the paligemma variant to convert.", + ) + + parser.add_argument( + "--do_convert_weights", action="store_true", help="Whether or not to reload and convert the weights." + ) + + args = parser.parse_args() + convert_paligemma_checkpoint( + checkpoint_path=args.checkpoint_path, + tokenizer_model_file=args.tokenizer_model_file, + pytorch_dump_folder_path=args.pytorch_dump_folder_path, + variant=args.variant, + precision=args.precision, + do_convert_weights=args.do_convert_weights, + ) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py new file mode 100644 index 00000000000000..b2f2904b2f83ac --- /dev/null +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -0,0 +1,552 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch PaliGemmamodel.""" +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) +from .configuration_paligemma import PaliGemmaConfig + + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from ..auto import AutoModel, AutoModelForCausalLM + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "PaliGemmaConfig" + + +@dataclass +class PaliGemmaCausalLMOutputWithPast(ModelOutput): + """ + Base class for PaliGemmacausal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class PaliGemmaMultiModalProjector(nn.Module): + def __init__(self, config: PaliGemmaConfig): + super().__init__() + self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True) + + def forward(self, image_features): + hidden_states = self.linear(image_features) + + return hidden_states + + +PALIGEMMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PaliGemmaConfig`] or [`PaliGemmaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
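The multimodal projector defined above is a single linear layer from the SigLIP hidden size to `projection_dim`. A small sketch using the default configuration built by `PaliGemmaConfig` (shapes follow the defaults in this PR; in the model, the projected features are additionally divided by `sqrt(hidden_size)` before being merged with the text embeddings):

```python
import torch
from transformers import PaliGemmaConfig
from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector

config = PaliGemmaConfig()  # default SigLIP-style vision tower + Gemma-2B-style text config
projector = PaliGemmaMultiModalProjector(config)

# 256 image tokens with SigLIP hidden size 1152 -> projection_dim 2048
image_features = torch.randn(1, 256, config.vision_config.hidden_size)
print(projector(image_features).shape)  # torch.Size([1, 256, 2048])
```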
+""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + PALIGEMMA_START_DOCSTRING, +) +class PaliGemmaPreTrainedModel(PreTrainedModel): + config_class = PaliGemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PaliGemmaMultiModalProjector"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = False + _supports_sdpa = True + + def _init_weights(self, module): + # important: this ported version of PaliGemmaisn't meant for training from scratch - only + # inference and fine-tuning + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA or not. + """ + return self.language_model._supports_sdpa + + +PALIGEMMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses + [`SiglipImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
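Beyond generation, the forward pass also accepts `labels` for fine-tuning (see its docstring below). A hypothetical training-style call, assuming the processor returns `input_ids`, `attention_mask` and `pixel_values`; the checkpoint id is a placeholder from this PR's docstrings, and a real setup would typically mask the prefix and image positions with the ignore index (-100) rather than supervising every token:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/PaliGemma-test-224px-hf"  # placeholder
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.new("RGB", (224, 224))
inputs = processor(text="caption en a cow on a beach", images=image, return_tensors="pt")

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pixel_values=inputs["pixel_values"],
    labels=inputs["input_ids"].clone(),  # naive: supervise every position
)
print(outputs.loss)
```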
+""" + + +@add_start_docstrings( + """The PALIGEMMA model which consists of a vision backbone and a language model.""", + PALIGEMMA_START_DOCSTRING, +) +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel): + def __init__(self, config: PaliGemmaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = PaliGemmaMultiModalProjector(config) + self.vocab_size = config.vocab_size + self._attn_implementation = config._attn_implementation + + language_model = AutoModelForCausalLM.from_config( + config=config.text_config, attn_implementation=self._attn_implementation + ) + + if language_model._tied_weights_keys is not None: + self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma + def get_decoder(self): + return self.language_model.get_decoder() + + # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + _, _, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + scaled_image_features = image_features / (self.config.hidden_size**0.5) + final_embedding = torch.zeros( + batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id) + image_mask = input_ids == self.config.image_token_index + pad_mask = input_ids == self.pad_token_id + + # expand masks to match embedding dimension + 
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim) + pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim) + # insert padding and text token embeddings + final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding) + final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding) + # insert image embeddings - the image mask is always less or equal to the sentence in length + final_embedding = final_embedding.masked_scatter( + image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features + ) + final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding) + + final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1) + final_attention_mask_4d = final_attention_mask_4d.float().expand( + -1, self.config.text_config.num_key_value_heads, -1, -1 + ) + + # position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1) + # position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids) + position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1) + + if labels is not None: + final_labels = torch.full( + (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) + final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels) + else: + final_labels = None + return final_embedding, final_attention_mask_4d, final_labels, position_ids + + @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" 
+ >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # the attention mask is turned 4d after, we keep track of the original one + input_attention_mask = attention_mask + + if inputs_embeds is None: + # 1. Extra the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and images + if pixel_values is not None and input_ids.shape[1] != 1: + image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype)) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + + else: + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + # TODO @molbap this will only work for dynamic cache. 
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_seqlen = cache_position[-1] + 1 + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses PaliGemma+ Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + attention_mask = attention_mask.to(inputs_embeds.dtype) + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logits = outputs.logits + logits = logits.float() + loss = None + if labels is not None: + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if input_attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ shift_attention_mask = input_attention_mask[..., 1:] + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return PaliGemmaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + pixel_values=None, + attention_mask=None, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + # here we need to recall past_length is num_image_tokens + previous input_ids. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + elif self.config.image_token_index in input_ids: + input_ids = input_ids[:, input_ids.shape[1] - 1 :] + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. 
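Stepping back to the loss in `forward` above: it is a plain shifted next-token cross-entropy, with the original 2D attention mask used to drop padded positions before flattening. A toy, runnable illustration with random tensors (not the library code):

```python
# Toy illustration of the shifted, mask-filtered cross-entropy used in forward().
import torch
from torch import nn

vocab_size, seq_len = 11, 5
logits = torch.randn(1, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (1, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])            # last position is padding

shift_logits = logits[..., :-1, :]                          # predict token t+1 from position t
shift_labels = labels[..., 1:]
shift_attention_mask = attention_mask[..., 1:]

shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask != 0].contiguous()

loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(shift_logits.shape, loss)                             # 3 supervised positions remain
```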
+ if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "cache_position": cache_position, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + ) + return model_inputs + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py new file mode 100644 index 00000000000000..76a9aaf705e8ff --- /dev/null +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -0,0 +1,268 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for PaliGemma. +""" + + +import logging +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import AddedToken, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +logger = logging.getLogger(__name__) + +IMAGE_TOKEN = "" + + +# Copied from transformers.models.idefics2.processing_idefics2.is_url +def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url +def is_image_or_image_url(elem): + return is_url(elem) or is_valid_image(elem) + + +def _is_str_or_image(elem): + return isinstance(elem, (str)) or is_image_or_image_url(elem) + + +def build_string_from_input(prompt, bos_token, image_seq_len, image_token): + """ + Builds a string from the input prompt and image tokens. + For example, for the call: + build_string_from_input( + prompt="Prefix str" + bos_token="", + image_seq_len=3, + image_token="", + ) + The output will be: + "Initial str" + Args: + prompt (`List[Union[str, ImageInput]]`): The input prompt. + bos_token (`str`): The beginning of sentence token. + image_seq_len (`int`): The length of the image sequence. + image_token (`str`): The image token. 
+ """ + return f"{image_token * image_seq_len}{bos_token}{prompt}" + + +class PaliGemmaProcessor(ProcessorMixin): + r""" + Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor. + + [`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information. + + Args: + image_processor ([`SiglipImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None): + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + if not hasattr(image_processor, "image_seq_length"): + raise ValueError("Image processor is missing an `image_seq_length` attribute.") + + self.image_seq_length = image_processor.image_seq_length + + image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True) + tokens_to_add = {"additional_special_tokens": [image_token]} + tokenizer.add_special_tokens(tokens_to_add) + self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) + + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + tokenize_newline_separately: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + do_resize: bool = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 + input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821 + resample: "PILImageResampling" = None, # noqa: F821 + do_convert_rgb: bool = None, + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_rescale: bool = None, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. 
Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + tokenize_newline_separately (`bool`, defaults to `True`): + Adds a separately tokenized '\n' at the end of the prompt. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is None: + raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.") + if text is None: + logger.warning_once( + "You are using PaliGemma without a text prefix. It will perform as a picture-captioning model." + ) + + if isinstance(text, List) and isinstance(images, List): + if len(images) < len(text): + raise ValueError( + f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image." 
+ ) + if _is_str_or_image(text): + text = [text] + elif isinstance(text, list) and _is_str_or_image(text[0]): + pass + input_strings = [ + build_string_from_input( + prompt=prompt, + bos_token=self.tokenizer.bos_token, + image_seq_len=self.image_seq_length, + image_token=IMAGE_TOKEN, + ) + for prompt in text + ] + + pixel_values = self.image_processor( + images, + do_resize=do_resize, + do_normalize=do_normalize, + return_tensors=return_tensors, + image_mean=image_mean, + image_std=image_std, + input_data_format=input_data_format, + data_format=data_format, + resample=resample, + do_convert_rgb=do_convert_rgb, + )["pixel_values"] + + if max_length is not None: + max_length += self.image_seq_length # max_length has to account for the image tokens + + if tokenize_newline_separately: + inputs = self.tokenizer( + input_strings, + add_special_tokens=False, + return_tensors=None, + padding="do_not_pad", + max_length=max_length, + truncation=truncation, + ) + newline_token = self.tokenizer.convert_tokens_to_ids("\n") + concatenated_ids = [ids + [newline_token] for ids in inputs["input_ids"]] + concatenated_attention_masks = [mask + [1] for mask in inputs["attention_mask"]] + + text_inputs = self.tokenizer.pad( + {"input_ids": concatenated_ids, "attention_mask": concatenated_attention_masks}, + max_length=max_length, + padding=padding, + return_tensors=return_tensors, + ) + else: + text_inputs = self.tokenizer( + input_strings, + add_special_tokens=False, + return_tensors=return_tensors, + padding=padding, + max_length=max_length, + truncation=truncation, + ) + + return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/src/transformers/models/siglip/image_processing_siglip.py b/src/transformers/models/siglip/image_processing_siglip.py index 5f24ffb0a2a8b1..c624df3c751863 100644 --- a/src/transformers/models/siglip/image_processing_siglip.py +++ b/src/transformers/models/siglip/image_processing_siglip.py @@ -18,6 +18,7 @@ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( + convert_to_rgb, resize, to_channel_dimension_format, ) @@ -73,6 +74,8 @@ class SiglipImageProcessor(BaseImageProcessor): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 
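Zooming out from `PaliGemmaProcessor.__call__` above: the string handed to the tokenizer is simply `image_seq_length` copies of the image token, then the BOS token, then the text prefix, with a separately tokenized newline appended afterwards. A sketch of that layout, assuming `<image>` for the special image token and `<bos>` for the Gemma BOS token (both token strings are assumptions for illustration; the processor reads them from the tokenizer and image processor):

```python
# Sketch of the prompt layout produced by the processor. "<image>" and "<bos>"
# are assumed token strings; image_seq_len is shortened to 4 for readability.
IMAGE_TOKEN = "<image>"

def build_prompt(prefix: str, bos_token: str, image_seq_len: int, image_token: str) -> str:
    # Same shape as build_string_from_input: image tokens first, then BOS, then the text prefix.
    return f"{image_token * image_seq_len}{bos_token}{prefix}"

prompt = build_prompt("answer en Where is the cow standing?", "<bos>", 4, IMAGE_TOKEN)
print(prompt)
# <image><image><image><image><bos>answer en Where is the cow standing?
# The processor then tokenizes this string with add_special_tokens=False and, when
# tokenize_newline_separately=True, appends the id of "\n" before padding.
```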
Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. """ model_input_names = ["pixel_values"] @@ -87,6 +90,7 @@ def __init__( do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -102,6 +106,7 @@ def __init__( self.do_normalize = do_normalize self.image_mean = image_mean self.image_std = image_std + self.do_convert_rgb = do_convert_rgb self._valid_processor_keys = [ "images", "do_resize", @@ -115,6 +120,7 @@ def __init__( "return_tensors", "data_format", "input_data_format", + "do_convert_rgb", ] def preprocess( @@ -131,6 +137,7 @@ def preprocess( return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: bool = None, **kwargs, ) -> PIL.Image.Image: """ @@ -176,6 +183,8 @@ def preprocess( - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size @@ -186,6 +195,7 @@ def preprocess( do_normalize = do_normalize if do_normalize is not None else self.do_normalize image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb images = make_list_of_images(images) @@ -209,6 +219,9 @@ def preprocess( # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + if is_scaled_image(images[0]) and do_rescale: logger.warning_once( "It looks like you are trying to rescale already rescaled images. 
If the input" diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 5399006227e784..87fae3493bf9dd 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -881,7 +881,9 @@ def __init__(self, config: SiglipVisionConfig): self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.head = SiglipMultiheadAttentionPoolingHead(config) + self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) @@ -915,14 +917,13 @@ def forward( last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) - pooled_output = self.head(last_hidden_state) - + pooler_output = self.head(last_hidden_state) if self.use_head else None if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] + return (last_hidden_state, pooler_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, - pooler_output=pooled_output, + pooler_output=pooler_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) @@ -959,6 +960,7 @@ def forward(self, hidden_state): class SiglipVisionModel(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "pixel_values" + _no_split_modules = ["SiglipVisionTransformer"] def __init__(self, config: SiglipVisionConfig): super().__init__(config) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 9570e044d362d7..507cd8f0ae887a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6135,6 +6135,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class PaliGemmaForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PaliGemmaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PaliGemmaProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PatchTSMixerForPrediction(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/paligemma/__init__.py b/tests/models/paligemma/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py new file mode 100644 index 00000000000000..0653e1057cb667 --- /dev/null +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -0,0 +1,426 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
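The `vision_use_head` switch added to `SiglipVisionTransformer` above means `pooler_output` may now be `None`. A rough sketch of that behavior with a tiny random-weight config (the config values are invented, and `vision_use_head` is set as an ad-hoc attribute, which is exactly what the modified code probes with `hasattr`):

```python
# Rough sketch of the optional SigLIP pooling head (tiny random-weight config).
import torch
from transformers import SiglipVisionConfig, SiglipVisionModel

config = SiglipVisionConfig(
    hidden_size=32, intermediate_size=64, num_hidden_layers=2,
    num_attention_heads=4, image_size=32, patch_size=16,
)
config.vision_use_head = False                      # skip SiglipMultiheadAttentionPoolingHead
model = SiglipVisionModel(config)

out = model(pixel_values=torch.randn(1, 3, 32, 32))
print(out.last_hidden_state.shape)                  # (1, 4, 32): 2x2 patches, hidden_size 32
print(out.pooler_output)                            # None when the head is disabled
```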
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch PaliGemma model. """ + +import gc +import unittest + +import requests +from parameterized import parameterized + +from transformers import ( + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + PaliGemmaProcessor, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_bitsandbytes, + require_torch, + require_torch_sdpa, + slow, + torch_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch +else: + is_torch_greater_or_equal_than_2_0 = False + +if is_vision_available(): + from PIL import Image + + +class PaliGemmaVisionText2TextModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_index=98, + projector_hidden_act="gelu", + seq_length=7, + vision_feature_select_strategy="default", + vision_feature_layer=-1, + projection_dim=32, + text_config={ + "model_type": "gemma", + "seq_length": 128, + "is_training": True, + # "use_input_mask": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "head_dim": 8, + "intermediate_size": 37, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 0, + }, + is_training=True, + vision_config={ + "use_labels": True, + "image_size": 30, + "patch_size": 2, + "num_image_tokens": 4, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "projection_dim": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + use_cache=False, + ): + self.parent = parent + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.seq_length = seq_length + self.projection_dim = projection_dim + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + self.encoder_seq_length = seq_length + self.use_cache = use_cache + + def get_config(self): + return PaliGemmaConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ignore_index=self.ignore_index, + image_token_index=self.image_token_index, + 
projector_hidden_act=self.projector_hidden_act, + projection_dim=self.projection_dim, + vision_feature_select_strategy=self.vision_feature_select_strategy, + vision_feature_layer=self.vision_feature_layer, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(1).to(torch_device) + # setting the 4 first tokens to be image + input_ids[:, :4] = config.image_token_index + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase): + """ + Model tester for `PaliGemmaForConditionalGeneration`. + """ + + all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_torchscript = False + test_head_masking = False + + def setUp(self): + self.model_tester = PaliGemmaVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False) + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_safetensors(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_model_parallelism(self): + pass + + @require_torch_sdpa + @slow + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + self.skipTest( + "Due to custom causal mask, there is a slightly too big difference between eager and sdpa in bfloat16." 
+ ) + + @unittest.skip( + reason="PaliGemmma's SigLip encoder uses the same initialization scheme as the Flax original implementation" + ) + def test_initialization(self): + pass + + # TODO extend valid outputs to include this test @Molbap + @unittest.skip("PaliGemma has currently one output format.") + def test_model_outputs_equivalence(self): + pass + + # TODO fix the loss = nan in the testing configuration chosen @Molbap + @unittest.skip(reason="Edge case giving loss nan values in testing configuration.") + def test_determinism(self): + pass + + @unittest.skip(reason="PaliGemma does not use feedforward chunking.") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="PaliGemma does not support low_cpu_mem_usage.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip(reason="PaliGemma does not support low_cpu_mem_usage.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip(reason="PaliGemma does not support low_cpu_mem_usage.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + + +@slow +@require_torch +class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = PaliGemmaProcessor.from_pretrained("gv-hf/PaliGemma-test-224px-hf") + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + @slow + @require_bitsandbytes + def test_small_model_integration_test(self): + # Let' s make sure we test the preprocessing to replace what is used + model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf") + prompt = "" + image_file = ( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" + ) + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt") + # fmt: off + EXPECTED_INPUT_IDS = torch.tensor([[256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 
256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, + 256000, 256000, 256000, 256000, 2, 108]]) + # fmt: on + self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) + + output = model.generate(**inputs, max_new_tokens=20) + EXPECTED_DECODED_TEXT = "\ncow standing on the beach" # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_paligemma(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "gv-hf/PaliGemma-test-224px-hf" + + model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf") + processor = PaliGemmaProcessor.from_pretrained(model_id) + + prompt = "answer en Where is the cow standing?" + image_file = ( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" + ) + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16) + + output = model.generate(**inputs, max_new_tokens=900, do_sample=False) + EXPECTED_DECODED_TEXT = "answer en Where is the cow standing?\nbeach" # fmt: skip + + self.assertEqual( + processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_paligemma_batched(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "gv-hf/PaliGemma-test-224px-hf" + + model = PaliGemmaForConditionalGeneration.from_pretrained(model_id) + processor = PaliGemmaProcessor.from_pretrained(model_id) + + prompts = [ + "answer en Where is the cow standing?", + "", + ] + image1 = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png", + stream=True, + ).raw + ) + image2 = image1 + + inputs = processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow standing on the beach"] # fmt: skip + + self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_batch(self): + # Let' s make sure we test the preprocessing to replace what is used + model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf") + # The first batch is longer in terms of text, the second will be padded. 
+ prompts = [ + "answer en Where is the cow standing?", + "", + ] + image1 = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png", + stream=True, + ).raw + ) + image2 = image1 + + inputs = self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow standing on the beach"] # fmt: skip + self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_paligemma_index_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore + # Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for + # more details + model_id = "gv-hf/PaliGemma-test-224px-hf" + model = PaliGemmaForConditionalGeneration.from_pretrained(model_id) + + processor = PaliGemmaProcessor.from_pretrained(model_id) + + # Simulate a super long prompt + prompt = "\n" * 200 + image_file = ( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" + ) + + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor( + text=prompt, + images=raw_image, + return_tensors="pt", + ).to(torch.float16) + + # Make sure that `generate` works + _ = model.generate(**inputs, max_new_tokens=20) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cd46934b5fcfe4..3f103eb3eeae07 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3948,6 +3948,10 @@ def test_sdpa_can_dispatch_on_flash(self): inputs_dict = self._prepare_for_class(inputs_dict, model_class) if config.model_type in ["llava", "llava_next", "vipllava"]: self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input") + if config.model_type in ["paligemma"]: + self.skipTest( + "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input" + ) if config.model_type in ["idefics"]: self.skipTest("Idefics currently (transformers==4.39.1) requires an image_attention_mask input") model = model_class(config)
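Finally, the tiny configuration used by `PaliGemmaVisionText2TextModelTester` above can be exercised by hand outside the unittest harness. The sketch below copies the relevant values from the tester and runs a single forward pass with random weights; it is an orientation aid only, and the exact keyword handling may differ slightly from what `PaliGemmaConfig` accepts in the released version:

```python
# Hand-driven version of the tiny test setup (values copied from the model tester;
# random weights, so the logits are meaningless). Treat as a sketch only.
import torch
from transformers import PaliGemmaConfig, PaliGemmaForConditionalGeneration

text_config = {
    "model_type": "gemma", "vocab_size": 99, "hidden_size": 32, "num_hidden_layers": 2,
    "num_attention_heads": 4, "num_key_value_heads": 1, "head_dim": 8,
    "intermediate_size": 37, "hidden_activation": "gelu_pytorch_tanh", "pad_token_id": 0,
}
vision_config = {
    "image_size": 30, "patch_size": 2, "num_channels": 3, "hidden_size": 32,
    "num_hidden_layers": 2, "num_attention_heads": 4, "intermediate_size": 37,
}
config = PaliGemmaConfig(
    text_config=text_config, vision_config=vision_config,
    image_token_index=98, projection_dim=32,          # projector output must match the text hidden size
)
model = PaliGemmaForConditionalGeneration(config).eval()

input_ids = torch.randint(1, 98, (1, 7))
input_ids[:, :4] = config.image_token_index           # the first slots hold the image placeholder token
attention_mask = torch.ones_like(input_ids)
pixel_values = torch.rand(1, 3, 30, 30)

with torch.no_grad():
    out = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
print(out.logits.shape)                                # (1, 7, text vocab size)
```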