From 3947ad8635e0dfd6d828563decd689c9fbeceea0 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 14 Jun 2024 10:36:09 +0200 Subject: [PATCH 01/22] squash into single commit --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/llava-next-video.md | 259 ++++ docs/source/en/perf_infer_gpu_one.md | 1 + src/transformers/__init__.py | 20 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + .../models/llava_next/modeling_llava_next.py | 8 +- .../models/llava_next_video/__init__.py | 70 ++ .../configuration_llava_next_video.py | 153 +++ .../convert_llava_next_video_weights_to_hf.py | 276 +++++ .../llava_next_video/diff_llava_next_video.py | 559 +++++++++ .../image_processing_llava_next_video.py | 421 +++++++ .../modeling_llava_next_video.py | 1077 +++++++++++++++++ .../processing_llava_next_video.py | 221 ++++ src/transformers/utils/dummy_pt_objects.py | 14 + .../utils/dummy_vision_objects.py | 7 + tests/models/llava/test_modeling_llava.py | 8 + .../llava_next/test_modeling_llava_next.py | 8 + tests/models/llava_next_video/__init__.py | 0 .../test_image_processing_llava_next_video.py | 217 ++++ .../test_modeling_llava_next_video.py | 469 +++++++ .../video_llava/test_modeling_video_llava.py | 8 + .../models/vipllava/test_modeling_vipllava.py | 8 + utils/diff_model_converter.py | 12 +- 29 files changed, 3823 insertions(+), 4 deletions(-) create mode 100644 docs/source/en/model_doc/llava-next-video.md create mode 100644 src/transformers/models/llava_next_video/__init__.py create mode 100644 src/transformers/models/llava_next_video/configuration_llava_next_video.py create mode 100644 src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py create mode 100644 src/transformers/models/llava_next_video/diff_llava_next_video.py create mode 100644 src/transformers/models/llava_next_video/image_processing_llava_next_video.py create mode 100644 src/transformers/models/llava_next_video/modeling_llava_next_video.py create mode 100644 src/transformers/models/llava_next_video/processing_llava_next_video.py create mode 100644 tests/models/llava_next_video/__init__.py create mode 100644 tests/models/llava_next_video/test_image_processing_llava_next_video.py create mode 100644 tests/models/llava_next_video/test_modeling_llava_next_video.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index be3001dc761a90..f81e712cf5b4fa 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -790,6 +790,8 @@ title: Llava - local: model_doc/llava_next title: LLaVA-NeXT + - local: model_doc/llava-next-video + title: LLaVa-NeXT-Video - local: model_doc/lxmert title: LXMERT - local: model_doc/matcha diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 72237d13839569..31d8b770ed8403 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -181,6 +181,7 @@ Flax), PyTorch, and/or TensorFlow. 
 | [Llama3](model_doc/llama3) | ✅ | ❌ | ✅ |
 | [LLaVa](model_doc/llava) | ✅ | ❌ | ❌ |
 | [LLaVA-NeXT](model_doc/llava_next) | ✅ | ❌ | ❌ |
+| [LLaVa-NeXT-Video](model_doc/llava-next-video) | ✅ | ❌ | ❌ |
 | [Longformer](model_doc/longformer) | ✅ | ✅ | ❌ |
 | [LongT5](model_doc/longt5) | ✅ | ❌ | ✅ |
 | [LUKE](model_doc/luke) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/llava-next-video.md b/docs/source/en/model_doc/llava-next-video.md
