From bd9f4d79517a3ad2f9da999d090dc3bbfc506dc4 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Wed, 15 May 2024 16:42:29 +0500
Subject: [PATCH] Add Video Llava (#29733)

* add model draft
* update docstring
* add tests
* support image and video as input
* update for better handling of mixed input and clean-up a bit
* bug when mixed inputs & add tests
* Update README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Merge remote-tracking branch 'upstream/main' into video_llava
* link to abstract of paper in README
* fix test
* fix-copies
* make tests happy
* skip docstest for now
* do not run doctest for now
* Update src/transformers/models/video_llava/processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/image_processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/image_processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/image_processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/image_processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/video_llava/test_modeling_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/image_processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* address review comments
* failing tests
* Fix vocab_size in common tests for VLMs
* codestyle
* Update src/transformers/models/video_llava/configuration_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/configuration_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/modeling_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/modeling_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/video_llava.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/video_llava.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/image_processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/video_llava.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/processing_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/video_llava/test_modeling_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/video_llava/test_modeling_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/video_llava/test_modeling_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* PR suggestions
* fix-copies
* Update src/transformers/models/video_llava/configuration_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/video_llava/configuration_video_llava.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add full example in docs
* clean-up with new model-id
* [run-slow] video_llava
* update docstring
* [run-slow] video_llava
* remove all achive maps
* fix some tests
* test was supposed to be skipped for llava :)

---------

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 README_fr.md | 1 -
 README_te.md | 1 -
 docs/source/en/_toctree.yml | 2 +
 docs/source/en/index.md | 1 +
 docs/source/en/model_doc/video_llava.md | 129 ++++
 docs/source/en/perf_infer_gpu_one.md | 1 +
 src/transformers/__init__.py | 16 +
 src/transformers/image_utils.py | 3 +
 src/transformers/models/__init__.py | 1 +
 .../models/auto/configuration_auto.py | 2 +
 .../models/auto/image_processing_auto.py | 1 +
 src/transformers/models/auto/modeling_auto.py | 2 +
 .../models/auto/processing_auto.py | 1 +
 .../models/auto/tokenization_auto.py | 1 +
 .../models/video_llava/__init__.py | 71 ++
 .../video_llava/configuration_video_llava.py | 126 ++++
 .../convert_video_llava_weights_to_hf.py | 159 ++++
 .../image_processing_video_llava.py | 422 +++++++++++
 .../video_llava/modeling_video_llava.py | 686 ++++++++++++++++++
 .../video_llava/processing_video_llava.py | 143 ++++
 src/transformers/utils/dummy_pt_objects.py | 28 +
 tests/models/video_llava/__init__.py | 0
 .../test_image_processing_video_llava.py | 305 ++++++++
 .../video_llava/test_modeling_video_llava.py | 536 ++++++++++++++
 tests/test_modeling_common.py | 2 +-
 25 files changed, 2637 insertions(+), 3 deletions(-)
 create mode 100644 docs/source/en/model_doc/video_llava.md
 create mode 100644 src/transformers/models/video_llava/__init__.py
 create mode 100644 src/transformers/models/video_llava/configuration_video_llava.py
 create mode 100644 src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py
 create mode 100644 src/transformers/models/video_llava/image_processing_video_llava.py
 create mode 100644 src/transformers/models/video_llava/modeling_video_llava.py
 create mode 100644 src/transformers/models/video_llava/processing_video_llava.py
 create mode 100644 tests/models/video_llava/__init__.py
 create mode 100644 tests/models/video_llava/test_image_processing_video_llava.py
 create mode 100644 tests/models/video_llava/test_modeling_video_llava.py

diff --git a/README_fr.md b/README_fr.md
index d58bb0bbca385d..0fffb6d936076d 100644
--- a/README_fr.md
+++ b/README_fr.md
@@ -288,7 +288,6 @@ Suivez les pages d'installation de Flax, PyTorch ou TensorFlow pour voir comment
 
 Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen)
 
