From 9ad4c93536855d78bcc3ea56a95bd53dc95d1a8e Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:17:34 +0100 Subject: [PATCH] Add Aria (#34157) * Add Aria --------- Co-authored-by: Cyril Vallez Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 3 + docs/source/en/model_doc/aria.md | 106 + docs/source/en/model_doc/idefics3.md | 7 + docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 32 + src/transformers/generation/utils.py | 1 + src/transformers/models/__init__.py | 1 + src/transformers/models/aria/__init__.py | 30 + .../models/aria/configuration_aria.py | 299 +++ .../models/aria/convert_aria_weights_to_hf.py | 162 ++ .../models/aria/image_processing_aria.py | 504 +++++ src/transformers/models/aria/modeling_aria.py | 1920 +++++++++++++++++ src/transformers/models/aria/modular_aria.py | 1598 ++++++++++++++ .../models/aria/processing_aria.py | 164 ++ .../models/auto/configuration_auto.py | 8 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 5 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/idefics3/__init__.py | 6 +- src/transformers/utils/dummy_pt_objects.py | 49 + .../utils/dummy_vision_objects.py | 7 + tests/generation/test_utils.py | 1 + tests/models/aria/__init__.py | 0 .../models/aria/test_image_processing_aria.py | 268 +++ tests/models/aria/test_modeling_aria.py | 669 ++++++ tests/models/aria/test_processor_aria.py | 391 ++++ utils/check_config_attributes.py | 2 +- utils/check_docstrings.py | 7 +- utils/check_repo.py | 2 + utils/modular_model_converter.py | 2 +- 32 files changed, 6244 insertions(+), 7 deletions(-) create mode 100644 docs/source/en/model_doc/aria.md create mode 100644 src/transformers/models/aria/__init__.py create mode 100644 src/transformers/models/aria/configuration_aria.py create mode 100644 src/transformers/models/aria/convert_aria_weights_to_hf.py create mode 100644 src/transformers/models/aria/image_processing_aria.py create mode 100644 src/transformers/models/aria/modeling_aria.py create mode 100644 src/transformers/models/aria/modular_aria.py create mode 100644 src/transformers/models/aria/processing_aria.py create mode 100644 tests/models/aria/__init__.py create mode 100644 tests/models/aria/test_image_processing_aria.py create mode 100644 tests/models/aria/test_modeling_aria.py create mode 100644 tests/models/aria/test_processor_aria.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3521d4ccfed894..6e325e499f342d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -810,6 +810,8 @@ title: ALIGN - local: model_doc/altclip title: AltCLIP + - local: model_doc/aria + title: Aria - local: model_doc/blip title: BLIP - local: model_doc/blip-2 diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 3cad4e663f23fd..181ec8b10fb1fc 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -62,6 +62,8 @@ Flax), PyTorch, and/or TensorFlow. | [ALBERT](model_doc/albert) | ✅ | ✅ | ✅ | | [ALIGN](model_doc/align) | ✅ | ❌ | ❌ | | [AltCLIP](model_doc/altclip) | ✅ | ❌ | ❌ | +| [Aria](model_doc/aria) | ✅ | ❌ | ❌ | +| [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ | | [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ | | [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ | | [Bark](model_doc/bark) | ✅ | ❌ | ❌ | @@ -172,6 +174,7 @@ Flax), PyTorch, and/or TensorFlow. | [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ | | [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ | | [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ | +| [Idefics3VisionTransformer](model_doc/idefics3_vision) | ❌ | ❌ | ❌ | | [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ | | [Informer](model_doc/informer) | ✅ | ❌ | ❌ | | [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/aria.md b/docs/source/en/model_doc/aria.md new file mode 100644 index 00000000000000..9ff7a6687aa939 --- /dev/null +++ b/docs/source/en/model_doc/aria.md @@ -0,0 +1,106 @@ + + +# Aria + +## Overview + +The Aria model was proposed in [Aria: An Open Multimodal Native Mixture-of-Experts Model](https://huggingface.co/papers/2410.05993) by Li et al. from the Rhymes.AI team. + +Aria is an open multimodal-native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. It has a Mixture-of-Experts architecture, with respectively 3.9B and 3.5B activated parameters per visual token and text token. + +The abstract from the paper is the following: + +*Information comes in diverse modalities. Multimodal native AI models are essential to integrate real-world information and deliver comprehensive understanding. While proprietary multimodal native models exist, their lack of openness imposes obstacles for adoptions, let alone adaptations. To fill this gap, we introduce Aria, an open multimodal native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. Aria is a mixture-of-expert model with 3.9B and 3.5B activated parameters per visual token and text token, respectively. It outperforms Pixtral-12B and Llama3.2-11B, and is competitive against the best proprietary models on various multimodal tasks. We pre-train Aria from scratch following a 4-stage pipeline, which progressively equips the model with strong capabilities in language understanding, multimodal understanding, long context window, and instruction following. We open-source the model weights along with a codebase that facilitates easy adoptions and adaptations of Aria in real-world applications.* + +This model was contributed by [m-ric](https://huggingface.co/m-ric). +The original code can be found [here](https://github.com/rhymes-ai/Aria). + +## Usage tips + +Here's how to use the model for vision tasks: +```python +import requests +import torch +from PIL import Image + +from transformers import AriaProcessor, AriaForConditionalGeneration + +model_id_or_path = "rhymes-ai/Aria" + +model = AriaForConditionalGeneration.from_pretrained( + model_id_or_path, device_map="auto" +) + +processor = AriaProcessor.from_pretrained(model_id_or_path) + +image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + +messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"text": "what is the image?", "type": "text"}, + ], + } +] + +text = processor.apply_chat_template(messages, add_generation_prompt=True) +inputs = processor(text=text, images=image, return_tensors="pt") +inputs.to(model.device) + +output = model.generate( + **inputs, + max_new_tokens=15, + stop_strings=["<|im_end|>"], + tokenizer=processor.tokenizer, + do_sample=True, + temperature=0.9, +) +output_ids = output[0][inputs["input_ids"].shape[1]:] +response = processor.decode(output_ids, skip_special_tokens=True) +``` + + +## AriaImageProcessor + +[[autodoc]] AriaImageProcessor + +## AriaProcessor + +[[autodoc]] AriaProcessor + +## AriaTextConfig + +[[autodoc]] AriaTextConfig + +## AriaConfig + +[[autodoc]] AriaConfig + +## AriaTextModel + +[[autodoc]] AriaTextModel + +## AriaTextForCausalLM + +[[autodoc]] AriaTextForCausalLM + +## AriaForConditionalGeneration + +[[autodoc]] AriaForConditionalGeneration + - forward diff --git a/docs/source/en/model_doc/idefics3.md b/docs/source/en/model_doc/idefics3.md index dfaf40477a7b52..cf7c043e928901 100644 --- a/docs/source/en/model_doc/idefics3.md +++ b/docs/source/en/model_doc/idefics3.md @@ -51,6 +51,13 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) [[autodoc]] Idefics3Config +## Idefics3VisionConfig + +[[autodoc]] Idefics3VisionConfig + +## Idefics3VisionTransformer + +[[autodoc]] Idefics3VisionTransformer ## Idefics3Model diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index ec8dea2735b531..ab5e1c47a448f3 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -37,6 +37,7 @@ FlashAttention-2 is experimental and may change considerably in future versions. 2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them FlashAttention-2 is currently supported for the following architectures: +* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration) * [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon) @@ -216,6 +217,7 @@ PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.o For now, Transformers supports SDPA inference and training for the following architectures: * [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel) +* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration) * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 23940a240876b9..2eaec8f1def96e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -170,6 +170,11 @@ "AltCLIPTextConfig", "AltCLIPVisionConfig", ], + "models.aria": [ + "AriaConfig", + "AriaProcessor", + "AriaTextConfig", + ], "models.audio_spectrogram_transformer": [ "ASTConfig", "ASTFeatureExtractor", @@ -1176,6 +1181,7 @@ _import_structure["image_processing_base"] = ["ImageProcessingMixin"] _import_structure["image_processing_utils"] = ["BaseImageProcessor"] _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] + _import_structure["models.aria"].extend(["AriaImageProcessor"]) _import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"]) _import_structure["models.bit"].extend(["BitImageProcessor"]) _import_structure["models.blip"].extend(["BlipImageProcessor"]) @@ -1406,6 +1412,15 @@ "AltCLIPVisionModel", ] ) + _import_structure["models.aria"].extend( + [ + "AriaForConditionalGeneration", + "AriaPreTrainedModel", + "AriaTextForCausalLM", + "AriaTextModel", + "AriaTextPreTrainedModel", + ] + ) _import_structure["models.audio_spectrogram_transformer"].extend( [ "ASTForAudioClassification", @@ -2461,6 +2476,8 @@ "Idefics3Model", "Idefics3PreTrainedModel", "Idefics3Processor", + "Idefics3VisionConfig", + "Idefics3VisionTransformer", ] ) _import_structure["models.ijepa"].extend( @@ -5033,6 +5050,11 @@ AltCLIPTextConfig, AltCLIPVisionConfig, ) + from .models.aria import ( + AriaConfig, + AriaProcessor, + AriaTextConfig, + ) from .models.audio_spectrogram_transformer import ( ASTConfig, ASTFeatureExtractor, @@ -6096,6 +6118,7 @@ from .image_processing_base import ImageProcessingMixin from .image_processing_utils import BaseImageProcessor from .image_utils import ImageFeatureExtractionMixin + from .models.aria import AriaImageProcessor from .models.beit import BeitFeatureExtractor, BeitImageProcessor from .models.bit import BitImageProcessor from .models.blip import BlipImageProcessor @@ -6325,6 +6348,13 @@ AltCLIPTextModel, AltCLIPVisionModel, ) + from .models.aria import ( + AriaForConditionalGeneration, + AriaPreTrainedModel, + AriaTextForCausalLM, + AriaTextModel, + AriaTextPreTrainedModel, + ) from .models.audio_spectrogram_transformer import ( ASTForAudioClassification, ASTModel, @@ -7189,6 +7219,8 @@ Idefics3Model, Idefics3PreTrainedModel, Idefics3Processor, + Idefics3VisionConfig, + Idefics3VisionTransformer, ) from .models.ijepa import ( IJepaForImageClassification, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 015cbebaa8e5dc..89c57cb913fec2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1465,6 +1465,7 @@ def _prepare_generated_length( elif ( model_input_name == "inputs_embeds" and input_ids_length != inputs_tensor.shape[1] + and input_ids_length != 0 and not self.config.is_encoder_decoder ): generation_config.max_length -= inputs_tensor.shape[1] diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index e957d802d80e71..116b71c81ad9df 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -16,6 +16,7 @@ albert, align, altclip, + aria, audio_spectrogram_transformer, auto, autoformer, diff --git a/src/transformers/models/aria/__init__.py b/src/transformers/models/aria/__init__.py new file mode 100644 index 00000000000000..f73301321527c1 --- /dev/null +++ b/src/transformers/models/aria/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_aria import * + from .image_processing_aria import * + from .modeling_aria import * + from .processing_aria import * + +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py new file mode 100644 index 00000000000000..ff34d59f5dfe1a --- /dev/null +++ b/src/transformers/models/aria/configuration_aria.py @@ -0,0 +1,299 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/aria/modular_aria.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_aria.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation +from ..auto import CONFIG_MAPPING, AutoConfig + + +class AriaTextConfig(PretrainedConfig): + r""" + This class handles the configuration for the text component of the Aria model. + Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria + [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture. + This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 4096): + The size of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 2): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_heads + moe_num_experts (`int`, *optional*, defaults to 8): + The number of experts in the MoE layer. + moe_topk (`int`, *optional*, defaults to 2): + The number of top experts to route to for each token. + moe_num_shared_experts (`int`, *optional*, defaults to 2): + The number of shared experts. + """ + + model_type = "aria_text" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `AriaTextModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_config_key = "text_config" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size: int = 4096, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=2, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + moe_num_experts: int = 8, + moe_topk: int = 2, + moe_num_shared_experts: int = 2, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.moe_num_shared_experts = moe_num_shared_experts + + +class AriaConfig(PretrainedConfig): + r""" + This class handles the configuration for both vision and text components of the Aria model, + as well as additional parameters for image token handling and projector mapping. + Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria + [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`AriaVisionConfig` or `dict`, *optional*): + Configuration for the vision component. + vision_feature_layer (`int`, *optional*, defaults to -1): + The index of the layer to select the vision feature. + text_config (`AriaTextConfig` or `dict`, *optional*): + Configuration for the text component. + projector_patch_to_query_dict (`dict`, *optional*): + Mapping of patch sizes to query dimensions. + image_token_index (`int`, *optional*, defaults to 9): + Index used to represent image tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal initializer for initializing all weight matrices. + + Attributes: + model_type (`str`): + Type of the model, set to `"aria"`. + image_token_index (`int`): + Index used to represent image tokens. + projector_patch_to_query_dict (`dict`): + Mapping of patch sizes to query dimensions. + vision_config (`AriaVisionConfig`): + Configuration for the vision component. + text_config (`AriaTextConfig`): + Configuration for the text component. + """ + + model_type = "aria" + sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + vision_feature_layer: int = -1, + text_config: AriaTextConfig = None, + projector_patch_to_query_dict: Dict = None, + image_token_index: int = 9, + initializer_range: float = 0.02, + **kwargs, + ): + self.image_token_index = image_token_index + + # Convert the keys and values of projector_patch_to_query_dict to integers + # This ensures consistency even if they were provided as strings + if projector_patch_to_query_dict is None: + projector_patch_to_query_dict = { + 1225: 128, + 4900: 256, + } + self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()} + self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values()) + self.vision_feature_layer = vision_feature_layer + if isinstance(vision_config, dict): + vision_config["model_type"] = "idefics3_vision" + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["idefics3_vision"]() + + self.vision_config = vision_config + self.initializer_range = initializer_range + + if isinstance(text_config, dict) and "model_type" in text_config: + text_config = AriaTextConfig(**text_config) + elif text_config is None: + text_config = AriaTextConfig() + + self.text_config = text_config + + super().__init__(**kwargs) + + +__all__ = ["AriaConfig", "AriaTextConfig"] diff --git a/src/transformers/models/aria/convert_aria_weights_to_hf.py b/src/transformers/models/aria/convert_aria_weights_to_hf.py new file mode 100644 index 00000000000000..dcc9e4d1397672 --- /dev/null +++ b/src/transformers/models/aria/convert_aria_weights_to_hf.py @@ -0,0 +1,162 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import glob + +import torch +from huggingface_hub import snapshot_download +from safetensors import safe_open + +from transformers import ( + AddedToken, + AriaForConditionalGeneration, + AriaProcessor, + AutoConfig, + AutoTokenizer, +) + + +EPILOG_TXT = """Example: + python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria + +Example for creating the old state dict file with Python: + + import torch + from aria.model.language_model.aria_llama import AriaTextForCausalLM + + # load model + kwargs = {"device_map": "auto", "torch_dtype": torch.float16} + model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs) + + # load vision tower + model.get_vision_tower().load_model() + + # Save state dict + torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin") +""" + +KEYS_TO_MODIFY_MAPPING = { + "vision_tower.vision_model": "vision_tower", + "ln_ffn": "layer_norm", + "ffn": "feed_forward", + "ln_kv": "layer_norm_kv", +} + + +def load_original_state_dict(model_id): + directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"]) + + original_state_dict = {} + for path in glob.glob(f"{directory_path}/*"): + if path.endswith(".safetensors"): + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + original_state_dict[key] = f.get_tensor(key) + + return original_state_dict + + +def convert_state_dict_to_hf(state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if key.endswith(".inv_freq"): + continue + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + new_state_dict[key] = value + new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,)) + new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,)) + + return new_state_dict + + +def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id): + torch.set_default_dtype(torch.float16) + + tokenizer = AutoTokenizer.from_pretrained( + text_model_id, + extra_special_tokens={ + "image_token": "<|img|>", + "pad_token": "", + }, + ) + tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True) + tokenizer.add_special_tokens({"pad_token": ""}) + tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<|img|>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + + processor = AriaProcessor.from_pretrained( + text_model_id, + tokenizer=tokenizer, + ) + + config = AutoConfig.from_pretrained(text_model_id) + config.vision_config.hidden_size = 1152 + config.vision_config.attention_heads = 16 + config.pad_token_id = 2 + config.image_token_index = 9 + config.intermediate_size = config.moe_intermediate_size + config.auto_map = { + "AutoConfig": "modeling_aria.AriaConfig", + "AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration", + } + + with torch.device("meta"): + model = AriaForConditionalGeneration(config) + + state_dict = load_original_state_dict(old_state_dict_id) + + state_dict = convert_state_dict_to_hf(state_dict) + model.load_state_dict(state_dict, strict=False, assign=True) + + # print("Saving models") + # model.save_pretrained("local_aria", safe_serialization=False) + # processor.save_pretrained("local_aria") + print("Pushing to hub") + model.push_to_hub(output_hub_path, create_pr=True) + processor.push_to_hub(output_hub_path, create_pr=True) + + +def main(): + parser = argparse.ArgumentParser( + epilog=EPILOG_TXT, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--text_model_id", + default="rhymes-ai/Aria", + help="Hub location of the text model", + ) + parser.add_argument( + "--vision_model_id", + default="rhymes-ai/Aria", + help="Hub location of the vision model", + ) + parser.add_argument( + "--output_hub_path", + default="rhymes-ai/Aria", + help="Location on the hub of the converted model", + ) + parser.add_argument( + "--old_state_dict_id", + default="rhymes-ai/Aria", + help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`", + ) + args = parser.parse_args() + convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/aria/image_processing_aria.py b/src/transformers/models/aria/image_processing_aria.py new file mode 100644 index 00000000000000..7b00665aa2859d --- /dev/null +++ b/src/transformers/models/aria/image_processing_aria.py @@ -0,0 +1,504 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/aria/modular_aria.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_aria.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution +from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_valid_image, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType + + +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_valid_image(images): + return [images] + + raise ValueError(f"Could not make batched video from {images}") + + +def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]: + """ + Divides an image into patches of a specified size. + + Args: + image (`np.array`): + The input image. + patch_size (`int`): + The size of each patch. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + list: A list of np.array representing the patches. + """ + patches = [] + height, width = get_image_size(image, channel_dim=input_data_format) + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + if input_data_format == ChannelDimension.LAST: + patch = image[i : i + patch_size, j : j + patch_size] + else: + patch = image[:, i : i + patch_size, j : j + patch_size] + patches.append(patch) + + return patches + + +def _get_patch_output_size(image, target_resolution, input_data_format): + original_height, original_width = get_image_size(image, channel_dim=input_data_format) + target_height, target_width = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + return new_height, new_width + + +class AriaImageProcessor(BaseImageProcessor): + """ + A vision processor for the Aria model that handles image preprocessing. + Initialize the AriaImageProcessor. + + Args: + image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Mean values for normalization. + image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Standard deviation values for normalization. + max_image_size (`int`, *optional*, defaults to 980): + Maximum image size. + min_image_size (`int`, *optional*, defaults to 336): + Minimum image size. + split_resolutions (`list`, *optional*, defaults to a list of optimal,resolutions as tuples): + The optimal resolutions for splitting the image. + split_image (`bool`, *optional*, defaults to `False`): + Whether to split the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + resample (PILImageResampling, *optional*, defaults to `BICUBIC`): + The resampling filter to use if resizing the image. + """ + + def __init__( + self, + image_mean: List[float] = None, + image_std: List[float] = None, + max_image_size: int = 980, + min_image_size: int = 336, + split_resolutions: Optional[List[Tuple[int, int]]] = None, + split_image: Optional[bool] = False, + do_convert_rgb: Optional[bool] = True, + do_normalize: Optional[bool] = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + **kwargs, + ): + super().__init__(**kwargs) + + if image_mean is None: + image_mean = [0.5, 0.5, 0.5] + if image_std is None: + image_std = [0.5, 0.5, 0.5] + self.max_image_size = max_image_size + self.min_image_size = min_image_size + self.image_mean = image_mean + self.image_std = image_std + self.split_image = split_image + if split_resolutions is None: + split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip + split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions] + self.split_resolutions = split_resolutions + self.do_convert_rgb = do_convert_rgb + self.do_normalize = do_normalize + self.resample = resample + + def preprocess( + self, + images: Union[ImageInput, List[ImageInput]], + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + max_image_size: Optional[int] = None, + min_image_size: Optional[int] = None, + split_image: Optional[bool] = None, + do_convert_rgb: Optional[bool] = None, + do_normalize: Optional[bool] = None, + resample: PILImageResampling = None, + return_tensors: Optional[Union[str, TensorType]] = "pt", + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Process a list of images. + + Args: + images (ImageInput or list of ImageInput): + The input image or a list of images. + image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Mean values for normalization. + image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Standard deviation values for normalization. + max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)): + Maximum image size. + min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)): + Minimum image size. + split_image (`bool`, *optional*, defaults to `self.split_image` (False)): + Whether to split the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)): + Whether to normalize the image. + resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)): + The resampling filter to use if resizing the image. + return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"): + The type of tensor to return. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: + image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: + image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: + image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: + image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + BatchFeature: + A BatchFeature object containing: + - 'pixel_values': + Tensor of processed image pixel values. + - 'pixel_mask': + Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where: + - True (1) values indicate pixels that belong to the original resized image. + - False (0) values indicate pixels that are part of the padding. + The mask helps distinguish between actual image content and padded areas in subsequent processing steps. + - 'num_crops': + The maximum number of crops across all images. + """ + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + max_image_size = max_image_size if max_image_size is not None else self.max_image_size + min_image_size = min_image_size if min_image_size is not None else self.min_image_size + split_image = split_image if split_image is not None else self.split_image + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + resample = resample if resample is not None else self.resample + + if max_image_size not in [490, 980]: + raise ValueError("max_image_size must be either 490 or 980") + + images = make_batched_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_values = [] + pixel_masks = [] + num_crops = None + + for image in images: + if split_image: + crop_images = self.get_image_patches( + image, + self.split_resolutions, + max_image_size, + resample, + data_format=input_data_format, + input_data_format=input_data_format, + ) + else: + crop_images = [image] + if num_crops is None or len(crop_images) > num_crops: + num_crops = len(crop_images) + + for crop_image in crop_images: + # At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension + h, w = get_image_size(crop_image) + scale = max_image_size / max(h, w) + if w >= h: + new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w + else: + new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w + + crop_image_resized = resize( + crop_image, + new_size, + resample=resample, + data_format=input_data_format, + input_data_format=input_data_format, + ) + + padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1] + crop_image_padded = pad( + crop_image_resized, + ((0, padding_bottom), (0, padding_right)), + data_format=input_data_format, + input_data_format=input_data_format, + ) + + # Create a pixel mask + pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool) + pixel_mask[: new_size[0], : new_size[1]] = 1 + pixel_masks.append(pixel_mask) + + if do_normalize: + crop_image_padded = self.normalize( + crop_image_padded / 255.0, + self.image_mean, + self.image_std, + data_format=input_data_format, + input_data_format=input_data_format, + ) + crop_image_padded = ( + to_channel_dimension_format(crop_image_padded, data_format, input_data_format) + if data_format is not None + else crop_image_padded + ) + + pixel_values.append(crop_image_padded) + return BatchFeature( + data={ + "pixel_values": np.stack(pixel_values, axis=0), + "pixel_mask": np.stack(pixel_masks, axis=0), + "num_crops": num_crops, + }, + tensor_type=return_tensors, + ) + + def _resize_for_patching( + self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension + ) -> np.array: + """ + Resizes an image to a target resolution while maintaining aspect ratio. + + Args: + image (np.array): + The input image. + target_resolution (tuple): + The target resolution (height, width) of the image. + resample (`PILImageResampling`): + Resampling filter to use if resizing the image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + np.array: The resized and padded image. + """ + new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format) + + # Resize the image + resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format) + + return resized_image + + def _pad_for_patching( + self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension + ) -> np.array: + """ + Pad an image to a target resolution while maintaining aspect ratio. + """ + target_height, target_width = target_resolution + new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format) + + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + + padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x))) + + return padded_image + + def pad( + self, + image: np.ndarray, + padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]], + mode: PaddingMode = PaddingMode.CONSTANT, + constant_values: Union[float, Iterable[float]] = 0.0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`) + dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected + as input. + + Args: + image (`np.ndarray`): + The image to pad. + padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`): + Padding to apply to the edges of the height, width axes. Can be one of three formats: + - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis. + - `((before, after),)` yields same before and after pad for height and width. + - `(pad,)` or int is a shortcut for before = after = pad width for all axes. + mode (`PaddingMode`): + The padding mode to use. Can be one of: + - `"constant"`: pads with a constant value. + - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the + vector along each axis. + - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis. + - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + `np.ndarray`: The padded image. + + """ + + # call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim + if isinstance(padding, int) or len(padding) != 4: + return pad(image, padding, mode, constant_values, data_format, input_data_format) + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + padding_mode_mapping = { + PaddingMode.CONSTANT: "constant", + PaddingMode.REFLECT: "reflect", + PaddingMode.REPLICATE: "edge", + PaddingMode.SYMMETRIC: "symmetric", + } + image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values) + image = ( + to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image + ) + return image + + def get_image_patches( + self, + image: np.array, + grid_pinpoints: List[Tuple[int, int]], + patch_size: int, + resample: PILImageResampling, + data_format: ChannelDimension, + input_data_format: ChannelDimension, + ) -> List[np.array]: + """ + Process an image with variable resolutions by dividing it into patches. + + Args: + image (`np.array`): + The input image to be processed. + grid_pinpoints (List[Tuple[int, int]]): + A list of possible resolutions as tuples. + patch_size (`int`): + Size of the patches to divide the image into. + resample (`PILImageResampling`): + Resampling filter to use if resizing the image. + data_format (`ChannelDimension` or `str`): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + `List[np.array]`: A list of NumPy arrays containing the processed image patches. + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints must be a list of possible resolutions.") + + possible_resolutions = grid_pinpoints + + image_size = get_image_size(image, channel_dim=input_data_format) + best_resolution = select_best_resolution(image_size, possible_resolutions) + resized_image = self._resize_for_patching( + image, best_resolution, resample=resample, input_data_format=input_data_format + ) + padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format) + + patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format) + + # make sure that all patches are in the input data format + patches = [ + to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format) + for patch in patches + ] + return patches + + +__all__ = ["AriaImageProcessor"] diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py new file mode 100644 index 00000000000000..1b4e4087b1a49d --- /dev/null +++ b/src/transformers/models/aria/modeling_aria.py @@ -0,0 +1,1920 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/aria/modular_aria.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_aria.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import is_torch_available +from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_aria import AriaConfig, AriaTextConfig + + +if is_torch_available(): + import torch + from torch import nn + + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "AriaTextConfig" + + +class AriaTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + AriaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class AriaProjectorMLP(nn.Module): + """ + Feed-Forward Network module for the Aria Projector. + + Args: + in_features (`int`): + Input embedding dimension. + hidden_features (`int`): + Hidden dimension of the feed-forward network. + output_dim (`int`): + Output dimension. + """ + + def __init__(self, in_features, hidden_features, output_dim): + super().__init__() + self.linear_in = nn.Linear(in_features, hidden_features, bias=False) + self.linear_out = nn.Linear(hidden_features, output_dim, bias=False) + self.act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_states = self.act(self.linear_in(hidden_states)) + hidden_states = self.linear_out(hidden_states) + return hidden_states + + +class AriaCrossAttention(nn.Module): + """ + Aria Cross-Attention module. + + Args: + config (`AriaConfig`): + The configuration to use. + """ + + def __init__(self, config: AriaConfig, dropout_rate: float = 0): + super().__init__() + hidden_size = config.vision_config.hidden_size + num_heads = config.vision_config.num_attention_heads + self.num_heads = num_heads + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48 + self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True) + self.linear = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(dropout_rate) + + self.layer_norm = nn.LayerNorm(hidden_size) + self.layer_norm_kv = nn.LayerNorm(hidden_size) + + def forward(self, key_value_states, hidden_states, attn_mask=None): + """ + Forward pass of the AriaCrossAttention module. + + Args: + key_value_states (`torch.Tensor`): + Input tensor for key and value. + hidden_states (`torch.Tensor`): + Input tensor for query. + attn_mask (`torch.Tensor`, *optional*, defaults to None): + Attention mask. + + Returns: + torch.Tensor: + Output tensor after cross-attention. + """ + query = self.q_proj(self.layer_norm(hidden_states)) + + key_value_states = self.layer_norm_kv(key_value_states) + key = self.k_proj(key_value_states) + value = self.v_proj(key_value_states) + + attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask) + + attn_output = self.dropout(self.linear(attn_output)) + + return attn_output + + +class AriaProjector(nn.Module): + """ + Aria Projector module. + + This module projects vision features into the language model's embedding space, enabling interaction between vision and language components. + + Args: + config (`AriaConfig`): + Configuration object for the model. + """ + + def __init__( + self, + config: AriaConfig, + ): + super().__init__() + + self.patch_to_query_dict = config.projector_patch_to_query_dict + self.in_features = config.vision_config.hidden_size + self.num_heads = config.vision_config.num_attention_heads + self.kv_dim = config.vision_config.hidden_size + self.hidden_features = config.text_config.hidden_size + self.output_dim = config.text_config.hidden_size + + self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features)) + + self.cross_attn = AriaCrossAttention(config) + + self.layer_norm = nn.LayerNorm(self.in_features) + self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim) + + def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + """ + Forward pass of the Projector module. + + Args: + key_value_states (`torch.Tensor`): + Input tensor of shape (batch_size, num_patches, kv_dim). + attn_mask (`torch.Tensor`, *optional*, default is None): + Attention mask. + + Returns: + `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim). + """ + batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1] + + if num_patches not in self.patch_to_query_dict.keys(): + raise KeyError( + f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}." + ) + query_num = self.patch_to_query_dict[num_patches] + + queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1) + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask) + + out = self.feed_forward(self.layer_norm(attention_out)) + + return out + + +class AriaSharedExpertsMLP(nn.Module): + """ + Shared Expert MLP for shared experts. + + Unlike routed experts, shared experts process all tokens without routing. + This class reconfigures the intermediate size in comparison to the LlamaMLP. + + Args: + config (`AriaTextConfig`): Configuration object for the Aria language model. + """ + + def __init__(self, config: AriaTextConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert): + """ + Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts. + + Args: + token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features). + expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features). + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor of shape (num_tokens, out_features). + """ + num_tokens = token_states.shape[0] + out_features = expert_weights.shape[-1] + output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device) + + cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0) + # Insert zero at the begining for offset index's convenience + zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device) + cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens)) + + for expert_num in range(expert_weights.shape[0]): + start = cumsum_num_tokens[expert_num] + end = cumsum_num_tokens[expert_num + 1] + tokens = token_states[start:end] + + out = torch.matmul(tokens, expert_weights[expert_num]) + output[start:end] = out + return output + + +class AriaGroupedExpertsGemm(nn.Module): + """ + Grouped GEMM (General Matrix Multiplication) module for efficient expert computation. + This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm) + for optimized performance. If the grouped_gemm library is not installed, it gracefully + falls back to a sequential GEMM implementation, which may be slower but ensures + functionality. + + Args: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + groups (`int`): + Number of expert groups. + """ + + def __init__(self, in_features, out_features, groups): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.groups = groups + self.weight = nn.Parameter(torch.empty(groups, in_features, out_features)) + + def forward(self, input, tokens_per_expert): + """ + Perform grouped matrix multiplication. + + Args: + input (`torch.Tensor`): + Input tensor of shape (num_tokens, in_features). + tokens_per_expert (`torch.Tensor`): + Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor of shape (num_tokens, out_features). + """ + return sequential_experts_gemm( + input, + self.weight, + tokens_per_expert.cpu(), + ) + + +class AriaGroupedExpertsMLP(nn.Module): + """ + Grouped MLP module for Mixture of Experts. + + Args: + config (`AriaTextConfig`): + Configuration object for the model. + """ + + def __init__(self, config: AriaTextConfig) -> None: + super().__init__() + self.config = config + self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts) + self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts) + + def forward(self, permuted_tokens, tokens_per_expert): + """ + Forward pass of the Grouped MLP. + + Args: + permuted_tokens (torch.Tensor): Permuted input tokens. + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor after passing through the MLP. + """ + fc1_output = self.fc1(permuted_tokens, tokens_per_expert) + projection, gate = torch.chunk(fc1_output, 2, dim=-1) + fc1_output = nn.functional.silu(projection) * gate + fc2_output = self.fc2(fc1_output, tokens_per_expert) + return fc2_output + + +# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587 +class AriaTextMoELayer(nn.Module): + """ + Aria Text Mixture of Experts (MoE) Layer. + + This layer applies a gating mechanism to route input tokens to different experts. + + Args: + config (`AriaTextConfig`): + Configuration object for the text component of the model. + """ + + def __init__(self, config: AriaTextConfig): + super().__init__() + + self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) + self.experts = AriaGroupedExpertsMLP(config) + self.shared_experts = AriaSharedExpertsMLP(config) + self.config = config + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of shape (batch_size, sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + + Process: + 1. Route tokens to experts using the router. + 2. Permute tokens based on routing decisions. + 3. Process tokens through experts. + 4. Unpermute and combine expert outputs. + 5. Add shared expert output to the final result. + """ + original_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + # Top K Routing + logits = self.router(hidden_states) + top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1) + scores = nn.functional.softmax(top_logits, dim=-1) + + original_dtype = top_indices.dtype + + tokens_per_expert = torch.histc( + top_indices.flatten().to(torch.float32), + bins=self.config.moe_num_experts, + min=0, + max=self.config.moe_num_experts - 1, + ).to(original_dtype) + indices = top_indices + + # Token permutation + flatten_indices = indices.view(-1) + sorted_indices = torch.argsort(flatten_indices) + permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk) + + # Process through experts + expert_output = self.experts(permuted_tokens, tokens_per_expert) + + # Token unpermutation + unpermuted_tokens = torch.zeros( + (scores.shape[0] * self.config.moe_topk, expert_output.size(1)), + dtype=expert_output.dtype, + device=expert_output.device, + ) + unpermuted_tokens.index_copy_(0, sorted_indices, expert_output) + unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1)) + + output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape) + + # Add shared expert output + shared_expert_output = self.shared_experts(hidden_states.view(original_shape)) + return output + shared_expert_output + + +class AriaTextRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[AriaTextConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`AriaTextRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class AriaTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AriaTextConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = AriaTextRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class AriaTextFlashAttention2(AriaTextAttention): + """ + AriaText flash attention module. This module inherits from `AriaTextAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (AriaTextRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class AriaTextSdpaAttention(AriaTextAttention): + """ + AriaText attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `AriaTextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from AriaTextAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "AriaTextModel is using AriaTextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +ARIA_TEXT_ATTENTION_CLASSES = { + "eager": AriaTextAttention, + "flash_attention_2": AriaTextFlashAttention2, + "sdpa": AriaTextSdpaAttention, +} + + +class AriaTextDecoderLayer(nn.Module): + """ + Aria Text Decoder Layer. + + This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network. + + Args: + config (`AriaTextConfig`): + Configuration object for the text component of the model. + layer_idx (`int`): + Index of the layer. + """ + + def __init__(self, config: AriaTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ARIA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.mlp = AriaTextMoELayer(config) + self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class AriaTextPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + """ + + config_class = AriaConfig + base_model_prefix = "model" + _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"] + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = False + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, AriaGroupedExpertsGemm): + module.weight.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=std) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + +ARIA_TEXT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`AriaTextConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Aria Model outputting raw hidden-states without any specific head on top.", + ARIA_TEXT_START_DOCSTRING, +) +class AriaPreTrainedModel(PreTrainedModel): + config_class = AriaTextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["AriaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, AriaProjector): + nn.init.trunc_normal_(module.query, std=std) + + +ARIA_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare AriaText Model outputting raw hidden-states without any specific head on top.", + ARIA_TEXT_START_DOCSTRING, +) +class AriaTextModel(AriaTextPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AriaTextDecoderLayer`] + + Args: + config: AriaTextConfig + """ + + def __init__(self, config: AriaTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = AriaTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): + """ + Aria model for causal language modeling tasks. + + This class extends `LlamaForCausalLM` to incorporate the Mixture of Experts (MoE) approach, + allowing for more efficient and scalable language modeling. + + Args: + config (`AriaTextConfig`): + Configuration object for the model. + """ + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + config_class = AriaTextConfig + + def __init__(self, config: AriaTextConfig): + super().__init__(config) + self.model = AriaTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, AriaTextForCausalLM + + >>> model = AriaTextForCausalLM.from_pretrained("meta-aria_text/AriaText-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-aria_text/AriaText-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +class AriaCausalLMOutputWithPast(ModelOutput): + """ + Base class for Aria causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +ARIA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor`, *optional*): + Input token IDs. + pixel_values (`torch.FloatTensor`, *optional*): + Pixel values of the images. + pixel_mask (`torch.LongTensor`, *optional*): + Mask for the pixel values. + attention_mask (`torch.Tensor`, *optional*): + Attention mask. + position_ids (`torch.LongTensor`, *optional*): + Position IDs. + past_key_values (`List[torch.FloatTensor]`, *optional*): + Past key values for efficient processing. + inputs_embeds (`torch.FloatTensor`, *optional*): + Input embeddings. + labels (`torch.LongTensor`, *optional*): + Labels for computing the language modeling loss. + use_cache (`bool`, *optional*): + Whether to use the model's cache mechanism. + output_attentions (`bool`, *optional*): + Whether to output attention weights. + output_hidden_states (`bool`, *optional*): + Whether to output hidden states. + return_dict (`bool`, *optional*): + Whether to return a `ModelOutput` object. + num_logits_to_keep (`int`, *optional*, defaults to 0): + Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`. + cache_position (`torch.LongTensor`, *optional*): + Cache positions. + **loss_kwargs: + Additional keyword arguments for loss calculation. +""" + +ARIA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config (`AriaConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language model + to perform tasks that involve both image and text inputs.""", + ARIA_START_DOCSTRING, +) +class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): + config_class = AriaConfig + _supports_flash_attn_2 = False + _supports_sdpa = False + + def __init__(self, config: AriaConfig): + super().__init__(config) + + self.vision_tower = AutoModel.from_config(config.vision_config) + self.multi_modal_projector = AriaProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + self.post_init() + + def _create_patch_attention_mask(self, pixel_mask): + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.FloatTensor = None, + vision_feature_layer: int = -1, + ): + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + image_outputs = self.vision_tower( + pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True + ) + image_attn_mask = None + if patch_attention_mask is not None: + flattened_mask = patch_attention_mask.flatten(1) + image_attn_mask = torch.logical_not(flattened_mask) + + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) + return image_features + + @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + num_logits_to_keep: int = 0, + cache_position: Optional[torch.LongTensor] = None, + **loss_kwargs, + ) -> Union[Tuple, AriaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + + Example: + + ```python + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from transformers import AutoProcessor, AutoModel + >>> from transformers.image_utils import load_image + + >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible + >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg") + >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg") + >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg") + + >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria") + >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", torch_dtype=torch.bfloat16, device_map="auto") + + >>> # Create inputs + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."}, + ... {"type": "image"}, + ... {"type": "text", "text": "What can we see in this image?"}, + ... ] + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "In which city is that bridge located?"}, + ... ] + ... } + ... ] + + >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages] + >>> images = [[image1, image2], [image3]] + >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=256) + >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True) + + >>> print(generated_texts[0]) + Assistant: There are buildings, trees, lights, and water visible in this image. + + >>> print(generated_texts[1]) + Assistant: The bridge is in San Francisco. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and images + if pixel_values is not None and inputs_embeds.shape[1] != 1: + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + image_embeds = input_ids == self.config.image_token_index + special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) + image_features = self.get_image_features( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + vision_feature_layer=self.config.vision_feature_layer, + ) + n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] + n_image_features = n_images * n_features_per_image + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + num_logits_to_keep=num_logits_to_keep, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return AriaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + pixel_mask=None, + attention_mask=None, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_mask"] = pixel_mask + + return model_inputs + + +__all__ = [ + "AriaForConditionalGeneration", + "AriaPreTrainedModel", + "AriaTextPreTrainedModel", + "AriaTextModel", + "AriaTextForCausalLM", +] diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py new file mode 100644 index 00000000000000..78c6e08bdfd0e5 --- /dev/null +++ b/src/transformers/models/aria/modular_aria.py @@ -0,0 +1,1598 @@ +# coding=utf-8 +# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np + +from ...activations import ACT2FN +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution +from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...modeling_utils import PreTrainedModel +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils import ( + PreTokenizedInput, + TextInput, +) +from ...utils import ( + TensorType, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import is_torch_available +from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer +from ..llama.configuration_llama import LlamaConfig +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, +) +from ..llava.modeling_llava import LlavaCausalLMOutputWithPast +from ..llava_next.image_processing_llava_next import divide_to_patches, make_batched_images + + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + from torch import nn + + +def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert): + """ + Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts. + + Args: + token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features). + expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features). + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor of shape (num_tokens, out_features). + """ + num_tokens = token_states.shape[0] + out_features = expert_weights.shape[-1] + output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device) + + cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0) + # Insert zero at the begining for offset index's convenience + zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device) + cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens)) + + for expert_num in range(expert_weights.shape[0]): + start = cumsum_num_tokens[expert_num] + end = cumsum_num_tokens[expert_num + 1] + tokens = token_states[start:end] + + out = torch.matmul(tokens, expert_weights[expert_num]) + output[start:end] = out + return output + + +class AriaTextConfig(LlamaConfig): + r""" + This class handles the configuration for the text component of the Aria model. + Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria + [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture. + This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 4096): + The size of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 2): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_heads + moe_num_experts (`int`, *optional*, defaults to 8): + The number of experts in the MoE layer. + moe_topk (`int`, *optional*, defaults to 2): + The number of top experts to route to for each token. + moe_num_shared_experts (`int`, *optional*, defaults to 2): + The number of shared experts. + """ + + model_type = "aria_text" + base_config_key = "text_config" + + def __init__( + self, + intermediate_size: int = 4096, + moe_num_experts: int = 8, + moe_topk: int = 2, + moe_num_shared_experts: int = 2, + pad_token_id=2, + **super_kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **super_kwargs) + self.intermediate_size = intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.moe_num_shared_experts = moe_num_shared_experts + + +class AriaConfig(PretrainedConfig): + r""" + This class handles the configuration for both vision and text components of the Aria model, + as well as additional parameters for image token handling and projector mapping. + Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria + [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`AriaVisionConfig` or `dict`, *optional*): + Configuration for the vision component. + vision_feature_layer (`int`, *optional*, defaults to -1): + The index of the layer to select the vision feature. + text_config (`AriaTextConfig` or `dict`, *optional*): + Configuration for the text component. + projector_patch_to_query_dict (`dict`, *optional*): + Mapping of patch sizes to query dimensions. + image_token_index (`int`, *optional*, defaults to 9): + Index used to represent image tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal initializer for initializing all weight matrices. + + Attributes: + model_type (`str`): + Type of the model, set to `"aria"`. + image_token_index (`int`): + Index used to represent image tokens. + projector_patch_to_query_dict (`dict`): + Mapping of patch sizes to query dimensions. + vision_config (`AriaVisionConfig`): + Configuration for the vision component. + text_config (`AriaTextConfig`): + Configuration for the text component. + """ + + model_type = "aria" + sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + vision_feature_layer: int = -1, + text_config: AriaTextConfig = None, + projector_patch_to_query_dict: Dict = None, + image_token_index: int = 9, + initializer_range: float = 0.02, + **kwargs, + ): + self.image_token_index = image_token_index + + # Convert the keys and values of projector_patch_to_query_dict to integers + # This ensures consistency even if they were provided as strings + if projector_patch_to_query_dict is None: + projector_patch_to_query_dict = { + 1225: 128, + 4900: 256, + } + self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()} + self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values()) + self.vision_feature_layer = vision_feature_layer + if isinstance(vision_config, dict): + vision_config["model_type"] = "idefics3_vision" + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["idefics3_vision"]() + + self.vision_config = vision_config + self.initializer_range = initializer_range + + if isinstance(text_config, dict) and "model_type" in text_config: + text_config = AriaTextConfig(**text_config) + elif text_config is None: + text_config = AriaTextConfig() + + self.text_config = text_config + + super().__init__(**kwargs) + + +class AriaTextRMSNorm(LlamaRMSNorm): + pass + + +class AriaProjectorMLP(nn.Module): + """ + Feed-Forward Network module for the Aria Projector. + + Args: + in_features (`int`): + Input embedding dimension. + hidden_features (`int`): + Hidden dimension of the feed-forward network. + output_dim (`int`): + Output dimension. + """ + + def __init__(self, in_features, hidden_features, output_dim): + super().__init__() + self.linear_in = nn.Linear(in_features, hidden_features, bias=False) + self.linear_out = nn.Linear(hidden_features, output_dim, bias=False) + self.act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_states = self.act(self.linear_in(hidden_states)) + hidden_states = self.linear_out(hidden_states) + return hidden_states + + +class AriaCrossAttention(nn.Module): + """ + Aria Cross-Attention module. + + Args: + config (`AriaConfig`): + The configuration to use. + """ + + def __init__(self, config: AriaConfig, dropout_rate: float = 0): + super().__init__() + hidden_size = config.vision_config.hidden_size + num_heads = config.vision_config.num_attention_heads + self.num_heads = num_heads + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48 + self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True) + self.linear = nn.Linear(hidden_size, hidden_size) + self.dropout = nn.Dropout(dropout_rate) + + self.layer_norm = nn.LayerNorm(hidden_size) + self.layer_norm_kv = nn.LayerNorm(hidden_size) + + def forward(self, key_value_states, hidden_states, attn_mask=None): + """ + Forward pass of the AriaCrossAttention module. + + Args: + key_value_states (`torch.Tensor`): + Input tensor for key and value. + hidden_states (`torch.Tensor`): + Input tensor for query. + attn_mask (`torch.Tensor`, *optional*, defaults to None): + Attention mask. + + Returns: + torch.Tensor: + Output tensor after cross-attention. + """ + query = self.q_proj(self.layer_norm(hidden_states)) + + key_value_states = self.layer_norm_kv(key_value_states) + key = self.k_proj(key_value_states) + value = self.v_proj(key_value_states) + + attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask) + + attn_output = self.dropout(self.linear(attn_output)) + + return attn_output + + +class AriaProjector(nn.Module): + """ + Aria Projector module. + + This module projects vision features into the language model's embedding space, enabling interaction between vision and language components. + + Args: + config (`AriaConfig`): + Configuration object for the model. + """ + + def __init__( + self, + config: AriaConfig, + ): + super().__init__() + + self.patch_to_query_dict = config.projector_patch_to_query_dict + self.in_features = config.vision_config.hidden_size + self.num_heads = config.vision_config.num_attention_heads + self.kv_dim = config.vision_config.hidden_size + self.hidden_features = config.text_config.hidden_size + self.output_dim = config.text_config.hidden_size + + self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features)) + + self.cross_attn = AriaCrossAttention(config) + + self.layer_norm = nn.LayerNorm(self.in_features) + self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim) + + def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + """ + Forward pass of the Projector module. + + Args: + key_value_states (`torch.Tensor`): + Input tensor of shape (batch_size, num_patches, kv_dim). + attn_mask (`torch.Tensor`, *optional*, default is None): + Attention mask. + + Returns: + `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim). + """ + batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1] + + if num_patches not in self.patch_to_query_dict.keys(): + raise KeyError( + f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}." + ) + query_num = self.patch_to_query_dict[num_patches] + + queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1) + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask) + + out = self.feed_forward(self.layer_norm(attention_out)) + + return out + + +def _get_patch_output_size(image, target_resolution, input_data_format): + original_height, original_width = get_image_size(image, channel_dim=input_data_format) + target_height, target_width = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + return new_height, new_width + + +class AriaImageProcessor(BaseImageProcessor): + """ + A vision processor for the Aria model that handles image preprocessing. + Initialize the AriaImageProcessor. + + Args: + image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Mean values for normalization. + image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Standard deviation values for normalization. + max_image_size (`int`, *optional*, defaults to 980): + Maximum image size. + min_image_size (`int`, *optional*, defaults to 336): + Minimum image size. + split_resolutions (`list`, *optional*, defaults to a list of optimal,resolutions as tuples): + The optimal resolutions for splitting the image. + split_image (`bool`, *optional*, defaults to `False`): + Whether to split the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + resample (PILImageResampling, *optional*, defaults to `BICUBIC`): + The resampling filter to use if resizing the image. + """ + + def __init__( + self, + image_mean: List[float] = None, + image_std: List[float] = None, + max_image_size: int = 980, + min_image_size: int = 336, + split_resolutions: Optional[List[Tuple[int, int]]] = None, + split_image: Optional[bool] = False, + do_convert_rgb: Optional[bool] = True, + do_normalize: Optional[bool] = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + **kwargs, + ): + super().__init__(**kwargs) + + if image_mean is None: + image_mean = [0.5, 0.5, 0.5] + if image_std is None: + image_std = [0.5, 0.5, 0.5] + self.max_image_size = max_image_size + self.min_image_size = min_image_size + self.image_mean = image_mean + self.image_std = image_std + self.split_image = split_image + if split_resolutions is None: + split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip + split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions] + self.split_resolutions = split_resolutions + self.do_convert_rgb = do_convert_rgb + self.do_normalize = do_normalize + self.resample = resample + + def preprocess( + self, + images: Union[ImageInput, List[ImageInput]], + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + max_image_size: Optional[int] = None, + min_image_size: Optional[int] = None, + split_image: Optional[bool] = None, + do_convert_rgb: Optional[bool] = None, + do_normalize: Optional[bool] = None, + resample: PILImageResampling = None, + return_tensors: Optional[Union[str, TensorType]] = "pt", + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Process a list of images. + + Args: + images (ImageInput or list of ImageInput): + The input image or a list of images. + image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Mean values for normalization. + image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]): + Standard deviation values for normalization. + max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)): + Maximum image size. + min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)): + Minimum image size. + split_image (`bool`, *optional*, defaults to `self.split_image` (False)): + Whether to split the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)): + Whether to convert the image to RGB. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)): + Whether to normalize the image. + resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)): + The resampling filter to use if resizing the image. + return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"): + The type of tensor to return. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: + image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: + image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: + image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: + image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + BatchFeature: + A BatchFeature object containing: + - 'pixel_values': + Tensor of processed image pixel values. + - 'pixel_mask': + Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where: + - True (1) values indicate pixels that belong to the original resized image. + - False (0) values indicate pixels that are part of the padding. + The mask helps distinguish between actual image content and padded areas in subsequent processing steps. + - 'num_crops': + The maximum number of crops across all images. + """ + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + max_image_size = max_image_size if max_image_size is not None else self.max_image_size + min_image_size = min_image_size if min_image_size is not None else self.min_image_size + split_image = split_image if split_image is not None else self.split_image + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + resample = resample if resample is not None else self.resample + + if max_image_size not in [490, 980]: + raise ValueError("max_image_size must be either 490 or 980") + + images = make_batched_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_values = [] + pixel_masks = [] + num_crops = None + + for image in images: + if split_image: + crop_images = self.get_image_patches( + image, + self.split_resolutions, + max_image_size, + resample, + data_format=input_data_format, + input_data_format=input_data_format, + ) + else: + crop_images = [image] + if num_crops is None or len(crop_images) > num_crops: + num_crops = len(crop_images) + + for crop_image in crop_images: + # At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension + h, w = get_image_size(crop_image) + scale = max_image_size / max(h, w) + if w >= h: + new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w + else: + new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w + + crop_image_resized = resize( + crop_image, + new_size, + resample=resample, + data_format=input_data_format, + input_data_format=input_data_format, + ) + + padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1] + crop_image_padded = pad( + crop_image_resized, + ((0, padding_bottom), (0, padding_right)), + data_format=input_data_format, + input_data_format=input_data_format, + ) + + # Create a pixel mask + pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool) + pixel_mask[: new_size[0], : new_size[1]] = 1 + pixel_masks.append(pixel_mask) + + if do_normalize: + crop_image_padded = self.normalize( + crop_image_padded / 255.0, + self.image_mean, + self.image_std, + data_format=input_data_format, + input_data_format=input_data_format, + ) + crop_image_padded = ( + to_channel_dimension_format(crop_image_padded, data_format, input_data_format) + if data_format is not None + else crop_image_padded + ) + + pixel_values.append(crop_image_padded) + return BatchFeature( + data={ + "pixel_values": np.stack(pixel_values, axis=0), + "pixel_mask": np.stack(pixel_masks, axis=0), + "num_crops": num_crops, + }, + tensor_type=return_tensors, + ) + + def _resize_for_patching( + self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension + ) -> np.array: + """ + Resizes an image to a target resolution while maintaining aspect ratio. + + Args: + image (np.array): + The input image. + target_resolution (tuple): + The target resolution (height, width) of the image. + resample (`PILImageResampling`): + Resampling filter to use if resizing the image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + np.array: The resized and padded image. + """ + new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format) + + # Resize the image + resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format) + + return resized_image + + def _pad_for_patching( + self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension + ) -> np.array: + """ + Pad an image to a target resolution while maintaining aspect ratio. + """ + target_height, target_width = target_resolution + new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format) + + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + + padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x))) + + return padded_image + + def pad( + self, + image: np.ndarray, + padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]], + mode: PaddingMode = PaddingMode.CONSTANT, + constant_values: Union[float, Iterable[float]] = 0.0, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`) + dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected + as input. + + Args: + image (`np.ndarray`): + The image to pad. + padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`): + Padding to apply to the edges of the height, width axes. Can be one of three formats: + - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis. + - `((before, after),)` yields same before and after pad for height and width. + - `(pad,)` or int is a shortcut for before = after = pad width for all axes. + mode (`PaddingMode`): + The padding mode to use. Can be one of: + - `"constant"`: pads with a constant value. + - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the + vector along each axis. + - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis. + - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + `np.ndarray`: The padded image. + + """ + + # call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim + if isinstance(padding, int) or len(padding) != 4: + return pad(image, padding, mode, constant_values, data_format, input_data_format) + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + padding_mode_mapping = { + PaddingMode.CONSTANT: "constant", + PaddingMode.REFLECT: "reflect", + PaddingMode.REPLICATE: "edge", + PaddingMode.SYMMETRIC: "symmetric", + } + image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values) + image = ( + to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image + ) + return image + + def get_image_patches( + self, + image: np.array, + grid_pinpoints: List[Tuple[int, int]], + patch_size: int, + resample: PILImageResampling, + data_format: ChannelDimension, + input_data_format: ChannelDimension, + ) -> List[np.array]: + """ + Process an image with variable resolutions by dividing it into patches. + + Args: + image (`np.array`): + The input image to be processed. + grid_pinpoints (List[Tuple[int, int]]): + A list of possible resolutions as tuples. + patch_size (`int`): + Size of the patches to divide the image into. + resample (`PILImageResampling`): + Resampling filter to use if resizing the image. + data_format (`ChannelDimension` or `str`): + The channel dimension format for the output image. + input_data_format (`ChannelDimension` or `str`): + The channel dimension format of the input image. + + Returns: + `List[np.array]`: A list of NumPy arrays containing the processed image patches. + """ + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints must be a list of possible resolutions.") + + possible_resolutions = grid_pinpoints + + image_size = get_image_size(image, channel_dim=input_data_format) + best_resolution = select_best_resolution(image_size, possible_resolutions) + resized_image = self._resize_for_patching( + image, best_resolution, resample=resample, input_data_format=input_data_format + ) + padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format) + + patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format) + + # make sure that all patches are in the input data format + patches = [ + to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format) + for patch in patches + ] + return patches + + +class AriaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + "max_image_size": 980, + "split_image": False, + }, + "return_tensors": TensorType.PYTORCH, + } + + +class AriaProcessor(ProcessorMixin): + """ + AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer. + + Args: + image_processor (`AriaImageProcessor`, *optional*): + The AriaImageProcessor to use for image preprocessing. + tokenizer (`PreTrainedTokenizerBase`, *optional*): + An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input. + chat_template (`str`, *optional*): + A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. + size_conversion (`Dict`, *optional*): + A dictionary indicating size conversions for images. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "size_conversion"] + image_processor_class = "AriaImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer: Union[AutoTokenizer, str] = None, + chat_template: Optional[str] = None, + size_conversion: Optional[Dict[Union[float, int], int]] = None, + ): + if size_conversion is None: + size_conversion = {490: 128, 980: 256} + self.size_conversion = {int(k): v for k, v in size_conversion.items()} + + if tokenizer is not None and tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[AriaProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). + + Args: + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + AriaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + if images is not None: + image_inputs = self.image_processor( + images, + **output_kwargs["images_kwargs"], + ) + # expand the image_token according to the num_crops and tokens per image + tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]] + prompt_strings = [] + num_crops = image_inputs.pop("num_crops") * tokens_per_image + for sample in text: + sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops) + prompt_strings.append(sample) + + else: + image_inputs = {} + prompt_strings = text + + text_inputs = self.tokenizer( + prompt_strings, + **output_kwargs["text_kwargs"], + ) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +class AriaSharedExpertsMLP(LlamaMLP): + """ + Shared Expert MLP for shared experts. + + Unlike routed experts, shared experts process all tokens without routing. + This class reconfigures the intermediate size in comparison to the LlamaMLP. + + Args: + config (`AriaTextConfig`): Configuration object for the Aria language model. + """ + + def __init__(self, config: AriaTextConfig): + super().__init__(self) + self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts + + +class AriaGroupedExpertsGemm(nn.Module): + """ + Grouped GEMM (General Matrix Multiplication) module for efficient expert computation. + This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm) + for optimized performance. If the grouped_gemm library is not installed, it gracefully + falls back to a sequential GEMM implementation, which may be slower but ensures + functionality. + + Args: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + groups (`int`): + Number of expert groups. + """ + + def __init__(self, in_features, out_features, groups): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.groups = groups + self.weight = nn.Parameter(torch.empty(groups, in_features, out_features)) + + def forward(self, input, tokens_per_expert): + """ + Perform grouped matrix multiplication. + + Args: + input (`torch.Tensor`): + Input tensor of shape (num_tokens, in_features). + tokens_per_expert (`torch.Tensor`): + Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor of shape (num_tokens, out_features). + """ + return sequential_experts_gemm( + input, + self.weight, + tokens_per_expert.cpu(), + ) + + +class AriaGroupedExpertsMLP(nn.Module): + """ + Grouped MLP module for Mixture of Experts. + + Args: + config (`AriaTextConfig`): + Configuration object for the model. + """ + + def __init__(self, config: AriaTextConfig) -> None: + super().__init__() + self.config = config + self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts) + self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts) + + def forward(self, permuted_tokens, tokens_per_expert): + """ + Forward pass of the Grouped MLP. + + Args: + permuted_tokens (torch.Tensor): Permuted input tokens. + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + + Returns: + torch.Tensor: Output tensor after passing through the MLP. + """ + fc1_output = self.fc1(permuted_tokens, tokens_per_expert) + projection, gate = torch.chunk(fc1_output, 2, dim=-1) + fc1_output = nn.functional.silu(projection) * gate + fc2_output = self.fc2(fc1_output, tokens_per_expert) + return fc2_output + + +# Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587 +class AriaTextMoELayer(nn.Module): + """ + Aria Text Mixture of Experts (MoE) Layer. + + This layer applies a gating mechanism to route input tokens to different experts. + + Args: + config (`AriaTextConfig`): + Configuration object for the text component of the model. + """ + + def __init__(self, config: AriaTextConfig): + super().__init__() + + self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) + self.experts = AriaGroupedExpertsMLP(config) + self.shared_experts = AriaSharedExpertsMLP(config) + self.config = config + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of shape (batch_size, sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + + Process: + 1. Route tokens to experts using the router. + 2. Permute tokens based on routing decisions. + 3. Process tokens through experts. + 4. Unpermute and combine expert outputs. + 5. Add shared expert output to the final result. + """ + original_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + # Top K Routing + logits = self.router(hidden_states) + top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1) + scores = nn.functional.softmax(top_logits, dim=-1) + + original_dtype = top_indices.dtype + + tokens_per_expert = torch.histc( + top_indices.flatten().to(torch.float32), + bins=self.config.moe_num_experts, + min=0, + max=self.config.moe_num_experts - 1, + ).to(original_dtype) + indices = top_indices + + # Token permutation + flatten_indices = indices.view(-1) + sorted_indices = torch.argsort(flatten_indices) + permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk) + + # Process through experts + expert_output = self.experts(permuted_tokens, tokens_per_expert) + + # Token unpermutation + unpermuted_tokens = torch.zeros( + (scores.shape[0] * self.config.moe_topk, expert_output.size(1)), + dtype=expert_output.dtype, + device=expert_output.device, + ) + unpermuted_tokens.index_copy_(0, sorted_indices, expert_output) + unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1)) + + output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape) + + # Add shared expert output + shared_expert_output = self.shared_experts(hidden_states.view(original_shape)) + return output + shared_expert_output + + +class AriaTextDecoderLayer(LlamaDecoderLayer): + """ + Aria Text Decoder Layer. + + This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network. + + Args: + config (`AriaTextConfig`): + Configuration object for the text component of the model. + layer_idx (`int`): + Index of the layer. + """ + + def __init__(self, config: AriaTextConfig, layer_idx: int): + super().__init__(self) + self.mlp = AriaTextMoELayer(config) + + +class AriaTextPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + """ + + config_class = AriaConfig + base_model_prefix = "model" + _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"] + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = False + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, AriaGroupedExpertsGemm): + module.weight.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=std) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + +class AriaPreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, AriaProjector): + nn.init.trunc_normal_(module.query, std=std) + + +class AriaTextModel(LlamaModel): + def __init__(self, config: AriaTextConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + self.post_init() + + +class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM): + """ + Aria model for causal language modeling tasks. + + This class extends `LlamaForCausalLM` to incorporate the Mixture of Experts (MoE) approach, + allowing for more efficient and scalable language modeling. + + Args: + config (`AriaTextConfig`): + Configuration object for the model. + """ + + _tied_weights_keys = ["lm_head.weight"] + config_class = AriaTextConfig + + def __init__(self, config: AriaTextConfig): + super().__init__(config) + self.model = AriaTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + +class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): + pass + + +ARIA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor`, *optional*): + Input token IDs. + pixel_values (`torch.FloatTensor`, *optional*): + Pixel values of the images. + pixel_mask (`torch.LongTensor`, *optional*): + Mask for the pixel values. + attention_mask (`torch.Tensor`, *optional*): + Attention mask. + position_ids (`torch.LongTensor`, *optional*): + Position IDs. + past_key_values (`List[torch.FloatTensor]`, *optional*): + Past key values for efficient processing. + inputs_embeds (`torch.FloatTensor`, *optional*): + Input embeddings. + labels (`torch.LongTensor`, *optional*): + Labels for computing the language modeling loss. + use_cache (`bool`, *optional*): + Whether to use the model's cache mechanism. + output_attentions (`bool`, *optional*): + Whether to output attention weights. + output_hidden_states (`bool`, *optional*): + Whether to output hidden states. + return_dict (`bool`, *optional*): + Whether to return a `ModelOutput` object. + num_logits_to_keep (`int`, *optional*, defaults to 0): + Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`. + cache_position (`torch.LongTensor`, *optional*): + Cache positions. + **loss_kwargs: + Additional keyword arguments for loss calculation. +""" + +ARIA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config (`AriaConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language model + to perform tasks that involve both image and text inputs.""", + ARIA_START_DOCSTRING, +) +class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): + config_class = AriaConfig + _supports_flash_attn_2 = False + _supports_sdpa = False + + def __init__(self, config: AriaConfig): + super().__init__(config) + + self.vision_tower = AutoModel.from_config(config.vision_config) + self.multi_modal_projector = AriaProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + self.post_init() + + def _create_patch_attention_mask(self, pixel_mask): + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.FloatTensor = None, + vision_feature_layer: int = -1, + ): + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + image_outputs = self.vision_tower( + pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True + ) + image_attn_mask = None + if patch_attention_mask is not None: + flattened_mask = patch_attention_mask.flatten(1) + image_attn_mask = torch.logical_not(flattened_mask) + + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) + return image_features + + @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + pixel_mask: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + num_logits_to_keep: int = 0, + cache_position: Optional[torch.LongTensor] = None, + **loss_kwargs, + ) -> Union[Tuple, AriaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`). + Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + + Example: + + ```python + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from transformers import AutoProcessor, AutoModel + >>> from transformers.image_utils import load_image + + >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible + >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg") + >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg") + >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg") + + >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria") + >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", torch_dtype=torch.bfloat16, device_map="auto") + + >>> # Create inputs + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."}, + ... {"type": "image"}, + ... {"type": "text", "text": "What can we see in this image?"}, + ... ] + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "In which city is that bridge located?"}, + ... ] + ... } + ... ] + + >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages] + >>> images = [[image1, image2], [image3]] + >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device) + + >>> # Generate + >>> generated_ids = model.generate(**inputs, max_new_tokens=256) + >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True) + + >>> print(generated_texts[0]) + Assistant: There are buildings, trees, lights, and water visible in this image. + + >>> print(generated_texts[1]) + Assistant: The bridge is in San Francisco. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and images + if pixel_values is not None and inputs_embeds.shape[1] != 1: + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + image_embeds = input_ids == self.config.image_token_index + special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) + image_features = self.get_image_features( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + vision_feature_layer=self.config.vision_feature_layer, + ) + n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] + n_image_features = n_images * n_features_per_image + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + num_logits_to_keep=num_logits_to_keep, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return AriaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + pixel_mask=None, + attention_mask=None, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_mask"] = pixel_mask + + return model_inputs + + +__all__ = [ + "AriaConfig", + "AriaTextConfig", + "AriaImageProcessor", + "AriaProcessor", + "AriaForConditionalGeneration", + "AriaPreTrainedModel", + "AriaTextPreTrainedModel", + "AriaTextModel", + "AriaTextForCausalLM", +] diff --git a/src/transformers/models/aria/processing_aria.py b/src/transformers/models/aria/processing_aria.py new file mode 100644 index 00000000000000..2cfbd72a002061 --- /dev/null +++ b/src/transformers/models/aria/processing_aria.py @@ -0,0 +1,164 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/aria/modular_aria.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_aria.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils import PreTokenizedInput, TextInput +from ...utils import TensorType +from ..auto import AutoTokenizer + + +class AriaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": { + "max_image_size": 980, + "split_image": False, + }, + "return_tensors": TensorType.PYTORCH, + } + + +class AriaProcessor(ProcessorMixin): + """ + AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer. + + Args: + image_processor (`AriaImageProcessor`, *optional*): + The AriaImageProcessor to use for image preprocessing. + tokenizer (`PreTrainedTokenizerBase`, *optional*): + An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input. + chat_template (`str`, *optional*): + A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. + size_conversion (`Dict`, *optional*): + A dictionary indicating size conversions for images. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "size_conversion"] + image_processor_class = "AriaImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer: Union[AutoTokenizer, str] = None, + chat_template: Optional[str] = None, + size_conversion: Optional[Dict[Union[float, int], int]] = None, + ): + if size_conversion is None: + size_conversion = {490: 128, 980: 256} + self.size_conversion = {int(k): v for k, v in size_conversion.items()} + + if tokenizer is not None and tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.unk_token + + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + images: Optional[ImageInput] = None, + audio=None, + videos=None, + **kwargs: Unpack[AriaProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). + + Args: + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + AriaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + if images is not None: + image_inputs = self.image_processor( + images, + **output_kwargs["images_kwargs"], + ) + # expand the image_token according to the num_crops and tokens per image + tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]] + prompt_strings = [] + num_crops = image_inputs.pop("num_crops") * tokens_per_image + for sample in text: + sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops) + prompt_strings.append(sample) + + else: + image_inputs = {} + prompt_strings = text + + text_inputs = self.tokenizer( + prompt_strings, + **output_kwargs["text_kwargs"], + ) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["AriaProcessor"] diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c1f2d689df7095..cc3a7d5baaeb49 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -35,6 +35,8 @@ ("albert", "AlbertConfig"), ("align", "AlignConfig"), ("altclip", "AltCLIPConfig"), + ("aria", "AriaConfig"), + ("aria_text", "AriaTextConfig"), ("audio-spectrogram-transformer", "ASTConfig"), ("autoformer", "AutoformerConfig"), ("bark", "BarkConfig"), @@ -135,6 +137,7 @@ ("idefics", "IdeficsConfig"), ("idefics2", "Idefics2Config"), ("idefics3", "Idefics3Config"), + ("idefics3_vision", "Idefics3VisionConfig"), ("ijepa", "IJepaConfig"), ("imagegpt", "ImageGPTConfig"), ("informer", "InformerConfig"), @@ -327,6 +330,8 @@ ("albert", "ALBERT"), ("align", "ALIGN"), ("altclip", "AltCLIP"), + ("aria", "Aria"), + ("aria_text", "AriaText"), ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"), ("autoformer", "Autoformer"), ("bark", "Bark"), @@ -441,6 +446,7 @@ ("idefics", "IDEFICS"), ("idefics2", "Idefics2"), ("idefics3", "Idefics3"), + ("idefics3_vision", "Idefics3VisionTransformer"), ("ijepa", "I-JEPA"), ("imagegpt", "ImageGPT"), ("informer", "Informer"), @@ -687,6 +693,8 @@ ("clip_vision_model", "clip"), ("qwen2_audio_encoder", "qwen2_audio"), ("clip_text_model", "clip"), + ("aria_text", "aria"), + ("idefics3_vision", "idefics3"), ("siglip_vision_model", "siglip"), ("chinese_clip_vision_model", "chinese_clip"), ("rt_detr_resnet", "rt_detr"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index e19c8efd205552..a699314f858928 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -54,6 +54,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( [ ("align", ("EfficientNetImageProcessor",)), + ("aria", ("AriaImageProcessor")), ("beit", ("BeitImageProcessor",)), ("bit", ("BitImageProcessor",)), ("blip", ("BlipImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7a7cd9d475884c..e8e4814e6a0f6a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -35,6 +35,8 @@ ("albert", "AlbertModel"), ("align", "AlignModel"), ("altclip", "AltCLIPModel"), + ("aria", "AriaForConditionalGeneration"), + ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), ("autoformer", "AutoformerModel"), ("bark", "BarkModel"), @@ -132,6 +134,7 @@ ("idefics", "IdeficsModel"), ("idefics2", "Idefics2Model"), ("idefics3", "Idefics3Model"), + ("idefics3_vision", "Idefics3VisionTransformer"), ("ijepa", "IJepaModel"), ("imagegpt", "ImageGPTModel"), ("informer", "InformerModel"), @@ -464,6 +467,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping + ("aria_text", "AriaTextForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), ("bert-generation", "BertGenerationDecoder"), @@ -768,6 +772,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( [ + ("aria", "AriaForConditionalGeneration"), ("blip", "BlipForConditionalGeneration"), ("blip-2", "Blip2ForConditionalGeneration"), ("chameleon", "ChameleonForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c1f23bc1cb3f18..3e475b1be211fa 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -47,6 +47,7 @@ [ ("align", "AlignProcessor"), ("altclip", "AltCLIPProcessor"), + ("aria", "AriaProcessor"), ("bark", "BarkProcessor"), ("blip", "BlipProcessor"), ("blip-2", "Blip2Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e246bf3094c9cb..3cc181ac87adc4 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -68,6 +68,7 @@ ), ), ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("bart", ("BartTokenizer", "BartTokenizerFast")), ( diff --git a/src/transformers/models/idefics3/__init__.py b/src/transformers/models/idefics3/__init__.py index 35b1df5c678439..cec07ca6f5e2d3 100644 --- a/src/transformers/models/idefics3/__init__.py +++ b/src/transformers/models/idefics3/__init__.py @@ -16,7 +16,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available -_import_structure = {"configuration_idefics3": ["Idefics3Config"]} +_import_structure = {"configuration_idefics3": ["Idefics3Config", "Idefics3VisionConfig"]} try: @@ -38,11 +38,12 @@ "Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", + "Idefics3VisionTransformer", ] _import_structure["processing_idefics3"] = ["Idefics3Processor"] if TYPE_CHECKING: - from .configuration_idefics3 import Idefics3Config + from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig try: if not is_vision_available(): @@ -62,6 +63,7 @@ Idefics3ForConditionalGeneration, Idefics3Model, Idefics3PreTrainedModel, + Idefics3VisionTransformer, ) from .processing_idefics3 import Idefics3Processor diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d770b83df935a5..8747939a62ce2a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -685,6 +685,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class AriaForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AriaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AriaTextForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AriaTextModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AriaTextPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ASTForAudioClassification(metaclass=DummyObject): _backends = ["torch"] @@ -4978,6 +5013,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Idefics3VisionConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Idefics3VisionTransformer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class IJepaForImageClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index d2ccaeaaed23a8..3ebda4404aae9c 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -23,6 +23,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class AriaImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class BeitFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 12faeb8da9256a..76ab793e3a36c0 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1727,6 +1727,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self): num_hidden_layers = text_config.num_hidden_layers inputs_embeds = model.get_input_embeddings()(input_ids) + max_cache_len += inputs_embeds.shape[1] outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict) # we should get `max_length` in shape, not `max_length - embeds_length` diff --git a/tests/models/aria/__init__.py b/tests/models/aria/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/aria/test_image_processing_aria.py b/tests/models/aria/test_image_processing_aria.py new file mode 100644 index 00000000000000..8a0f84d34eefed --- /dev/null +++ b/tests/models/aria/test_image_processing_aria.py @@ -0,0 +1,268 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.image_utils import PILImageResampling +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin + + +if is_vision_available(): + from PIL import Image + + from transformers import AriaImageProcessor + + +if is_torch_available(): + import torch + + +class AriaImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + num_images=1, + min_resolution=30, + max_resolution=40, + size=None, + max_image_size=980, + min_image_size=336, + split_resolutions=None, + split_image=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + do_convert_rgb=True, + resample=PILImageResampling.BICUBIC, + ): + super().__init__() + self.size = size if size is not None else {"longest_edge": max_resolution} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.num_images = num_images + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.resample = resample + self.max_image_size = max_image_size + self.min_image_size = min_image_size + self.split_resolutions = split_resolutions if split_resolutions is not None else [[980, 980]] + self.split_image = split_image + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def prepare_image_processor_dict(self): + return { + "image_mean": self.image_mean, + "image_std": self.image_std, + "max_image_size": self.max_image_size, + "min_image_size": self.min_image_size, + "split_resolutions": self.split_resolutions, + "split_image": self.split_image, + "do_convert_rgb": self.do_convert_rgb, + "do_normalize": self.do_normalize, + "resample": self.resample, + } + + def get_expected_values(self, image_inputs, batched=False): + """ + This function computes the expected height and width when providing images to AriaImageProcessor, + assuming do_resize is set to True. The expected size in that case the max image size. + """ + return self.max_image_size, self.max_image_size + + def expected_output_image_shape(self, images): + height, width = self.get_expected_values(images, batched=True) + return self.num_channels, height, width + + def prepare_image_inputs( + self, + batch_size=None, + min_resolution=None, + max_resolution=None, + num_channels=None, + num_images=None, + size_divisor=None, + equal_resolution=False, + numpify=False, + torchify=False, + ): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + + One can specify whether the images are of the same resolution or not. + """ + assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" + + batch_size = batch_size if batch_size is not None else self.batch_size + min_resolution = min_resolution if min_resolution is not None else self.min_resolution + max_resolution = max_resolution if max_resolution is not None else self.max_resolution + num_channels = num_channels if num_channels is not None else self.num_channels + num_images = num_images if num_images is not None else self.num_images + + images_list = [] + for i in range(batch_size): + images = [] + for j in range(num_images): + if equal_resolution: + width = height = max_resolution + else: + # To avoid getting image width/height 0 + if size_divisor is not None: + # If `size_divisor` is defined, the image needs to have width/size >= `size_divisor` + min_resolution = max(size_divisor, min_resolution) + width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2) + images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8)) + images_list.append(images) + + if not numpify and not torchify: + # PIL expects the channel dimension as last dimension + images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list] + + if torchify: + images_list = [[torch.from_numpy(image) for image in images] for images in images_list] + + if numpify: + # Numpy images are typically in channels last format + images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list] + + return images_list + + +@require_torch +@require_vision +class AriaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = AriaImageProcessor if is_vision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = AriaImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + self.assertTrue(hasattr(image_processing, "max_image_size")) + self.assertTrue(hasattr(image_processing, "min_image_size")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "split_image")) + + def test_call_numpy(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for sample_images in image_inputs: + for image in sample_images: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_call_numpy_4_channels(self): + # Aria always processes images as RGB, so it always returns images with 3 channels + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processor_dict = self.image_processor_dict + image_processing = self.image_processing_class(**image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + + for sample_images in image_inputs: + for image in sample_images: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_call_pil(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for images in image_inputs: + for image in images: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_call_pytorch(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + for images in image_inputs: + for image in images: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + tuple(encoded_images.shape), + (self.image_processor_tester.batch_size, *expected_output_image_shape), + ) diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py new file mode 100644 index 00000000000000..d3458530ac349e --- /dev/null +++ b/tests/models/aria/test_modeling_aria.py @@ -0,0 +1,669 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Aria model.""" + +import gc +import unittest + +import requests + +from transformers import ( + AriaConfig, + AriaForConditionalGeneration, + AriaTextConfig, + AutoProcessor, + AutoTokenizer, + is_torch_available, + is_vision_available, +) +from transformers.models.idefics3 import Idefics3VisionConfig +from transformers.testing_utils import ( + require_bitsandbytes, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch +else: + is_torch_greater_or_equal_than_2_0 = False + +if is_vision_available(): + from PIL import Image + + +class AriaVisionText2TextModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_index=9, + projector_hidden_act="gelu", + seq_length=7, + vision_feature_select_strategy="default", + vision_feature_layer=-1, + text_config=AriaTextConfig( + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=1, + hidden_size=32, + intermediate_size=64, + max_position_embeddings=60, + model_type="aria_moe_lm", + moe_intermediate_size=4, + moe_num_experts=4, + moe_topk=2, + num_attention_heads=20, + num_experts_per_tok=3, + num_hidden_layers=2, + num_key_value_heads=20, + rope_theta=5000000, + vocab_size=99, + eos_token_id=2, + head_dim=2, + ), + is_training=True, + vision_config=Idefics3VisionConfig( + image_size=358, + patch_size=10, + num_channels=3, + is_training=True, + hidden_size=32, + projection_dim=20, + num_hidden_layers=2, + num_attention_heads=16, + intermediate_size=10, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + ), + ): + self.parent = parent + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.pad_token_id = text_config.pad_token_id + self.eos_token_id = text_config.eos_token_id + self.num_hidden_layers = text_config.num_hidden_layers + self.vocab_size = text_config.vocab_size + self.hidden_size = text_config.hidden_size + self.num_attention_heads = text_config.num_attention_heads + self.is_training = is_training + + self.batch_size = 10 + self.num_channels = 3 + self.image_size = 358 + self.num_image_tokens = 128 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return AriaConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ignore_index=self.ignore_index, + image_token_index=self.image_token_index, + projector_hidden_act=self.projector_hidden_act, + vision_feature_select_strategy=self.vision_feature_select_strategy, + vision_feature_layer=self.vision_feature_layer, + eos_token_id=self.eos_token_id, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config.num_channels, + self.vision_config.image_size, + self.vision_config.image_size, + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(1).to(torch_device) + input_ids[input_ids == config.image_token_index] = self.pad_token_id + input_ids[:, : self.num_image_tokens] = config.image_token_index + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def create_and_check_aria_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask): + model = AriaForConditionalGeneration(config=config) + model.to(torch_device) + model.eval() + with torch.autocast(device_type="cuda", dtype=torch.float16): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + +@require_torch +class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `AriaForConditionalGeneration`. + """ + + all_model_classes = (AriaForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (AriaForConditionalGeneration,) if is_torch_available() else () + test_pruning = False + test_head_masking = False + _is_composite = True + + def setUp(self): + self.model_tester = AriaVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + # while some other models require pixel_values to be present + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + self.assertTrue(torch.allclose(out_embeds, out_ids)) + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Compile not yet supported because in LLava models") + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip(reason="Compile not yet supported because in LLava models") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @unittest.skip(reason="") + def test_new_cache_format_0(self): + pass + + @unittest.skip(reason="") + def test_new_cache_format_1(self): + pass + + @unittest.skip(reason="") + def test_new_cache_format_2(self): + pass + + @unittest.skip(reason="Feedforward chunking is not yet supported") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Unstable test") + def test_initialization(self): + pass + + @unittest.skip(reason="Unstable test") + def test_dola_decoding_sample(self): + pass + + @unittest.skip(reason="Unsupported") + def test_generate_from_inputs_embeds_0_greedy(self): + pass + + @unittest.skip(reason="Unsupported") + def test_generate_from_inputs_embeds_1_beam_search(self): + pass + + @unittest.skip(reason="Unsupported") + def test_generate_with_static_cache(self): + pass + + +@require_torch +class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("rhymes-ai/Aria") + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + @slow + @require_bitsandbytes + def test_small_model_integration_test(self): + # Let' s make sure we test the preprocessing to replace what is used + model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True) + + prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" + image_file = "https://aria-vl.github.io/static/images/view.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt") + + EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip + self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) + + output = model.generate(**inputs, max_new_tokens=20) + EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama_single(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "rhymes-ai/Aria" + + model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT:" + image_file = "https://aria-vl.github.io/static/images/view.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) + + output = model.generate(**inputs, max_new_tokens=900, do_sample=False) + EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip + + self.assertEqual( + processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama_batched(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "rhymes-ai/Aria" + + model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:", + "USER: \nWhat is this? ASSISTANT:", + ] + image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip + + self.assertEqual( + processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_batch(self): + # Let' s make sure we test the preprocessing to replace what is used + model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True) + # The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!. + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = [ + 'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring.', + 'USER: \nWhat is this?\nASSISTANT: Cats' + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama_batched_regression(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "rhymes-ai/Aria" + + # Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before) + model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True, attn_implementation="eager") + processor = AutoProcessor.from_pretrained(model_id, pad_token="") + + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip + + self.assertEqual( + processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_torch + @require_vision + def test_batched_generation(self): + model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True) + + processor = AutoProcessor.from_pretrained("rhymes-ai/Aria") + + prompt1 = "\n\nUSER: What's the the difference of two images?\nASSISTANT:" + prompt2 = "\nUSER: Describe the image.\nASSISTANT:" + prompt3 = "\nUSER: Describe the image.\nASSISTANT:" + url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" + image1 = Image.open(requests.get(url1, stream=True).raw) + image2 = Image.open(requests.get(url2, stream=True).raw) + + inputs = processor( + images=[image1, image2, image1, image2], + text=[prompt1, prompt2, prompt3], + return_tensors="pt", + padding=True, + ).to(torch_device) + + model = model.eval() + + EXPECTED_OUTPUT = [ + "\n \nUSER: What's the the difference of two images?\nASSISTANT: The difference between the two images is that one shows a dog standing on a grassy field, while", + "\nUSER: Describe the image.\nASSISTANT: The image features a brown and white dog sitting on a sidewalk. The dog is holding a small", + "\nUSER: Describe the image.\nASSISTANT: The image features a lone llama standing on a grassy hill. The llama is the", + ] + + generate_ids = model.generate(**inputs, max_new_tokens=20) + outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(outputs, EXPECTED_OUTPUT) + + @slow + @require_bitsandbytes + def test_aria_index_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore + # Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for + # more details + model_id = "rhymes-ai/Aria" + model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + + processor = AutoProcessor.from_pretrained(model_id) + + # Simulate a super long prompt + user_prompt = "Describe the image:?\n" * 200 + prompt = f"USER: \n{user_prompt}ASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) + + # Make sure that `generate` works + _ = model.generate(**inputs, max_new_tokens=20) + + @slow + @require_torch_gpu + def test_aria_merge_inputs_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore + model_id = "rhymes-ai/Aria" + model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + + # Simulate some user inputs + pixel_values = torch.randn( + (1, 3, 336, 336), + dtype=torch.float, + device=torch_device, + ) + input_ids = torch.tensor( + [ + [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], + ], + dtype=torch.long, + device=torch_device, + ) + attention_mask = torch.tensor( + [[0, 0, 1, 1, 1, 1, 1, 1, 1]], + dtype=torch.long, + device=torch_device, + ) + + # Make sure that the loss is properly computed + loss = model( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + labels=input_ids, + ).loss + loss.backward() + + def test_tokenizer_integration(self): + model_id = "rhymes-ai/Aria" + slow_tokenizer = AutoTokenizer.from_pretrained( + model_id, bos_token="<|startoftext|>", eos_token="<|endoftext|>", use_fast=False + ) + slow_tokenizer.add_tokens("", True) + + fast_tokenizer = AutoTokenizer.from_pretrained( + model_id, + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + from_slow=True, + legacy=False, + ) + fast_tokenizer.add_tokens("", True) + + prompt = "<|startoftext|><|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nWhat is shown in this image?<|im_end|>" + EXPECTED_OUTPUT = ['<|startoftext|>', '<', '|', 'im', '_', 'start', '|', '>', 'system', '\n', 'Answer', '▁the', '▁questions', '.<', '|', 'im', '_', 'end', '|', '><', '|', 'im', '_', 'start', '|', '>', 'user', '\n', '', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<', '|', 'im', '_', 'end', '|', '>'] # fmt: skip + self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT) + self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT) + + @slow + @require_bitsandbytes + def test_generation_no_images(self): + model_id = "rhymes-ai/Aria" + model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + # Prepare inputs with no images + inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device) + + # Make sure that `generate` works + _ = model.generate(**inputs, max_new_tokens=20) + + @slow + @require_bitsandbytes + def test_generation_siglip_backbone(self): + model_id = "rhymes-ai/Aria" + model = AriaForConditionalGeneration.from_pretrained(model_id, torch_dtype="float16", device_map=torch_device) + processor = AutoProcessor.from_pretrained(model_id) + + # check processing with expansion of inputs (w/o expansion should work with any backbone) + processor.vision_feature_select_strategy = "default" + processor.patch_size = 14 + + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor( + text="<|im_start|>user\n\nWhat are these?<|im_end|>\n<|im_start|>assistant", + images=raw_image, + return_tensors="pt", + ).to(torch_device, torch.float16) + + # Make sure that `generate` works + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = "user\n\nWhat are these?\nassistant The image shows two cats, one on the left and one on the right. They appear to be resting or sleeping on a pink blanket. The cat" + self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_expansion_in_processing(self): + model_id = "rhymes-ai/Aria" + model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nDescribe the image:\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + + # check processing with expansion of inputs + processor.vision_feature_select_strategy = "default" + processor.patch_size = 14 + inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) + self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593) + + # check processing without expansion of inputs (legacy behavior) + processor.vision_feature_select_strategy = None + processor.patch_size = None + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) + self.assertTrue(inputs.input_ids.shape[-1] == 18) + + # generate exactly 20 tokens + output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) + output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) + + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + + @slow + @require_bitsandbytes + def test_pixtral(self): + model_id = "rhymes-ai/Aria" + model = AriaForConditionalGeneration.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id) + + IMG_URLS = [ + Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/27/500/500", stream=True).raw), + Image.open(requests.get("https://picsum.photos/id/17/150/600", stream=True).raw), + ] + PROMPT = "[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]" + + # image = Image.open(requests.get(url, stream=True).raw) + inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda") + generate_ids = model.generate(**inputs, max_new_tokens=500) + ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + + # fmt: off + EXPECTED_GENERATION = """ +Describe the images. +Sure, let's break down each image description: + +1. **Image 1:** + - **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera. + - **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur. + +2. **Image 2:** + - **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley. + - **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image. + +3. **Image 3:** + - **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset. + - **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene. + +4. **Image 4:** + - **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers. + - **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden. + +Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it. +""" + # fmt: on + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(ouptut, EXPECTED_GENERATION) diff --git a/tests/models/aria/test_processor_aria.py b/tests/models/aria/test_processor_aria.py new file mode 100644 index 00000000000000..7e23d861c775c0 --- /dev/null +++ b/tests/models/aria/test_processor_aria.py @@ -0,0 +1,391 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +import unittest +from io import BytesIO +from typing import Optional + +import numpy as np +import requests + +from transformers import AriaProcessor +from transformers.models.auto.processing_auto import AutoProcessor +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from PIL import Image + + +@require_torch +@require_vision +class AriaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = AriaProcessor + + @classmethod + def setUpClass(cls): + cls.tmpdirname = tempfile.mkdtemp() + processor = AriaProcessor.from_pretrained("m-ric/Aria_hf_2", image_seq_len=2) + processor.save_pretrained(cls.tmpdirname) + cls.image1 = Image.open( + BytesIO( + requests.get( + "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + ).content + ) + ) + cls.image2 = Image.open( + BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content) + ) + cls.image3 = Image.open( + BytesIO( + requests.get( + "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg" + ).content + ) + ) + cls.bos_token = "<|im_start|>" + cls.eos_token = "<|im_end|>" + + cls.image_token = processor.tokenizer.image_token + cls.fake_image_token = "o" + cls.global_img_token = "<|img|>" + + cls.bos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.bos_token) + cls.eos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.eos_token) + + cls.image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.image_token) + cls.fake_image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.fake_image_token) + cls.global_img_tokens_id = processor.tokenizer(cls.global_img_token, add_special_tokens=False)["input_ids"] + cls.padding_token_id = processor.tokenizer.pad_token_id + cls.image_seq_len = 256 + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def get_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdirname) + + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + processor_components = self.prepare_components() + processor_components["image_processor"] = self.get_component( + "image_processor", do_rescale=True, rescale_factor=1 + ) + processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(**processor_components) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0) + + def test_process_interleaved_images_prompts_image_splitting(self): + processor = self.get_processor() + processor.image_processor.split_image = True + + # Test that a single image is processed correctly + inputs = processor(images=self.image1, text="Ok<|img|>", images_kwargs={"split_image": True}) + self.assertEqual(np.array(inputs["pixel_values"]).shape, (2, 3, 980, 980)) + self.assertEqual(np.array(inputs["pixel_mask"]).shape, (2, 980, 980)) + + def test_process_interleaved_images_prompts_no_image_splitting(self): + processor = self.get_processor() + processor.image_processor.split_image = False + + # Test that a single image is processed correctly + inputs = processor(images=self.image1, text="Ok<|img|>") + image1_expected_size = (980, 980) + self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 3, *image1_expected_size)) + self.assertEqual(np.array(inputs["pixel_mask"]).shape, (1, *image1_expected_size)) + # fmt: on + + # Test a single sample with image and text + image_str = "<|img|>" + text_str = "In this image, we see" + text = image_str + text_str + inputs = processor(text=text, images=self.image1) + + # fmt: off + tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False) + + expected_input_ids = [[self.image_token_id] * self.image_seq_len + tokenized_sentence["input_ids"]] + # self.assertEqual(len(inputs["input_ids"]), len(expected_input_ids)) + + self.assertEqual(inputs["input_ids"], expected_input_ids) + self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])]) + self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 3, *image1_expected_size)) + self.assertEqual(np.array(inputs["pixel_mask"]).shape, (1, *image1_expected_size)) + # fmt: on + + # Test that batch is correctly processed + image_str = "<|img|>" + text_str_1 = "In this image, we see" + text_str_2 = "In this image, we see" + + text = [ + image_str + text_str_1, + image_str + image_str + text_str_2, + ] + images = [[self.image1], [self.image2, self.image3]] + + inputs = processor(text=text, images=images, padding=True) + + # fmt: off + tokenized_sentence_1 = processor.tokenizer(text_str_1, add_special_tokens=False) + tokenized_sentence_2 = processor.tokenizer(text_str_2, add_special_tokens=False) + + image_tokens = [self.image_token_id] * self.image_seq_len + expected_input_ids_1 = image_tokens + tokenized_sentence_1["input_ids"] + expected_input_ids_2 = 2 * image_tokens + tokenized_sentence_2["input_ids"] + + # Pad the first input to match the second input + pad_len = len(expected_input_ids_2) - len(expected_input_ids_1) + + expected_attention_mask = [[0] * pad_len + [1] * len(expected_input_ids_1), [1] * (len(expected_input_ids_2))] + + self.assertEqual( + inputs["attention_mask"], + expected_attention_mask + ) + self.assertEqual(np.array(inputs['pixel_values']).shape, (3, 3, 980, 980)) + self.assertEqual(np.array(inputs['pixel_mask']).shape, (3, 980, 980)) + # fmt: on + + def test_non_nested_images_with_batched_text(self): + processor = self.get_processor() + processor.image_processor.do_image_splitting = False + + image_str = "<|img|>" + text_str_1 = "In this image, we see" + text_str_2 = "In this image, we see" + + text = [ + image_str + text_str_1, + image_str + image_str + text_str_2, + ] + images = [self.image1, self.image2, self.image3] + + inputs = processor(text=text, images=images, padding=True) + + self.assertEqual(np.array(inputs["pixel_values"]).shape, (3, 3, 980, 980)) + self.assertEqual(np.array(inputs["pixel_mask"]).shape, (3, 980, 980)) + + def test_apply_chat_template(self): + # Message contains content which a mix of lists with images and image urls and string + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do these images show?"}, + {"type": "image"}, + {"type": "image"}, + "What do these images show?", + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.", + } + ], + }, + {"role": "user", "content": [{"type": "text", "text": "And who is that?"}]}, + ] + processor = self.get_processor() + # Make short sequence length to test that the fake tokens are added correctly + rendered = processor.apply_chat_template(messages, add_generation_prompt=True) + print(rendered) + + expected_rendered = """<|im_start|>user +What do these images show?<|img|><|img|><|im_end|> +<|im_start|>assistant +The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.<|im_end|> +<|im_start|>user +And who is that?<|im_end|> +<|im_start|>assistant +""" + self.assertEqual(rendered, expected_rendered) + + # Override as AriaProcessor needs image tokens in prompts + def prepare_text_inputs(self, batch_size: Optional[int] = None): + if batch_size is None: + return "lower newer <|img|>" + + if batch_size < 1: + raise ValueError("batch_size must be greater than 0") + + if batch_size == 1: + return ["lower newer <|img|>"] + return ["lower newer <|img|>", "<|img|> upper older longer string"] + ["<|img|> lower newer"] * ( + batch_size - 2 + ) + + # Override tests as inputs_ids padded dimension is the second one but not the last one + @require_vision + @require_torch + def test_kwargs_overrides_default_tokenizer_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=30) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=30) + self.assertEqual(len(inputs["input_ids"][0]), 30) + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + inputs = processor( + text=input_str, + images=image_input, + common_kwargs={"return_tensors": "pt"}, + images_kwargs={"max_image_size": 980}, + text_kwargs={"padding": "max_length", "max_length": 120, "truncation": "longest_first"}, + ) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs["pixel_values"].shape[3], 980) + + self.assertEqual(len(inputs["input_ids"][0]), 120) + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"max_image_size": 980}, + "text_kwargs": {"padding": "max_length", "max_length": 120, "truncation": "longest_first"}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertEqual(inputs["pixel_values"].shape[3], 980) + self.assertEqual(len(inputs["input_ids"][0]), 120) + + @require_vision + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=30) + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, return_tensors="pt") + self.assertEqual(len(inputs["input_ids"][0]), 30) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs(batch_size=2) + image_input = self.prepare_image_inputs(batch_size=2) + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + padding="longest", + max_length=76, + truncation=True, + max_image_size=980, + ) + + self.assertEqual(inputs["pixel_values"].shape[1], 3) + self.assertEqual(inputs["pixel_values"].shape[3], 980) + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = self.prepare_text_inputs() + image_input = self.prepare_image_inputs() + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + max_image_size=980, + padding="max_length", + max_length=120, + truncation="longest_first", + ) + + self.assertEqual(inputs["pixel_values"].shape[3], 980) + self.assertEqual(len(inputs["input_ids"][0]), 120) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 9b8244c243fc4a..1c81c08fd845b1 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -381,7 +381,7 @@ def check_config_attributes_being_used(config_class): def check_config_attributes(): - """Check the arguments in `__init__` of all configuration classes are used in python files""" + """Check the arguments in `__init__` of all configuration classes are used in python files""" configs_with_unused_attributes = {} for _config_class in list(CONFIG_MAPPING.values()): # Skip deprecated models diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index a2ea05edce8063..a63ca59690f748 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -865,9 +865,10 @@ def match_docstring_with_signature(obj: Any) -> Optional[Tuple[str, str]]: # We went too far by one (perhaps more if there are a lot of new lines) idx -= 1 - while len(obj_doc_lines[idx].strip()) == 0: - arguments[current_arg] = arguments[current_arg][:-1] - idx -= 1 + if current_arg: + while len(obj_doc_lines[idx].strip()) == 0: + arguments[current_arg] = arguments[current_arg][:-1] + idx -= 1 # And we went too far by one again. idx += 1 diff --git a/utils/check_repo.py b/utils/check_repo.py index 10be5cdcd26230..3dbe59f192293a 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -85,6 +85,8 @@ "Idefics2PerceiverResampler", "Idefics2VisionTransformer", "Idefics3VisionTransformer", + "AriaTextForCausalLM", + "AriaTextModel", ] # Update this list for models that are not tested with a comment explaining the reason it should not be. diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index cf1a0cfd95ca0a..e8d117cd2af08f 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1678,7 +1678,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma/modular_gemma.py"], + default=["src/transformers/models/aria/modular_aria.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", )