Chat template: return vectorized output in processors (#34275)

* update chat template

* style

* fix tests

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* typehints + docs

* fix tests

* remove unnecessary warnings

* forgot code style :(

* allow users to pass backend and num frames

* Update docs/source/en/chat_templating.md

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Update src/transformers/processing_utils.py

Co-authored-by: Pavel Iakubovskii <[email protected]>

* typo fix

* style

* address comments

* align with "pipeline" template

* update docs

* update docs

* unpack for all kwargs?

* wrong conflict resolution while rebasing

* tmp

* update docs

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <[email protected]>

---------

Co-authored-by: Pavel Iakubovskii <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
3 people authored Jan 10, 2025
1 parent 5f087d1 commit e0646f3
Showing 12 changed files with 880 additions and 46 deletions.
132 changes: 132 additions & 0 deletions benchmark.py
@@ -0,0 +1,132 @@
import os
import time

import cv2
import av
import numpy as np
from numba import jit, cuda
from decord import VideoReader, cpu, gpu

import torch
from torchvision import io


video_dir = "/raid/raushan/temp_dir/"
NUM_FRAMES = 32


# @jit(nopython=True, target_backend='cuda') # <-- If you have a cuda GPU
def process_video_cv2(video: cv2.VideoCapture, indices: np.array, length: int):
    index = 0
    frames = []
    while video.isOpened():
        success, frame = video.read()
        if index in indices:
            # Channel 0:B 1:G 2:R
            height, width, channel = frame.shape
            frames.append(frame[0:height, 0:width, 0:channel])
        if success:
            index += 1
        if index >= length:
            break

    video.release()
    return frames


def read_video_opencv(video_path, num_frames=NUM_FRAMES):
    '''
    Decode the video with open-cv decoder.
    Args:
        video_path (str): Path to the video file.
        num_frames (int): Number of frames to sample uniformly. Defaults to NUM_FRAMES
    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    video = cv2.VideoCapture(video_path)
    fps = int(video.get(cv2.CAP_PROP_FPS))
    total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
    frames = process_video_cv2(video, indices, total_num_frames)
    return np.stack(frames)



def read_video_decord(video_path, num_frames=NUM_FRAMES):
    '''
    Decode the video with Decord decoder.
    Args:
        video_path (str): Path to the video file.
        num_frames (int): Number of frames to sample uniformly. Defaults to NUM_FRAMES
    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    vr = VideoReader(uri=video_path, ctx=cpu(0))  # you need to install from source to use gpu ctx
    indices = np.arange(0, len(vr), len(vr) / num_frames).astype(int)
    frames = vr.get_batch(indices).asnumpy()
    return frames


def read_video_pyav(video_path, num_frames=NUM_FRAMES):
    '''
    Decode the video with PyAV decoder.
    Args:
        video_path (str): Path to the video file.
        num_frames (int): Number of frames to sample uniformly. Defaults to NUM_FRAMES
    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    container = av.open(video_path)

    # sample uniformly "num_frames" frames from the video
    total_frames = container.streams.video[0].frames
    indices = np.arange(0, total_frames, total_frames / num_frames).astype(int)

    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])



def read_video_torchvision(video_path, num_frames=NUM_FRAMES):
    video, _, info = io.read_video(
        video_path,
        start_pts=0.0,
        end_pts=None,
        pts_unit="sec",
        output_format="TCHW",
    )

    idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
    return video[idx]


decoders = {"decord": read_video_decord, "opencv": read_video_opencv, "av": read_video_pyav, "torchvision": read_video_torchvision}
for name, fn in decoders.items():
    start = time.perf_counter()
    for video_file in os.listdir(video_dir):
        path = f"{video_dir}/{video_file}"
        output = fn(path)

    end = time.perf_counter()
    print(f"Time taken for {name}: {(end-start):.04f} sec")


# Time taken for decord: 475.2979 sec
# Time taken for opencv: 614.6062 sec
# Time taken for av: 1067.0860 sec
# Time taken for torchvision: 1924.0433 sec
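
Decord comes out fastest in this benchmark, which matches the commit messages above about letting users pass a backend and frame count. As a rough illustration only — the `num_frames` and `video_load_backend` keyword names are inferred from those commit messages and should be treated as assumptions — a call that picks the decord backend might look like:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4"},
            {"type": "text", "text": "Describe this video."},
        ],
    },
]

# Sample 32 frames and decode them with decord (kwarg names assumed, see note above).
processed_chat = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    num_frames=32,
    video_load_backend="decord",
)
print(processed_chat.keys())
```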

44 changes: 42 additions & 2 deletions docs/source/en/chat_templating.md
@@ -23,7 +23,7 @@ of text (as is the case with a standard language model), the model instead conti
of one or more **messages**, each of which includes a **role**, like "user" or "assistant", as well as message text.

Much like tokenization, different models expect very different input formats for chat. This is the reason we added
**chat templates** as a feature. Chat templates are part of the tokenizer. They specify how to convert conversations,
**chat templates** as a feature. Chat templates are part of the tokenizer for text-only LLMs, or of the processor for multimodal LLMs. They specify how to convert conversations,
represented as lists of messages, into a single tokenizable string in the format that the model expects.

Let's make this concrete with a quick example using the `mistralai/Mistral-7B-Instruct-v0.1` model:
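
A minimal sketch of such a call, assuming the standard `AutoTokenizer` API (the message contents below are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works!"},
]

# Render the conversation into the single prompt string this model expects.
print(tokenizer.apply_chat_template(chat, tokenize=False))
```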
@@ -66,10 +66,12 @@ for you, allowing you to write universal code that works for any model.
## How do I use chat templates?

As you can see in the example above, chat templates are easy to use. Simply build a list of messages, with `role`
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] method. Once you do that,
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] or [`~ProcessorMixin.apply_chat_template`] method,
depending on the type of model you are using. Once you do that,
you'll get output that's ready to go! When using chat templates as input for model generation, it's also a good idea
to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts).

## Usage with text-only LLMs
Here's an example of preparing input for `model.generate()`, using `Zephyr` again:
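
A minimal sketch of that preparation, assuming the `HuggingFaceH4/zephyr-7b-beta` checkpoint (the prose only names "Zephyr") and the standard tokenizer/model APIs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceH4/zephyr-7b-beta"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

messages = [
    {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]

# Apply the chat template and append a generation prompt so the model replies next.
tokenized_chat = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(tokenized_chat, max_new_tokens=128)
print(tokenizer.decode(outputs[0]))
```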

@@ -116,6 +118,44 @@ How many helicopters can a human eat in one sitting?</s>

```text
Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
```

## Usage with multimodal LLMs

For multimodal LLMs such as [LLaVA](https://huggingface.co/llava-hf), prompts can be formatted in a similar way. The only difference is that you also need to pass the input images or videos along with the text. Each `"content"`
has to be a list of entries, each containing either text or an image/video.

Here's an example of preparing inputs for a `LLaVA` model:

```python
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(model_id) # You may want to use bfloat16 and/or move to GPU here
processor = AutoProcessor.from_pretrained(model_id)

messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a friendly chatbot who always responds in the style of a pirate"}],
},
{
"role": "user",
"content": [
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
{"type": "text", "text": "What are these?"},
],
},
]

processed_chat = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt")
print(processor.batch_decode(processed_chat["input_ids"][:, :30]))
```
This yields a string in LLaVA's expected input format, with many `<image>` tokens at the end.
The `<image>` tokens are placeholders, and each one will be replaced by an image embedding when the model runs its forward call. The `processed_chat` can then be passed to [`~GenerationMixin.generate`] to generate text, as sketched after the output below.
```text
'<|im_start|>system
You are a friendly chatbot who always responds in the style of a pirate<|im_end|><|im_start|>user <image><image><image><image><image><image><image><image>'
```
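
To actually generate a reply, the `processed_chat` dictionary can be passed straight into generation. A minimal sketch, reusing the `model` and `processor` objects from the snippet above (the generation settings are illustrative):

```python
# Move the processed inputs onto the model's device before generating.
processed_chat = processed_chat.to(model.device)
output_ids = model.generate(**processed_chat, max_new_tokens=100)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```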

Arr, 'twas easy after all!

## Is there an automated pipeline for chat?
77 changes: 77 additions & 0 deletions read_video.py
@@ -0,0 +1,77 @@
import numpy as np
import cv2
import requests
from yt_dlp import YoutubeDL
from contextlib import redirect_stdout
from pathlib import Path
import io
import imageio.v3 as iio


url = "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4"
vid = cv2.VideoCapture(url)
# ret, frame = vid.read()

while True:
    # Capture frame-by-frame
    ret, frame = vid.read()
    # print cap.isOpened(), ret
    if frame is not None:
        pass
        # print(frame.shape)
    else:
        break

print(vid.isOpened(), frame is not None)

buffer = io.BytesIO(requests.get(url).content)
video = buffer.getvalue()
frames = iio.imread(video, index=None)
print(frames.shape)





youtube_id = "https://www.youtube.com/watch?v=BaW_jenozKc"

ctx = {
    "outtmpl": "-",
    "logtostderr": True,
}

buffer = io.BytesIO()
with redirect_stdout(buffer), YoutubeDL(ctx) as foo:
    foo.download([youtube_id])
# Path(f"vi.mp4").write_bytes(buffer.getvalue())

video = buffer.getvalue()
print(type(video))
frames = iio.imread(video, index=None)
print(frames.shape)


import decord
file_obj = io.BytesIO(video)
container = decord.VideoReader(file_obj)
print(container[2].shape)

# print(np.frombuffer(video, dtype=np.uint8).shape)
# img_array = np.asarray(bytearray(video), dtype=np.uint8)
# im = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)



import av

file_obj = io.BytesIO(video)
container = av.open(file_obj)
container.seek(0)
frames = []
for i, frame in enumerate(container.decode(video=0)):
    if i > 10:
        break
    if i >= 0:
        frames.append(frame)
out = np.stack([x.to_ndarray(format="rgb24") for x in frames])
print(out.shape)
107 changes: 107 additions & 0 deletions run.py
@@ -0,0 +1,107 @@
import av
import torch
import decord
from decord import VideoReader, cpu

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration, SiglipImageProcessor

model_id = "/raid/raushan/llava-next-video-qwen-7b"

model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to(0)

processor = LlavaNextVideoProcessor.from_pretrained(model_id, torch_dtype=torch.bfloat16)
img_proc = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

image = Image.open("/raid/raushan/image.png")


def load_video(video_path, max_frames_num, fps=1, force_sample=False):
    vr = VideoReader(video_path)
    total_frame_num = len(vr)
    video_time = total_frame_num / vr.get_avg_fps()
    fps = round(vr.get_avg_fps() / fps)
    frame_idx = [i for i in range(0, len(vr), fps)]
    frame_time = [i / fps for i in frame_idx]
    if len(frame_idx) > max_frames_num or force_sample:
        sample_fps = max_frames_num
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        frame_time = [i / vr.get_avg_fps() for i in frame_idx]
    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    print(spare_frames.shape)
    return spare_frames, frame_time, video_time


def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


# define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image", "video")
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <image>Time frames are these moments and we have 64 frames
# Please describe this video in detail.<|im_end|>
# <|im_start|>assistant

conversation = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "The video lasts for 19.97 seconds, and 64 frames are uniformly sampled from it. These frames are located at 0.00s,0.30s,0.60s,0.93s,1.23s,1.57s,1.87s,2.20s,2.50s,2.83s,3.13s,3.47s,3.77s,4.10s,4.40s,4.73s,5.03s,5.37s,5.67s,6.00s,6.30s,6.63s,6.93s,7.27s,7.57s,7.90s,8.20s,8.53s,8.83s,9.17s,9.47s,9.80s,10.10s,10.43s,10.73s,11.07s,11.37s,11.70s,12.00s,12.33s,12.63s,12.97s,13.27s,13.60s,13.90s,14.23s,14.53s,14.87s,15.17s,15.50s,15.80s,16.13s,16.43s,16.77s,17.07s,17.40s,17.70s,18.03s,18.33s,18.67s,18.97s,19.30s,19.60s,19.93s.Please answer the following questions related to this video.\nPlease describe this video in detail."},
            {"type": "video"},
        ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<video>The video lasts for 19.97 seconds, and 64 frames are uniformly sampled from it. These frames are located at 0.00s,0.30s,0.60s,0.93s,1.23s,1.57s,1.87s,2.20s,2.50s,2.83s,3.13s,3.47s,3.77s,4.10s,4.40s,4.73s,5.03s,5.37s,5.67s,6.00s,6.30s,6.63s,6.93s,7.27s,7.57s,7.90s,8.20s,8.53s,8.83s,9.17s,9.47s,9.80s,10.10s,10.43s,10.73s,11.07s,11.37s,11.70s,12.00s,12.33s,12.63s,12.97s,13.27s,13.60s,13.90s,14.23s,14.53s,14.87s,15.17s,15.50s,15.80s,16.13s,16.43s,16.77s,17.07s,17.40s,17.70s,18.03s,18.33s,18.67s,18.97s,19.30s,19.60s,19.93s.Please answer the following questions related to this video.\nPlease describe this video in detail.<|im_end|>\n<|im_start|>assistant"

video_path = "/raid/raushan/karate.mp4" # hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)

# sample uniformly 8 frames from the video, can sample more for longer videos
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 64).astype(int)
clip = read_video_pyav(container, indices)

clip, frame_time,video_time = load_video(video_path, max_frames_num=64, force_sample=True)
inputs_video = processor(text=prompt, videos=clip, return_tensors="pt").to(device=model.device, dtype=torch.bfloat16)

output = model.generate(**inputs_video, max_new_tokens=100, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))