new file mode 100644
index 00000000000000..88e41efc29c87c
--- /dev/null
+++ b/docs/source/en/model_doc/llava-next-video.md
@@ -0,0 +1,259 @@
+
+
+# LLaVa-NeXT-Video
+
+## Overview
+
+The LLaVa-NeXT-Video model was proposed in [LLaVA-NeXT: A Strong Zero-shot Video Understanding Model
+](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/) by Yuanhan Zhang, Bo Li, Haotian Liu, Yong Jae Lee, Liangke Gui, Di Fu, Jiashi Feng, Ziwei Liu, Chunyuan Li. LLaVa-NeXT-Video improves upon [LLaVa-NeXT](llava_next) by fine-tuning on a mix of video and image datasets, thus increasing the model's performance on videos.
+
+[LLaVA-NeXT](llava_next) surprisingly has strong performance in understanding video content in a zero-shot fashion with the AnyRes technique that it uses. The AnyRes technique naturally represents a high-resolution image as multiple images. This technique is naturally generalizable to represent videos because videos can be considered a set of frames (similar to a set of images in LLaVa-NeXT). The current version of LLaVA-NeXT makes use of AnyRes and trains with supervised fine-tuning (SFT) on top of LLaVA-Next on video data to achieve better video understanding capabilities. The model is the current SOTA among open-source models on [VideoMME bench](https://arxiv.org/abs/2405.21075).
+
+
+The introduction from the blog is the following:
+
+On January 30, 2024, we released LLaVA-NeXT, an open-source Large Multimodal Model (LMM) that has been trained exclusively on text-image data. With the proposed AnyRes technique, it boosts capabilities in reasoning, OCR, and world knowledge, demonstrating remarkable performance across a spectrum of image-based multimodal understanding tasks, and even exceeding Gemini-Pro on several image benchmarks, e.g. MMMU and MathVista.
+
+**In today’s exploration, we delve into the performance of LLaVA-NeXT within the realm of video understanding tasks. We reveal that LLaVA-NeXT surprisingly has strong performance in understanding video content. The current version of LLaVA-NeXT for videos has several improvements:
+
+- Zero-shot video representation capabilities with AnyRes: The AnyRes technique naturally represents a high-resolution image as multiple images that a pre-trained VIT is able to digest, and forms them into a concatenated sequence. This technique is naturally generalizable to represent videos (consisting of multiple frames), allowing the image-only-trained LLaVA-Next model to perform surprisingly well on video tasks. Notably, this is the first time that LMMs show strong zero-shot modality transfer ability.
+- Inference with length generalization improves on longer videos. The linear scaling technique enables length generalization, allowing LLaVA-NeXT to effectively handle long videos beyond the limitation of the "max_token_length" of the LLM.
+- Strong video understanding ability. (1) LLaVA-Next-Image, which combines the above two techniques, yields superior zero-shot performance compared to open-source LMMs tuned on videos. (2) LLaVA-Next-Video, obtained by further supervised fine-tuning (SFT) of LLaVA-Next-Image on video data, achieves better video understanding capabilities than LLaVA-Next-Image. (3) LLaVA-Next-Video-DPO, which aligns the model response with AI feedback using direct preference optimization (DPO), shows a significant performance boost.
+- Efficient deployment and inference with SGLang. It allows 5x faster inference on video tasks, allowing more scalable serving such as million-level video re-captioning. See instructions in our repo.**
+
+
+This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
+The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tree/inference).
+
+## Usage tips
+
+- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating (see the batched generation example below).
+
+- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that.
+
+We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows:
+
+```python
+from transformers import LlavaNextVideoProcessor
+
+processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
+
+conversation = [
+    {
+        "role": "system",
+        "content": [
+            {"type": "text", "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."},
+        ],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "What’s shown in this image?"},
+            {"type": "image"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Why is this video funny?"},
+            {"type": "video"},
+        ],
+    },
+]
+
+text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your visuals
+print(text_prompt)
+```
+
+## Usage example
+
+### Single Media Mode
+
+The model can accept both images and videos as input. Here's example code for inference in half-precision (`torch.float16`):
+
+```python
+import av
+import torch
+import numpy as np
+from huggingface_hub import hf_hub_download
+from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
+
+def read_video_pyav(container, indices):
+    '''
+    Decode the video with PyAV decoder.
+    Args:
+        container (`av.container.input.InputContainer`): PyAV container.
+        indices (`List[int]`): List of frame indices to decode.
+    Returns:
+        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+ ''' + frames = [] + container.seek(0) + start_index = indices[0] + end_index = indices[-1] + for i, frame in enumerate(container.decode(video=0)): + if i > end_index: + break + if i >= start_index and i in indices: + frames.append(frame) + return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + +# Load the model in half-precision +model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", torch_dtype=torch.float16, device_map="auto") +processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf") + +# Load the video as an np.array, sampling uniformly 8 frames (can sample more for longer videos) +video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset") +container = av.open(video_path) +total_frames = container.streams.video[0].frames +indices = np.arange(0, total_frames, total_frames / 8).astype(int) +video = read_video_pyav(container, indices) + +conversation = [ + { + + "role": "user", + "content": [ + {"type": "text", "text": "Why is this video funny?"}, + {"type": "video"}, + ], + }, +] + +prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) +inputs = processor(text=prompt, videos=video, return_tensors="pt") + +out = model.generate(**inputs, max_new_tokens=60) +processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True) +``` + + +### Mixed Media Mode + +The model can also generate from an interleaved image-video inputs. However note, that it was not trained in interleaved image-video setting which might affect the performance. Below is an example usage for mixed media input, add the following lines to the above code snippet: + +```python +from PIL import Image +import requests + +# Generate from image and video mixed inputs +# Load and image and write a new prompt +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = Image.open(requests.get(url, stream=True).raw) +conversation = [ + { + + "role": "user", + "content": [ + {"type": "text", "text": "How many cats are there in the image?"}, + {"type": "image"}, + ], + }, + { + + "role": "assistant", + "content": [{"type": "text", "text": "There are two cats"}], + }, + { + + "role": "user", + "content": [ + {"type": "text", "text": "Why is this video funny?"}, + {"type": "video"}, + ], + }, +] +prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) +inputs = processor(text=prompt, images=image, videos=clip, padding=True, return_tensors="pt") + +# Generate +generate_ids = model.generate(**inputs, max_length=50) +processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + +``` + +## Model optimization + +### Quantization using Bitsandbytes for memory efficiency + +The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. This allows for efficient deployment on resource-constrained cases. + +First make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a CUDA compatible GPU device. 
+
+## Model optimization
+
+### Quantization using Bitsandbytes for memory efficiency
+
+The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. This allows for efficient deployment in resource-constrained settings.
+
+First make sure to install bitsandbytes by running `pip install bitsandbytes`, and make sure you have access to a CUDA-compatible GPU device. Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below:
+
+
+```python
+import torch
+from transformers import BitsAndBytesConfig, LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
+
+# specify how to quantize the model
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.float16,
+)
+
+model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", quantization_config=quantization_config, device_map="auto")
+```
+
+
+### Flash-Attention 2 to speed up generation
+
+Additionally, we can greatly speed up model inference by using [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
+
+First, make sure to install the latest version of Flash Attention 2:
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+Also, you should have hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.
+
+To load and run a model using Flash Attention-2, simply add `attn_implementation="flash_attention_2"` when loading the model as follows:
+
+```python
+import torch
+from transformers import LlavaNextVideoForConditionalGeneration
+
+model = LlavaNextVideoForConditionalGeneration.from_pretrained(
+    "llava-hf/LLaVA-NeXT-Video-7B-hf",
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",
+).to(0)
+```
+
+
+
+## LlavaNextVideoConfig
+
+[[autodoc]] LlavaNextVideoConfig
+
+## LlavaNextVideoProcessor
+
+[[autodoc]] LlavaNextVideoProcessor
+
+## LlavaNextVideoImageProcessor
+
+[[autodoc]] LlavaNextVideoImageProcessor
+
+## LlavaNextVideoForConditionalGeneration
+
+[[autodoc]] LlavaNextVideoForConditionalGeneration
+    - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index a81fc493813816..c5145da3a62d76 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -55,6 +55,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
 * [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
 * [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
+* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
 * [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
 * [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
 * [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 4976a4a1b90e7e..ad61e39422df54 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -510,6 +510,10 @@
         "LlavaNextConfig",
         "LlavaNextProcessor",
     ],
+    "models.llava_next_video": [
+        "LlavaNextVideoConfig",
+        "LlavaNextVideoProcessor",
+    ],
     "models.longformer": [
         "LongformerConfig",
         "LongformerTokenizer",
@@ -1140,6 +1144,7 @@
     _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
     _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
_import_structure["models.llava_next"].append("LlavaNextImageProcessor") + _import_structure["models.llava_next_video"].append("LlavaNextVideoImageProcessor") _import_structure["models.mask2former"].append("Mask2FormerImageProcessor") _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"]) _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"]) @@ -2415,6 +2420,12 @@ "LlavaNextPreTrainedModel", ] ) + _import_structure["models.llava_next_video"].extend( + [ + "LlavaNextVideoForConditionalGeneration", + "LlavaNextVideoPreTrainedModel", + ] + ) _import_structure["models.longformer"].extend( [ "LongformerForMaskedLM", @@ -5105,6 +5116,10 @@ LlavaNextConfig, LlavaNextProcessor, ) + from .models.llava_next_video import ( + LlavaNextVideoConfig, + LlavaNextVideoProcessor, + ) from .models.longformer import ( LongformerConfig, LongformerTokenizer, @@ -5767,6 +5782,7 @@ ) from .models.levit import LevitFeatureExtractor, LevitImageProcessor from .models.llava_next import LlavaNextImageProcessor + from .models.llava_next_video import LlavaNextVideoImageProcessor from .models.mask2former import Mask2FormerImageProcessor from .models.maskformer import ( MaskFormerFeatureExtractor, @@ -6830,6 +6846,10 @@ LlavaNextForConditionalGeneration, LlavaNextPreTrainedModel, ) + from .models.llava_next_video import ( + LlavaNextVideoForConditionalGeneration, + LlavaNextVideoPreTrainedModel, + ) from .models.longformer import ( LongformerForMaskedLM, LongformerForMultipleChoice, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 24b602f18c8f38..b86f088ed5506a 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -124,6 +124,7 @@ llama, llava, llava_next, + llava_next_video, longformer, longt5, luke, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 40e282166ef99e..7799029cbaee35 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -140,6 +140,7 @@ ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), ("llava", "LlavaConfig"), + ("llava-next-video", "LlavaNextVideoConfig"), ("llava_next", "LlavaNextConfig"), ("longformer", "LongformerConfig"), ("longt5", "LongT5Config"), @@ -417,6 +418,7 @@ ("llama2", "Llama2"), ("llama3", "Llama3"), ("llava", "LLaVa"), + ("llava-next-video", "LLaVa-NeXT-Video"), ("llava_next", "LLaVA-NeXT"), ("longformer", "Longformer"), ("longt5", "LongT5"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index b316a1a55ddeed..6dac0b2e9a8dea 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -94,6 +94,7 @@ ("layoutlmv3", ("LayoutLMv3ImageProcessor",)), ("levit", ("LevitImageProcessor",)), ("llava", ("CLIPImageProcessor",)), + ("llava-next-video", ("LlavaNextVideoImageProcessor",)), ("llava_next", ("LlavaNextImageProcessor",)), ("mask2former", ("Mask2FormerImageProcessor",)), ("maskformer", ("MaskFormerImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index adfcc7af9fbc88..aa7c737fb45836 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -298,6 +298,7 @@ ("idefics2", 
"Idefics2ForConditionalGeneration"), ("layoutlm", "LayoutLMForMaskedLM"), ("llava", "LlavaForConditionalGeneration"), + ("llava-next-video", "LlavaNextVideoForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), ("longformer", "LongformerForMaskedLM"), ("luke", "LukeForMaskedLM"), @@ -698,6 +699,7 @@ ("instructblip", "InstructBlipForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), + ("llava-next-video", "LlavaNextVideoForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 4a8295cc830419..dc2d4ec11c11d0 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -68,6 +68,7 @@ ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), ("llava", "LlavaProcessor"), + ("llava-next-video", "LlavaNextVideoProcessor"), ("llava_next", "LlavaNextProcessor"), ("markuplm", "MarkupLMProcessor"), ("mctct", "MCTCTProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e99bc89205cbdf..20283fffbcf646 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -241,6 +241,7 @@ ), ), ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("llava-next-video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index c052af3b3c8a19..0640960274cb92 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -577,6 +577,9 @@ def _merge_input_ids_with_image_features( final_attention_mask = torch.zeros( batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device ) + final_input_ids = torch.full( + (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=input_ids.device + ) final_labels = None if labels is not None: final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) @@ -594,6 +597,7 @@ def _merge_input_ids_with_image_features( # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] if labels is not None: final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] @@ -626,7 +630,7 @@ def _merge_input_ids_with_image_features( final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - return final_embedding, final_attention_mask, position_ids, final_labels + return 
final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids def pack_image_features(self, image_features, image_sizes, image_newline=None): """ @@ -796,7 +800,7 @@ def forward( ) inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features( + inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features( image_features, feature_lens, inputs_embeds, diff --git a/src/transformers/models/llava_next_video/__init__.py b/src/transformers/models/llava_next_video/__init__.py new file mode 100644 index 00000000000000..d079643e73e99d --- /dev/null +++ b/src/transformers/models/llava_next_video/__init__.py @@ -0,0 +1,70 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_llava_next_video": ["LlavaNextVideoConfig"], + "processing_llava_next_video": ["LlavaNextVideoProcessor"], +} + + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_llava_next_video"] = ["LlavaNextVideoImageProcessor"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_llava_next_video"] = [ + "LlavaNextVideoForConditionalGeneration", + "LlavaNextVideoPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_llava_next_video import LlavaNextVideoConfig + from .processing_llava_next_video import LlavaNextVideoProcessor + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_llava_next_video import LlavaNextVideoImageProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_llava_next_video import ( + LlavaNextVideoForConditionalGeneration, + LlavaNextVideoPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py new file mode 100644 index 00000000000000..59bf460e84a631 --- /dev/null +++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py @@ -0,0 +1,153 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the diff. 
If any change should be done, please apply the change to the +# diff.py file directly. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from transformers import PretrainedConfig + +from ..auto import CONFIG_MAPPING + + +class LlavaNextVideoConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlavaNextVideoForConditionalGeneration`]. It is used to instantiate an + Llava-NeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the [llava-hf/LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) + model. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + image_token_index (`int`, *optional*, defaults to 32001): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. + If `"full"`, the full vision features are used. + vision_feature_layer (`int`, *optional*, defaults to -2): + The index of the layer to select the vision feature. + image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`): + A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list + of the form `(height, width)`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + video_token_index (`int`, *optional*, defaults to 32000): + The video token index to encode the image prompt. + spatial_pool_mode (`str`, *optional*, defaults to `"average"`): + Pooling mode to use for videos. Can be "average", "max" or "conv". + spatial_pool_stride (`int`, *optional*, defaults to 2): + Stride used in the pooling layer for videos. 
+ + Example: + + ```python + >>> from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> configuration = LlavaNextVideoConfig(vision_config, text_config) + + >>> model = LlavaNextVideoForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llava_next_video" + is_composition = True + + def __init__( + self, + vision_config=None, + text_config=None, + ignore_index=-100, + image_token_index=32001, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_grid_pinpoints=None, + tie_word_embeddings=False, + video_token_index=32000, + spatial_pool_mode="average", + spatial_pool_stride=2, + **kwargs, + ): + self.video_token_index = video_token_index + self.spatial_pool_mode = spatial_pool_mode + self.spatial_pool_stride = spatial_pool_stride + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." + f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + image_grid_pinpoints = ( + image_grid_pinpoints + if image_grid_pinpoints is not None + else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]] + ) + self.image_grid_pinpoints = image_grid_pinpoints + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py b/src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py new file mode 100644 index 00000000000000..aae44eee97a032 --- /dev/null +++ b/src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py @@ -0,0 +1,276 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert LLaVa-NeXT-Video checkpoints from the original repository.
+
+URL: https://github.com/LLaVA-VL/LLaVA-NeXT/tree/inference
+"""
+
+import argparse
+import glob
+import json
+from pathlib import Path
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors import safe_open
+
+from transformers import (
+    AddedToken,
+    AutoConfig,
+    AutoTokenizer,
+    LlavaNextImageProcessor,
+    LlavaNextVideoConfig,
+    LlavaNextVideoForConditionalGeneration,
+    LlavaNextVideoImageProcessor,
+    LlavaNextVideoProcessor,
+)
+
+
+KEYS_TO_MODIFY_MAPPING = {
+    "model.vision_tower.": "",
+    ".vision_resampler": "",  # all lmms-lab models do avg pooling, so no vision_resampler
+    "model.mm_projector": "multi_modal_projector",
+    "model": "model.model",
+    "vision_model.model": "vision_model",
+    "lm_head": "language_model.lm_head",
+    "model.model": "language_model.model",
+    "multi_modal_projector.0": "multi_modal_projector.linear_1",
+    "multi_modal_projector.2": "multi_modal_projector.linear_2",
+    "language_model.model.image_newline": "image_newline",
+}
+
+# {{SYSTEM_PROMPT}} USER: <image>\n{{PROMPT}} ASSISTANT:" assistant end with "</s> "
+chat_vicuna = (
+    "{% for message in messages %}"
+    "{% if message['role'] == 'system' %}"
+    "{{ message['content'][0]['text'] }}"
+    "{% else %}"
+    "{{ message['role'].upper() + ': '}}"
+    "{% endif %}"
+    "{# Render all images first #}"
+    "{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}"
+    "{{ '<image>\n' }}"
+    "{% endfor %}"
+    "{# Render all text next #}"
+    "{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}"
+    "{{ content['text'] + ' '}}"
+    "{% endfor %}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ 'ASSISTANT:' }}"
+    "{% endif %}"
+)
+
+# "[INST] <image>\nWhat is shown in this image? [/INST]" assistant end with "</s> "
+chat_mistral = (
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ '[INST] ' }}"
+    "{# Render all images first #}"
+    "{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}"
+    "{{ '<image>\n' }}"
+    "{% endfor %}"
+    "{# Render all text next #}"
+    "{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}"
+    "{{ content['text'] }}"
+    "{% endfor %}"
+    "{{' [/INST]' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    r"{{ ' ' + message['content'][0]['text'] + '<\s> '}}"
+    "{% else %}"
+    "{{ raise_exception('Only user and assistant roles are supported!') }}"
+    "{% endif %}"
+    "{% endfor %}"
+)
+
+# "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
+chat_yi = (
+    "{% for message in messages %}"
+    "{{'<|im_start|>' + message['role'] + '\n'}}"
+    "{# Render all images first #}"
+    "{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}"
+    "{{ '<image>\n' }}"
+    "{% endfor %}"
+    "{# Render all text next #}"
+    "{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}"
+    "{{ content['text'] }}"
+    "{% endfor %}"
+    "{{'<|im_end|>' + '\n'}}"
+    "{% endfor %}"
+    "{% if add_generation_prompt %}"
+    "{{ '<|im_start|>assistant\n' }}"
+    "{% endif %}"
+)
+
+model2template = {
+    "lmms-lab/LLaVA-NeXT-Video-7B-32K": chat_mistral,
+    "lmms-lab/LLaVA-NeXT-Video-7B": chat_vicuna,
+    "lmms-lab/LLaVA-NeXT-Video-7B-DPO": chat_vicuna,
+    "lmms-lab/LLaVA-NeXT-Video-34B": chat_yi,
+    "lmms-lab/LLaVA-NeXT-Video-34B-DPO": chat_yi,
+}
+
+
+def load_original_state_dict(model_id):
+    directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
+
+    original_state_dict = {}
+    for path in glob.glob(f"{directory_path}/*"):
+        if path.endswith(".safetensors"):
+            with safe_open(path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    original_state_dict[key] = f.get_tensor(key)
+
+    return original_state_dict
+
+
+def convert_state_dict_to_hf(state_dict):
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key.endswith(".inv_freq"):
+            continue
+        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+            if key_to_modify in key:
+                key = key.replace(key_to_modify, new_key)
+
+        new_state_dict[key] = value.to(torch.bfloat16)
+    return new_state_dict
+
+
+def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
+    # load original config
+    filepath = hf_hub_download(repo_id=model_id, filename="config.json", repo_type="model")
+    with open(filepath) as f:
+        data = json.load(f)
+        print(data)
+
+    if model_id == "lmms-lab/LLaVA-NeXT-Video-7B-32K":
+        text_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+        video_token_index = 32000
+        image_token_index = 32001
+        overwrite_text_config = {}
+    elif model_id in ["lmms-lab/LLaVA-NeXT-Video-7B", "lmms-lab/LLaVA-NeXT-Video-7B-DPO"]:
+        text_model_id = "lmsys/vicuna-7b-v1.5"
+        video_token_index = 32000
+        image_token_index = 32001
+        overwrite_text_config = {"factor": 2.0, "type": "linear"}
+    elif model_id in ["lmms-lab/LLaVA-NeXT-Video-34B", "lmms-lab/LLaVA-NeXT-Video-34B-DPO"]:
+        text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B"
+        video_token_index = 64000
+        image_token_index = 64001
+        overwrite_text_config = {}
+    else:
+        raise ValueError("Incorrect checkpoint referenced. Text model-id not identified!")
+
+    vision_model_id = data["mm_vision_tower"]
+
+    torch.set_default_dtype(torch.bfloat16)
+    text_config = AutoConfig.from_pretrained(text_model_id)
+    text_config = text_config.to_dict()
+    text_config.update(overwrite_text_config)
+
+    tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=True, padding_side="left")
+    tokenizer.add_tokens(AddedToken("