From 0c8aa0a1bdfd6f2c3a09163a3ea6f3caa05682a3 Mon Sep 17 00:00:00 2001 From: aymeric-roucher Date: Fri, 25 Oct 2024 13:58:46 +0000 Subject: [PATCH] Remove AriaVisionModel by just using Idefics3 --- src/transformers/__init__.py | 6 +- src/transformers/models/aria/__init__.py | 5 +- .../models/aria/configuration_aria.py | 122 +-- src/transformers/models/aria/modeling_aria.py | 788 +++++------------- src/transformers/models/aria/modular_aria.py | 62 +- .../models/auto/configuration_auto.py | 4 +- src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/idefics3/__init__.py | 6 +- tests/models/aria/test_modeling_aria.py | 1 - 9 files changed, 263 insertions(+), 732 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b70a0249a50ce8..c5ac462ec82b76 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -172,7 +172,6 @@ "models.aria": [ "AriaConfig", "AriaTextConfig", - "AriaVisionConfig", ], "models.audio_spectrogram_transformer": [ "ASTConfig", @@ -2461,6 +2460,8 @@ "Idefics3Model", "Idefics3PreTrainedModel", "Idefics3Processor", + "Idefics3VisionTransformer", + "Idefics3VisionConfig", ] ) _import_structure["models.imagegpt"].extend( @@ -5022,7 +5023,6 @@ from .models.aria import ( AriaConfig, AriaTextConfig, - AriaVisionConfig, ) from .models.audio_spectrogram_transformer import ( ASTConfig, @@ -7178,6 +7178,8 @@ Idefics3Model, Idefics3PreTrainedModel, Idefics3Processor, + Idefics3VisionTransformer, + Idefics3VisionConfig, ) from .models.imagegpt import ( ImageGPTForCausalImageModeling, diff --git a/src/transformers/models/aria/__init__.py b/src/transformers/models/aria/__init__.py index 1a78426275ba4c..20cf672586c0d6 100644 --- a/src/transformers/models/aria/__init__.py +++ b/src/transformers/models/aria/__init__.py @@ -17,7 +17,7 @@ _import_structure = { - "configuration_aria": ["AriaConfig", "AriaForCausalLM", "AriaTextConfig", "AriaVisionConfig"], + "configuration_aria": ["AriaConfig", "AriaForCausalLM", "AriaTextConfig"], "modeling_aria": ["AriaForConditionalGeneration", "AriaPreTrainedModel"], "processing_aria": ["AriaProcessor"], } @@ -41,13 +41,12 @@ ] _import_structure["configuration_aria"] = [ "AriaConfig", - "AriaVisionConfig", "AriaTextConfig", ] if TYPE_CHECKING: - from .configuration_aria import AriaConfig, AriaTextConfig, AriaVisionConfig + from .configuration_aria import AriaConfig, AriaTextConfig try: if not is_torch_available(): diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py index ad0df22c96e276..3916639df2ac98 100644 --- a/src/transformers/models/aria/configuration_aria.py +++ b/src/transformers/models/aria/configuration_aria.py @@ -4,113 +4,11 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_aria.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import os -from typing import Union + from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -class AriaVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`AriaVisionModel`]. It is used to instantiate a - Aria vision encoder according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the vision encoder of the Aria - [google/aria-base-patch16-224](https://huggingface.co/google/aria-base-patch16-224) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - hidden_size (`int`, *optional*, defaults to 768): - Dimensionality of the encoder layers and the pooler layer. - intermediate_size (`int`, *optional*, defaults to 3072): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - num_hidden_layers (`int`, *optional*, defaults to 12): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. - num_channels (`int`, *optional*, defaults to 3): - Number of channels in the input images. - image_size (`int`, *optional*, defaults to 224): - The size (resolution) of each image. - patch_size (`int`, *optional*, defaults to 16): - The size (resolution) of each patch. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. - layer_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the layer normalization layers. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - - Example: - - ```python - >>> from transformers import AriaVisionConfig, AriaVisionModel - - >>> # Initializing a AriaVisionConfig with google/aria-base-patch16-224 style configuration - >>> configuration = AriaVisionConfig() - - >>> # Initializing a AriaVisionModel (with random weights) from the google/aria-base-patch16-224 style configuration - >>> model = AriaVisionModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - Configuration class for AriaVisionModel.""" - - model_type = "aria_vision_model" - - def __init__( - self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - num_channels=3, - image_size=224, - patch_size=16, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, - **kwargs, - ): - super().__init__(**kwargs) - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.image_size = image_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self._supports_sdpa = False - self.hidden_act = hidden_act - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - - # get the vision config dict if we are loading from AriaConfig - if config_dict.get("model_type") == "aria": - config_dict = config_dict["vision_config"] - - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: - logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) - - return cls.from_dict(config_dict, **kwargs) +from ..auto import CONFIG_MAPPING class AriaTextConfig(PretrainedConfig): @@ -245,7 +143,6 @@ def __init__( image_token_index=32000, **kwargs, ): - super().__init__(**kwargs) self.ignore_index = ignore_index self.image_token_index = image_token_index @@ -257,17 +154,20 @@ def __init__( 4900: 256, } self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()} - if vision_config is None: - vision_config = AriaVisionConfig() - if text_config is None: - text_config = AriaTextConfig() - if isinstance(vision_config, dict) and "model_type" in vision_config: - vision_config = AriaVisionConfig(**vision_config) + if isinstance(vision_config, dict): + vision_config["model_type"] = "idefics3_vision" + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["idefics3_vision"]() self.vision_config = vision_config if isinstance(text_config, dict) and "model_type" in text_config: text_config = AriaTextConfig(**text_config) + elif text_config is None: + text_config = AriaTextConfig() self.text_config = text_config + + super().__init__(**kwargs) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 019ae5440a11f4..36a6199cd2acad 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -4,6 +4,8 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_aria.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -11,25 +13,24 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn.init import trunc_normal_ +from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_ -from ... import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( + ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, + torch_int, ) from ..auto import AutoModel, AutoModelForCausalLM -from .configuration_aria import AriaConfig, AriaVisionConfig +from .configuration_aria import AriaConfig, AriaTextConfig from .processing_utils import ( experts_gemm, ) @@ -37,71 +38,15 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward - -import math -import warnings - -import torch -from torch.nn.init import _calculate_fan_in_and_fan_out - -from ...cache_utils import StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, - BaseModelOutputWithPooling, CausalLMOutputWithPast, -) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...utils import ( ModelOutput, - is_flash_attn_2_available, ) -from .configuration_aria import AriaTextConfig - - -class AriaPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. - """ - - config_class = AriaConfig - base_model_prefix = "model" - _no_split_modules = [] - supports_gradient_checkpointing = True - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_cache_class = True - - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA (Scaled Dot Product Attention) or not. - """ - return self.language_model._supports_sdpa - - def _init_weights(self, module): - if hasattr(self.config, 'initializer_range'): - std = self.config.initializer_range - elif hasattr(self.config, 'text_config'): - std = self.config.text_config.initializer_range - else: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, AriaGroupedGEMM): - module.weight.data.normal_(mean=0.0, std=std) - elif isinstance(module, nn.Conv2d): - module.weight.data.normal_(mean=0.0, std=std) - if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS class IdentityOp(torch.nn.Module): @@ -119,6 +64,29 @@ def forward(self, x, *args, **kwargs): return x +class AriaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + AriaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +logger = logging.get_logger(__name__) + + def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -155,9 +123,6 @@ def norm_cdf(x): tensor.clamp_(min=a, max=b) -logger = logging.get_logger(__name__) - - def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 ) -> torch.Tensor: @@ -184,64 +149,6 @@ def trunc_normal_tf_( tensor.mul_(std).add_(mean) -class AriaVisionEmbeddings(nn.Module): - """ - This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable - resolution. - - The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) - which allows treating images in their native aspect ratio and without the need to resize them to the same - fixed size. In particular, we start from the original pre-trained SigLIP model - (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions. - """ - - def __init__(self, config: AriaVisionConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches_per_side = self.image_size // self.patch_size - self.num_patches = self.num_patches_per_side**2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - - def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: - batch_size, _, max_im_h, max_im_w = pixel_values.shape - - patch_embeds = self.patch_embedding(pixel_values) - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) - position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) - - for batch_idx, p_attn_mask in enumerate(patch_attention_mask): - nb_patches_h = p_attn_mask[:, 0].sum() - nb_patches_w = p_attn_mask[0].sum() - - fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) - fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - - bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) - bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) - - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids - - position_ids = position_ids.to(self.position_embedding.weight.device) - embeddings = embeddings + self.position_embedding(position_ids) - return embeddings - - def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": @@ -267,7 +174,7 @@ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): raise ValueError(f"invalid distribution {distribution}") -class AriaVisionAttention(nn.Module): +class AriaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config): @@ -289,9 +196,6 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - # Ignore copy - self.is_causal = False - def forward( self, hidden_states: torch.Tensor, @@ -345,13 +249,15 @@ def forward( return attn_output, attn_weights -class AriaVisionFlashAttention2(AriaVisionAttention): +class AriaFlashAttention2(AriaAttention): """ - AriaVision flash attention module. This module inherits from `AriaVisionAttention` as the weights of the module stays + AriaAttention flash attention module. This module inherits from `AriaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ + is_causal = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -360,19 +266,16 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, output_attentions: bool = False, - use_cache: bool = False, - **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False - bsz, q_len, _ = hidden_states.size() + batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -381,16 +284,13 @@ def forward( # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) @@ -400,7 +300,7 @@ def forward( # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (AriaVisionRMSNorm handles it correctly) + # in fp32. input_dtype = query_states.dtype if input_dtype == torch.float32: @@ -433,7 +333,7 @@ def forward( use_top_left_mask=self._flash_attn_uses_top_left_mask, ) - attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -442,204 +342,89 @@ def forward( return attn_output, attn_weights -IDEFICS_VISION_ATTENTION_CLASSES = { - "eager": AriaVisionAttention, - "flash_attention_2": AriaVisionFlashAttention2, -} - - -class AriaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout +class AriaSdpaAttention(AriaAttention): + """ + Aria attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `AriaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + is_causal = False + # Adapted from AriaAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Input shape: Batch x Time x Channel""" - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): - raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "AriaModel is using AriaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class AriaVisionMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class AriaFlashAttention2(AriaAttention): - """ - AriaAttention flash attention module. This module inherits from `AriaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - is_causal = False - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - batch_size, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if self.is_causal and q_len > 1 else False - attn_output = _flash_attention_forward( + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attention_mask, - q_len, - dropout=dropout_rate, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, ) - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None + return attn_output, None - return attn_output, attn_weights + +ARIA_ATTENTION_CLASSES = { + "eager": AriaAttention, + "flash_attention_2": AriaFlashAttention2, + "sdpa": AriaSdpaAttention, +} class AriaEncoderLayer(nn.Module): - def __init__(self, config: AriaVisionConfig): + def __init__(self, config: AriaTextConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn = ARIA_ATTENTION_CLASSES[config._attn_implementation](config=config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = AriaVisionMLP(config) + self.mlp = AriaMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + # Ignore copy def forward( self, hidden_states: torch.Tensor, @@ -651,98 +436,65 @@ def forward( hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class AriaSdpaAttention(AriaAttention): - """ - Aria attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `AriaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - is_causal = False - - # Adapted from AriaAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "AriaModel is using AriaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states - batch_size, q_len, _ = hidden_states.size() + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + outputs = (hidden_states,) - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() + if output_attentions: + outputs += (attn_weights,) - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if self.is_causal and q_len > 1 else False + return outputs - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, q_len, self.embed_dim) +ARIA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) - attn_output = self.out_proj(attn_output) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. - return attn_output, None + Parameters: + config ([`AriaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ARIA_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" class AriaEncoder(nn.Module): @@ -751,10 +503,10 @@ class AriaEncoder(nn.Module): [`AriaEncoderLayer`]. Args: - config: AriaConfig + config: AriaTextConfig """ - def __init__(self, config: AriaConfig): + def __init__(self, config: AriaTextConfig): super().__init__() self.config = config self.layers = nn.ModuleList([AriaEncoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -833,178 +585,11 @@ def forward( ) -ARIA_ATTENTION_CLASSES = { - "eager": AriaAttention, - "flash_attention_2": AriaFlashAttention2, - "sdpa": AriaSdpaAttention, -} - - -ARIA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`AriaConfig`] or [`AriaVisionConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -ARIA_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): - Whether to interpolate the pre-trained position encodings. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -ARIA_VISION_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`AriaVisionConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The Aria Vision Transformer Model outputting raw image embedding.", - ARIA_VISION_START_DOCSTRING, -) -class AriaVisionTransformer(AriaPreTrainedModel): - """ - Aria Vision Transformer model based on Idefics3VisionTransformer. - - This class extends the original Idefics3VisionTransformer by removing the post-layernorm operation. - """ - - config_class = AriaVisionConfig - - _supports_sdpa = False - - def __init__(self, config: AriaVisionConfig): - super().__init__(config) - self.embed_dim = config.hidden_size - - self.embeddings = AriaVisionEmbeddings(config) - self.encoder = AriaEncoder(config) - self.patch_size = config.patch_size - self.post_layernorm = IdentityOp() - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, value): - self.embeddings = value - - def forward( - self, - pixel_values, - patch_attention_mask: Optional[torch.BoolTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size = pixel_values.size(0) - if patch_attention_mask is None: - patch_size = self.patch_size - patch_attention_mask = torch.ones( - ( - batch_size, - pixel_values.size(2) // patch_size, - pixel_values.size(3) // patch_size, - ) - ) - patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device) - - hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) - - patch_attention_mask = patch_attention_mask.view(batch_size, -1) - # The call to `_upad_input` in `_flash_attention_forward` is expensive - # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), - # avoiding passing the attention_mask, which is equivalent to attending to the full sequence - if not torch.any(~patch_attention_mask): - patch_attention_mask = None - elif not self._use_flash_attention_2: - patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=patch_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.post_layernorm(last_hidden_state) - - if not return_dict: - return (last_hidden_state,) + encoder_outputs[1:] - - return BaseModelOutput( - last_hidden_state=last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class AriaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - AriaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - @add_start_docstrings( """The vision model from Aria without any head or projection on top.""", ARIA_START_DOCSTRING, ) -class AriaVisionModel(AriaPreTrainedModel): +class AriaVisionModel(PreTrainedModel): """ Aria Vision Model extends SiglipVisionModel to support pixel_mask. @@ -1017,13 +602,12 @@ class AriaVisionModel(AriaPreTrainedModel): This mask helps the model focus on the relevant parts of the image during processing. """ - config_class = AriaVisionConfig main_input_name = "pixel_values" _supports_sdpa = False - def __init__(self, config: AriaVisionConfig): + def __init__(self, config): super().__init__(config) - self.vision_model = AriaVisionTransformer(config) + self.vision_model = AutoModel.from_config(config) # Initialize weights and apply final processing self.post_init() @@ -1032,7 +616,7 @@ def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @add_start_docstrings_to_model_forward(ARIA_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=AriaVisionConfig) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling) def forward( self, pixel_values: torch.Tensor, @@ -1061,12 +645,16 @@ def forward( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + output_hidden_states=True, return_dict=return_dict, ) image_attentions = self._create_image_attention_mask(patch_attention_mask) + last_hidden_state_pre_normalization = vision_output.hidden_states[-1] + + vision_output.last_hidden_state = last_hidden_state_pre_normalization + if not return_dict: return vision_output, image_attentions @@ -1252,6 +840,50 @@ def forward(self, x, attn_mask=None): return out +class AriaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + """ + + config_class = AriaConfig + base_model_prefix = "model" + _no_split_modules = [] + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA (Scaled Dot Product Attention) or not. + """ + return self.language_model._supports_sdpa + + def _init_weights(self, module): + if hasattr(self.config, "initializer_range"): + std = self.config.initializer_range + elif hasattr(self.config, "text_config"): + std = self.config.text_config.initializer_range + else: + std = 0.02 + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, AriaGroupedGEMM): + module.weight.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=std) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + # adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/router.py#L96-L304 class AriaTopKRouter(nn.Module): """ @@ -1261,7 +893,7 @@ class AriaTopKRouter(nn.Module): It also applies auxiliary losses to encourage load balancing among experts. Args: - config (AriaConfig): Configuration object containing MoE-related parameters. + config (AriaTextConfig): Configuration object containing MoE-related parameters. """ def __init__(self, config: AriaTextConfig): @@ -1298,7 +930,7 @@ class AriaMLP(nn.Module): This class reconfigures the intermediate size in comparison to the LlamaMLP. Args: - config (AriaConfig): Configuration object for the Aria language model. + config (AriaTextConfig): Configuration object for the Aria language model. """ def __init__(self, config: AriaTextConfig): @@ -1384,7 +1016,7 @@ class AriaGroupedMLP(nn.Module): Grouped MLP module for Mixture of Experts. Args: - config (AriaConfig): Configuration object for the model. + config (AriaTextConfig): Configuration object for the model. """ def __init__(self, config: AriaTextConfig) -> None: @@ -1421,7 +1053,7 @@ class AriaTextMoELayer(nn.Module): # TODO: check naming convenstion for Instruc the outputs. Args: - config (AriaConfig): Configuration object for the MoE layer. + config (AriaTextConfig): Configuration object for the MoE layer. """ def __init__(self, config: AriaTextConfig): @@ -1489,7 +1121,7 @@ def __init__( device=None, scaling_factor=1.0, rope_type="default", - config: Optional[AriaConfig] = None, + config: Optional[AriaTextConfig] = None, ): super().__init__() # TODO (joao): remove the `if` below, only used for BC @@ -1601,7 +1233,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -_CONFIG_FOR_DOC = "AriaConfig" +_CONFIG_FOR_DOC = "AriaTextConfig" def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -2849,6 +2481,18 @@ class AriaCausalLMOutputWithPast(ModelOutput): image_hidden_states: Optional[torch.FloatTensor] = None +class Idefics3Wrapper(AriaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.vision_model = AutoModel.from_config( + config.vision_config, attn_implementation=config._attn_implementation + ) + self.post_init() + + def forward(self, pixel_values, **kwargs): + return self.vision_model(pixel_values, **kwargs) + + class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): """ Aria model for conditional generation tasks. @@ -2865,9 +2509,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): def __init__(self, config: AriaConfig): super().__init__(config) - self.vision_tower = AutoModel.from_config( - config.vision_config, attn_implementation=config._attn_implementation + self.vision_tower = Idefics3Wrapper( + config ) + print("PREFIX", self.vision_tower.base_model_prefix) + print(dir(self.vision_tower)) + # self.vision_tower.base_model_prefix = "vision_tower.vision_model" self.multi_modal_projector = AriaProjector( patch_to_query_dict=config.projector_patch_to_query_dict, embed_dim=config.vision_config.hidden_size, @@ -2877,6 +2524,7 @@ def __init__(self, config: AriaConfig): output_dim=config.text_config.hidden_size, ) self.vocab_size = config.text_config.vocab_size + self.language_model = AutoModelForCausalLM.from_config( config.text_config, attn_implementation=config._attn_implementation ) diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 6aa1ba6afb9efe..c4aefaf907beb9 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -27,7 +27,8 @@ from ...utils import ( logging, ) -from ..auto import AutoModel, AutoModelForCausalLM, AutoTokenizer +from ..auto import CONFIG_MAPPING, AutoModel, AutoModelForCausalLM, AutoTokenizer +from ..idefics3.configuration_idefics3 import Idefics3VisionConfig from ..idefics3.modeling_idefics3 import Idefics3VisionTransformer from ..llama.configuration_llama import LlamaConfig from ..llama.modeling_llama import ( @@ -40,7 +41,6 @@ LlamaRMSNorm, ) from ..llava.modeling_llava import LlavaCausalLMOutputWithPast -from ..siglip.configuration_siglip import SiglipVisionConfig from ..siglip.modeling_siglip import SiglipVisionModel from .processing_utils import ( experts_gemm, @@ -52,19 +52,6 @@ logger = logging.get_logger(__name__) -class AriaVisionConfig(SiglipVisionConfig): - """Configuration class for AriaVisionModel.""" - - model_type = "aria_vision_model" - - def __init__( - self, - **kwargs, - ): - super().__init__(**kwargs) - self._supports_sdpa = False - - class IdentityOp(torch.nn.Module): """ An identity operation that returns the input unchanged. @@ -80,19 +67,6 @@ def forward(self, x, *args, **kwargs): return x -class AriaVisionTransformer(Idefics3VisionTransformer): - """ - Aria Vision Transformer model based on Idefics3VisionTransformer. - - This class extends the original Idefics3VisionTransformer by removing the post-layernorm operation. - """ - - _supports_sdpa = False - - def __init__(self, config: AriaVisionConfig): - super().__init__(config) - self.post_layernorm = IdentityOp() - class AriaRMSNorm(LlamaRMSNorm): pass @@ -111,16 +85,12 @@ class AriaVisionModel(SiglipVisionModel): This mask helps the model focus on the relevant parts of the image during processing. """ - config_class = AriaVisionConfig main_input_name = "pixel_values" _supports_sdpa = False - def __init__(self, config: AriaVisionConfig): + def __init__(self, config: Idefics3VisionConfig): super().__init__(config) - self.vision_model = AriaVisionTransformer(config) - - # Initialize weights and apply final processing - self.post_init() + self.vision_model = Idefics3VisionTransformer(config) def forward( self, @@ -156,6 +126,10 @@ def forward( image_attentions = self._create_image_attention_mask(patch_attention_mask) + last_hidden_state_pre_normalization = vision_output.hidden_states[-1] + + vision_output.last_hidden_state = last_hidden_state_pre_normalization + if not return_dict: return vision_output, image_attentions @@ -826,13 +800,17 @@ def __init__( 4900: 256, } self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()} - if vision_config is None: - vision_config = AriaVisionConfig() if text_config is None: text_config = AriaTextConfig() - if isinstance(vision_config, dict) and "model_type" in vision_config: - vision_config = AriaVisionConfig(**vision_config) + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "idefics3" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["idefics3"]() self.vision_config = vision_config @@ -895,7 +873,7 @@ class AriaTopKRouter(nn.Module): It also applies auxiliary losses to encourage load balancing among experts. Args: - config (AriaConfig): Configuration object containing MoE-related parameters. + config (AriaTextConfig): Configuration object containing MoE-related parameters. """ def __init__(self, config: AriaTextConfig): @@ -932,7 +910,7 @@ class AriaMLP(LlamaMLP): This class reconfigures the intermediate size in comparison to the LlamaMLP. Args: - config (AriaConfig): Configuration object for the Aria language model. + config (AriaTextConfig): Configuration object for the Aria language model. """ def __init__(self, config: AriaTextConfig): @@ -998,7 +976,7 @@ class AriaGroupedMLP(nn.Module): Grouped MLP module for Mixture of Experts. Args: - config (AriaConfig): Configuration object for the model. + config (AriaTextConfig): Configuration object for the model. """ def __init__(self, config: AriaTextConfig) -> None: @@ -1035,7 +1013,7 @@ class AriaTextMoELayer(nn.Module): # TODO: check naming convenstion for Instruc the outputs. Args: - config (AriaConfig): Configuration object for the MoE layer. + config (AriaTextConfig): Configuration object for the MoE layer. """ def __init__(self, config: AriaTextConfig): diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4beac67cc2e99a..b6fbb0477c832c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -37,7 +37,6 @@ ("altclip", "AltCLIPConfig"), ("aria", "AriaConfig"), ("aria_text_model", "AriaTextConfig"), - ("aria_vision_model", "AriaVisionConfig"), ("audio-spectrogram-transformer", "ASTConfig"), ("autoformer", "AutoformerConfig"), ("bark", "BarkConfig"), @@ -138,6 +137,7 @@ ("idefics", "IdeficsConfig"), ("idefics2", "Idefics2Config"), ("idefics3", "Idefics3Config"), + ("idefics3_vision", "Idefics3VisionConfig"), ("imagegpt", "ImageGPTConfig"), ("informer", "InformerConfig"), ("instructblip", "InstructBlipConfig"), @@ -445,6 +445,7 @@ ("idefics", "IDEFICS"), ("idefics2", "Idefics2"), ("idefics3", "Idefics3"), + ("idefics3_vision", "Idefics3VisionTransformer"), ("imagegpt", "ImageGPT"), ("informer", "Informer"), ("instructblip", "InstructBLIP"), @@ -691,6 +692,7 @@ ("clip_text_model", "clip"), ("aria_text_model", "aria"), ("aria_vision_model", "aria"), + ("idefics3_vision", "idefics3"), ("siglip_vision_model", "siglip"), ("chinese_clip_vision_model", "chinese_clip"), ("rt_detr_resnet", "rt_detr"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 053bad1b7f3b52..02e3da8c630d25 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -135,6 +135,7 @@ ("idefics", "IdeficsModel"), ("idefics2", "Idefics2Model"), ("idefics3", "Idefics3Model"), + ("idefics3_vision", "Idefics3VisionTransformer"), ("imagegpt", "ImageGPTModel"), ("informer", "InformerModel"), ("jamba", "JambaModel"), diff --git a/src/transformers/models/idefics3/__init__.py b/src/transformers/models/idefics3/__init__.py index 35b1df5c678439..080ded94f368e7 100644 --- a/src/transformers/models/idefics3/__init__.py +++ b/src/transformers/models/idefics3/__init__.py @@ -16,7 +16,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available -_import_structure = {"configuration_idefics3": ["Idefics3Config"]} +_import_structure = {"configuration_idefics3": ["Idefics3Config", "Idefics3VisionConfig"]} try: @@ -38,11 +38,12 @@ "Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", + "Idefics3VisionTransformer", ] _import_structure["processing_idefics3"] = ["Idefics3Processor"] if TYPE_CHECKING: - from .configuration_idefics3 import Idefics3Config + from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig try: if not is_vision_available(): @@ -61,6 +62,7 @@ from .modeling_idefics3 import ( Idefics3ForConditionalGeneration, Idefics3Model, + Idefics3VisionTransformer, Idefics3PreTrainedModel, ) from .processing_idefics3 import Idefics3Processor diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index b449dfc6dd98d0..eba08288ad71b8 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -23,7 +23,6 @@ AriaConfig, AriaForConditionalGeneration, AriaTextConfig, - AriaVisionConfig, AutoProcessor, AutoTokenizer, is_torch_available,