LLaVA torch.compile implementation #29891
Thanks! Mistral should be fairly easy to implement, just follow the updates done to Llama! 🤗 Generate will be refactored to support compile soon! FYI @gante and @zucchini-nlp
I've checked that the following code runs:

```python
import os
from functools import partial

import requests
import torch
from PIL import Image

from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache

os.environ["TOKENIZERS_PARALLELISM"] = "true"

with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained("llava-mistral")
    model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-mistral", low_cpu_mem_usage=True, torch_dtype=torch.float16
    ).cuda()

    static_cache = partial(StaticCache, dtype=torch.float16)
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
    model.language_model.compile(fullgraph=True)
    model.vision_tower.compile(fullgraph=True)

    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "[INST] <image>"
    inputs = processor(prompt, image, return_tensors="pt").to(model.device)

    output = model(**inputs)
    print(output)
```
Hey @sheepymeh 👋 Given that the individual models are compileable (according to your script above), the next step would be to rewrite the logic in between the language model and the vision tower to be compilable as well. The process is very model-dependent, often requiring padding or clever manipulations to enable it. My suggestion would be to iteratively call the compiled forward pass, see where it crashes, rewrite, and repeat until it runs. Then, after the whole thing compiles, do a second pass to optimize performance (with the aid of a profiler) -- the compiled forward should be significantly faster than the uncompiled one, after 2-3 warmup forward passes. Looking forward to the PR! If you get stuck, let us know :)
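Once the whole forward compiles, a minimal way to sanity-check the warm-up and speedup might look like this (a sketch, not part of the thread; `model` and `inputs` are assumed to be the ones from the script above):

```python
import time

import torch

# Sketch: compile the forward, run a few warm-up passes (which trigger compilation),
# then time a single compiled call.
compiled_forward = torch.compile(model.forward, fullgraph=True)

with torch.inference_mode():
    for _ in range(3):  # warm-up passes
        compiled_forward(**inputs)
    torch.cuda.synchronize()

    start = time.perf_counter()
    compiled_forward(**inputs)
    torch.cuda.synchronize()
    print(f"compiled forward pass: {time.perf_counter() - start:.4f}s")
```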
Thank you for your suggestions. I'm currently working on a PR, and my current progress is on creating fixed-length tensors for everything. For example, […]
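To illustrate the fixed-length idea (a hypothetical toy example, not the actual PR code; the shapes below are arbitrary):

```python
import torch
import torch.nn.functional as F

# Toy example: pad the sequence dimension up to a constant MAX_LEN so that
# downstream tensors always have the same shape, regardless of the prompt length.
MAX_LEN = 4096
embeds = torch.randn(1, 733, 4096)  # (batch, seq_len, hidden_dim); seq_len varies per prompt
pad = MAX_LEN - embeds.shape[1]

fixed_embeds = F.pad(embeds, (0, 0, 0, pad))                         # pad along the sequence dim
fixed_mask = F.pad(torch.ones(1, 733, dtype=torch.long), (0, pad))   # 1 = real token, 0 = padding
print(fixed_embeds.shape, fixed_mask.shape)  # torch.Size([1, 4096, 4096]) torch.Size([1, 4096])
```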
For reference, here is the (very unoptimized) version I'm working on:

```python
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = torch.any(input_ids[:, -1] == torch.tensor(self.pad_token_id)).to(torch.int8)  # CHANGE: multiply instead of using booleans
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    max_pos_embed = self.language_model.config.max_position_embeddings  # CHANGE: automatically use maximum position embeddings
    text_tokens = input_ids != self.config.image_token_index  # CHANGE: use a boolean mask instead of torch.where

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    new_token_positions += nb_image_pad[:, None] * left_padding  # offset for left padding
    text_to_overwrite = new_token_positions.clone()
    last_token = new_token_positions.max(dim=-1).values + 1
    text_to_overwrite[~text_tokens] = -1  # set to -1 to place image tokens in "waste" element

    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    text_tokens, text_to_overwrite = (
        text_tokens.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 3. Create the full embedding
    # CHANGE: pad to max seq len by default and create a "waste" element at the end
    final_embedding = torch.zeros(
        batch_size, max_pos_embed + 1, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_pos_embed + 1, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_pos_embed + 1), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    # I'm not sure how I could do this in a vectorized way
    for batch in range(batch_size):
        final_embedding[batch, text_to_overwrite[batch]] = inputs_embeds[batch]
        final_attention_mask[batch, text_to_overwrite[batch]] = attention_mask[batch]
        if labels is not None:
            final_labels[batch, text_to_overwrite[batch]] = labels[batch]

    pad_mask = torch.arange(max_pos_embed + 1, device=target_device).unsqueeze(dim=0).repeat(batch_size, 1) >= last_token.unsqueeze(-1)
    final_embedding[pad_mask] = self.pad_token_id
    final_attention_mask[pad_mask] = 0

    # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
    # if image_to_overwrite.sum() != image_features.shape[:-1].numel():
    #     raise ValueError(
    #         f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
    #         f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
    #     )
    image_features = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    print(image_to_overwrite)
    final_embedding[image_to_overwrite] = image_features
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]
    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    # CHANGE: incompatible with torch.compile; try to remove the additional tokens
    final_embedding, final_attention_mask, final_labels, position_ids = (
        final_embedding[:, :max_pos_embed],
        final_attention_mask[:, :max_pos_embed],
        final_labels[:, :max_pos_embed] if final_labels is not None else None,
        position_ids[:, :max_pos_embed],
    )
    return final_embedding, final_attention_mask, final_labels, position_ids
```
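For the per-batch loop in step 4, one possible vectorized alternative (an untested sketch, not the author's code; it assumes the "waste" slot at index `max_pos_embed` is used instead of `-1`, since `scatter_` does not accept negative indices) would be:

```python
# Route every text token to its target position and every image token to the
# "waste" slot, then scatter in one shot instead of looping over the batch.
waste_slot = torch.full_like(text_to_overwrite, max_pos_embed)
positions = torch.where(text_tokens, text_to_overwrite, waste_slot)  # (batch, seq_len)

final_embedding.scatter_(1, positions.unsqueeze(-1).expand(-1, -1, embed_dim), inputs_embeds)
final_attention_mask.scatter_(1, positions, attention_mask)
if labels is not None:
    final_labels.scatter_(1, positions, labels)
```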
When you generate, you mostly want the […]. One thing that could help with this is actually to pre-process the strings, in order to make sure that the input of the model is the only thing that changes in terms of shapes.
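For example, the text side could be padded to a fixed length at preprocessing time, so that the tokenized inputs always have the same shape (a sketch; `processor`, `prompt`, and `image` are assumed to be the ones from the script above, and `padding`/`max_length` are the standard tokenizer arguments the processor forwards):

```python
# Pad every prompt to the same fixed length so that only the image-dependent
# tensors can still change shape between calls.
MAX_TEXT_LEN = 256  # arbitrary fixed budget for the text tokens
inputs = processor(
    prompt,
    image,
    return_tensors="pt",
    padding="max_length",
    max_length=MAX_TEXT_LEN,
)
print(inputs["input_ids"].shape)  # always (1, MAX_TEXT_LEN), independent of the prompt text
```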
Thank you for the help so far. I've managed to implement @ArthurZucker's suggestion successfully in the preprocessor. However, the `unpad_image` function

```python
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor
```

requires […]. Is it possible to somehow pass these values in as a list of ints? Trying that causes this error when using the […]
Not super sure, but LLaVA does not require this, only LLaVA-NeXT.
Yes, it's only LLaVA-NeXT. Btw, not sure how helpful this info is, but in the linked PR we had to convert "image sizes" to a list inside "modeling.py", because there were mismatches in the way resolutions are calculated in the image processor vs. on the modeling side.
Thanks for the input. Unfortunately, it's not possible to convert tensors to lists in a compiled PyTorch model, thus we would need to implement a vectorized `unpad_image`.
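One possible direction for such a vectorized variant (an untested sketch, not a drop-in replacement: instead of slicing, it keeps the padded shape and returns a boolean mask marking the un-padded region, computed with tensor ops only) could be:

```python
import torch

def unpad_mask(tensor: torch.Tensor, original_sizes: torch.Tensor) -> torch.Tensor:
    """Boolean mask over (height, width) that is True inside the un-padded region.

    tensor: (batch, num_channels, height, width)
    original_sizes: (batch, 2) tensor of (original_height, original_width)
    """
    current_height, current_width = tensor.shape[-2:]
    original_height = original_sizes[:, 0].to(torch.float32)
    original_width = original_sizes[:, 1].to(torch.float32)

    original_ar = original_width / original_height
    current_ar = current_width / current_height

    # Mirror the two branches of `unpad_image` without data-dependent Python ints.
    pad_h = (current_height - (original_height * current_width / original_width).to(torch.int64)) // 2
    pad_w = (current_width - (original_width * current_height / original_height).to(torch.int64)) // 2
    height_padded = original_ar > current_ar  # True if padding was added along the height axis
    pad_h = torch.where(height_padded, pad_h, torch.zeros_like(pad_h))
    pad_w = torch.where(height_padded, torch.zeros_like(pad_w), pad_w)

    rows = torch.arange(current_height, device=tensor.device)
    cols = torch.arange(current_width, device=tensor.device)
    row_ok = (rows[None, :] >= pad_h[:, None]) & (rows[None, :] < current_height - pad_h[:, None])
    col_ok = (cols[None, :] >= pad_w[:, None]) & (cols[None, :] < current_width - pad_w[:, None])
    return row_ok[:, :, None] & col_ok[:, None, :]  # (batch, height, width)
```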
Feature request

As per #28981, LLaVA is planned to receive `torch.compile` support. Given that LLaVA is composed of a vision tower and an LLM, both of which can be compiled separately with `fullgraph=True` (after support has been added, which is not yet the case for Mistral), it seems much easier to compile both parts separately as well.

Motivation
The `_merge_input_ids_with_image_features` function that connects the two parts is difficult to compile, as PyTorch has yet to add support for many of the functions used, which require dynamic input sizes; these are necessary here because the number of input image tokens is subject to change.
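A minimal, hypothetical illustration of the kind of data-dependent operation involved (not taken from the model code): boolean-mask indexing produces a tensor whose length depends on tensor values, which `torch.compile(fullgraph=True)` generally cannot capture in a single graph:

```python
import torch

@torch.compile(fullgraph=True)
def select_image_positions(input_ids: torch.Tensor, image_token_index: int) -> torch.Tensor:
    mask = input_ids == image_token_index
    return input_ids[mask]  # output length depends on how many image tokens appear

try:
    select_image_positions(torch.tensor([[1, 32000, 2, 32000]]), 32000)
except Exception as exc:  # expected: a data-dependent / unsupported-op error under fullgraph=True
    print(type(exc).__name__, exc)
```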
Your contribution

I'd love to try submitting a PR if possible, but I'm not sure what the best way to do so is given the current circumstances.