Skip to content

Commit

Permalink
Add support for Pixtral (#33449)
Browse files Browse the repository at this point in the history
* initial commit

* gloups

* updates

* work

* weights match

* nits

* nits

* updates to support the tokenizer :)

* updates

* Pixtral processor (#33454)

* rough outline

* Add in image break and end tokens

* Fix

* Udo some formatting changes

* Set patch_size default

* Fix

* Fix token expansion

* nit in conversion script

* Fix image token list creation

* done

* add expected results

* Process list of list of images (#33465)

* updates

* working image and processor

* this is the expected format

* some fixes

* push current updated

* working mult images!

* add a small integration test

* Uodate configuration docstring

* Formatting

* Config docstring fix

* simplify model test

* fixup modeling and etests

* Return BatchMixFeature in image processor

* fix some copies

* update

* nits

* Update model docstring

* Apply suggestions from code review

* Fix up

* updates

* revert modeling changes

* update

* update

* fix load safe

* addd liscence

* update

* use pixel_values as required by the model

* skip some tests and refactor

* Add pixtral image processing tests (#33476)

* Image processing tests

* Add processing tests

* woops

* defaults reflect pixtral image processor

* fixup post merge

* images -> pixel values

* oups sorry Mr docbuilder

* isort

* fix

* fix processor tests

* small fixes

* nit

* update

* last nits

* oups this was really breaking!

* nits

* is composition needs to be true

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
ArthurZucker and amyeroberts authored Sep 14, 2024
1 parent 7bb1c99 commit 8bd2b1e
Show file tree
Hide file tree
Showing 24 changed files with 2,707 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,8 @@
title: Perceiver
- local: model_doc/pix2struct
title: Pix2Struct
- local: model_doc/pixtral
title: Pixtral
- local: model_doc/sam
title: Segment Anything
- local: model_doc/siglip
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Phi3](model_doc/phi3) ||||
| [PhoBERT](model_doc/phobert) ||||
| [Pix2Struct](model_doc/pix2struct) ||||
| [Pixtral](model_doc/pixtral) ||||
| [PLBart](model_doc/plbart) ||||
| [PoolFormer](model_doc/poolformer) ||||
| [Pop2Piano](model_doc/pop2piano) ||||
Expand Down
98 changes: 98 additions & 0 deletions docs/source/en/model_doc/pixtral.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Pixtral

## Overview

The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!


Tips:

- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
- The format for one or mulitple prompts is the following:
```
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
```
Then, the processor will replace each `[IMG]` token with a number of `[IMG]` token that depends on the height and the width of the image. Each *row* of the image is separated by a `[IMG_BREAK]` token, and each image is separated by a `[IMG_END]` token.

This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ)

Here is an example of how to run it:

```python
from transformers import LlavaForConditionalGeneration, AutoProcessor
from PIL import Image

model_id = "hf-internal-testing/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

IMG_URLS = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/231/200/300",
"https://picsum.photos/id/27/500/500",
"https://picsum.photos/id/17/150/600",
]
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"

inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=500)
ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

EXPECTED_GENERATION = """
Describe the images.
Sure, let's break down each image description:
1. **Image 1:**
- **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
- **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
2. **Image 2:**
- **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
- **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
3. **Image 3:**
- **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
- **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
4. **Image 4:**
- **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
- **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
"""

```
## PixtralVisionConfig

[[autodoc]] PixtralVisionConfig

## PixtralModel

[[autodoc]] PixtralModel
- forward

## PixtralImageProcessor

[[autodoc]] PixtralImageProcessor
- preprocess

## PixtralProcessor

[[autodoc]] PixtralProcessor
13 changes: 12 additions & 1 deletion src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@
"Pix2StructTextConfig",
"Pix2StructVisionConfig",
],
"models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
"models.plbart": ["PLBartConfig"],
"models.poolformer": ["PoolFormerConfig"],
"models.pop2piano": ["Pop2PianoConfig"],
Expand Down Expand Up @@ -1199,6 +1200,7 @@
_import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
_import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
_import_structure["models.pixtral"].append("PixtralImageProcessor")
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
Expand Down Expand Up @@ -1359,7 +1361,6 @@
"AlignVisionModel",
]
)

_import_structure["models.altclip"].extend(
[
"AltCLIPModel",
Expand Down Expand Up @@ -2977,6 +2978,7 @@
"Pix2StructVisionModel",
]
)
_import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
_import_structure["models.plbart"].extend(
[
"PLBartForCausalLM",
Expand Down Expand Up @@ -5434,6 +5436,10 @@
Pix2StructTextConfig,
Pix2StructVisionConfig,
)
from .models.pixtral import (
PixtralProcessor,
PixtralVisionConfig,
)
from .models.plbart import PLBartConfig
from .models.poolformer import (
PoolFormerConfig,
Expand Down Expand Up @@ -6009,6 +6015,7 @@
from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
from .models.pix2struct import Pix2StructImageProcessor
from .models.pixtral import PixtralImageProcessor
from .models.poolformer import (
PoolFormerFeatureExtractor,
PoolFormerImageProcessor,
Expand Down Expand Up @@ -7448,6 +7455,10 @@
Pix2StructTextModel,
Pix2StructVisionModel,
)
from .models.pixtral import (
PixtralModel,
PixtralPreTrainedModel,
)
from .models.plbart import (
PLBartForCausalLM,
PLBartForConditionalGeneration,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
phi3,
phobert,
pix2struct,
pixtral,
plbart,
poolformer,
pop2piano,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
("phi", "PhiConfig"),
("phi3", "Phi3Config"),
("pix2struct", "Pix2StructConfig"),
("pixtral", "PixtralVisionConfig"),
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
("pop2piano", "Pop2PianoConfig"),
Expand Down Expand Up @@ -509,6 +510,7 @@
("phi3", "Phi3"),
("phobert", "PhoBERT"),
("pix2struct", "Pix2Struct"),
("pixtral", "Pixtral"),
("plbart", "PLBart"),
("poolformer", "PoolFormer"),
("pop2piano", "Pop2Piano"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
("owlvit", ("OwlViTImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
("poolformer", ("PoolFormerImageProcessor",)),
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
("persimmon", "PersimmonModel"),
("phi", "PhiModel"),
("phi3", "Phi3Model"),
("pixtral", "PixtralModel"),
("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
("owlvit", "OwlViTProcessor"),
("paligemma", "PaliGemmaProcessor"),
("pix2struct", "Pix2StructProcessor"),
("pixtral", "PixtralProcessor"),
("pop2piano", "Pop2PianoProcessor"),
("qwen2_audio", "Qwen2AudioProcessor"),
("qwen2_vl", "Qwen2VLProcessor"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@
("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/llava/configuration_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class LlavaConfig(PretrainedConfig):
```"""

model_type = "llava"
is_composition = False
is_composition = True

def __init__(
self,
Expand Down
70 changes: 70 additions & 0 deletions src/transformers/models/pixtral/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
"configuration_pixtral": ["PixtralVisionConfig"],
"processing_pixtral": ["PixtralProcessor"],
}


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_pixtral"] = [
"PixtralModel",
"PixtralPreTrainedModel",
]

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]


if TYPE_CHECKING:
from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_pixtral import (
PixtralModel,
PixtralPreTrainedModel,
)

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_pixtral import PixtralImageProcessor

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
Loading

0 comments on commit 8bd2b1e

Please sign in to comment.