From d599d5dffa887758cab1d43187b8df6dcb777433 Mon Sep 17 00:00:00 2001
From: Matt
Date: Mon, 23 Dec 2024 15:32:07 +0000
Subject: [PATCH] Put the legacy processing code back

---
 .../models/llava/configuration_llava.py      |  4 ++
 .../models/llava/modeling_llava.py           | 56 ++++++++++++++++++-
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py
index b054c0abda1ded..8b3be6fcedbc63 100644
--- a/src/transformers/models/llava/configuration_llava.py
+++ b/src/transformers/models/llava/configuration_llava.py
@@ -48,6 +48,8 @@ class LlavaConfig(PretrainedConfig):
             Can be one of `"default"` or `"full"`.
         vision_feature_layer (`int`, *optional*, defaults to -2):
             The index of the layer to select the vision feature.
+        image_seq_length (`int`, *optional*, defaults to 576):
+            Sequence length of one image embedding.
         multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
             Whether to use bias in the multimodal projector.
 
@@ -84,12 +86,14 @@ def __init__(
         projector_hidden_act="gelu",
         vision_feature_select_strategy="default",
         vision_feature_layer=-2,
+        image_seq_length=576,
         multimodal_projector_bias=True,
         **kwargs,
     ):
         self.ignore_index = ignore_index
         self.image_token_index = image_token_index
         self.projector_hidden_act = projector_hidden_act
+        self.image_seq_length = image_seq_length
 
         if vision_feature_select_strategy not in ["default", "full"]:
             raise ValueError(
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index f307f871e46d38..05dd7e49f1ea25 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -464,9 +464,17 @@ def forward(
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
+        legacy_processing = False
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
+            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
         image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
@@ -475,7 +483,53 @@ def forward(
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
 
-        if image_features is not None:
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0] * image_features.shape[1]
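
Below is a standalone sketch, not part of the patch, of the legacy-vs-new routing heuristic the patch restores: if a prompt contains fewer image tokens than `image_seq_length`, the processor did not expand the placeholder and the model falls back to legacy in-model expansion; during decoding, a single-token input arriving together with `pixel_values` signals the same legacy path. The token id 32000 and the tensor values are made-up illustration values; only the `image_seq_length=576` default (a 24x24 patch grid from the 336px CLIP tower) comes from the patch.

import torch

IMAGE_TOKEN_INDEX = 32000   # hypothetical <image> token id, for illustration only
IMAGE_SEQ_LENGTH = 576      # LlavaConfig default introduced by this patch

def uses_legacy_processing(input_ids: torch.Tensor, pixel_values) -> bool:
    # Prefill: fewer image tokens than one image's embedding length means the
    # processor did NOT expand the placeholder -> legacy in-model expansion.
    fewer_image_tokens = (input_ids == IMAGE_TOKEN_INDEX).sum(1).max() < IMAGE_SEQ_LENGTH
    # Decoding: a single new token passed alongside pixel_values also keeps the legacy path.
    decoding_with_pixels = input_ids.shape[-1] == 1 and pixel_values is not None
    return bool(fewer_image_tokens) or decoding_with_pixels

# Old-style prompt: one placeholder <image> token -> True (legacy path)
legacy_prompt = torch.tensor([[1, IMAGE_TOKEN_INDEX, 306, 4966]])
# New-style prompt: the processor already repeated the token 576 times -> False (new path)
expanded_prompt = torch.cat(
    [torch.tensor([[1]]), torch.full((1, IMAGE_SEQ_LENGTH), IMAGE_TOKEN_INDEX), torch.tensor([[306, 4966]])],
    dim=-1,
)
print(uses_legacy_processing(legacy_prompt, pixel_values=torch.zeros(1)))    # True
print(uses_legacy_processing(expanded_prompt, pixel_values=torch.zeros(1)))  # False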