From 8bd2b1e8c23234cd607ca8d63f53c1edfea27462 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Sat, 14 Sep 2024 12:28:39 +0200 Subject: [PATCH] Add support for Pixtral (#33449) * initial commit * gloups * updates * work * weights match * nits * nits * updates to support the tokenizer :) * updates * Pixtral processor (#33454) * rough outline * Add in image break and end tokens * Fix * Udo some formatting changes * Set patch_size default * Fix * Fix token expansion * nit in conversion script * Fix image token list creation * done * add expected results * Process list of list of images (#33465) * updates * working image and processor * this is the expected format * some fixes * push current updated * working mult images! * add a small integration test * Uodate configuration docstring * Formatting * Config docstring fix * simplify model test * fixup modeling and etests * Return BatchMixFeature in image processor * fix some copies * update * nits * Update model docstring * Apply suggestions from code review * Fix up * updates * revert modeling changes * update * update * fix load safe * addd liscence * update * use pixel_values as required by the model * skip some tests and refactor * Add pixtral image processing tests (#33476) * Image processing tests * Add processing tests * woops * defaults reflect pixtral image processor * fixup post merge * images -> pixel values * oups sorry Mr docbuilder * isort * fix * fix processor tests * small fixes * nit * update * last nits * oups this was really breaking! * nits * is composition needs to be true --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/pixtral.md | 98 ++++ src/transformers/__init__.py | 13 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + .../models/llava/configuration_llava.py | 2 +- src/transformers/models/pixtral/__init__.py | 70 +++ .../models/pixtral/configuration_pixtral.py | 103 ++++ .../pixtral/convert_pixtral_weights_to_hf.py | 285 ++++++++++ .../pixtral/image_processing_pixtral.py | 519 ++++++++++++++++++ .../models/pixtral/modeling_pixtral.py | 517 +++++++++++++++++ .../models/pixtral/processing_pixtral.py | 282 ++++++++++ src/transformers/utils/dummy_pt_objects.py | 14 + .../utils/dummy_vision_objects.py | 7 + tests/models/llava/test_modeling_llava.py | 47 ++ tests/models/pixtral/__init__.py | 0 .../pixtral/test_image_processing_pixtral.py | 217 ++++++++ tests/models/pixtral/test_modeling_pixtral.py | 292 ++++++++++ .../models/pixtral/test_processor_pixtral.py | 233 ++++++++ 24 files changed, 2707 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/model_doc/pixtral.md create mode 100644 src/transformers/models/pixtral/__init__.py create mode 100644 src/transformers/models/pixtral/configuration_pixtral.py create mode 100644 src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py create mode 100644 src/transformers/models/pixtral/image_processing_pixtral.py create mode 100644 src/transformers/models/pixtral/modeling_pixtral.py create mode 100644 src/transformers/models/pixtral/processing_pixtral.py create mode 100644 tests/models/pixtral/__init__.py create mode 100644 
tests/models/pixtral/test_image_processing_pixtral.py create mode 100644 tests/models/pixtral/test_modeling_pixtral.py create mode 100644 tests/models/pixtral/test_processor_pixtral.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1c7f62ec6ea7b8..235ea81a7f1ea6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -862,6 +862,8 @@ title: Perceiver - local: model_doc/pix2struct title: Pix2Struct + - local: model_doc/pixtral + title: Pixtral - local: model_doc/sam title: Segment Anything - local: model_doc/siglip diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 8e3a4da8b021de..c18426de4c031c 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -253,6 +253,7 @@ Flax), PyTorch, and/or TensorFlow. | [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ | | [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ | | [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ | +| [Pixtral](model_doc/pixtral) | ❌ | ❌ | ❌ | | [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ | | [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ | | [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/pixtral.md b/docs/source/en/model_doc/pixtral.md new file mode 100644 index 00000000000000..8df2bf5af5f9ca --- /dev/null +++ b/docs/source/en/model_doc/pixtral.md @@ -0,0 +1,98 @@ + + +# Pixtral + +## Overview + +The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found! + + +Tips: + +- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized) +- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders. +- The format for one or mulitple prompts is the following: +``` +"[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]" +``` +Then, the processor will replace each `[IMG]` token with a number of `[IMG]` token that depends on the height and the width of the image. Each *row* of the image is separated by a `[IMG_BREAK]` token, and each image is separated by a `[IMG_END]` token. + +This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ) + +Here is an example of how to run it: + +```python +from transformers import LlavaForConditionalGeneration, AutoProcessor +from PIL import Image + +model_id = "hf-internal-testing/pixtral-12b" +model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cuda") +processor = AutoProcessor.from_pretrained(model_id) + +IMG_URLS = [ + "https://picsum.photos/id/237/400/300", + "https://picsum.photos/id/231/200/300", + "https://picsum.photos/id/27/500/500", + "https://picsum.photos/id/17/150/600", +] +PROMPT = "[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]" + +inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda") +generate_ids = model.generate(**inputs, max_new_tokens=500) +ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + +EXPECTED_GENERATION = """ +Describe the images. +Sure, let's break down each image description: + +1. **Image 1:** + - **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera. 
+ - **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur. + +2. **Image 2:** + - **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley. + - **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image. + +3. **Image 3:** + - **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset. + - **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene. + +4. **Image 4:** + - **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers. + - **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden. + +Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it. +""" + +``` +## PixtralVisionConfig + +[[autodoc]] PixtralVisionConfig + +## PixtralModel + +[[autodoc]] PixtralModel + - forward + +## PixtralImageProcessor + +[[autodoc]] PixtralImageProcessor + - preprocess + +## PixtralProcessor + +[[autodoc]] PixtralProcessor diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 00cc67915f3664..36775d8454ab8c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -649,6 +649,7 @@ "Pix2StructTextConfig", "Pix2StructVisionConfig", ], + "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"], "models.plbart": ["PLBartConfig"], "models.poolformer": ["PoolFormerConfig"], "models.pop2piano": ["Pop2PianoConfig"], @@ -1199,6 +1200,7 @@ _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"]) _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"]) _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"]) + _import_structure["models.pixtral"].append("PixtralImageProcessor") _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) _import_structure["models.pvt"].extend(["PvtImageProcessor"]) _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) @@ -1359,7 +1361,6 @@ "AlignVisionModel", ] ) - _import_structure["models.altclip"].extend( [ "AltCLIPModel", @@ -2977,6 +2978,7 @@ "Pix2StructVisionModel", ] ) + _import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"]) _import_structure["models.plbart"].extend( [ "PLBartForCausalLM", @@ -5434,6 +5436,10 @@ Pix2StructTextConfig, Pix2StructVisionConfig, ) + from .models.pixtral import ( + PixtralProcessor, + PixtralVisionConfig, + ) from .models.plbart import PLBartConfig from .models.poolformer import ( PoolFormerConfig, @@ -6009,6 +6015,7 @@ from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor from .models.pix2struct 
import Pix2StructImageProcessor + from .models.pixtral import PixtralImageProcessor from .models.poolformer import ( PoolFormerFeatureExtractor, PoolFormerImageProcessor, @@ -7448,6 +7455,10 @@ Pix2StructTextModel, Pix2StructVisionModel, ) + from .models.pixtral import ( + PixtralModel, + PixtralPreTrainedModel, + ) from .models.plbart import ( PLBartForCausalLM, PLBartForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 26b96def67d992..2022048cd4553f 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -187,6 +187,7 @@ phi3, phobert, pix2struct, + pixtral, plbart, poolformer, pop2piano, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index fa1a7fb88eafa8..2cd7d550d90b7a 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -205,6 +205,7 @@ ("phi", "PhiConfig"), ("phi3", "Phi3Config"), ("pix2struct", "Pix2StructConfig"), + ("pixtral", "PixtralVisionConfig"), ("plbart", "PLBartConfig"), ("poolformer", "PoolFormerConfig"), ("pop2piano", "Pop2PianoConfig"), @@ -509,6 +510,7 @@ ("phi3", "Phi3"), ("phobert", "PhoBERT"), ("pix2struct", "Pix2Struct"), + ("pixtral", "Pixtral"), ("plbart", "PLBart"), ("poolformer", "PoolFormer"), ("pop2piano", "Pop2Piano"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index c83c43518a6a31..95d9ddef8f7979 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -114,6 +114,7 @@ ("owlvit", ("OwlViTImageProcessor",)), ("perceiver", ("PerceiverImageProcessor",)), ("pix2struct", ("Pix2StructImageProcessor",)), + ("pixtral", ("PixtralImageProcessor",)), ("poolformer", ("PoolFormerImageProcessor",)), ("pvt", ("PvtImageProcessor",)), ("pvt_v2", ("PvtImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 45a9c4d0d078b7..e0d15f1e236590 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -193,6 +193,7 @@ ("persimmon", "PersimmonModel"), ("phi", "PhiModel"), ("phi3", "Phi3Model"), + ("pixtral", "PixtralModel"), ("plbart", "PLBartModel"), ("poolformer", "PoolFormerModel"), ("prophetnet", "ProphetNetModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 7f49e0e8d99730..82d325248eabfb 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -82,6 +82,7 @@ ("owlvit", "OwlViTProcessor"), ("paligemma", "PaliGemmaProcessor"), ("pix2struct", "Pix2StructProcessor"), + ("pixtral", "PixtralProcessor"), ("pop2piano", "Pop2PianoProcessor"), ("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c8eb06db04a098..e735579108d857 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -385,6 +385,7 @@ ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("phobert", ("PhobertTokenizer", None)), ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), + ("pixtral", 
(None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), ("prophetnet", ("ProphetNetTokenizer", None)), ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py index f2338a7c5a5df7..3a4cb09855f0ec 100644 --- a/src/transformers/models/llava/configuration_llava.py +++ b/src/transformers/models/llava/configuration_llava.py @@ -73,7 +73,7 @@ class LlavaConfig(PretrainedConfig): ```""" model_type = "llava" - is_composition = False + is_composition = True def __init__( self, diff --git a/src/transformers/models/pixtral/__init__.py b/src/transformers/models/pixtral/__init__.py new file mode 100644 index 00000000000000..e09ed8e60127dd --- /dev/null +++ b/src/transformers/models/pixtral/__init__.py @@ -0,0 +1,70 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_pixtral": ["PixtralVisionConfig"], + "processing_pixtral": ["PixtralProcessor"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_pixtral"] = [ + "PixtralModel", + "PixtralPreTrainedModel", + ] + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"] + + +if TYPE_CHECKING: + from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_pixtral import ( + PixtralModel, + PixtralPreTrainedModel, + ) + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_pixtral import PixtralImageProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/pixtral/configuration_pixtral.py b/src/transformers/models/pixtral/configuration_pixtral.py new file mode 100644 index 00000000000000..dcc1e458ca78a3 --- /dev/null +++ b/src/transformers/models/pixtral/configuration_pixtral.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pixtral model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PixtralVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate an + Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Pixtral-9B. + + e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of input channels in the input images. + image_size (`int`, *optional*, defaults to 1024): + Max dimension of the input images. + patch_size (`int`, *optional*, defaults to 16): + Size of the image patches. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + Activation function used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the attention layers. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie the word embeddings with the input embeddings. 
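As a rough illustration of how the defaults above combine: the head dimension works out to `hidden_size // num_attention_heads = 64`, and with images capped at `image_size = 1024` and `patch_size = 16` the vision tower sees at most 64 patches per side. A minimal sketch, assuming `PixtralVisionConfig` is importable as registered by this patch:

```python
from transformers import PixtralVisionConfig

config = PixtralVisionConfig()  # defaults documented above
head_dim = config.hidden_size // config.num_attention_heads    # 1024 // 16 = 64
max_patches_per_side = config.image_size // config.patch_size  # 1024 // 16 = 64
print(head_dim, max_patches_per_side)  # 64 64
```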
+ + Example: + + ```python + >>> from transformers import PixtralModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a Pixtral 12B style configuration + >>> config = PixtralVisionConfig() + + >>> # Initializing a model from the pixtral 12B style configuration + >>> model = PixtralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pixtral" + + def __init__( + self, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=24, + num_attention_heads=16, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + attention_dropout=0.0, + rope_theta=10000.0, + tie_word_embeddings=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.rope_theta = rope_theta + self.tie_word_embeddings = tie_word_embeddings + self.head_dim = hidden_size // num_attention_heads diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py new file mode 100644 index 00000000000000..c4190082d99471 --- /dev/null +++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py @@ -0,0 +1,285 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +import regex as re +import torch +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from safetensors.torch import load_file as safe_load_file +from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors +from tokenizers.models import BPE + +from transformers import ( + LlavaConfig, + LlavaForConditionalGeneration, + MistralConfig, + PixtralImageProcessor, + PixtralProcessor, + PixtralVisionConfig, + PreTrainedTokenizerFast, +) +from transformers.convert_slow_tokenizer import bytes_to_unicode + + +""" +# Here is how to get the original tokens! 
+model_name = "mistralai/Pixtral-12B-2409" +tok = MistralTokenizer.from_model(model_name) + +from mistral_common.protocol.instruct.request import ChatCompletionRequest, UserMessage, ImageChunk, TextChunk + +EXPECTED_TOKENS = tok.encode_chat_completion( + ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text="Describe the images"), + ] + [ImageChunk(image=img) for img in IMG_URLS] + ) + ], + model="pixtral", + ) +) +assert tokenizer.decode(inputs["input_ids"][0]) == EXPECTED_TOKENS +""" + +OLD_KEY_TO_NEW_KEY_MAPPING = { + # Layer Normalization Weights + r"vision_encoder.transformer.layers.(\d+).input_layernorm.weight": r"vision_tower.transformer.layers.\1.attention_norm.weight", + r"vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"vision_tower.transformer.layers.\1.ffn_norm.weight", + # Self Attention Projections + r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wk.weight": r"vision_tower.transformer.layers.\1.attention.k_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wv.weight": r"vision_tower.transformer.layers.\1.attention.v_proj.weight", + r"vision_encoder.transformer.layers.(\d+).attention.wo.weight": r"vision_tower.transformer.layers.\1.attention.o_proj.weight", + # MLP Projections + r"vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", + r"vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", + # Additional mappings + r"vision_encoder": r"vision_tower", + r"vision_language_adapter.w_in": r"multi_modal_projector.linear_1", + r"vision_language_adapter.w_out": r"multi_modal_projector.linear_2", + r"layers.(\d+).attention.wq.weight": r"language_model.model.layers.\1.self_attn.q_proj.weight", + r"layers.(\d+).attention.wk.weight": r"language_model.model.layers.\1.self_attn.k_proj.weight", + r"layers.(\d+).attention.wv.weight": r"language_model.model.layers.\1.self_attn.v_proj.weight", + r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight", + r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.mlp.gate_proj.weight", + r"layers.(\d+).feed_forward.w2.weight": r"language_model.model.layers.\1.mlp.down_proj.weight", + r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.mlp.up_proj.weight", + r"layers.(\d+).ffn_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", + r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight", + r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight", + r"output.weight": r"language_model.lm_head.weight", + r"norm.weight": r"language_model.model.norm.weight", +} + + +class MistralConverter: + """ + A general tiktoken converter. 
+ """ + + def __init__( + self, + vocab=None, + pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""", + add_prefix_space=False, + additional_special_tokens=None, + *args, + **kwargs, + ): + super().__init__(*args) + self.vocab = vocab + self.pattern = pattern + self.add_prefix_space = add_prefix_space + self.additional_special_tokens = additional_special_tokens + + def extract_vocab_merges_from_model(self, vocab: str): + bpe_ranks = vocab + byte_encoder = bytes_to_unicode() + + def token_bytes_to_string(b): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + merges = [] + vocab = {} + for idx, (token, rank) in enumerate(bpe_ranks.items()): + if token not in self.additional_special_tokens: + vocab[token_bytes_to_string(token)] = idx + if len(token) == 1: + continue + local = [] + for index in range(1, len(token)): + piece_l, piece_r = token[:index], token[index:] + if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks: + local.append((piece_l, piece_r, rank)) + local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False) + merges.extend(local) + else: + vocab[token] = idx + merges = sorted(merges, key=lambda val: val[2], reverse=False) + merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges] + return vocab, merges + + def tokenizer(self): + vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab) + tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False)) + if hasattr(tokenizer.model, "ignore_merges"): + tokenizer.model.ignore_merges = True + return tokenizer + + def converted(self) -> Tokenizer: + tokenizer = self.tokenizer() + tokenizer.pre_tokenizer = pre_tokenizers.Sequence( + [ + pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False), + pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False), + ] + ) + tokenizer.decoder = decoders.ByteLevel() + tokenizer.add_special_tokens(self.additional_special_tokens) + + tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) + + return tokenizer + + +def convert_mistral_tokenizer(): + model_name = "mistralai/Pixtral-12B-2409" + + tokenizer = MistralTokenizer.from_model(model_name) + + vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial + all_special = [ + token.value if hasattr(token, "value") else token + for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens + ] + specials_tokens = {token: all_special.index(token) for token in all_special} + specials_tokens.update(vocab) + vocab = specials_tokens + + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), + bos_token="", + unk_token="", + eos_token="", + ) + tokenizer.model_input_names = ["input_ids", "attention_mask"] + + return tokenizer + + +def permute_for_rope(value, n_heads, config): + dim1 = value.shape[0] + dim2 = config.hidden_size + return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def convert_dictionnary(original_state_dict, vision_config, text_config): + new_dict = {} + + all_keys = "\n" + "\n".join(original_state_dict.keys()) + old_keys = all_keys + for old, new in OLD_KEY_TO_NEW_KEY_MAPPING.items(): + all_keys = re.sub(r"\n" + old, r"\n" + new, all_keys) + + OLD_TO_NEW = dict(zip(old_keys.split("\n"), all_keys.split("\n"))) + + for key, 
value in original_state_dict.items(): + new_key = OLD_TO_NEW[key] + if "vision_encoder" in key: + _config = vision_config + num_attention_heads = _config.num_attention_heads + else: + _config = text_config + if "q_proj" in new_key: + num_attention_heads = _config.num_attention_heads + if "k_proj" in new_key: + num_attention_heads = _config.num_key_value_heads + # convert the text model (basically mistral model) + + if "q_proj" in new_key or "k_proj" in new_key: + value = permute_for_rope(value, num_attention_heads, _config) + + new_dict[new_key] = value + return new_dict + + +def convert_mistral_model(input_dir, output_dir): + text_config = MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + head_dim=128, + hidden_act="silu", + hidden_size=5120, + initializer_range=0.02, + intermediate_size=14336, + max_position_embeddings=1024000, + model_type="mistral", + num_attention_heads=32, + num_hidden_layers=40, + num_key_value_heads=8, + rms_norm_eps=1e-05, + rope_theta=1000000000.0, + sliding_window=None, + tie_word_embeddings=False, + vocab_size=131072, + ) + + vision_config = PixtralVisionConfig() + config = LlavaConfig( + vision_config, + text_config, + vision_feature_layer=-1, + image_token_index=10, + vision_feature_select_strategy="full", + image_seq_length=1, + ) + config.architectures = ["LlavaForConditionalGeneration"] + config.save_pretrained(output_dir) + + original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors") + new_dict = convert_dictionnary(original_state_dict, vision_config, text_config) + + with torch.device("meta"): + model = LlavaForConditionalGeneration(config) + model.load_state_dict(new_dict, strict=True, assign=True) + + model.save_pretrained(output_dir) + + tokenizer = convert_mistral_tokenizer() + image_processor = PixtralImageProcessor() + processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]") + processor.save_pretrained(output_dir) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of LLaMA weights, which contains tokenizer.model and model folders", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + + args = parser.parse_args() + convert_mistral_model(args.input_dir, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py new file mode 100644 index 00000000000000..c6d18420bec575 --- /dev/null +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -0,0 +1,519 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
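For reference, `convert_dictionnary` above renames checkpoint keys by joining all key names into one newline-separated string, running every regex in `OLD_KEY_TO_NEW_KEY_MAPPING` over it with `re.sub`, and zipping the old and new names back together. A minimal sketch of that trick on two illustrative key names (the standard `re` module is used here, whereas the script itself imports the `regex` package as `re`):

```python
import re

# Two entries taken from OLD_KEY_TO_NEW_KEY_MAPPING above, for illustration only.
MAPPING = {
    r"vision_encoder.transformer.layers.(\d+).attention.wq.weight": r"vision_tower.transformer.layers.\1.attention.q_proj.weight",
    r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
}

old_keys = ["vision_encoder.transformer.layers.3.attention.wq.weight", "tok_embeddings.weight"]

# Same newline-join + re.sub trick used in convert_dictionnary.
all_keys = "\n" + "\n".join(old_keys)
renamed = all_keys
for old, new in MAPPING.items():
    renamed = re.sub(r"\n" + old, r"\n" + new, renamed)

old_to_new = dict(zip(all_keys.split("\n"), renamed.split("\n")))
print(old_to_new["vision_encoder.transformer.layers.3.attention.wq.weight"])
# vision_tower.transformer.layers.3.attention.q_proj.weight
```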
+"""Image processor class for Pixtral.""" + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging +from ...utils.import_utils import requires_backends + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class BatchMixFeature(BatchFeature): + def to(self, *args, **kwargs) -> "BatchMixFeature": + """ + Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in + different `dtypes` and sending the `BatchFeature` to a different `device`. + + Args: + args (`Tuple`): + Will be passed to the `to(...)` function of the tensors. + kwargs (`Dict`, *optional*): + Will be passed to the `to(...)` function of the tensors. + + Returns: + [`BatchFeature`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + import torch # noqa + + new_data = {} + device = kwargs.get("device") + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # The first argument is a dtype + pass + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + # check if v is a floating point + if isinstance(v, list): + new_data[k] = [ + element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element) + ] + elif torch.is_floating_point(v): + # cast and send to device + new_data[k] = v.to(*args, **kwargs) + elif device is not None: + new_data[k] = v.to(device=device) + else: + new_data[k] = v + self.data = new_data + return self + + +# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images +def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: + """ + Convert a single image or a list of images to a list of numpy arrays. + + Args: + images (`ImageInput`): + A single image or a list of images. + + Returns: + A list of numpy arrays. + """ + # If it's a single image, convert it to a list of lists + if is_valid_image(images): + images = [[images]] + # If it's a list of images, it's a single batch, so convert it to a list of lists + elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]): + images = [images] + # If it's a list of batches, it's already in the right format + elif ( + isinstance(images, (list, tuple)) + and len(images) > 0 + and isinstance(images[0], (list, tuple)) + and is_valid_image(images[0][0]) + ): + pass + else: + raise ValueError( + "Invalid input type. Must be a single image, a list of images, or a list of batches of images." 
+ ) + return images + + +# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white. +def convert_to_rgb(image: ImageInput) -> ImageInput: + """ + Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image + as is. + Args: + image (Image): + The image to convert. + """ + requires_backends(convert_to_rgb, ["vision"]) + + if not isinstance(image, PIL.Image.Image): + return image + + if image.mode == "RGB": + return image + + # First we convert to RGBA to set background to white. + image = image.convert("RGBA") + + # Create a new image with a white background. + new_image = PIL.Image.new("RGBA", image.size, "WHITE") + new_image.paste(image, (0, 0), image) + new_image = new_image.convert("RGB") + return new_image + + +def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int: + """ + Calculate the number of image tokens given the image size and patch size. + + Args: + image_size (`Tuple[int, int]`): + The size of the image as `(height, width)`. + patch_size (`Tuple[int, int]`): + The patch size as `(height, width)`. + + Returns: + `int`: The number of image tokens. + """ + height, width = image_size + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + num_width_tokens = (width - 1) // patch_width + 1 + num_height_tokens = (height - 1) // patch_height + 1 + return num_height_tokens, num_width_tokens + + +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, Tuple[int, int], List[int], Tuple[int]], + patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple: + """ + Find the target (height, width) dimension of the output image after resizing given the input image and the desired + size. + + Args: + input_image (`np.ndarray`): + The image to resize. + size (`int` or `Tuple[int, int]`): + Max image size an input image can be. Must be a dictionary with the key "longest_edge". + patch_size (`int` or `Tuple[int, int]`): + The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)` + will be used + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. + + Returns: + `tuple`: The target (height, width) dimension of the output image after resizing. 
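A hand-worked sketch of this resizing rule with illustrative numbers: a 1365×2048 input with `longest_edge = 1024` and 16×16 patches gives a scaling ratio of 2.0, and the scaled size is then rounded up to whole patches.

```python
import numpy as np

height, width = 1365, 2048      # illustrative input size
max_height = max_width = 1024   # size["longest_edge"]
patch_height = patch_width = 16

ratio = max(height / max_height, width / max_width)     # 2.0
if ratio > 1:
    height = int(np.ceil(height / ratio))               # 683
    width = int(np.ceil(width / ratio))                 # 1024

num_height_tokens = (height - 1) // patch_height + 1    # 43 patch rows
num_width_tokens = (width - 1) // patch_width + 1       # 64 patches per row
print(num_height_tokens * patch_height, num_width_tokens * patch_width)  # 688 1024
```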
+ """ + max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size) + patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size) + height, width = get_image_size(input_image, input_data_format) + + ratio = max(height / max_height, width / max_width) + + if ratio > 1: + # Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results + height = int(np.ceil(height / ratio)) + width = int(np.ceil(width / ratio)) + + num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width)) + return num_height_tokens * patch_height, num_width_tokens * patch_width + + +# Hack to get tensor conversion used in BatchFeature without batching the images +def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]: + return BatchFeature()._get_is_as_tensor_fns(tensor_type) + + +def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any: + is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type) + if is_tensor(array): + return array + return as_tensor(array) + + +class PixtralImageProcessor(BaseImageProcessor): + r""" + Constructs a Pixtral image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`): + Size of the maximum dimension of either the height or width dimension of the image. Used to control how + images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)` + patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`): + Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. 
+ do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + patch_size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"longest_edge": 1024} + patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} + patch_size = get_size_dict(patch_size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.patch_size = patch_size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073] + self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711] + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "patch_size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + patch_size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dict containing the longest possible edge of the image. + patch_size (`Dict[str, int]`): + Patch size used to calculate the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. 
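The same numbers can be reproduced through the public `resize` method, which takes the `size` and `patch_size` dictionaries directly. A minimal sketch on a dummy channels-last array (illustrative values; requires the vision backend):

```python
import numpy as np
from transformers import PixtralImageProcessor

image_processor = PixtralImageProcessor()  # size={"longest_edge": 1024}, patch_size={"height": 16, "width": 16}
image = np.zeros((1365, 2048, 3), dtype=np.uint8)  # dummy HWC image

resized = image_processor.resize(image, size=image_processor.size, patch_size=image_processor.patch_size)
print(resized.shape)  # (688, 1024, 3)
```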
+ """ + if "longest_edge" in size: + size = (size["longest_edge"], size["longest_edge"]) + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.") + + if "height" in patch_size and "width" in patch_size: + patch_size = (patch_size["height"], patch_size["width"]) + else: + raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + patch_size=patch_size, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + patch_size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Patch size in the model. Used to calculate the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + patch_size = patch_size if patch_size is not None else self.patch_size + patch_size = get_size_dict(patch_size, default_to_square=True) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images_list = make_list_of_images(images) + + if not valid_images(images_list[0]): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images_list = [[convert_to_rgb(image) for image in images] for images in images_list] + + # All transformations expect numpy arrays. + images_list = [[to_numpy_array(image) for image in images] for images in images_list] + + if is_scaled_image(images_list[0][0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images_list[0][0]) + + batch_images = [] + batch_image_sizes = [] + for sample_images in images_list: + images = [] + image_sizes = [] + for image in sample_images: + if do_resize: + image = self.resize( + image=image, + size=size, + patch_size=patch_size, + resample=resample, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + images.append(image) + image_sizes.append(get_image_size(image, input_data_format)) + batch_images.append(images) + batch_image_sizes.append(image_sizes) + + images_list = [ + [to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images] + for images in batch_images + ] + + # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes + images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list] + return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None) diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py new file mode 100644 index 00000000000000..0e10c78b7852af --- /dev/null +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -0,0 +1,517 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Pixtral model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ... import PreTrainedModel +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_pixtral import PixtralVisionConfig + + +logger = logging.get_logger(__name__) + + +def position_ids_in_meshgrid(patch_embeds_list, max_width): + positions = [] + for patch in patch_embeds_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +class PixtralRotaryEmbedding(nn.Module): + """ + The key with pixtral embedding is just that you have a frequency for each pixel positions. + If you have height x width pixels (or embedding pixels) + + then the frequency used for ROPE is given by indexing the pre_computed frequency on the + width and height. + + What you output is of dimension batch, height * width, dim with dim the embed dim. 
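A minimal sketch (with an illustrative 2×3 grid) of what `position_ids_in_meshgrid` above computes: each patch of a height × width grid is assigned the flat id `row * max_width + column`, which is then used to index the precomputed 2D frequencies.

```python
import torch

height, width, max_width = 2, 3, 64  # 2 x 3 patch grid; 64 = image_size // patch_size

mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
ids = h_grid * max_width + v_grid
print(ids[:, 0].tolist())  # [0, 1, 2, 64, 65, 66]
```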
+ + This simply means that for each image hidden states, you are going to add + a corresponding positional embedding, based on it's index in the grid. + """ + + def __init__(self, config, device): + super().__init__() + self.rope_type = "default" + self.dim = config.head_dim + self.base = config.rope_theta + max_patches_per_side = config.image_size // config.patch_size + freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + + h = torch.arange(max_patches_per_side, device=freqs.device) + w = torch.arange(max_patches_per_side, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes + # Different from paper, but it uses a different permutation in order to obtain the same calculation + + # TODO maybe make it torch compatible later on. We can also just slice + self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + freqs = self.inv_freq[position_ids] + # position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + emb = freqs + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class PixtralAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, patches, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral +class PixtralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, 
self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral +class PixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + PixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class PixtralAttentionLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + self.feed_forward = PixtralMLP(config) + self.attention = PixtralAttention(config) + self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + return outputs + + +class PixtralTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layers = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(PixtralAttentionLayer(config)) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + position_embeddings, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=[hidden_states], attentions=all_attentions + ) + + +PIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PixtralVisionConfig`] or [`PixtralVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + PIXTRAL_START_DOCSTRING, +) +class PixtralPreTrainedModel(PreTrainedModel): + config_class = PixtralVisionConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PixtralVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + + def _init_weights(self, module): + # important: this ported version of Pixtral isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed - the original codebase + # https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PIXTRAL_INPUTS_DOCSTRING = r""" + Args: + pixel_values: list of N_img images of variable sizes, + each of shape (C, H, W) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + +@add_start_docstrings( + """The PIXTRAL model which consists of a vision backbone and a language model.""", + PIXTRAL_START_DOCSTRING, +) +class PixtralModel(PixtralPreTrainedModel): + base_model_prefix = "vision_encoder" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5) + self.transformer = PixtralTransformer(config) + self.patch_positional_embedding = PixtralRotaryEmbedding(config, device=self.device) + + @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: List[torch.Tensor], + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + *args, + **kwargs, + ) -> Union[Tuple, BaseModelOutput]: + """ + Returns: + pixel_values: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + patch_embeds_list, max_width=self.config.image_size // self.config.patch_size + ).to(self.device) + + position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + return self.transformer(patch_embeds, attention_mask, position_embedding) diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py new file mode 100644 index 00000000000000..9362703c8aa6da --- /dev/null +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -0,0 +1,282 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Pixtral. 
+""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput, is_valid_image, load_image +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.idefics2.processing_idefics2.is_url +def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + +# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url +def is_image_or_image_url(elem): + return is_url(elem) or is_valid_image(elem) + + +# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature +class BatchMixFeature(BatchFeature): + def to(self, *args, **kwargs) -> "BatchMixFeature": + """ + Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in + different `dtypes` and sending the `BatchFeature` to a different `device`. + + Args: + args (`Tuple`): + Will be passed to the `to(...)` function of the tensors. + kwargs (`Dict`, *optional*): + Will be passed to the `to(...)` function of the tensors. + + Returns: + [`BatchFeature`]: The same instance after modification. + """ + requires_backends(self, ["torch"]) + import torch # noqa + + new_data = {} + device = kwargs.get("device") + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # The first argument is a dtype + pass + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + # check if v is a floating point + if isinstance(v, list): + new_data[k] = [ + element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element) + ] + elif torch.is_floating_point(v): + # cast and send to device + new_data[k] = v.to(*args, **kwargs) + elif device is not None: + new_data[k] = v.to(device=device) + else: + new_data[k] = v + self.data = new_data + return self + + +class PixtralProcessor(ProcessorMixin): + r""" + Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor. + + [`PixtralProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information. + + Args: + image_processor ([`PixtralImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*, defaults to 16): + Patch size from the vision tower. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + image_token (`str`, *optional*, defaults to `"[IMG]"`): + Special token used to denote image location. 
+ image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`): + Special token used to denote the end of a line of pixels in an image. + image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`): + Special token used to denote the end of an image input. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = [ + "chat_template", + "patch_size", + "image_token", + "image_break_token", + "image_end_token", + ] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + patch_size: int = 16, + chat_template=None, + image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases + image_break_token="[IMG_BREAK]", + image_end_token="[IMG_END]", + **kwargs, + ): + self.patch_size = patch_size + self.image_token = image_token + self.image_break_token = image_break_token + self.image_end_token = image_end_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchMixFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. 
+ return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is not None: + if is_image_or_image_url(images): + images = [[images]] + elif isinstance(images, list) and is_image_or_image_url(images[0]): + images = [images] + elif ( + not isinstance(images, list) + and not isinstance(images[0], list) + and not is_image_or_image_url(images[0][0]) + ): + raise ValueError( + "Invalid input images. Please provide a single image or a list of images or a list of list of images." + ) + images = [[load_image(im) for im in sample] for sample in images] + image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + # try to expand inputs in processing if we have the necessary parts + prompt_strings = text + if image_inputs.get("pixel_values") is not None: + # Replace the image token with the expanded image token sequence + images = image_inputs["pixel_values"] + image_sizes = image_inputs.pop("image_sizes") + prompt_strings = [] + + for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text): + replace_strings = [] + # First calculate the number of tokens needed for each image and put in a placeholder + for image, image_size in zip(sample_images, sample_image_sizes): + height, width = image_size + num_height_tokens = height // self.patch_size + num_width_tokens = width // self.patch_size + replace_tokens = [ + [self.image_token] * num_width_tokens + [self.image_break_token] + ] * num_height_tokens + # Flatten list + replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens[-1] = self.image_end_token + replace_str = "".join(replace_tokens) + replace_strings.append(replace_str) + sample = sample.replace(self.image_token, "", 1) + + while "" in sample: + replace_str = replace_strings.pop(0) + sample = sample.replace("", replace_str, 1) + + prompt_strings.append(sample) + + text_inputs = self.tokenizer( + prompt_strings, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + return BatchMixFeature(data={**text_inputs, **image_inputs}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. 
+ """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b9ce0d0f15bbf5..2db7b38b580375 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7067,6 +7067,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class PixtralModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PixtralPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PLBartForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 2493954a518b2c..436378582e54ca 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -506,6 +506,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class PixtralImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class PoolFormerFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 2fed802b5a2fb3..5c05480ffa6dbb 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -569,3 +569,50 @@ def test_expansion_in_processing(self): # check that both inputs are handled correctly and generate the same output self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + + @slow + @require_bitsandbytes + def test_pixtral(self): + model_id = "hf-internal-testing/pixtral-12b" + model = LlavaForConditionalGeneration.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id) + + IMG_URLS = [ + Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/27/500/500", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/17/150/600", stream=True).raw), + ] + PROMPT = "[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]" + + # image = Image.open(requests.get(url, stream=True).raw) + inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda") + generate_ids = model.generate(**inputs, max_new_tokens=500) + ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + # fmt: off + EXPECTED_GENERATION = """ 
+Describe the images. +Sure, let's break down each image description: + +1. **Image 1:** + - **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera. + - **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur. + +2. **Image 2:** + - **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley. + - **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image. + +3. **Image 3:** + - **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset. + - **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene. + +4. **Image 4:** + - **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers. + - **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden. + +Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it. +""" + # fmt: on + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(ouptut, EXPECTED_GENERATION) diff --git a/tests/models/pixtral/__init__.py b/tests/models/pixtral/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py new file mode 100644 index 00000000000000..3994201c065c45 --- /dev/null +++ b/tests/models/pixtral/test_image_processing_pixtral.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import PixtralImageProcessor + + +class PixtralImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + max_num_images_per_sample=3, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + patch_size=None, + do_normalize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + do_convert_rgb=True, + ): + size = size if size is not None else {"longest_edge": 24} + patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.max_num_images_per_sample = max_num_images_per_sample + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.patch_size = patch_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "patch_size": self.patch_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + } + + def expected_output_image_shape(self, image): + if isinstance(image, Image.Image): + width, height = image.size + elif isinstance(image, np.ndarray): + height, width = image.shape[:2] + elif isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + + max_height = max_width = self.size.get("longest_edge") + + ratio = max(height / max_height, width / max_width) + if ratio > 1: + height = int(np.ceil(height / ratio)) + width = int(np.ceil(width / ratio)) + + patch_height, patch_width = self.patch_size["height"], self.patch_size["width"] + num_height_tokens = (height - 1) // patch_height + 1 + num_width_tokens = (width - 1) // patch_width + 1 + + height = num_height_tokens * patch_height + width = num_width_tokens * patch_width + + return self.num_channels, height, width + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + # Use prepare_image_inputs to make a list of list of single images + + images_list = [] + for _ in range(self.batch_size): + images = [] + for _ in range(random.randint(1, self.max_num_images_per_sample)): + img = prepare_image_inputs( + batch_size=1, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + )[0] + images.append(img) + images_list.append(images) + return images_list + + +@require_torch +@require_vision +class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = PixtralImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = PixtralImageProcessingTester(self) + + @property + def image_processor_dict(self): + return 
self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "patch_size")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + + def test_call_pil(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs_list = self.image_processor_tester.prepare_image_inputs() + for image_inputs in image_inputs_list: + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0]) + self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape) + + # Test batched + batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values + for encoded_images, images in zip(batch_encoded_images, image_inputs_list): + for encoded_image, image in zip(encoded_images, images): + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image) + self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape) + + def test_call_numpy(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True) + for image_inputs in image_inputs_list: + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0]) + self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape) + + # Test batched + batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values + for encoded_images, images in zip(batch_encoded_images, image_inputs_list): + for encoded_image, image in zip(encoded_images, images): + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image) + self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape) + + def test_call_pytorch(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True) + for image_inputs in image_inputs_list: + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0]) + 
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape) + + # Test batched + batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values + for encoded_images, images in zip(batch_encoded_images, image_inputs_list): + for encoded_image, image in zip(encoded_images, images): + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image) + self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape) + + @unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy + def test_call_numpy_4_channels(self): + pass diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py new file mode 100644 index 00000000000000..bd41fa1c9e62fb --- /dev/null +++ b/tests/models/pixtral/test_modeling_pixtral.py @@ -0,0 +1,292 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Pixtral model.""" + +import gc +import unittest + +import requests + +from transformers import ( + AutoProcessor, + PixtralModel, + PixtralVisionConfig, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_bitsandbytes, + require_torch, + slow, + torch_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor + + +if is_torch_available(): + import torch +else: + is_torch_greater_or_equal_than_2_0 = False + +if is_vision_available(): + from PIL import Image + + +class PixtralModelTester: + def __init__( + self, + parent, + batch_size=12, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + hidden_size=32, + projection_dim=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.scope = scope + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return PixtralVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + 
num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values): + model = PixtralModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = (self.image_size, self.image_size) + patch_size = (self.patch_size, self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def create_and_check_model_with_projection(self, config, pixel_values): + model = PixtralModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = (self.image_size, self.image_size) + patch_size = (self.patch_size, self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase): + """ + Model tester for `PixtralModel`. 
+ """ + + all_model_classes = (PixtralModel,) if is_torch_available() else () + test_pruning = False + test_head_masking = False + + def setUp(self): + self.model_tester = PixtralModelTester(self) + self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False) + + @unittest.skip("model does not support input embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip("model does not support input embeds") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Compile not yet supported because in Pixtral models") + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip(reason="Compile not yet supported because in Pixtral models") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_batching_equivalence(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_model_parallelism(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_model_outputs_equivalence(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_save_load(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_model_main_input_name(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_initialization(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_gradient_checkpointing_backward_compatibility(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_disk_offload_safetensors(self): + pass + + @unittest.skip(reason="Not supported yet") + def test_determinism(self): + pass + + +@require_torch +class PixtralModelIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b") + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + @slow + @require_bitsandbytes + def test_small_model_integration_test(self): + # Let' s make sure we test the preprocessing 
to replace what is used + model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True) + + prompt = "[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]" + image_file = "https://pixtral-vl.github.io/static/images/view.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = self.processor(prompt, raw_image, return_tensors="pt") + + EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip + self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) + + output = model.generate(**inputs, max_new_tokens=20) + EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py new file mode 100644 index 00000000000000..b70cab1c074480 --- /dev/null +++ b/tests/models/pixtral/test_processor_pixtral.py @@ -0,0 +1,233 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import requests +import torch + +from transformers.testing_utils import require_vision +from transformers.utils import is_vision_available + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoTokenizer, PixtralImageProcessor, PixtralProcessor + + +@require_vision +class PixtralProcessorTest(unittest.TestCase): + processor_class = PixtralProcessor + + @classmethod + def setUpClass(cls): + cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg" + cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw) + cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg" + cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw) + cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" + cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw) + + def setUp(self): + super().setUp() + + # FIXME - just load the processor directly from the checkpoint + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b") + image_processor = PixtralImageProcessor() + self.processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor) + + @unittest.skip("No chat template was set for this model (yet)") + def test_chat_template(self): + expected_prompt = "USER: [IMG]\nWhat is shown in this image? 
ASSISTANT:" + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) + self.assertEqual(expected_prompt, formatted_prompt) + + @unittest.skip("No chat template was set for this model (yet)") + def test_image_token_filling(self): + # Important to check with non square image + image = torch.randint(0, 2, (3, 500, 316)) + expected_image_tokens = 1526 + image_token_index = 32000 + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + inputs = self.processor( + text=[self.processor.apply_chat_template(messages)], + images=[image], + return_tensors="pt", + ) + image_tokens = (inputs["input_ids"] == image_token_index).sum().item() + self.assertEqual(expected_image_tokens, image_tokens) + + def test_processor_with_single_image(self): + prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:" + + # Make small for checking image token expansion + self.processor.image_processor.size = {"longest_edge": 30} + self.processor.image_processor.patch_size = {"height": 2, "width": 2} + + # Test passing in an image + inputs_image = self.processor(text=prompt_string, images=self.image_0, return_tensors="pt") + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 1) + self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], list) + self.assertTrue(len(inputs_image["pixel_values"]) == 1) + self.assertIsInstance(inputs_image["pixel_values"][0], list) + self.assertTrue(len(inputs_image["pixel_values"][0]) == 1) + self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor) + + # fmt: off + input_ids = inputs_image["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:" + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing in a url + inputs_url = self.processor(text=prompt_string, images=self.url_0, return_tensors="pt") + self.assertIn("input_ids", inputs_url) + self.assertTrue(len(inputs_url["input_ids"]) == 1) + self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_url["pixel_values"], list) + self.assertTrue(len(inputs_url["pixel_values"]) == 1) + self.assertIsInstance(inputs_url["pixel_values"][0], list) + self.assertTrue(len(inputs_url["pixel_values"][0]) == 1) + self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor) + + # fmt: off + input_ids = inputs_url["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:" + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + def test_processor_with_multiple_images_single_list(self): + prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? 
ASSISTANT:" + + # Make small for checking image token expansion + self.processor.image_processor.size = {"longest_edge": 30} + self.processor.image_processor.patch_size = {"height": 2, "width": 2} + + # Test passing in an image + inputs_image = self.processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt") + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 1) + self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], list) + self.assertTrue(len(inputs_image["pixel_values"]) == 1) + self.assertIsInstance(inputs_image["pixel_values"][0], list) + self.assertTrue(len(inputs_image["pixel_values"][0]) == 2) + self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor) + + # fmt: off + input_ids = inputs_image["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing in a url + inputs_url = self.processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt") + self.assertIn("input_ids", inputs_url) + self.assertTrue(len(inputs_url["input_ids"]) == 1) + self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_url["pixel_values"], list) + self.assertTrue(len(inputs_url["pixel_values"]) == 1) + self.assertIsInstance(inputs_url["pixel_values"][0], list) + self.assertTrue(len(inputs_url["pixel_values"][0]) == 2) + self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor) + # fmt: off + input_ids = inputs_url["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + def test_processor_with_multiple_images_multiple_lists(self): + prompt_string = [ + "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:", + "USER: [IMG]\nWhat's the content of the image? 
ASSISTANT:", + ] + self.processor.tokenizer.pad_token = "" + image_inputs = [[self.image_0, self.image_1], [self.image_2]] + + # Make small for checking image token expansion + self.processor.image_processor.size = {"longest_edge": 30} + self.processor.image_processor.patch_size = {"height": 2, "width": 2} + + # Test passing in an image + inputs_image = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 2) + self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_image["pixel_values"], list) + self.assertTrue(len(inputs_image["pixel_values"]) == 2) + self.assertIsInstance(inputs_image["pixel_values"][0], list) + self.assertTrue(len(inputs_image["pixel_values"][0]) == 2) + self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor) + + # fmt: off + input_ids = inputs_image["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on + + # Test passing in a url + inputs_url = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + self.assertIn("input_ids", inputs_url) + self.assertTrue(len(inputs_url["input_ids"]) == 2) + self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) + self.assertIsInstance(inputs_url["pixel_values"], list) + self.assertTrue(len(inputs_url["pixel_values"]) == 2) + self.assertIsInstance(inputs_url["pixel_values"][0], list) + self.assertTrue(len(inputs_url["pixel_values"][0]) == 2) + self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor) + + # fmt: off + input_ids = inputs_url["input_ids"] + self.assertEqual( + input_ids[0].tolist(), + # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] + [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] + ) + # fmt: on