-
 🤗 Transformers fournit actuellement les architectures suivantes: consultez [ici](https://huggingface.co/docs/transformers/model_summary) pour un résumé global de chacune d'entre elles.
 
 Pour vérifier si chaque modèle a une implémentation en Flax, PyTorch ou TensorFlow, ou s'il a un tokenizer associé pris en charge par la bibliothèque 🤗 Tokenizers, consultez [ce tableau](https://huggingface.co/docs/transformers/index#supported-frameworks).
diff --git a/README_te.md b/README_te.md
index 19cbe320624186..f23476efda5f2f 100644
--- a/README_te.md
+++ b/README_te.md
@@ -293,7 +293,6 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్‌స్టా
 
 🤗 ట్రాన్స్‌ఫార్మర్లు ప్రస్తుతం కింది ఆర్కిటెక్చర్‌లను అందజేస్తున్నాయి: వాటిలో ప్రతి ఒక్కటి ఉన్నత స్థాయి సారాంశం కోసం [ఇక్కడ](https://huggingface.co/docs/transformers/model_summary) చూడండి.
-
 
 ఈ అమలులు అనేక డేటాసెట్‌లలో పరీక్షించబడ్డాయి (ఉదాహరణ స్క్రిప్ట్‌లను చూడండి) మరియు అసలైన అమలుల పనితీరుతో సరిపోలాలి. మీరు [డాక్యుమెంటేషన్](https://github.com/huggingface/transformers/tree/main/examples) యొక్క ఉదాహరణల విభాగంలో పనితీరుపై మరిన్ని వివరాలను కనుగొనవచ్చు.
 
 ## ఇంకా నేర్చుకో
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 325ecb0c4d2c80..ae671cbf13e343 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -806,6 +806,8 @@
       title: TVP
     - local: model_doc/udop
       title: UDOP
+    - local: model_doc/video_llava
+      title: VideoLlava
     - local: model_doc/vilt
       title: ViLT
     - local: model_doc/vipllava
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 16018c32e57ec0..9a8c2ebbe37e00 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -304,6 +304,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [UnivNet](model_doc/univnet) | ✅ | ❌ | ❌ |
 | [UPerNet](model_doc/upernet) | ✅ | ❌ | ❌ |
 | [VAN](model_doc/van) | ✅ | ❌ | ❌ |
+| [VideoLlava](model_doc/video_llava) | ✅ | ❌ | ❌ |
 | [VideoMAE](model_doc/videomae) | ✅ | ❌ | ❌ |
 | [ViLT](model_doc/vilt) | ✅ | ❌ | ❌ |
 | [VipLlava](model_doc/vipllava) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md
new file mode 100644
index 00000000000000..0dad4df06f0e87
--- /dev/null
+++ b/docs/source/en/model_doc/video_llava.md
@@ -0,0 +1,129 @@
+
+
+# Video-LLaVA
+
+## Overview
+
+Video-LLaVA is an open-source multimodal LLM trained by fine-tuning LLaMA/Vicuna on multimodal instruction-following data generated by LLaVA-1.5 and VideoChat. It is an auto-regressive language model based on the transformer architecture. Video-LLaVA unifies visual representations in the language feature space, enabling an LLM to perform visual reasoning on both images and videos simultaneously.
+
+The Video-LLaVA model was proposed in [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/abs/2311.10122) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
+
+The abstract from the paper is the following:
+
+*The Large Vision-Language Model (LVLM) has enhanced the performance of various downstream tasks in
+visual-language understanding. Most existing approaches
+encode images and videos into separate feature spaces,
+which are then fed as inputs to large language models.
+However, due to the lack of unified tokenization for images and videos, namely misalignment before projection, it
+becomes challenging for a Large Language Model (LLM)
+to learn multi-modal interactions from several poor projection layers. In this work, we unify visual representation into the language feature space to advance the foundational LLM towards a unified LVLM. As a result, we establish a simple but robust LVLM baseline, Video-LLaVA,
+which learns from a mixed dataset of images and videos,
+mutually enhancing each other. Video-LLaVA achieves superior performances on a broad range of 9 image benchmarks across 5 image question-answering datasets and 4
+image benchmark toolkits. Additionally, our Video-LLaVA
+also outperforms Video-ChatGPT by 5.8%, 9.9%, 18.6%,
+and 10.1% on MSRVTT, MSVD, TGIF, and ActivityNet, respectively. Notably, extensive experiments demonstrate that
+Video-LLaVA mutually benefits images and videos within
+a unified visual representation, outperforming models designed specifically for images or videos. We aim for this
+work to provide modest insights into the multi-modal inputs
+for the LLM*
+
+Tips:
+
+- We advise users to use `padding_side="left"` when running batched generation, as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
+
+- Note that the model has not been explicitly trained to process multiple images/videos in the same prompt. Although this is technically possible, you may experience inaccurate results.
+
+- For better results, we recommend prompting the model with the correct prompt format, as shown in the example below:
+
+```python
+import av
+import torch
+import numpy as np
+import requests
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor
+
+def read_video_pyav(container, indices):
+    '''
+    Decode the video with PyAV decoder.
+    Args:
+        container (`av.container.input.InputContainer`): PyAV container.
+        indices (`List[int]`): List of frame indices to decode.
+    Returns:
+        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+    '''
+    frames = []
+    container.seek(0)
+    start_index = indices[0]
+    end_index = indices[-1]
+    for i, frame in enumerate(container.decode(video=0)):
+        if i > end_index:
+            break
+        if i >= start_index and i in indices:
+            frames.append(frame)
+    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+model = VideoLlavaForConditionalGeneration.from_pretrained("RaushanTurganbay/video-llava-7b-hf", device_map="auto")
+processor = VideoLlavaProcessor.from_pretrained("RaushanTurganbay/video-llava-7b-hf")
+
+video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
+
+container = av.open(video_path)
+total_frames = container.streams.video[0].frames
+# sample 8 frame indices uniformly across the video
+indices = np.arange(0, total_frames, total_frames / 8).astype(int)
+video = read_video_pyav(container, indices)
+
+prompt = "USER: