From 8a08b6bdd0033fb22f53b5a9c79d13103ecb5813 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 30 Oct 2024 09:57:33 +0100 Subject: [PATCH 01/15] fix tests --- .../llava_next_video/modular_llava_next_video.py | 13 ++++++++++++- .../models/video_llava/modeling_video_llava.py | 13 ++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index e9974e920493ff..c1ed7571941b9e 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -623,6 +623,17 @@ def prepare_inputs_for_generation( ): # Overwritten -- extra custom processing + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) + model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -635,7 +646,7 @@ def prepare_inputs_for_generation( # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model - if cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values model_inputs["pixel_values_videos"] = pixel_values_videos model_inputs["image_sizes"] = image_sizes diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 30f82e45056c77..a9bd8b745a6f68 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -720,6 +720,17 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values_images is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) + model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -730,7 +741,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values_images"] = pixel_values_images From a022e60df2655f37b9d7b5ead4a66d9df0ed360e Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 30 Oct 2024 09:59:25 +0100 Subject: [PATCH 02/15] [run-slow] llava_next_video From b0e1c7c3764509b2d4d3cd53e5f4d4ef0950a7fb Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 30 Oct 2024 10:26:24 +0100 Subject: [PATCH 03/15] fix copies --- .../llava_next_video/modeling_llava_next_video.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 44b372535d70bd..96f4373afd9ec6 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -1110,6 +1110,17 @@ def prepare_inputs_for_generation( ): # Overwritten -- extra custom processing + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) + model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -1122,7 +1133,7 @@ def prepare_inputs_for_generation( # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model - if cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values model_inputs["pixel_values_videos"] = pixel_values_videos model_inputs["image_sizes"] = image_sizes From 2bfd72217dcc03ab7ba99974ab38822cd5cee795 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 30 Oct 2024 10:26:26 +0100 Subject: [PATCH 04/15] [run-slow] llava_next_video From 931b03a393955f709b18e10b0153d253882d2b50 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 22 Nov 2024 14:40:42 +0100 Subject: [PATCH 05/15] remove legacy in all models --- .../models/llava/modeling_llava.py | 56 ------- .../models/llava/processing_llava.py | 30 ++-- .../models/llava_next/modeling_llava_next.py | 61 -------- .../llava_next/processing_llava_next.py | 34 ++-- .../modular_llava_next_video.py | 145 ++++-------------- .../processing_llava_next_video.py | 61 ++++---- .../video_llava/modeling_video_llava.py | 145 ++++-------------- .../video_llava/processing_video_llava.py | 15 +- .../models/vipllava/modeling_vipllava.py | 54 ------- tests/models/llava/test_modeling_llava.py | 33 +--- .../models/vipllava/test_modeling_vipllava.py | 32 ++-- 11 files changed, 128 insertions(+), 538 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index a0079f1787a2e9..626a7146e537fb 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -461,18 +461,9 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - legacy_processing = ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - ) or (input_ids.shape[-1] == 1 and pixel_values is not None) - - image_features = None if pixel_values is not None: image_features = self.get_image_features( pixel_values=pixel_values, @@ -480,53 +471,6 @@ def forward( vision_feature_select_strategy=vision_feature_select_strategy, ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - # prefill stage vs decoding stage (legacy behavior copied) - if input_ids.shape[1] != 1: - inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() n_image_features = image_features.shape[1] if n_image_tokens != n_image_features: diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 8a9597892c6021..d1b0beb2d0f698 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -143,25 +143,17 @@ def __call__( # try to expand inputs in processing if we have the necessary parts prompt_strings = text if image_inputs.get("pixel_values") is not None: - if self.patch_size is not None and self.vision_feature_select_strategy is not None: - # Replace the image token with the expanded image token sequence - pixel_values = image_inputs["pixel_values"] - height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1 - if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 - - prompt_strings = [] - for sample in text: - sample = sample.replace(self.image_token, self.image_token * num_image_tokens) - prompt_strings.append(sample) - else: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) + # Replace the image token with the expanded image token sequence + pixel_values = image_inputs["pixel_values"] + height, width = get_image_size(to_numpy_array(pixel_values[0])) + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1 + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + prompt_strings = [] + for sample in text: + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs}) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 5a49337b2b5d96..cd9763f02b5135 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -835,18 +835,9 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True - legacy_processing = ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - ) or (input_ids.shape[-1] == 1 and pixel_values is not None) - - image_features = None if pixel_values is not None and pixel_values.size(0) > 0: image_features = self.get_image_features( pixel_values, @@ -863,58 +854,6 @@ def forward( image_newline=self.image_newline, ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - if input_ids.shape[1] != 1: - inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features( - image_features, - feature_lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - elif image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] if n_image_tokens != n_image_features: diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index ce11be6d6309a8..ac4396ff10ab9b 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -138,27 +138,19 @@ def __call__( prompt_strings = text if image_inputs: - if self.patch_size is None or self.vision_feature_select_strategy is None: - logger.warning_once( - "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - else: - image_sizes = iter(image_inputs["image_sizes"]) - height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) - prompt_strings = [] - for sample in text: - while self.image_token in sample: - image_size = next(image_sizes) - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 - sample = sample.replace(self.image_token, "" * num_image_tokens, 1) - prompt_strings.append(sample) - prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] + image_sizes = iter(image_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "" * num_image_tokens, 1) + prompt_strings.append(sample) + prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index c1ed7571941b9e..691fa92766fc92 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -427,25 +427,9 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) - legacy_processing = inputs_not_expanded or pixels_present - - image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: image_features = self.get_image_features( pixel_values, @@ -460,7 +444,21 @@ def forward( image_newline=self.image_newline, ) - video_features = video_feature_lens = None + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: video_features = self.get_video_features( pixel_values_videos, @@ -472,94 +470,20 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in iterator: - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if image_features is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_features is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = ( + (input_ids == self.config.video_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, @@ -623,17 +547,6 @@ def prepare_inputs_for_generation( ): # Overwritten -- extra custom processing - if input_ids is not None: - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - legacy_processing = (img_token_not_enough and pixel_values is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -646,7 +559,7 @@ def prepare_inputs_for_generation( # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model - if legacy_processing or cache_position[0] == 0: + if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values model_inputs["pixel_values_videos"] = pixel_values_videos model_inputs["image_sizes"] = image_sizes diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index e0e4534e42b565..3bedf6e8f51724 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -161,42 +161,33 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - if self.patch_size is None or self.vision_feature_select_strategy is None: - logger.warning_once( - "Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - else: - # images expand taking into account num_of_patches in each image - if image_inputs: - image_sizes = iter(image_inputs["image_sizes"]) - height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) - prompt_strings = [] - for sample in text: - while self.image_token in sample: - image_size = next(image_sizes) - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 - sample = sample.replace(self.image_token, "" * num_image_tokens, 1) - prompt_strings.append(sample) - text = [sample.replace("", self.image_token) for sample in prompt_strings] + if image_inputs: + image_sizes = iter(image_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "" * num_image_tokens, 1) + prompt_strings.append(sample) + text = [sample.replace("", self.image_token) for sample in prompt_strings] - # videos are easier, simply get frames and multiply - if videos_inputs: - one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0]) - height, width = get_image_size(one_video[0]) - num_frames = one_video.shape[0] # frame dim is always after batch dim - num_image_tokens = (height // self.patch_size) * (width // self.patch_size) - num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer - prompt_strings = [] - for sample in text: - sample = sample.replace(self.video_token, self.video_token * num_video_tokens) - prompt_strings.append(sample) - text = prompt_strings + # videos are easier, simply get frames and multiply + if videos_inputs: + one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer + prompt_strings = [] + for sample in text: + sample = sample.replace(self.video_token, self.video_token * num_video_tokens) + prompt_strings.append(sample) + text = prompt_strings text_inputs = self.tokenizer( text, diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index a9bd8b745a6f68..0537736087b169 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -537,127 +537,49 @@ def forward( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) - legacy_processing = False if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing - # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - pixels_present = input_ids.shape[-1] == 1 and ( - pixel_values_images is not None or pixel_values_videos is not None - ) - legacy_processing = inputs_not_expanded or pixels_present - - image_features = None if pixel_values_images is not None: image_features = self.get_image_features( pixel_values_images, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() + n_image_features = image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - video_features = None - num_frames = 0 if pixel_values_videos is not None: video_features, num_frames = self.get_video_features( pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer ) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." - ) - if input_ids.shape[1] != 1: - for features, frames in ((image_features, 1), (video_features, num_frames)): - if features is not None: - ( - inputs_embeds, - attention_mask, - labels, - position_ids, - input_ids, - ) = self._merge_input_ids_with_visual_features( - features, - inputs_embeds, - input_ids, - attention_mask, - labels, - num_frames=frames, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - - # TODO: @raushan retain only the new behavior after v4.47 - else: - if pixel_values_images is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() - n_image_features = image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item() + n_video_features = video_features.shape[1] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if pixel_values_videos is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item() - n_video_features = video_features.shape[1] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = ( + (input_ids == self.config.video_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, @@ -720,17 +642,6 @@ def prepare_inputs_for_generation( ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - if input_ids is not None: - img_token_not_enough = (input_ids == self.config.image_token_index).sum( - 1 - ).max() < self.config.image_seq_length - video_token_not_enough = (input_ids == self.config.video_token_index).sum( - 1 - ).max() < self.config.video_seq_length - legacy_processing = (img_token_not_enough and pixel_values_images is not None) or ( - video_token_not_enough and pixel_values_videos is not None - ) - model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -741,7 +652,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing or cache_position[0] == 0: + if cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values_images"] = pixel_values_images diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index bd6f91270965bb..057324b28bd5b9 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -62,8 +62,8 @@ def __init__( self, image_processor=None, tokenizer=None, - patch_size=None, - vision_feature_select_strategy=None, + patch_size=14, + vision_feature_select_strategy="default", image_token="", # set the default and let users change if they have peculiar special tokens in rare cases video_token="