Add PaliGemma (#30814)
* add new model like

* add state dict slicing + new model config

* update palma config and weights, passes vision activations

* fix

* update

* reorder loading/unpacking

* clean up

* add debug statements

* change device

* fix

* debugging

* fix noncausal mask

* fixup sdpa + causal mask

* fix activation function

* remove debug before changing modeling file

* add variants

* debug attention mask in generate

* revert to non-debug sdpa

* revert gemma modifications

* add custom language modeling

* use Processor

* add language modeling file to init

* try thin wrapper around generate

* Update

* update mask

* breakpoints galore

* remove conflict

* switch to left-padding

* add incomplete model doc

* add paligemma global files

* batch rename paligemma

* make generation match outputs and captioning

* style

* style

* remove copied from + doc

* remove more copied from

* remove copy from projector

* minor fix

* update config and style

* add readme - dummy

* CORRECT image captioning

* moving to args

* add siglip proper + fix merging image + text features

* take update_causal_mask from upstream

* remove breakpoint

* leverage AutoModel

* fix input_ids slicing

* make siglip head conditional

* remove encoder_decoder value

* remove unneeded modeling file

* add commented 4d attention mask

* FIXED generation with 4D mask

* Update src/transformers/models/siglip/modeling_siglip.py

Co-authored-by: Arthur <[email protected]>

* fix left padding detection

* shuffle order of verifications

* fix missing labels for training

* fix

* vectorize merging of features, improve slicing

* improve testing before conversion

* handle merging in processor

* image token index depends on checkpoint

* add variants, save processor too

* save processors, base tokenizer off spm file

* expand model embeddings due to additional image token

* pass image processing args

* add convert rgb to siglip processor

* add \n token separately

* fix tokenizer and prompts

* fix docstrings

* change to camel

* fix casing

* debug pos_ids and sdpa

* pass and use cache_position

* add flag for newline tokenization

* Update src/transformers/models/paligemma/processing_paligemma.py

Co-authored-by: Merve Noyan <[email protected]>

* simplify conversion script

* add copied from

* add precision to conversion script

* Update src/transformers/models/paligemma/modeling_paligemma.py

Co-authored-by: Pedro Cuenca <[email protected]>

* clean up

* Shift attention mask from `1:`

After discussion with @molbap

* add docs, fix quality

* quality, tied weights inheritance, and logits/label alignment

* fix more tests

* pass attn_implementation to language model correctly

* add SiglipVisionTransformer to no split modules

* skip paligemma test for sdpa dispatch to flash

* skip incompatible tests

* quality

* [broken archive maps]

* Apply suggestions

- remove archive lists
- style
- take shape of inputs_embeds for batch

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/utils/dummy_pt_objects.py

Co-authored-by: Arthur <[email protected]>

* simplify conversion script

* add suggestions

* add suggestions

* add copied from

* fix

* move labels out

* revert

* fix

* remove placeholder labels if None

* use cache_position

* fix quality + docstrings

* fix quality

* fix paligemma 4d gemma mask incompatibility

* fix config docstring

* fix query and attn_mask dtype

---------

Co-authored-by: ArthurZucker <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Merve Noyan <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
5 people authored May 14, 2024
1 parent c96aca3 commit 1360801
Showing 23 changed files with 1,890 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -784,6 +784,8 @@
title: OWL-ViT
- local: model_doc/owlv2
title: OWLv2
- local: model_doc/paligemma
title: PaliGemma
- local: model_doc/perceiver
title: Perceiver
- local: model_doc/pix2struct
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -230,6 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
| [OPT](model_doc/opt) ||||
| [OWL-ViT](model_doc/owlvit) ||||
| [OWLv2](model_doc/owlv2) ||||
| [PaliGemma](model_doc/paligemma) ||||
| [PatchTSMixer](model_doc/patchtsmixer) ||||
| [PatchTST](model_doc/patchtst) ||||
| [Pegasus](model_doc/pegasus) ||||
38 changes: 38 additions & 0 deletions docs/source/en/model_doc/paligemma.md
@@ -0,0 +1,38 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# PaliGemma

## Overview

The PaliGemma model was proposed by Google. It is a 3B vision-language model composed of a SigLIP-400m vision encoder and a Gemma-2B decoder, linked by a multimodal linear projection. It is not a chat model with images: the model cuts an image into a fixed number of ViT tokens and prepends them to an optional prompt. One particularity is that the model uses full block attention on all the image tokens plus the input text tokens. It comes in 3 resolutions, 224x224, 448x448 and 896x896, with 3 base models, 55 fine-tuned versions for different tasks, and 2 mix models.


This model was contributed by [Molbap](https://huggingface.co/Molbap).
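Below is a minimal captioning sketch. The checkpoint name is illustrative and assumes released PaliGemma weights are available on the Hub; the mix models expect a task-prefix prompt such as `caption en`.

```python
import requests
from PIL import Image

from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Checkpoint name is illustrative; any PaliGemma checkpoint on the Hub works.
model_id = "google/paligemma-3b-mix-224"
processor = PaliGemmaProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The processor expands the prompt with image tokens and a trailing newline.
inputs = processor(text="caption en", images=image, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```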


## PaliGemmaConfig

[[autodoc]] PaliGemmaConfig

## PaliGemmaProcessor

[[autodoc]] PaliGemmaProcessor

## PaliGemmaForConditionalGeneration

[[autodoc]] PaliGemmaForConditionalGeneration
- forward
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -203,6 +203,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
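For PaliGemma this means SDPA can be requested explicitly at load time; a minimal sketch, assuming an illustrative checkpoint name:

```python
import torch
from transformers import PaliGemmaForConditionalGeneration

# Checkpoint name is illustrative; SDPA is also selected by default when available.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-mix-224",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```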
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
@@ -582,6 +582,7 @@
"OwlViTTextConfig",
"OwlViTVisionConfig",
],
"models.paligemma": ["PaliGemmaConfig"],
"models.patchtsmixer": ["PatchTSMixerConfig"],
"models.patchtst": ["PatchTSTConfig"],
"models.pegasus": [
@@ -2651,6 +2652,13 @@
"OwlViTVisionModel",
]
)
_import_structure["models.paligemma"].extend(
[
"PaliGemmaForConditionalGeneration",
"PaliGemmaPreTrainedModel",
"PaliGemmaProcessor",
]
)
_import_structure["models.patchtsmixer"].extend(
[
"PatchTSMixerForPrediction",
@@ -5126,6 +5134,9 @@
OwlViTTextConfig,
OwlViTVisionConfig,
)
from .models.paligemma import (
PaliGemmaConfig,
)
from .models.patchtsmixer import (
PatchTSMixerConfig,
)
@@ -6956,6 +6967,11 @@
OwlViTTextModel,
OwlViTVisionModel,
)
from .models.paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaPreTrainedModel,
PaliGemmaProcessor,
)
from .models.patchtsmixer import (
PatchTSMixerForPrediction,
PatchTSMixerForPretraining,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -173,6 +173,7 @@
opt,
owlv2,
owlvit,
paligemma,
patchtsmixer,
patchtst,
pegasus,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -182,6 +182,7 @@
("opt", "OPTConfig"),
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
("paligemma", "PaliGemmaConfig"),
("patchtsmixer", "PatchTSMixerConfig"),
("patchtst", "PatchTSTConfig"),
("pegasus", "PegasusConfig"),
@@ -464,6 +465,7 @@
("opt", "OPT"),
("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"),
("paligemma", "PaliGemma"),
("patchtsmixer", "PatchTSMixer"),
("patchtst", "PatchTST"),
("pegasus", "Pegasus"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -93,6 +93,7 @@
("oneformer", "OneFormerImageProcessor"),
("owlv2", "Owlv2ImageProcessor"),
("owlvit", "OwlViTImageProcessor"),
("paligemma", "CLIPImageProcessor"),
("perceiver", "PerceiverImageProcessor"),
("pix2struct", "Pix2StructImageProcessor"),
("poolformer", "PoolFormerImageProcessor"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -313,6 +313,7 @@
("nezha", "NezhaForPreTraining"),
("nllb-moe", "NllbMoeForConditionalGeneration"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("retribert", "RetriBertModel"),
("roberta", "RobertaForMaskedLM"),
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
@@ -697,6 +698,7 @@
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -74,6 +74,7 @@
("oneformer", "OneFormerProcessor"),
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("paligemma", "PaliGemmaProcessor"),
("pix2struct", "Pix2StructProcessor"),
("pop2piano", "Pop2PianoProcessor"),
("sam", "SamProcessor"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -331,6 +331,7 @@
("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"pegasus",
(
3 changes: 0 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -794,7 +794,6 @@ def _init_weights(self, module):
"The bare Gemma Model outputting raw hidden-states without any specific head on top.",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma
class GemmaModel(GemmaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
@@ -988,8 +987,6 @@ def _update_causal_mask(

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
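The removed guard rejected custom 4D masks whose max was not 0. Per the commit history ("fix paligemma 4d gemma mask incompatibility"), PaliGemma hands Gemma a prebuilt 4D mask with full attention over the image-plus-prompt prefix, so the check is dropped and the mask is used as-is. A sketch of such a prefix-LM additive mask; the helper and shapes are illustrative, not the model's actual code:

```python
import torch

def prefix_lm_mask(seq_len: int, prefix_len: int, dtype=torch.float32) -> torch.Tensor:
    """Additive 4D mask: full attention within the image+prompt prefix,
    causal attention for the generated suffix. Illustrative helper."""
    min_dtype = torch.finfo(dtype).min
    # Causal mask in inverted/additive form: min_dtype above the diagonal.
    mask = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    # Unmask the prefix block: prefix tokens attend to each other bidirectionally.
    mask[:prefix_len, :prefix_len] = 0.0
    # Broadcast over batch and head dimensions: (1, 1, seq_len, seq_len).
    return mask[None, None, :, :]

mask = prefix_lm_mask(seq_len=8, prefix_len=5)
```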
54 changes: 54 additions & 0 deletions src/transformers/models/paligemma/__init__.py
@@ -0,0 +1,54 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {"configuration_paligemma": ["PaliGemmaConfig"]}


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_paligemma"] = [
"PaliGemmaForConditionalGeneration",
"PaliGemmaPreTrainedModel",
]
_import_structure["processing_paligemma"] = ["PaliGemmaProcessor"]


if TYPE_CHECKING:
from .configuration_paligemma import PaliGemmaConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaPreTrainedModel,
)
from .processing_paligemma import PaliGemmaProcessor


else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
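A quick sketch of what this lazy structure buys: the config resolves without importing torch-dependent code, while the model classes are only materialized on first access (the second import assumes torch is installed):

```python
# Served via _import_structure; no torch-dependent module is imported here:
from transformers.models.paligemma import PaliGemmaConfig

# First access materializes modeling_paligemma; this requires torch:
from transformers.models.paligemma import PaliGemmaForConditionalGeneration
```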
130 changes: 130 additions & 0 deletions src/transformers/models/paligemma/configuration_paligemma.py
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaliGemmamodel configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)


class PaliGemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate a
PaliGemma model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the PaliGemma-2B,
e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`PaliGemmaVisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Defaults to a `GemmaConfig` when not provided.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 256000):
The image token index to encode the image prompt.
vocab_size (`int`, *optional*, defaults to 257152):
Vocabulary size of the PaliGemma model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`~PaliGemmaForConditionalGeneration`].
projection_dim (`int`, *optional*, defaults to 2048):
Dimension of the multimodal projection space.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden layer of the Language model.
Example:
```python
>>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a PaliGemma config
>>> text_config = GemmaConfig()
>>> # Initializing a PaliGemma paligemma-3b-224 style configuration
>>> configuration = PaliGemmaConfig(vision_config, text_config)
>>> # Initializing a model from the paligemma-3b-224 style configuration
>>> model = PaliGemmaForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "paligemma"
is_composition = False

def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=256000,
vocab_size=257152,
projection_dim=2048,
hidden_size=2048,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
self.hidden_size = hidden_size
self.vision_config = vision_config
self.is_encoder_decoder = False

if isinstance(self.vision_config, dict):
vision_config["model_type"] = (
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
)
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
intermediate_size=4096,
hidden_size=1152,
patch_size=14,
image_size=224,
num_hidden_layers=27,
num_attention_heads=16,
vocab_size=257152,
vision_use_head=False,
)
self.vocab_size = self.vocab_size

self.text_config = text_config

if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
self.vocab_size = self.text_config.vocab_size
elif text_config is None:
self.text_config = CONFIG_MAPPING["gemma"](
hidden_size=2048,
num_hidden_layers=18,
intermediate_size=16384,
num_attention_heads=8,
num_key_value_heads=1,
is_encoder_decoder=False,
)
self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
self.vision_config.projection_dim = projection_dim
super().__init__(**kwargs)
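The `num_image_tokens` assignment above is plain patch-count arithmetic; a quick check for the three released resolutions (patch size 14 throughout):

```python
# num_image_tokens = (image_size // patch_size) ** 2
for image_size in (224, 448, 896):
    print(image_size, (image_size // 14) ** 2)
# 224 -> 256, 448 -> 1024, 896 -> 4096
```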