diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 435b482df599cf..bd7aaf20a58944 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -864,6 +864,8 @@ title: InstructBlipVideo - local: model_doc/kosmos-2 title: KOSMOS-2 + - local: model_doc/kosmos-2.5 + title: KOSMOS-2.5 - local: model_doc/layoutlm title: LayoutLM - local: model_doc/layoutlmv2 diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 3bd1c286d43240..24c4fbdb113c23 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -186,6 +186,7 @@ Flax), PyTorch, and/or TensorFlow. | [JetMoe](model_doc/jetmoe) | ✅ | ❌ | ❌ | | [Jukebox](model_doc/jukebox) | ✅ | ❌ | ❌ | | [KOSMOS-2](model_doc/kosmos-2) | ✅ | ❌ | ❌ | +| [KOSMOS-2.5](model_doc/kosmos-2.5) | ✅ | ❌ | ❌ | | [LayoutLM](model_doc/layoutlm) | ✅ | ✅ | ❌ | | [LayoutLMv2](model_doc/layoutlmv2) | ✅ | ❌ | ❌ | | [LayoutLMv3](model_doc/layoutlmv3) | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/kosmos-2.5.md b/docs/source/en/model_doc/kosmos-2.5.md new file mode 100644 index 00000000000000..d221781a288f60 --- /dev/null +++ b/docs/source/en/model_doc/kosmos-2.5.md @@ -0,0 +1,63 @@ + + +# KOSMOS-2.5 + +## Overview + +Kosmos-2.5 is a multimodal literate model for machine reading of text-intensive images. Pre-trained on large-scale text-intensive images, Kosmos-2.5 excels in two distinct yet cooperative transcription tasks: (1) generating spatially-aware text blocks, where each block of text is assigned its spatial coordinates within the image, and (2) producing structured text output that captures styles and structures into the markdown format. This unified multimodal literate capability is achieved through a shared decoder-only auto-regressive Transformer architecture, task-specific prompts, and flexible text representations. We evaluate Kosmos-2.5 on end-to-end document-level text recognition and image-to-markdown text generation. Furthermore, the model can be readily adapted for any text-intensive image understanding task with different prompts through supervised fine-tuning, making it a general-purpose tool for real-world applications involving text-rich images. This work also paves the way for the future scaling of multimodal large language models. + +The abstract from the paper is the following: + +*We present Kosmos-2.5, a multimodal literate model for machine reading of text-intensive images. Pre-trained on large-scale text-intensive images, Kosmos-2.5 excels in two distinct yet cooperative transcription tasks: (1) generating spatially-aware text blocks, where each block of text is assigned its spatial coordinates within the image, and (2) producing structured text output that captures styles and structures into the markdown format. This unified multimodal literate capability is achieved through a shared Transformer architecture, task-specific prompts, and flexible text representations. We evaluate Kosmos-2.5 on end-to-end document-level text recognition and image-to-markdown text generation. Furthermore, the model can be readily adapted for any text-intensive image understanding task with different prompts through supervised fine-tuning, making it a general-purpose tool for real-world applications involving text-rich images. This work also paves the way for the future scaling of multimodal large language models.* + + + + + + Overview of tasks that KOSMOS-2.5 can handle. Taken from the original paper. + +## Example +**Markdown Task:** For usage instructions, please refer to [md.py](https://huggingface.co/microsoft/kosmos-2.5/blob/main/md.py). + +**OCR Task:** For usage instructions, please refer to [ocr.py](https://huggingface.co/microsoft/kosmos-2.5/blob/main/ocr.py). + + + +## Kosmos2_5Config + +[[autodoc]] Kosmos2_5Config + +## Kosmos2_5ImageProcessor + +[[autodoc]] Kosmos2_5ImageProcessor + +## Kosmos2_5Processor + +[[autodoc]] Kosmos2_5Processor + - __call__ + +## Kosmos2_5Model + +[[autodoc]] Kosmos2_5Model + - forward + +## Kosmos2_5ForConditionalGeneration + +[[autodoc]] Kosmos2_5ForConditionalGeneration + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index cbb498070d69e5..dc6388b75e720c 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -61,6 +61,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) +* [Kosmos-2.5](https://huggingface.co/docs/transformers/model_doc/kosmos2_5#transformers.Kosmos2_5Model) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * [Llava](https://huggingface.co/docs/transformers/model_doc/llava) * [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next) @@ -253,6 +254,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) +* [Kosmos-2.5](https://huggingface.co/docs/transformers/model_doc/kosmos2_5#transformers.Kosmos2_5Model) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * [Llava](https://huggingface.co/docs/transformers/model_doc/llava) * [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 920dc334dbb2a4..bb32358c21cd73 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -516,6 +516,10 @@ "Kosmos2Config", "Kosmos2Processor", ], + "models.kosmos2_5": [ + "Kosmos2_5Config", + "Kosmos2_5Processor", + ], "models.layoutlm": [ "LayoutLMConfig", "LayoutLMTokenizer", @@ -1220,6 +1224,7 @@ _import_structure["models.idefics3"].extend(["Idefics3ImageProcessor"]) _import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"]) _import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"]) + _import_structure["models.kosmos2_5"].extend(["Kosmos2_5ImageProcessor"]) _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"]) _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"]) _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"]) @@ -2568,6 +2573,13 @@ "Kosmos2PreTrainedModel", ] ) + _import_structure["models.kosmos2_5"].extend( + [ + "Kosmos2_5ForConditionalGeneration", + "Kosmos2_5Model", + "Kosmos2_5PreTrainedModel", + ] + ) _import_structure["models.layoutlm"].extend( [ "LayoutLMForMaskedLM", @@ -5453,6 +5465,10 @@ Kosmos2Config, Kosmos2Processor, ) + from .models.kosmos2_5 import ( + Kosmos2_5Config, + Kosmos2_5Processor, + ) from .models.layoutlm import ( LayoutLMConfig, LayoutLMTokenizer, @@ -6192,6 +6208,7 @@ from .models.idefics3 import Idefics3ImageProcessor from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor from .models.instructblipvideo import InstructBlipVideoImageProcessor + from .models.kosmos2_5 import Kosmos2_5ImageProcessor from .models.layoutlmv2 import ( LayoutLMv2FeatureExtractor, LayoutLMv2ImageProcessor, @@ -7321,6 +7338,11 @@ Kosmos2Model, Kosmos2PreTrainedModel, ) + from .models.kosmos2_5 import ( + Kosmos2_5ForConditionalGeneration, + Kosmos2_5Model, + Kosmos2_5PreTrainedModel, + ) from .models.layoutlm import ( LayoutLMForMaskedLM, LayoutLMForQuestionAnswering, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5eb74fab5abe71..8b8eee9ec25989 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -128,6 +128,7 @@ jamba, jetmoe, kosmos2, + kosmos2_5, layoutlm, layoutlmv2, layoutlmv3, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d7d8281c2e3f03..76df351dd0fdb1 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -149,6 +149,7 @@ ("jetmoe", "JetMoeConfig"), ("jukebox", "JukeboxConfig"), ("kosmos-2", "Kosmos2Config"), + ("kosmos-2.5", "Kosmos2_5Config"), ("layoutlm", "LayoutLMConfig"), ("layoutlmv2", "LayoutLMv2Config"), ("layoutlmv3", "LayoutLMv3Config"), @@ -462,6 +463,7 @@ ("jetmoe", "JetMoe"), ("jukebox", "Jukebox"), ("kosmos-2", "KOSMOS-2"), + ("kosmos-2.5", "KOSMOS-2.5"), ("layoutlm", "LayoutLM"), ("layoutlmv2", "LayoutLMv2"), ("layoutlmv3", "LayoutLMv3"), @@ -695,6 +697,7 @@ ("data2vec-vision", "data2vec"), ("donut-swin", "donut"), ("kosmos-2", "kosmos2"), + ("kosmos-2.5", "kosmos2_5"), ("maskformer-swin", "maskformer"), ("xclip", "x_clip"), ("clip_vision_model", "clip"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index db25591eaa3544..88c57b64e5f86e 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -98,6 +98,7 @@ ("instructblip", ("BlipImageProcessor",)), ("instructblipvideo", ("InstructBlipVideoImageProcessor",)), ("kosmos-2", ("CLIPImageProcessor",)), + ("kosmos-2.5", ("Kosmos2_5ImageProcessor",)), ("layoutlmv2", ("LayoutLMv2ImageProcessor",)), ("layoutlmv3", ("LayoutLMv3ImageProcessor",)), ("levit", ("LevitImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5d41ad42beea7e..4b12cd4c840f42 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -143,6 +143,7 @@ ("jetmoe", "JetMoeModel"), ("jukebox", "JukeboxModel"), ("kosmos-2", "Kosmos2Model"), + ("kosmos-2.5", "Kosmos2_5Model"), ("layoutlm", "LayoutLMModel"), ("layoutlmv2", "LayoutLMv2Model"), ("layoutlmv3", "LayoutLMv3Model"), @@ -762,6 +763,7 @@ ("instructblip", "InstructBlipForConditionalGeneration"), ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), + ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), @@ -795,6 +797,7 @@ ("idefics3", "Idefics3ForConditionalGeneration"), ("instructblip", "InstructBlipForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), + ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 815e2ca755bee3..fe6ae598eb7e49 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -71,6 +71,7 @@ ("instructblip", "InstructBlipProcessor"), ("instructblipvideo", "InstructBlipVideoProcessor"), ("kosmos-2", "Kosmos2Processor"), + ("kosmos-2.5", "Kosmos2_5Processor"), ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), ("llava", "LlavaProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 1cdebde8cd904f..c24e215e6b7916 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -248,6 +248,7 @@ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, ), ), + ("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index ffd8277f0268a3..aa561b18c51b90 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -2073,6 +2073,7 @@ def forward( vision_model_output=vision_model_output, ) + @torch.no_grad() def generate( self, pixel_values: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/kosmos2_5/__init__.py b/src/transformers/models/kosmos2_5/__init__.py new file mode 100644 index 00000000000000..6d7f300084fcca --- /dev/null +++ b/src/transformers/models/kosmos2_5/__init__.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_kosmos2_5 import * + from .image_processing_kosmos2_5 import * + from .modeling_kosmos2_5 import * + from .processing_kosmos2_5 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/kosmos2_5/configuration_kosmos2_5.py b/src/transformers/models/kosmos2_5/configuration_kosmos2_5.py new file mode 100644 index 00000000000000..d1e0ce43b547e6 --- /dev/null +++ b/src/transformers/models/kosmos2_5/configuration_kosmos2_5.py @@ -0,0 +1,275 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""KOSMOS-2.5 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Kosmos2_5TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Kosmos2_5TextModel`]. It is used to instantiate a + KOSMOS-2.5 text decoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text decoder of the KOSMOS-2.5 + [microsoft/kosmos-2.5](https://huggingface.co/microsoft/kosmos-2.5) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 108481): + Vocabulary size of the Kosmos2_5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Kosmos2_5Model`]. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + embed_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the layers and the pooler layer. + layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + ffn_dim (`int`, *optional*, defaults to 8192): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_embedding (`bool`, *optional*, defaults to `True`): + Scale embeddings by diving by sqrt(embed_dim). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + ```""" + + model_type = "kosmos_2_5_text_model" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "attention_heads", + "hidden_size": "embed_dim", + "num_hidden_layers": "layers", + } + + def __init__( + self, + vocab_size=108481, + max_position_embeddings=4096, + embed_dim=1536, + layers=24, + ffn_dim=6144, + attention_heads=16, + activation_function="gelu", + dropout=0.1, + attention_dropout=0, + activation_dropout=0.0, + layerdrop=0.0, + layer_norm_eps=1e-5, + init_std=0.02, + scale_embedding=True, + use_cache=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.embed_dim = embed_dim + self.layers = layers + self.ffn_dim = ffn_dim + self.attention_heads = attention_heads + self.activation_function = activation_function + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.init_std = init_std + self.scale_embedding = scale_embedding + self.use_cache = use_cache + + +class Kosmos2_5VisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Kosmos2_5VisionModel`]. It is used to + instantiate a KOSMOS-2.5 vision encoder according to the specified arguments, defining the model architecture. + Instantiating a configuration defaults will yield a similar configuration to that of the vision encoder of the KOSMOS-2.5 + [microsoft/kosmos-2.5](https://huggingface.co/microsoft/kosmos-2.5) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + patch_embed_hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the input patch_embedding layer in the Transformer encoder. + d_ff (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + d_kv (`int`, *optional*, defaults to 64): + Dimensionality of the key, query, value projections per attention head. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + dense_act_fn (`str` or `function`, *optional*, defaults to `"gelu_new"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + seq_len (`int`, *optional*, defaults to 4096): + Maximum sequence length (here number of patches) supported by the model. + Example: + + ```python + >>> from transformers import Kosmos2_5VisionConfig, Kosmos2_5VisionModel + + >>> # Initializing a Kosmos2_5VisionConfig with microsoft/kosmos-2.5 style configuration + >>> configuration = Kosmos2_5VisionConfig() + + >>> # Initializing a Kosmos2_5VisionModel (with random weights) from the microsoft/kosmos-2.5 style configuration + >>> model = Kosmos2_5VisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "kosmos_2_5_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1536, + patch_embed_hidden_size=768, + d_ff=3968, + d_kv=64, + num_hidden_layers=18, + num_attention_heads=24, + dense_act_fn="gelu_new", + layer_norm_eps=1e-6, + dropout_rate=0.0, + attention_dropout=0.0, + seq_len=4096, + initializer_factor=1.0, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.patch_embed_hidden_size = patch_embed_hidden_size + self.d_ff = d_ff + self.dropout_rate = dropout_rate + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.dense_act_fn = dense_act_fn + self.seq_len = seq_len + self.d_kv = d_kv + self.initializer_factor = initializer_factor + self.initializer_range = initializer_range + + +class Kosmos2_5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Kosmos2_5Model`]. It is used to instantiate a + KOSMOS-2.5 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the KOSMOS-2.5 + [microsoft/kosmos-2.5](https://huggingface.co/microsoft/kosmos-2.5) architecture. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Kosmos2_5TextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`Kosmos2_5VisionConfig`]. + latent_query_num (`int`, *optional*, defaults to 2048): + The number of latent query tokens that represent the image features used in the text decoder component. + kwargs (*optional*): + Dictionary of keyword arguments. + """ + + model_type = "kosmos-2.5" + sub_configs = {"text_config": Kosmos2_5TextConfig, "vision_config": Kosmos2_5VisionConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + latent_query_num=2048, + **kwargs, + ): + super().__init__(**kwargs) + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the Kosmos2_5TextConfig with default values.") + if vision_config is None: + vision_config = {} + logger.info("vision_config is None. Initializing the Kosmos2_5VisionConfig with default values.") + + self.text_config = Kosmos2_5TextConfig(**text_config) + self.vision_config = Kosmos2_5VisionConfig(**vision_config) + + self.latent_query_num = latent_query_num + + @classmethod + def from_text_vision_configs( + cls, + text_config: Kosmos2_5TextConfig, + vision_config: Kosmos2_5VisionConfig, + **kwargs, + ): + r""" + Instantiate a [`Kosmos2_5Config`] (or a derived class) from Kosmos2_5 text model configuration and Kosmos2_5 + vision model configuration. + + Returns: + [`Kosmos2_5Config`]: An instance of a configuration object + """ + + return cls( + text_config=text_config.to_dict(), + vision_config=vision_config.to_dict(), + **kwargs, + ) + + +__all__ = ["Kosmos2_5Config"] diff --git a/src/transformers/models/kosmos2_5/convert_kosmos2_5.py b/src/transformers/models/kosmos2_5/convert_kosmos2_5.py new file mode 100644 index 00000000000000..39966cf760bba2 --- /dev/null +++ b/src/transformers/models/kosmos2_5/convert_kosmos2_5.py @@ -0,0 +1,87 @@ +import argparse + +from fairseq.checkpoint_utils import load_checkpoint_to_cpu + +from transformers import Kosmos2_5Config, Kosmos2_5ForConditionalGeneration + + +KEYS_TO_MODIFY_MAPPING = { + "gpt_model.decoder.output_projection": "text_model.lm_head", + "gpt_model.decoder": "text_model.model", + "img_connector": "image_to_text_projection", + "img_model.embeddings": "vision_model.embeddings", + "img_model.encoder": "vision_model.encoder", + "img_model.layernorm": "vision_model.layernorm", + "img_model": "vision_model", + "ln_pre": "pre_layrnorm", + "ln_post": "post_layernorm", + "transformer.resblocks": "encoder.layers", + "ts_attn": "self_attn", + "ln_1": "layer_norm1", + "ln_2": "layer_norm2", + "c_fc": "fc1", + "c_proj": "fc2", +} + + +KEYS_TO_IGNORE = [ + # this buffer in the original code is only used to send weights to the desired device + "gpt_model.decoder.embed_positions._float_tensor", + # this weight is never used in the forward in the original KOSMOS-2.5) + "gpt_model.decoder.self_attn_sope.scale", +] + + +def rename_key(key): + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + return key + + +def convert_kosmos2_5_checkpoint_to_pytorch(checkpoint_path, pytorch_dump_folder_path): + state = load_checkpoint_to_cpu(checkpoint_path) + state_dict = state["model"] + state_dict_keys = list(state_dict.keys()) + + config = Kosmos2_5Config() + # This is necessary to match the results given by the original demo + config.text_config.no_repeat_ngram_size = 3 + model = Kosmos2_5ForConditionalGeneration(config) + + # convert (by renaming keys) + converted_state_dict = {} + for key in state_dict_keys: + if key in KEYS_TO_IGNORE: + continue + renamed_key = rename_key(key) + converted_state_dict[renamed_key] = state_dict[key] + + # set + # check weight loading + # check whether the state in converted_state_dict is the same as the state in the model + model.load_state_dict(converted_state_dict, strict=True) + # save the result + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--kosmos2_5_checkpoint_path", + default="ckpt.pt", + type=str, + required=False, + help="Path the official PyTorch dump.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="ckpt", + type=str, + required=False, + help="Path to the output PyTorch model.", + ) + args = parser.parse_args() + convert_kosmos2_5_checkpoint_to_pytorch(args.kosmos2_5_checkpoint_path, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py new file mode 100644 index 00000000000000..8112f61cdbff32 --- /dev/null +++ b/src/transformers/models/kosmos2_5/image_processing_kosmos2_5.py @@ -0,0 +1,342 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Kosmos2_5.""" + +import math +from typing import Dict, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import ( + convert_to_rgb, + normalize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + get_image_size, + infer_channel_dimension_format, + make_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, is_torch_available, logging +from ...utils.import_utils import requires_backends + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) +DEFAULT_FONT_PATH = "ybelkada/fonts" + + +# Copied from transformers.models.pix2struct.image_processing_pix2struct.torch_extract_patches +def torch_extract_patches(image_tensor, patch_height, patch_width): + """ + Utiliy function to extract patches from a given image tensor. Returns a tensor of shape (1, `patch_height`, + `patch_width`, `num_channels`x `patch_height` x `patch_width`) + + Args: + image_tensor (torch.Tensor): + The image tensor to extract patches from. + patch_height (int): + The height of the patches to extract. + patch_width (int): + The width of the patches to extract. + """ + requires_backends(torch_extract_patches, ["torch"]) + + image_tensor = image_tensor.unsqueeze(0) + patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) + patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1) + patches = patches.permute(0, 4, 2, 3, 1).reshape( + image_tensor.size(2) // patch_height, + image_tensor.size(3) // patch_width, + image_tensor.size(1) * patch_height * patch_width, + ) + return patches.unsqueeze(0) + + +# similar to transformers.models.pix2struct.image_processing_pix2struct.Pix2StructImageProcessor, but delete is_vqa and additionaly return width and height after resizing +class Kosmos2_5ImageProcessor(BaseImageProcessor): + r""" + Constructs a Kosmos2_5 image processor. + + Args: + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. According to Kosmos2_5 paper and code, the image is normalized with its own mean and standard + deviation. + patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`): + The patch size to use for the image. According to Kosmos2_5 paper and code, the patch size is 16x16. + max_patches (`int`, *optional*, defaults to 4096): + The maximum number of patches to extract from the image as per the [Kosmos2_5 + paper](https://arxiv.org/pdf/2309.11419). + """ + + model_input_names = ["flattened_patches"] + + def __init__( + self, + do_convert_rgb: bool = True, + do_normalize: bool = True, + patch_size: Dict[str, int] = None, + max_patches: int = 4096, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} + self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb + self.max_patches = max_patches + + def extract_flattened_patches( + self, + image: np.ndarray, + max_patches: int, + patch_size: dict, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Extract flattened patches from an image. + + Args: + image (`np.ndarray`): + Image to extract flattened patches from. + max_patches (`int`): + Maximum number of patches to extract. + patch_size (`dict`): + Dictionary containing the patch height and width. + + Returns: + result (`np.ndarray`): + A sequence of `max_patches` flattened patches. + """ + requires_backends(self.extract_flattened_patches, "torch") + + # convert to torch + image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format) + image = torch.from_numpy(image) + + patch_height, patch_width = patch_size["height"], patch_size["width"] + image_height, image_width = get_image_size(image, ChannelDimension.FIRST) + + # maximize scale s.t. + scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width)) + num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1) + num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1) + resized_height = max(num_feasible_rows * patch_height, 1) + resized_width = max(num_feasible_cols * patch_width, 1) + + image = torch.nn.functional.interpolate( + image.unsqueeze(0), + size=(resized_height, resized_width), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze(0) + + # [1, rows, columns, patch_height * patch_width * image_channels] + patches = torch_extract_patches(image, patch_height, patch_width) + + patches_shape = patches.shape + rows = patches_shape[1] + columns = patches_shape[2] + depth = patches_shape[3] + + # [rows * columns, patch_height * patch_width * image_channels] + patches = patches.reshape([rows * columns, depth]) + + # [rows * columns, 1] + row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1]) + col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1]) + + # Offset by 1 so the ids do not contain zeros, which represent padding. + row_ids += 1 + col_ids += 1 + + # Prepare additional patch features. + # [rows * columns, 1] + row_ids = row_ids.to(torch.float32) + col_ids = col_ids.to(torch.float32) + + # [rows * columns, 2 + patch_height * patch_width * image_channels] + result = torch.cat([row_ids, col_ids, patches], -1) + + # [max_patches, 2 + patch_height * patch_width * image_channels] + result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float() + + result = to_numpy_array(result) + + return result, resized_width, resized_height, rows, columns + + def normalize( + self, + image: np.ndarray, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Normalize an image. image = (image - image_mean) / image_std. + + The image std is to mimic the tensorflow implementation of the `per_image_standardization`: + https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization + + Args: + image (`np.ndarray`): + Image to normalize. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if image.dtype == np.uint8: + image = image.astype(np.float32) + + # take mean across the whole `image` + mean = np.mean(image) + std = np.std(image) + adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape))) + + return normalize( + image, + mean=mean, + std=adjusted_stddev, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_convert_rgb: bool = None, + do_normalize: Optional[bool] = None, + max_patches: Optional[int] = None, + patch_size: Optional[Dict[str, int]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> ImageInput: + """ + Preprocess an image or batch of images. The processor first computes the maximum possible number of + aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the + image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the + images are standardized following the tensorflow implementation of `per_image_standardization` + (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization). + + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + max_patches (`int`, *optional*, defaults to `self.max_patches`): + Maximum number of patches to extract. + patch_size (`dict`, *optional*, defaults to `self.patch_size`): + Dictionary containing the patch height and width. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + patch_size = patch_size if patch_size is not None else self.patch_size + max_patches = max_patches if max_patches is not None else self.max_patches + + if kwargs.get("data_format", None) is not None: + raise ValueError("data_format is not an accepted input as the outputs are ") + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + flattened_patches, width, height, rows, cols, attention_masks = [], [], [], [], [], [] + for image in images: + if do_normalize: + image = self.normalize(image=image, input_data_format=input_data_format) + + # convert to torch tensor and permute + f, w, h, r, c = self.extract_flattened_patches( + image=image, + max_patches=max_patches, + patch_size=patch_size, + input_data_format=input_data_format, + ) + flattened_patches.append(f) + width.append(w) + height.append(h) + rows.append(r) + cols.append(c) + # create attention mask in numpy + attention_masks.append((f.sum(axis=-1) != 0).astype(np.float32)) + + encoded_outputs = BatchFeature( + data={ + "flattened_patches": flattened_patches, + "attention_mask": attention_masks, + "width": width, + "height": height, + "rows": rows, + "cols": cols, + }, + tensor_type=return_tensors, + ) + + return encoded_outputs + + +__all__ = ["Kosmos2_5ImageProcessor"] diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py new file mode 100644 index 00000000000000..21d5ddb47e832a --- /dev/null +++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py @@ -0,0 +1,2255 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch KOSMOS-2.5 model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + CausalLMOutputWithCrossAttentions, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_kosmos2_5 import ( + Kosmos2_5Config, + Kosmos2_5TextConfig, + Kosmos2_5VisionConfig, +) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = Kosmos2_5Config + + +# Copied from transformers.models.kosmos2.modeling_kosmos2._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +KOSMOS2_5_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Kosmos2_5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +KOSMOS2_5_VISION_INPUTS_DOCSTRING = r""" + Args: + flattened_patches (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`Kosmos2_5ImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +KOSMOS2_5_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + image_embeds: (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): + Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. + image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, + 1]`: + + - 1 for places where to put the image features, + - 0 for places that are not for image features (i.e. for text tokens). + + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +KOSMOS2_5_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`Kosmos2_5ImageProcessor.__call__`] for details. + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0, + 1]`: + + - 1 for places where to put the image features, + - 0 for places that are not for image features (i.e. for text tokens). + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + image_embeds: (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): + Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@dataclass +class Kosmos2_5ModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): + Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. + projection_attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute + the weighted average in the self-attention heads. + vision_model_output(`BaseModelOutputWithPooling`, *optional*): + The output of the [`Kosmos2VisionModel`]. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_embeds: Optional[torch.FloatTensor] = None + projection_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + (self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()) + for k in self.keys() + ) + + +@dataclass +class Kosmos2_5ForConditionalGenerationModelOutput(ModelOutput): + """ + Model output class for `Kosmos2_5ForConditionalGeneration`. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*): + Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`. + projection_attentions (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute + the weighted average in the self-attention heads. + vision_model_output(`BaseModelOutputWithPooling`, *optional*): + The output of the [`Kosmos2VisionModel`]. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + # past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_embeds: Optional[torch.FloatTensor] = None + projection_attentions: Optional[Tuple[torch.FloatTensor]] = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + (self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()) + for k in self.keys() + ) + + +# Copied from transformers.models.pix2struct.modeling_pix2struct.Pix2StructLayerNorm with Pix2Struct->Kosmos2_5 +class Kosmos2_5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + Kosmos2_5LayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Kosmos2_5LayerNorm") +except ImportError: + # using the normal Kosmos2_5LayerNorm + pass +except Exception: + logger.warning("Discovered apex but it failed to load, falling back to Kosmos2_5LayerNorm") + pass + + +# similar to transformers.models.pix2struct.modeling_pix2struct.Pix2StructVisionEmbeddings but with `inplace=False` +# TODO: check with krip +class Kosmos2_5VisionEmbeddings(nn.Module): + def __init__(self, config: Kosmos2_5VisionConfig) -> None: + super().__init__() + self.config = config + self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size) + + self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size) + self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size) + + self.dropout = nn.Dropout(config.dropout_rate, inplace=False) + + def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor: + # the row and column indices are stored in the first and second position of the flattened_patches + # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2 + row_indices = flattened_patches[:, :, 0].long() + col_indices = flattened_patches[:, :, 1].long() + + flattened_patches = flattened_patches[:, :, 2:] + + embeddings = self.patch_projection(flattened_patches) + row_embeddings = self.row_embedder(row_indices).to(embeddings.device) + col_embeddings = self.column_embedder(col_indices).to(embeddings.device) + + # sum all embeddings together + embeddings = embeddings + row_embeddings + col_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5DenseGatedActDense->Pix2StructVisionMlp,T5Config->Pix2StructVisionConfig,config.d_model->config.hidden_size,dropout_rate->dropout_rate +class Kosmos2_5VisionMlp(nn.Module): + def __init__(self, config: Kosmos2_5VisionConfig): + super().__init__() + self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + # Ignore copy + self.config = config + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class Kosmos2_5VisionAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.is_causal = False + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False) + self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + ): + """ + Self-attention block + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length, _ = hidden_states.size() + + query_states = self.query(hidden_states) + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + # get query states + # (batch_size, n_heads, seq_length, dim_per_head) + query_states = query_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.key_value_proj_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, -1) + attn_output = self.output(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class Kosmos2_5VisionFlashAttention2(Kosmos2_5VisionAttention): + """ + Kosmos-2.5 vision encoder flash attention module. This module inherits from `Kosmos2_5VisionAttention` as the + weights of the module stays untouched. The only required change would be on the forward pass where it needs to + correctly call the public API of flash attention and deal with padding tokens in case the input contains any of + them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + ): + """ + Flash attn Self-attention block + """ + output_attentions = False + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + batch_size, seq_length, _ = hidden_states.size() + # (batch_size, seq_length, inner_dim) + query_states = self.query(hidden_states) + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + # (batch_size, seq_length, self.n_heads , self.key_value_proj_dim) + query_states = query_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim) + key_states = key_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim) + value_states = value_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + seq_length, + dropout=self.dropout, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.output(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class Kosmos2_5VisionSdpaAttention(Kosmos2_5VisionAttention): + """ + Kosmos-2.5 vision encoder attention module using torch.nn.functional.scaled_dot_product_attention. This module + inherits from` Kosmos2_5VisionAttention` as the weights of the module stays untouched. The only changes are on the + forward pass to adapt to SDPA API. + """ + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + ): + if output_attentions: + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + ) + batch_size, seq_length, _ = hidden_states.size() + + query_states = self.query(hidden_states) + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = query_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + causal_mask = attention_mask + if attention_mask is not None: + # Slice the causal_mask to match key_states' last dimension + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and seq_length > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, -1) + + attn_output = self.output(attn_output) + + return attn_output, None + + +KOSMOS2_5_VISION_ATTENTION_CLASSES = { + "eager": Kosmos2_5VisionAttention, + "flash_attention_2": Kosmos2_5VisionFlashAttention2, + "sdpa": Kosmos2_5VisionSdpaAttention, +} + + +class Kosmos2_5VisionLayer(nn.Module): + def __init__(self, config: Kosmos2_5VisionConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + + self.config = config + + self.attention = KOSMOS2_5_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.mlp = Kosmos2_5VisionMlp(config) + self.pre_mlp_layer_norm = Kosmos2_5LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pre_attention_layer_norm = Kosmos2_5LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + residual = hidden_states + + # in Kosmos2_5Vision, layernorm is applied before self-attention + hidden_states = self.pre_attention_layer_norm(hidden_states) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + residual + + # in Kosmos2_5Vision, layernorm is also applied after self-attention + layer_output = self.pre_mlp_layer_norm(hidden_states) + layer_output = self.mlp(layer_output) + hidden_states # second residual connection + + outputs = (layer_output,) + outputs + + return outputs + + +# Adapted from transformers.models.pix2struct.modeling_pix2struct.Pix2StructVisionEncoder with Pix2Struct->Kosmos2_5 +class Kosmos2_5VisionEncoder(nn.Module): + def __init__(self, config: Kosmos2_5VisionConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([Kosmos2_5VisionLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def _prepare_attention_mask(self, attention_mask, input_shape, inputs_embeds): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + return expanded_attn_mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + attention_mask = self._prepare_attention_mask(attention_mask, hidden_states.shape[:2], hidden_states) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.kosmos2.modeling_kosmos2.Kosmos2TextSinusoidalPositionalEmbedding with Kosmos2->Kosmos2_5 +class Kosmos2_5TextSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__ + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.register_buffer("weights", emb_weights, persistent=False) + + @staticmethod + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of + "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + past_key_values_length: int = 0, + position_ids: torch.Tensor = None, + ): + if input_ids is not None: + bsz, seq_len = input_ids.size() + if position_ids is None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids( + input_ids, self.padding_idx, past_key_values_length + ).to(input_ids.device) + else: + bsz, seq_len = inputs_embeds.size()[:-1] + if position_ids is None: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + past_key_values_length + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach() + + # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds + def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length + + +# Copied from transformers.models.kosmos2.modeling_kosmos2.Kosmos2TextFFN with Kosmos2->Kosmos2_5 +class Kosmos2_5TextFFN(nn.Module): + def __init__(self, config: Kosmos2_5TextConfig): + super().__init__() + + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim) + self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim) + + self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.ffn_layernorm(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + return hidden_states + + +class Kosmos2_5TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Similar to ...models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`. + def __init__( + self, + config, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + add_inner_attn_layernorm: bool = False, + bias: bool = True, + is_causal=True, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.layer_idx = layer_idx + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.is_causal = is_causal + + # End opy + self.inner_attn_ln = None + if add_inner_attn_layernorm: + self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # Copied from transformers.models.kosmos2.modeling_kosmos2.KosmosTextAttention._shape + def _shape(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward( + self, + hidden_states: torch.Tensor, # text part + encoder_hidden_states: Optional[torch.Tensor] = None, # image part + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_length, _ = hidden_states.size() + + # use encoder_hidden_states if cross attention + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + + key_states = self._shape(self.k_proj(current_states)) + value_states = self._shape(self.v_proj(current_states)) + query_states = self._shape(self.q_proj(hidden_states) * self.scaling) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # this weight maybe overflow with fp16 + attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, seq_length, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, seq_length, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(batch_size, seq_length, -1) + + if self.inner_attn_ln is not None: + attn_output = self.inner_attn_ln(attn_output) + + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Kosmos2_5TextFlashAttention2(Kosmos2_5TextAttention): + """ + Kosmos-2.5 text flash attention module. This module inherits from `Kosmos2_5TextAttention` as the weights of the + module stays untouched. The only required change would be on the forward pass where it needs to correctly call the + public API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, # text part + encoder_hidden_states: Optional[torch.Tensor] = None, # image part + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + is_cross_attention = encoder_hidden_states is not None + bsz, q_len, _ = hidden_states.size() + + # use encoder_hidden_states if cross attention + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = self._shape(self.q_proj(hidden_states)) + key_states = self._shape(self.k_proj(current_states)) + value_states = self._shape(self.v_proj(current_states)) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + None, + q_len, + dropout=self.dropout, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.view(bsz, -1, self.embed_dim) + if self.inner_attn_ln is not None: + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Kosmos2_5TextSdpaAttention(Kosmos2_5TextAttention): + """ + Kosmos-2.5 text decoder attention module using torch.nn.functional.scaled_dot_product_attention. This module + inherits from `Kosmos2_5TextAttention` as the weights of the module stays untouched. The only changes are on the + forward pass to adapt to SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, # text part + encoder_hidden_states: Optional[torch.Tensor] = None, # image part + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + if output_attentions: + logger.warning_once( + "Kosmos2_5TextModel is using Kosmos2_5TextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + bsz, q_len, _ = hidden_states.size() + + # use encoder_hidden_states if cross attention + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + + query_states = self._shape(self.q_proj(hidden_states)) + key_states = self._shape(self.k_proj(current_states)) + value_states = self._shape(self.v_proj(current_states)) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + is_causal = is_causal and self.is_causal + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + if self.inner_attn_ln is not None: + attn_output = self.inner_attn_ln(attn_output) + + attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value + + +KOSMOS2_5_TEXT_ATTENTION_CLASSES = { + "eager": Kosmos2_5TextAttention, + "flash_attention_2": Kosmos2_5TextFlashAttention2, + "sdpa": Kosmos2_5TextSdpaAttention, +} + + +class Kosmos2_5TextBlock(nn.Module): + def __init__(self, config: Kosmos2_5TextConfig, layer_idx: int): + super().__init__() + self.embed_dim = config.embed_dim + self.layer_idx = layer_idx + self.self_attn = KOSMOS2_5_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + config, + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + add_inner_attn_layernorm=False, + is_causal=True, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.ffn = Kosmos2_5TextFFN(config) + self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Adapted from transformers.models.kosmos2.modeling_kosmos2.Kosmos2TextBlock.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Only for language part. + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# Adapted from transformers.models.kosmos2.modeling_kosmos2.Kosmos2TextTransformer with Kosmos2->Kosmos2_5 +class Kosmos2_5TextTransformer(nn.Module): + """ + Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2_5TextBlock`]. + Here we doesn't have cross attention. + Args: + config: Kosmos2_5TextConfig + """ + + def __init__(self, config: Kosmos2_5TextConfig): + super().__init__() + self.config = config + self.dropout = config.dropout + self.layerdrop = config.layerdrop + + self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id) + + self.embed_positions = Kosmos2_5TextSinusoidalPositionalEmbedding( + num_positions=config.max_position_embeddings, + embedding_dim=config.embed_dim, + padding_idx=config.pad_token_id, + ) + + # Ignore copy + self.segment_emb = nn.Embedding(2, config.embed_dim) + self.layers = nn.ModuleList([Kosmos2_5TextBlock(config, layer_idx) for layer_idx in range(config.layers)]) + self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + self.gradient_checkpointing = False + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + image_embeds_position_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`. + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Ignore copy + if image_embeds is not None: + inputs_embeds[image_embeds_position_mask == 1] = image_embeds.to(inputs_embeds.device).view( + -1, image_embeds.size(-1) + ) + + inputs_embeds = inputs_embeds * self.embed_scale + + # embed positions + positions = self.embed_positions( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=0, + position_ids=position_ids, + ) + positions = positions.to(inputs_embeds.device) + + # Ignore copy + if image_embeds_position_mask is not None: + # make every not equal 0 be 1 + image_embeds_position_mask = image_embeds_position_mask.ne(0).long() + segment_embeds = self.segment_emb(image_embeds_position_mask).to(positions.device) + positions += segment_embeds + else: + # add zero embedding for padding tokens + bsz, seq_len, dim = positions.size() + zero_emb = self.segment_emb( + torch.zeros((bsz, 1), dtype=torch.long, device=self.segment_emb.weight.device) + ).to(positions.device) + positions += zero_emb + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add final layer norm + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +# Copied from transformers.models.kosmos2.modeling_kosmos2.Kosmos2ImageToTextProjection with Kosmos2->Kosmos2_5 +class Kosmos2_5ImageToTextProjection(nn.Module): + """The layer that transforms the image model's output to part of the text model's input (namely, image features)""" + + def __init__(self, config: Kosmos2_5Config): + super().__init__() + self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim) + self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim)) + + # Ignore copy + self.x_attn = KOSMOS2_5_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + config.text_config, + config.text_config.embed_dim, + config.text_config.attention_heads, + dropout=config.text_config.attention_dropout, + is_decoder=False, + add_inner_attn_layernorm=False, + is_causal=False, + ) + + def forward(self, features): + hidden_states = self.dense(features) + + # shape = [batch, latent_query_num, h_dim] + latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1) + key_value_states = torch.cat([hidden_states, latent_query], dim=1) + + hidden_states, attn_weights, _ = self.x_attn( + hidden_states=latent_query, + encoder_hidden_states=key_value_states, + past_key_value=None, + attention_mask=None, + output_attentions=None, + ) + + return hidden_states, attn_weights + + +class Kosmos2_5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Kosmos2_5Config + supports_gradient_checkpointing = True + _no_split_modules = ["Kosmos2_5VisionLayer", "Kosmos2_5TextBlock"] + _supports_flash_attn_2 = True + _supports_cache_class = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(self, Kosmos2_5VisionModel): + init_factor = self.config.initializer_factor + std = self.config.initializer_range * init_factor + elif isinstance(self, (Kosmos2_5TextModel, Kosmos2_5TextForCausalLM)): + std = self.config.init_std + elif isinstance(self, (Kosmos2_5Model, Kosmos2_5ForConditionalGeneration)): + std = self.config.text_config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Kosmos2_5VisionModel(Kosmos2_5PreTrainedModel): + config_class = Kosmos2_5VisionConfig + + # Copied from transformers.models.pix2struct.modeling_pix2struct.Pix2StructVisionModel.__init__ with Pix2Struct->Kosmos2_5 + def __init__(self, config: Kosmos2_5VisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = Kosmos2_5VisionEmbeddings(config) + self.encoder = Kosmos2_5VisionEncoder(config) + + self.layernorm = Kosmos2_5LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.pix2struct.modeling_pix2struct.Pix2StructVisionModel.get_input_embeddings + def get_input_embeddings(self): + return self.embeddings.patch_projection + + # Copied from transformers.models.pix2struct.modeling_pix2struct.Pix2StructVisionModel._prune_heads + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + # Similar to transformers.models.pix2struct.modeling_pix2struct.Pix2StructVisionModel.forward without docstring + def forward( + self, + flattened_patches: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if flattened_patches is None: + raise ValueError("You have to specify flattened_patches") + + if attention_mask is None: + # check where `flattened_patches` is not 0 + attention_mask = (flattened_patches.sum(dim=-1) != 0).float() + + embedding_output = self.embeddings(flattened_patches) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Adapted from transformers.models.kosmos2.modeling_kosmos2.Kosmos2TextModel with KOSMOS2->KOSMOS2_5 +class Kosmos2_5TextModel(Kosmos2_5PreTrainedModel): + config_class = Kosmos2_5TextConfig + + def __init__(self, config: Kosmos2_5TextConfig): + super().__init__(config) + self.model = Kosmos2_5TextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(KOSMOS2_5_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2_5TextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + image_embeds_position_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + r""" + Returns: + + """ + return self.model( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + +@add_start_docstrings( + """ + KOSMOS-2.5 Model for generating text and image features. The model consists of a vision encoder and a language model. + """, + KOSMOS2_5_START_DOCSTRING, +) +class Kosmos2_5Model(Kosmos2_5PreTrainedModel): + config_class = Kosmos2_5Config + + def __init__(self, config: Kosmos2_5Config): + super().__init__(config) + + self.text_model = Kosmos2_5TextModel(config.text_config) + self.vision_model = Kosmos2_5VisionModel(config.vision_config) + self.image_to_text_projection = Kosmos2_5ImageToTextProjection(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.model.embed_tokens + + def set_input_embeddings(self, value): + self.text_model.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(KOSMOS2_5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Kosmos2_5ModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + flattened_patches: Optional[torch.Tensor] = None, + image_embeds_position_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + image_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, Kosmos2_5ModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Kosmos2_5Model + + >>> model = Kosmos2_5Model.from_pretrained("microsoft/kosmos2.5") + >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos2.5") + + >>> url = "https://huggingface.co/microsoft/kosmos2.5/resolve/main/snowman.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = ( + ... " An image of a snowman" + ... " warming himself by a fire" + ... "" + ... ) + + >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True) + + >>> last_hidden_state = model( + ... pixel_values=inputs["pixel_values"], + ... input_ids=inputs["input_ids"], + ... attention_mask=inputs["attention_mask"], + ... image_embeds_position_mask=inputs["image_embeds_position_mask"], + ... ).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 91, 2048] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_model_output = None + projection_attentions = None + if image_embeds is None: + if flattened_patches is not None: + vision_model_output = self.vision_model( + flattened_patches=flattened_patches, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # normalized features + image_embeds = nn.functional.normalize(vision_model_output[0], dim=-1) + image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) + + outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + outputs = outputs + (image_embeds, projection_attentions, vision_model_output) + return tuple(output for output in outputs if output is not None) + + return Kosmos2_5ModelOutput( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_embeds=image_embeds, + projection_attentions=projection_attentions, + vision_model_output=vision_model_output, + ) + + +@add_start_docstrings( + """ + The text model from KOSMOS-2.5 with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + KOSMOS2_5_START_DOCSTRING, +) +# Adapted from transformers.models.kosmos2.modeling_kosmos2.Kosmos2TextForCausalLM with KOSMOS-2->KOSMOS-2.5,KOSMOS2->KOSMOS2_5,Kosmos2->Kosmos2_5 +class Kosmos2_5TextForCausalLM(Kosmos2_5PreTrainedModel): + config_class = Kosmos2_5TextConfig + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Kosmos2_5TextConfig): + super().__init__(config) + + self.model = Kosmos2_5TextTransformer(config) + self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(KOSMOS2_5_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2_5TextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + image_embeds_position_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + image_embeds=None, + image_embeds_position_mask=None, + past_key_values=None, + attention_mask=None, + use_cache=None, + cache_position=None, + position_ids=None, + **model_kwargs, + ): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + position_ids = None + + # cut input_ids if past_key_values is used + if past_key_values is not None: + position_ids = create_position_ids_from_input_ids( + input_ids, + padding_idx=self.config.pad_token_id, + past_key_values_length=0, + )[:, -cache_position.shape[0] :] + + input_ids = input_ids[:, -cache_position.shape[0] :] + # the image info. is already encoded into the past keys/values + if past_key_values.get_seq_length() > 0: + image_embeds = None + image_embeds_position_mask = None + elif image_embeds_position_mask is not None: + # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) + batch_size, seq_len = input_ids.size() + mask_len = image_embeds_position_mask.size()[-1] + image_embeds_position_mask = torch.cat( + ( + image_embeds_position_mask, + torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device), + ), + dim=1, + ) + + return { + "input_ids": input_ids, + "image_embeds": image_embeds, + "image_embeds_position_mask": image_embeds_position_mask, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + "position_ids": position_ids, + "use_cache": use_cache, + } + + @staticmethod + # Copied from transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + KOSMOS-2.5 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a + language model. + """, + KOSMOS2_5_START_DOCSTRING, +) +class Kosmos2_5ForConditionalGeneration(Kosmos2_5PreTrainedModel, GenerationMixin): + config_class = Kosmos2_5Config + _tied_weights_keys = ["text_model.lm_head.weight"] + + def __init__(self, config: Kosmos2_5Config): + super().__init__(config) + self.text_model = Kosmos2_5TextForCausalLM(config.text_config) + self.vision_model = Kosmos2_5VisionModel(config.vision_config) + self.image_to_text_projection = Kosmos2_5ImageToTextProjection(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.model.embed_tokens + + def set_input_embeddings(self, value): + self.text_model.model.embed_tokens = value + + def get_output_embeddings(self) -> nn.Module: + return self.text_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.text_model.set_output_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(KOSMOS2_5_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Kosmos2_5ForConditionalGenerationModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + flattened_patches: Optional[torch.Tensor] = None, + image_embeds_position_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + image_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Kosmos2_5ForConditionalGenerationModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> import torch + >>> from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration + + >>> repo = "microsoft/kosmos-2.5" + >>> device = "cuda:0" + >>> dtype = torch.bfloat16 # torch.float16 + >>> model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype) + >>> processor = AutoProcessor.from_pretrained(repo) + + >>> url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png" + + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> prompt = "" # + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + >>> height, width = inputs.pop("height"), inputs.pop("width") + >>> inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()} + >>> inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype) + + >>> generated_ids = model.generate(**inputs,max_new_tokens=1024) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> generated_text + '1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n' + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_model_output = None + projection_attentions = None + + if image_embeds is None: + if flattened_patches is not None: + vision_model_output = self.vision_model( + flattened_patches=flattened_patches, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeds = nn.functional.normalize(vision_model_output[0], dim=-1) + image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) + + lm_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + outputs = lm_outputs + (image_embeds, projection_attentions, vision_model_output) + return tuple(output for output in outputs if output is not None) + + return Kosmos2_5ForConditionalGenerationModelOutput( + loss=lm_outputs.loss, + logits=lm_outputs.logits, + past_key_values=lm_outputs.past_key_values, + hidden_states=lm_outputs.hidden_states, + attentions=lm_outputs.attentions, + image_embeds=image_embeds, + projection_attentions=projection_attentions, + vision_model_output=vision_model_output, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + flattened_patches=None, + image_embeds=None, + image_embeds_position_mask=None, + past_key_values=None, + attention_mask=None, + use_cache=None, + cache_position=None, + position_ids=None, + **model_kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = self.text_model.prepare_inputs_for_generation( + input_ids, + image_embeds=image_embeds, + image_embeds_position_mask=image_embeds_position_mask, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + cache_position=cache_position, + position_ids=position_ids, + **model_kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need `flattened_patches` to be passed to model + model_inputs["flattened_patches"] = flattened_patches + + return model_inputs + + @staticmethod + # Copied from transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = [ + "Kosmos2_5ForConditionalGeneration", + "Kosmos2_5Model", + "Kosmos2_5PreTrainedModel", +] diff --git a/src/transformers/models/kosmos2_5/processing_kosmos2_5.py b/src/transformers/models/kosmos2_5/processing_kosmos2_5.py new file mode 100644 index 00000000000000..36fe6f306a20c3 --- /dev/null +++ b/src/transformers/models/kosmos2_5/processing_kosmos2_5.py @@ -0,0 +1,164 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Kosmos2_5. +""" + +from typing import List, Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import TextInput +from ...utils import is_torch_available + + +if is_torch_available(): + import torch + + +class Kosmos2_5ImagesKwargs(ImagesKwargs, total=False): + max_patches: Optional[int] + num_image_tokens: Optional[int] + + +class Kosmos2_5ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: TextKwargs + images_kwargs: Kosmos2_5ImagesKwargs + _defaults = { + "text_kwargs": { + "padding": True, + "truncation": True, + "max_length": None, + "stride": 0, + "pad_to_multiple_of": None, + "return_attention_mask": None, + }, + "images_kwargs": { + "max_patches": 4096, + "num_image_tokens": 2048, + }, + "common_kwargs": {"return_tensors": "pt"}, + } + + +class Kosmos2_5Processor(ProcessorMixin): + r""" + Constructs a Kosmos2_5 processor which wraps a PreTrainedTokenizerFast and Kosmos2_5 image processor into a single + processor. + + [`Kosmos2_5Processor`] offers all the functionalities of [`Kosmos2_5ImageProcessor`] and [`PreTrainedTokenizerFast`]. See + the docstring of [`~Kosmos2_5Processor.__call__`] and [`~Kosmos2_5Processor.decode`] for more information. + + Args: + image_processor (`Kosmos2_5ImageProcessor`): + An instance of [`Kosmos2_5ImageProcessor`]. The image processor is a required input. + tokenizer (Union[`T5TokenizerFast`, `T5Tokenizer`]): + An instance of ['T5TokenizerFast`] or ['T5Tokenizer`]. The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "Kosmos2_5ImageProcessor" + tokenizer_class = "PreTrainedTokenizerFast" + + def __init__(self, image_processor, tokenizer): + tokenizer.return_token_type_ids = False + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, List[TextInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Kosmos2_5ProcessorKwargs], + ) -> BatchFeature: + """ + This method uses [`Kosmos2_5ImageProcessor.preprocess`] method to prepare image(s) for the model, and + [`PreTrainedTokenizerFast.__call__`] to prepare text for the model. + + Please refer to the docstring of the above two methods for more information. + + The rest of this documentation shows the arguments specific to `Kosmos2_5Processor`. + """ + if images is None and text is None: + raise ValueError("You have to specify either images or text.") + + if images is None: + raise ValueError("Kosmos2_5Processor requires images to be passed.") + + output_kwargs = self._merge_kwargs( + Kosmos2_5ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + num_image_tokens = output_kwargs["images_kwargs"].setdefault("num_image_tokens", None) + + encoding = BatchFeature() + + if images is not None: + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) + image_encoding.pop("rows") + image_encoding.pop("cols") + encoding.update(image_encoding) + + prompt = "" + "" * num_image_tokens + "" + + if text is not None: + if isinstance(text, str): + text = [prompt + text] + else: + text = [prompt + t for t in text] + input = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + batch_size, seq_len = input.input_ids.shape + image_embeds_position_mask = [0, -1] + [1] * num_image_tokens + [-1] + image_embeds_position_mask += [0] * (seq_len - len(image_embeds_position_mask)) + image_embeds_position_mask = ( + torch.LongTensor(image_embeds_position_mask).unsqueeze(0).repeat(batch_size, 1) + ) + + encoding.update( + { + "input_ids": input.input_ids, + "attention_mask": input.attention_mask, + "image_embeds_position_mask": image_embeds_position_mask, + } + ) + + return encoding + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Kosmos2_5TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. + Please refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Kosmos2_5TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Kosmos2_5Processor"] diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 176dadd5b883e1..d49028f928b63a 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -302,7 +302,7 @@ def forward( class Pix2StructVisionEncoder(nn.Module): - def __init__(self, config: Pix2StructConfig) -> None: + def __init__(self, config: Pix2StructVisionConfig) -> None: super().__init__() self.config = config self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)]) @@ -531,7 +531,7 @@ class Pix2StructVisionModel(Pix2StructPreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Pix2StructVisionLayer"] - def __init__(self, config: Pix2StructConfig): + def __init__(self, config: Pix2StructVisionConfig): super().__init__(config) self.config = config diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 823c51a290713d..7f2ffb34c7e22c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5272,6 +5272,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Kosmos2_5ForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Kosmos2_5Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Kosmos2_5PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LayoutLMForMaskedLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 3ebda4404aae9c..41d1a0c2e80c21 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -331,6 +331,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class Kosmos2_5ImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class LayoutLMv2FeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/kosmos2_5/__init__.py b/tests/models/kosmos2_5/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/kosmos2_5/test_image_processing_kosmos2_5.py b/tests/models/kosmos2_5/test_image_processing_kosmos2_5.py new file mode 100644 index 00000000000000..95f895de2589db --- /dev/null +++ b/tests/models/kosmos2_5/test_image_processing_kosmos2_5.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import requests + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import Kosmos2_5ImageProcessor + + +class Kosmos2_5ImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + size=None, + do_normalize=True, + do_convert_rgb=True, + patch_size=None, + ): + size = size if size is not None else {"height": 20, "width": 20} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.size = size + self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb + self.max_patches = [512, 1024, 2048, 4096] + self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} + + def prepare_image_processor_dict(self): + return {"do_normalize": self.do_normalize, "do_convert_rgb": self.do_convert_rgb} + + def prepare_dummy_image(self): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class Kosmos2_5ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = Kosmos2_5ImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = Kosmos2_5ImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processor = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + + def test_expected_patches(self): + dummy_image = self.image_processor_tester.prepare_dummy_image() + + image_processor = self.image_processing_class(**self.image_processor_dict) + max_patch = 2048 + + inputs = image_processor(dummy_image, return_tensors="pt", max_patches=max_patch) + self.assertTrue(torch.allclose(inputs.flattened_patches.mean(), torch.tensor(0.0606), atol=1e-3, rtol=1e-3)) + + def test_call_pil(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) + + def test_call_numpy(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) + + def test_call_numpy_4_channels(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + self.image_processor_tester.num_channels = 4 + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) + self.image_processor_tester.num_channels = 3 + + def test_call_pytorch(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) + + +@require_torch +@require_vision +class Kosmos2_5ImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = Kosmos2_5ImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = Kosmos2_5ImageProcessingTester(self, num_channels=4) + self.expected_encoded_image_num_channels = 3 + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processor = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + + def test_call_pil(self): + # Initialize image_processor + image_processor = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * (self.image_processor_tester.num_channels - 1) + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) + + @unittest.skip(reason="Kosmos2_5ImageProcessor does not support 4 channels yet") # FIXME Amy + def test_call_numpy(self): + return super().test_call_numpy() + + @unittest.skip(reason="Kosmos2_5ImageProcessor does not support 4 channels yet") # FIXME Amy + def test_call_pytorch(self): + return super().test_call_torch() + + @unittest.skip( + reason="Kosmos2_5ImageProcessor does treat numpy and PIL 4 channel images consistently" + ) # FIXME Amy + def test_call_numpy_4_channels(self): + return super().test_call_torch() diff --git a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py new file mode 100644 index 00000000000000..be2b29aad5a304 --- /dev/null +++ b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py @@ -0,0 +1,848 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch KOSMOS-2.5 model.""" + +import copy +import inspect +import os +import tempfile +import unittest + +import numpy as np +import pytest +import requests +from parameterized import parameterized + +from transformers import AutoProcessor, Kosmos2_5Config +from transformers.models.kosmos2_5.configuration_kosmos2_5 import ( + Kosmos2_5TextConfig, + Kosmos2_5VisionConfig, +) +from transformers.testing_utils import ( + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + require_vision, + slow, + torch_device, +) +from transformers.utils import is_torch_available, is_vision_available + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + random_attention_mask, +) +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import Kosmos2_5ForConditionalGeneration, Kosmos2_5Model + + +if is_vision_available(): + from PIL import Image + + +class Kosmos2_5VisionModelTester: + def __init__( + self, + parent, + batch_size=6, + image_size=32, + patch_size=4, + num_channels=3, + is_training=True, + hidden_size=32, + d_ff=64, + num_hidden_layers=2, + num_attention_heads=4, + dropout=0, + attention_dropout=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.hidden_size = hidden_size + self.d_ff = d_ff + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.patch_embed_hidden_size = patch_size * patch_size * num_channels + self.dropout = dropout + self.attention_dropout = attention_dropout + self.scope = scope + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def prepare_config_and_inputs(self): + flattened_patches = floats_tensor([self.batch_size, self.seq_length, self.patch_embed_hidden_size + 2]) + config = self.get_config() + + return config, flattened_patches + + def get_config(self): + return Kosmos2_5VisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + d_ff=self.d_ff, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + patch_embed_hidden_size=self.patch_embed_hidden_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, flattened_patches = config_and_inputs + inputs_dict = {"flattened_patches": flattened_patches} + return config, inputs_dict + + +class Kosmos2_5TextModelTester: + def __init__( + self, + parent, + batch_size=6, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + ffn_dim=64, + num_hidden_layers=2, + num_attention_heads=4, + dropout=0, + attention_dropout=0, + max_position_embeddings=512, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_dim = ffn_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.max_position_embeddings = max_position_embeddings + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + if input_mask is not None: + batch_size, seq_length = input_mask.shape + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + input_mask[batch_idx, :start_index] = 1 + input_mask[batch_idx, start_index:] = 0 + + config = self.get_config() + + return config, input_ids, input_mask + + def get_config(self): + return Kosmos2_5TextConfig( + vocab_size=self.vocab_size, + embed_dim=self.hidden_size, + ffn_dim=self.ffn_dim, + layers=self.num_hidden_layers, + attention_heads=self.num_attention_heads, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + max_position_embeddings=self.max_position_embeddings, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, input_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +class Kosmos2_5ModelTester: + def __init__( + self, + parent, + text_kwargs=None, + vision_kwargs=None, + latent_query_num=3, + is_training=True, + ): + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + + self.parent = parent + self.text_model_tester = Kosmos2_5TextModelTester(parent, **text_kwargs) + self.vision_model_tester = Kosmos2_5VisionModelTester(parent, **vision_kwargs) + self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test + self.seq_length = self.text_model_tester.seq_length + self.latent_query_num = latent_query_num + self.is_training = is_training + + def prepare_config_and_inputs(self): + text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs() + vision_config, flattened_patches = self.vision_model_tester.prepare_config_and_inputs() + + # build `image_embeds_position_mask` + image_embeds_position_mask = torch.zeros_like(input_ids) + image_embeds_position_mask[:, 1 : 1 + self.latent_query_num :] = 1 + + config = self.get_config() + + return ( + config, + input_ids, + attention_mask, + image_embeds_position_mask, + flattened_patches, + ) + + def get_config(self): + return Kosmos2_5Config( + self.text_model_tester.get_config().to_dict(), + self.vision_model_tester.get_config().to_dict(), + latent_query_num=self.latent_query_num, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + image_embeds_position_mask, + flattened_patches, + ): + model = Kosmos2_5Model(config).to(torch_device).eval() + with torch.no_grad(): + result = model(input_ids, flattened_patches, image_embeds_position_mask, attention_mask) + self.parent.assertEqual( + result.last_hidden_state.shape, + ( + self.text_model_tester.batch_size, + self.text_model_tester.seq_length, + self.text_model_tester.hidden_size, + ), + ) + self.parent.assertEqual( + result.image_embeds.shape, + ( + self.text_model_tester.batch_size, + self.latent_query_num, + self.text_model_tester.hidden_size, + ), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + image_embeds_position_mask, + flattened_patches, + ) = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "image_embeds_position_mask": image_embeds_position_mask, + "flattened_patches": flattened_patches, + } + return config, inputs_dict + + +@require_torch +class Kosmos2_5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Kosmos2_5Model, Kosmos2_5ForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (Kosmos2_5ForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": Kosmos2_5Model, + "image-to-text": Kosmos2_5ForConditionalGeneration, + } + if is_torch_available() + else {} + ) + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + _is_composite = True + + # TODO: `image-to-text` pipeline for this model needs Processor. + def is_pipeline_test_to_skip( + self, + pipeline_test_casse_name, + config_class, + model_architecture, + tokenizer_name, + processor_name, + ): + return pipeline_test_casse_name == "ImageToTextPipelineTests" + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = copy.deepcopy(inputs_dict) + + if return_labels: + if model_class.__name__ == "Kosmos2_5ForConditionalGeneration": + inputs_dict["labels"] = torch.zeros( + ( + self.model_tester.text_model_tester.batch_size, + self.model_tester.text_model_tester.seq_length, + ), + dtype=torch.long, + device=torch_device, + ) + + if model_class.__name__ in [ + "Kosmos2_5Model", + "Kosmos2_5ForConditionalGeneration", + ]: + bs, _ = inputs_dict["input_ids"].shape + seqlen = self.model_tester.text_model_tester.seq_length + inputs_dict["input_ids"] = torch.arange(seqlen, device=torch_device).unsqueeze(0).expand(bs, seqlen) + inputs_dict["input_ids"] = inputs_dict["input_ids"] % self.model_tester.text_model_tester.vocab_size + inputs_dict["attention_mask"] = torch.ones((bs, seqlen), device=torch_device) + inputs_dict["image_embeds_position_mask"] = torch.zeros((bs, seqlen), device=torch_device) + inputs_dict["image_embeds_position_mask"][:, : self.model_tester.latent_query_num] = 1 + return inputs_dict + + def setUp(self): + self.model_tester = Kosmos2_5ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Kosmos2_5Config, hidden_size=37) + + # overwrite from common to skip `image_to_text_projection.latent_query` + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if name == "image_to_text_projection.latent_query": + # The original code use ` nn.Parameter(torch.randn(...))` for which this test won't pass. + continue + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_ids"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_load_save_without_tied_weights(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config.text_config.tie_word_embeddings = False + for model_class in self.all_model_classes: + model = model_class(config) + with tempfile.TemporaryDirectory() as d: + model.save_pretrained(d) + + model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True) + # Checking the state dicts are correct + reloaded_state = model_reloaded.state_dict() + for k, v in model.state_dict().items(): + self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded") + torch.testing.assert_close( + v, + reloaded_state[k], + msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}", + ) + # Checking there was no complain of missing weights + self.assertEqual(infos["missing_keys"], []) + + # overwrite from common in order to use `self.model_tester.text_model_tester.num_hidden_layers` + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, + "expected_num_hidden_layers", + self.model_tester.text_model_tester.num_hidden_layers + 1, + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + seq_length = self.model_tester.text_model_tester.seq_length + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.text_model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # overwrite from common in order to use `config.text_config.vocab_size` instead of `config.vocab_size` + def test_tie_model_weights(self): + if not self.test_torchscript: + self.skipTest(reason="test_torchscript is set to False") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_same_values(layer_1, layer_2): + equal = True + for p1, p2 in zip(layer_1.weight, layer_2.weight): + if p1.data.ne(p2.data).sum() > 0: + equal = False + return equal + + for model_class in self.all_model_classes: + config.torchscript = True + model_not_tied = model_class(config) + if model_not_tied.get_output_embeddings() is None: + continue + + config_tied = copy.deepcopy(config) + config_tied.torchscript = False + model_tied = model_class(config_tied) + params_tied = list(model_tied.parameters()) + # Check that the embedding layer and decoding layer are the same in size and in value + # self.assertTrue(check_same_values(embeddings, decoding)) + + # # Check that after modification, they remain the same. + # embeddings.weight.data.div_(2) + # # Check that the embedding layer and decoding layer are the same in size and in value + # self.assertTrue(embeddings.weight.shape, decoding.weight.shape) + # self.assertTrue(check_same_values(embeddings, decoding)) + + # # Check that after modification, they remain the same. + # decoding.weight.data.div_(4) + # # Check that the embedding layer and decoding layer are the same in size and in value + # self.assertTrue(embeddings.weight.shape, decoding.weight.shape) + # self.assertTrue(check_same_values(embeddings, decoding)) + + # Check that after resize they remain tied. + model_tied.resize_token_embeddings(config.text_config.vocab_size + 10) + params_tied_2 = list(model_tied.parameters()) + self.assertEqual(len(params_tied_2), len(params_tied)) + + # decoding.weight.data.mul_(20) + # # Check that the embedding layer and decoding layer are the same in size and in value + # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape) + # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head)) + + @slow + def test_model_from_pretrained(self): + model_name = "microsoft/kosmos-2.5" + model = Kosmos2_5Model.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip(reason="Does not work on the tiny model as we keep hitting edge cases.") + def test_model_parallelism(self): + super().test_model_parallelism() + + # TODO: ydshieh + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + @unittest.skip(reason="kosmos-2.5 flash attention does not support right padding") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + # TODO: ydshieh + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + @unittest.skip(reason="kosmos-2.5 test : the dummy inputs should be tweaked: dummy_input = inputs_dict") + def test_flash_attn_2_inference_equivalence(self): + pass + + # TODO: ydshieh + @require_torch_sdpa + @require_torch_gpu + @slow + @unittest.skip(reason="_update_causal_mask is not implemented yet which fails this test") + def test_sdpa_can_dispatch_on_flash(self): + pass + + # TODO: ydshieh + @unittest.skip(reason="doesn't support padding yet") + def test_eager_matches_sdpa_inference_1_bfloat16(self): + pass + + # TODO: ydshieh + @unittest.skip(reason=" the model hasn't been added to auto class") + def test_flash_attn_2_from_config(self): + pass + + @unittest.skip("This test is currently not well designed for multimodal model (float type as an input).") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip("This test is currently not well designed for multimodal model (float type as an input).") + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip("Kosmos 2.5 is multimodel and has specific input shapes.") + def test_flash_attn_2_generate_reuse_cache(self): + pass + + @pytest.mark.generate + @parameterized.expand([("greedy", 1), ("beam search", 2)]) + @unittest.skip( + "KOSMOS-2.5 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten", + ) + def test_generate_from_inputs_embeds(self): + pass + + # TODO: ydshieh + @pytest.mark.generate + @unittest.skip( + "Kosmos2_5ForConditionalGeneration returns `vision_model_output` which is currently not working with `stack_model_outputs`", + ) + def test_beam_search_low_memory(self): + pass + + @pytest.mark.generate + def test_left_padding_compatibility(self): + # Overwrite because Kosmos-2.5 need to padd pixel values and pad image-attn-mask + + def _prepare_model_kwargs(input_ids, attention_mask, pad_size, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + model_kwargs["cache_position"] = cache_position + if "image_embeds_position_mask" in signature: + image_embeds_position_mask = torch.zeros_like(input_ids) + image_embeds_position_mask[:, (pad_size + 1) : pad_size + 1 + self.model_tester.latent_query_num] = 1 + model_kwargs["image_embeds_position_mask"] = image_embeds_position_mask + return model_kwargs + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + input_ids = inputs_dict["input_ids"] + flattened_patches = inputs_dict["flattened_patches"] + attention_mask = inputs_dict.get("attention_mask") + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + model = model_class(config).to(torch_device).eval() + signature = inspect.signature(model.forward).parameters.keys() + + # no cache as some models require special cache classes to be init outside forward + model.generation_config.use_cache = False + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, pad_size=0, signature=signature) + next_logits_wo_padding = model(**model_kwargs, flattened_patches=flattened_patches).logits[:, -1, :] + + # With left-padding (length 32) + # can hardcode pad_token to be 0 as we'll do attn masking anyway + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0 + ) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = _prepare_model_kwargs( + padded_input_ids, padded_attention_mask, pad_size=32, signature=signature + ) + next_logits_with_padding = model(**model_kwargs, flattened_patches=flattened_patches).logits[:, -1, :] + + # They should result in very similar logits + self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-3)) + + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + self.skipTest(reason="test_torchscript is set to False") + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + main_input_name = model_class.main_input_name + + try: + main_input = inputs[main_input_name] + model( + main_input, + inputs["flattened_patches"], + inputs["image_embeds_position_mask"], + ) + traced_model = torch.jit.trace( + model, + ( + main_input, + inputs["flattened_patches"], + inputs["image_embeds_position_mask"], + ), + ) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. + # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + + +@require_vision +@require_torch +@slow +class Kosmos2_5ModelIntegrationTest(unittest.TestCase): + # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # Depending on the hardware we get different logits / generations + cuda_compute_capability_major_version = None + + @classmethod + def setUpClass(cls): + if is_torch_available() and torch.cuda.is_available(): + # 8 is for A100 / A10 and 7 for T4 + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + + def run_example(self, prompt, image, model, processor): + inputs = processor(text=prompt, images=image, return_tensors="pt") + _, _ = inputs.pop("height"), inputs.pop("width") + inputs = {k: v.to(torch_device) if v is not None else None for k, v in inputs.items()} + inputs["flattened_patches"] = inputs["flattened_patches"].to(model.dtype) + + generation_outputs = model.generate( + **inputs, + max_new_tokens=1024, + ) + generated_ids = generation_outputs + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + + return generated_ids, generated_text + + def test_eager(self): + url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png" + image = Image.open(requests.get(url, stream=True).raw) + + dtype = torch.bfloat16 + repo = "microsoft/kosmos-2.5" + model = Kosmos2_5ForConditionalGeneration.from_pretrained( + repo, device_map=torch_device, torch_dtype=dtype, attn_implementation="eager" + ) + processor = AutoProcessor.from_pretrained(repo) + prompt = "" + generated_ids, generated_text = self.run_example(prompt, image, model, processor) + EXPECTED_TEXT = { + 7: [ + "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n" + ], + 8: [ + "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n" + ], + } + + self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version]) + + prompt = "" + generated_ids, generated_text = self.run_example(prompt, image, model, processor) + + EXPECTED_TEXT = { + 7: [ + "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000" + ], + 8: [ + "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000" + ], + } + + self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version]) + + def test_sdpa(self): + url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png" + image = Image.open(requests.get(url, stream=True).raw) + + dtype = torch.bfloat16 + repo = "microsoft/kosmos-2.5" + model = Kosmos2_5ForConditionalGeneration.from_pretrained( + repo, device_map=torch_device, torch_dtype=dtype, attn_implementation="sdpa" + ) + processor = AutoProcessor.from_pretrained(repo) + prompt = "" + generated_ids, generated_text = self.run_example(prompt, image, model, processor) + EXPECTED_TEXT = { + 7: [ + "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n", + ], + 8: [ + "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n" + ], + } + + self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version]) + + prompt = "" + generated_ids, generated_text = self.run_example(prompt, image, model, processor) + + EXPECTED_TEXT = { + 7: [ + "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000" + ], + 8: [ + "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n\nCard Payment 50,000" + ], + } + + self.assertListEqual(generated_text, EXPECTED_TEXT[self.cuda_compute_capability_major_version]) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_FA2(self): + url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png" + image = Image.open(requests.get(url, stream=True).raw) + + dtype = torch.bfloat16 + repo = "microsoft/kosmos-2.5" + model = Kosmos2_5ForConditionalGeneration.from_pretrained( + repo, + device_map=torch_device, + torch_dtype=dtype, + attn_implementation="flash_attention_2", + ) + processor = AutoProcessor.from_pretrained(repo) + prompt = "" + generated_ids, generated_text = self.run_example(prompt, image, model, processor) + EXPECTED_TEXT = [ + "1\n[REG] BLACK SAKURA\n45,455\n1\nCOOKIE DOH SAUCES\n0\n1\nNATA DE COCO\n0\nSub Total 45,455\nPB1 (10%) 4,545\nRounding 0\nTotal 50,000\nCard Payment 50,000\n" + ] + + self.assertListEqual(generated_text, EXPECTED_TEXT) + + prompt = "" + generated_ids, generated_text = self.run_example(prompt, image, model, processor) + # A10 gives the 1st one, but A100 gives the 2nd one + EXPECTED_TEXT = [ + "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\nSub Total\n\n45,455\n
\nPB1 (10%)\n\n4,545\n
\nRounding\n\n0\n
\n\nTotal\n\n\n\n50,000\n\n
\n\nCard Payment 50,000", + "- **1 \\[REG\\] BLACK SAKURA** 45,455\n- **1 COOKIE DOH SAUCES** 0\n- **1 NATA DE COCO** 0\n- **Sub Total** 45,455\n- **PB1 (10%)** 4,545\n- **Rounding** 0\n- **Total** **50,000**\n", + ] + self.assertIn(generated_text[0], EXPECTED_TEXT) diff --git a/tests/models/kosmos2_5/test_processor_kosmos2_5.py b/tests/models/kosmos2_5/test_processor_kosmos2_5.py new file mode 100644 index 00000000000000..df93fc0dfa89ac --- /dev/null +++ b/tests/models/kosmos2_5/test_processor_kosmos2_5.py @@ -0,0 +1,391 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +from tempfile import TemporaryDirectory + +import numpy as np +import pytest +import requests + +from transformers.testing_utils import ( + require_torch, + require_vision, +) +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from PIL import Image + + from transformers import ( + AutoProcessor, + AutoTokenizer, + Kosmos2_5ImageProcessor, + Kosmos2_5Processor, + PreTrainedTokenizerFast, + ) + + +@require_vision +class Kosmos2_5ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Kosmos2_5Processor + images_input_name = "flattened_patches" + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = Kosmos2_5ImageProcessor() + tokenizer = AutoTokenizer.from_pretrained("microsoft/kosmos-2.5") + processor = Kosmos2_5Processor(image_processor, tokenizer) + processor.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_image_procesor_load_save_reload(self): + # make sure load from Hub repo. -> save -> reload locally work + image_processor = Kosmos2_5ImageProcessor.from_pretrained("microsoft/kosmos-2.5") + with TemporaryDirectory() as tmp_dir: + image_processor.save_pretrained(tmp_dir) + reloaded_image_processor = Kosmos2_5ImageProcessor.from_pretrained(tmp_dir) + assert image_processor.to_dict() == reloaded_image_processor.to_dict() + assert image_processor.to_json_string() == reloaded_image_processor.to_json_string() + + def test_save_load_pretrained_additional_features(self): + processor = Kosmos2_5Processor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0) + + processor = Kosmos2_5Processor.from_pretrained( + self.tmpdirname, + bos_token="(BOS)", + eos_token="(EOS)", + do_normalize=False, + padding_value=1.0, + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, PreTrainedTokenizerFast) + + self.assertEqual( + processor.image_processor.to_json_string(), + image_processor_add_kwargs.to_json_string(), + ) + self.assertIsInstance(processor.image_processor, Kosmos2_5ImageProcessor) + + @unittest.skip(reason="kosmos-2.5 must have both image and text") + def test_image_processor(self): + pass + + @unittest.skip(reason="kosmos-2.5 must have both image and text") + def test_tokenizer(self): + pass + + def test_tokenizer_decode(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = Kosmos2_5Processor(tokenizer=tokenizer, image_processor=image_processor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) + + def test_can_load_various_tokenizers(self): + for checkpoint in ["microsoft/kosmos-2.5", "kirp/kosmos2_5"]: + processor = AutoProcessor.from_pretrained(checkpoint) + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__) + + @require_torch + def test_model_input_names(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = Kosmos2_5Processor(tokenizer=tokenizer, image_processor=image_processor) + + input_str = "This is a test" + image_input = self.prepare_image_inputs() + + # both image and text + inputs = processor(text=input_str, images=image_input) + self.assertListEqual( + list(inputs.keys()), + [ + "flattened_patches", + "attention_mask", + "width", + "height", + "input_ids", + "image_embeds_position_mask", + ], + ) + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + @require_torch + @require_vision + def test_image_processor_defaults_preserved_by_image_kwargs(self): + # Rewrite as KOSMOS-2.5 processor return "flattened_patches" and not "pixel_values" + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", max_patches=1024, patch_size={"height": 8, "width": 8}) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + self.assertEqual(len(inputs["flattened_patches"][0][0]), 194) + + @require_torch + @require_vision + def test_kwargs_overrides_default_image_processor_kwargs(self): + # Rewrite as KOSMOS-2.5 processor return "flattened_patches" and not "pixel_values" + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", max_patches=4096) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, max_patches=1024) + self.assertEqual(len(inputs["flattened_patches"][0]), 1024) + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + # Rewrite as KOSMOS-2.5 processor doesn't use `rescale_factor` + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + max_patches=1024, + padding="max_length", + max_length=76, + ) + + self.assertEqual(inputs["flattened_patches"].shape[1], 1024) + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + # Rewrite as KOSMOS-2.5 processor doesn't use `rescale_factor` + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs(batch_size=2) + image_input = self.prepare_image_inputs(batch_size=2) + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + max_patches=1024, + padding="longest", + max_length=76, + ) + + self.assertEqual(inputs["flattened_patches"].shape[1], 1024) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + # Rewrite as KOSMOS-2.5 processor doesn't use `rescale_factor` + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"max_patches": 1024}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs["flattened_patches"].shape[1], 1024) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + # Rewrite as KOSMOS-2.5 processor doesn't use `rescale_factor` + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"max_patches": 1024}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertEqual(inputs["flattened_patches"].shape[1], 1024) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + def test_full_processor(self): + url = "https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png" + processor = AutoProcessor.from_pretrained("microsoft/kosmos-2.5") + texts = ["", ""] + expected_input_ids = [ + [100288], + [100282], + ] + expected_attention_mask = [[1], [1]] + + image = Image.open(requests.get(url, stream=True).raw) + # To match the official (microsoft) Kosmos-2 demo from which the expected values here are grabbed + image_path = os.path.join(self.tmpdirname, "image.png") + image.save(image_path) + image = Image.open(image_path) + + # test single image + outputs = processor(images=image, text=texts[0]) + self.assertListEqual( + outputs.input_ids[0].numpy().tolist(), + [0, 100283] + [0] * 2048 + [100284] + expected_input_ids[0], + ) + self.assertListEqual( + outputs.image_embeds_position_mask[0].numpy().tolist(), + [0, -1] + [1] * 2048 + [-1] + [0] * (len(expected_input_ids[0])), + ) + self.assertListEqual( + outputs.attention_mask[0].numpy().tolist(), + [1, 1] + [1] * 2048 + [1] + expected_attention_mask[0], + ) + EXPECTED_FP_1 = [ + 1.0, + 2.0, + -2.9527735710144043, + -2.672085762023926, + -2.9933173656463623, + -2.905944585800171, + -2.5891761779785156, + -2.8751866817474365, + -2.962153434753418, + -2.588062047958374, + ] + EXPECTED_FP_200 = [ + 4.0, + 45.0, + 1.5713728666305542, + 1.584628939628601, + 1.3589054346084595, + 1.6515952348709106, + 1.7014952898025513, + 1.3731343746185303, + 1.6010395288467407, + 1.6607422828674316, + ] + self.assertTupleEqual(outputs.flattened_patches.shape, (1, 4096, 770)) + np.testing.assert_allclose( + outputs.flattened_patches[0][1][:10].numpy().tolist(), + EXPECTED_FP_1, + atol=1e-9, + ) + np.testing.assert_allclose( + outputs.flattened_patches[0][200][:10].numpy().tolist(), + EXPECTED_FP_200, + atol=1e-9, + ) + + # test a batch of images and texts, right padding + outputs = processor(images=[image, image], text=texts) + self.assertListEqual( + outputs.input_ids[1].numpy().tolist(), + [0, 100283] + [0] * 2048 + [100284] + expected_input_ids[1], + ) + self.assertListEqual( + outputs.image_embeds_position_mask[1].numpy().tolist(), + [0, -1] + [1] * 2048 + [-1] + [0] * (len(expected_input_ids[1])), + ) + self.assertListEqual( + outputs.attention_mask[1].numpy().tolist(), + [1, 1] + [1] * 2048 + [1] + expected_attention_mask[1], + ) + self.assertTupleEqual(outputs.flattened_patches.shape, (2, 4096, 770)) + np.testing.assert_allclose( + outputs.flattened_patches[1][1][:10].numpy().tolist(), + EXPECTED_FP_1, + atol=1e-9, + ) + np.testing.assert_allclose( + outputs.flattened_patches[1][200][:10].numpy().tolist(), + EXPECTED_FP_200, + atol=1e-9, + ) diff --git a/utils/check_repo.py b/utils/check_repo.py index 3dbe59f192293a..08e3dfdf0ab57e 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -85,6 +85,9 @@ "Idefics2PerceiverResampler", "Idefics2VisionTransformer", "Idefics3VisionTransformer", + "Kosmos2_5TextModel", + "Kosmos2_5TextForCausalLM", + "Kosmos2_5VisionModel", "AriaTextForCausalLM", "AriaTextModel", ]