
LLaVA torch.compile implementation #29891

Open
sheepymeh opened this issue Mar 27, 2024 · 11 comments
Labels: Compilation (Issues related to torchdynamo and torchinductor), Feature request (Request for a new feature), Good Difficult Issue

Comments

@sheepymeh

Feature request

As per #28981, LLaVA is planned to receive torch.compile support. Since LLaVA is composed of a vision tower and an LLM, both of which can be compiled with fullgraph=True on their own (once support has been added, which is not yet the case for Mistral), it seems much easier to compile the two parts separately as well.

Motivation

The _merge_input_ids_with_image_features function that connects the two parts is difficult to compile because PyTorch does not yet support many of the operations it uses that require dynamic input sizes, which are needed here since the number of image tokens varies between inputs.
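
As an illustration of the kind of operation that breaks fullgraph compilation (a standalone sketch, not the actual transformers code; the token id is made up):

import torch

IMAGE_TOKEN_INDEX = 32000  # hypothetical image-token id, for illustration only

def merge_like(input_ids: torch.Tensor, embeds: torch.Tensor) -> torch.Tensor:
    # Boolean-mask indexing: the output length depends on the *values* in
    # input_ids, not just its shape, so it cannot be captured as one static graph.
    text_mask = input_ids != IMAGE_TOKEN_INDEX
    return embeds[text_mask]

compiled = torch.compile(merge_like, fullgraph=True)
input_ids = torch.tensor([1, IMAGE_TOKEN_INDEX, 2, 3])
embeds = torch.randn(4, 8)
compiled(input_ids, embeds)  # typically raises a torch._dynamo error about data-dependent output shapes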

Your contribution

I'd love to try submitting a PR if possible but I'm not sure what the best way to do so is given the current circumstances.

ArthurZucker added the Feature request and Compilation labels on Mar 27, 2024
@ArthurZucker
Collaborator

Thanks! Mistral should be fairly easy to implement; just follow the updates done to Llama! 🤗 Generate will be refactored to support compile soon!

FYI @gante and @zucchini-nlp

@sheepymeh
Author

I've checked that the following code runs:

import os
from functools import partial

import requests
import torch
from PIL import Image

from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache

os.environ["TOKENIZERS_PARALLELISM"] = "true"

with torch.inference_mode():
	processor = LlavaNextProcessor.from_pretrained("llava-mistral")
	model = LlavaNextForConditionalGeneration.from_pretrained("llava-mistral", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

	static_cache = partial(StaticCache, dtype=torch.float16)
	model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
	model.language_model.compile(fullgraph=True)
	model.vision_tower.compile(fullgraph=True)

	url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
	image = Image.open(requests.get(url, stream=True).raw)
	prompt = "[INST] <image>"

	inputs = processor(prompt, image, return_tensors="pt").to(model.device)
	output = model(**inputs,)

print(output)

@gante
Member

gante commented Mar 29, 2024

Hey @sheepymeh 👋

Given that the individual models are compilable (according to your script above), the next step would be to rewrite the logic between the language model and the vision tower so that it is compilable as well, such that model.forward can be compiled. It shouldn't require API changes, so we're welcoming contributions 💪

The process is very model-dependent, often requiring padding or clever manipulations to enable it. My suggestion would be to iteratively call the compiled forward pass, see where it crashes, rewrite, and repeat until it runs. Then, after the whole thing compiles, do a second pass to optimize performance (with the aid of a profiler): the compiled forward should be significantly faster than the uncompiled one after 2-3 warmup forward passes.
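
Something along these lines, as a rough sketch; it reuses the model and inputs from your script above, and the names are purely illustrative:

import time

import torch

compiled_forward = torch.compile(model.forward, fullgraph=True)

with torch.inference_mode():
    for i in range(5):
        start = time.perf_counter()
        _ = compiled_forward(**inputs)  # will crash until the glue code is rewritten
        torch.cuda.synchronize()
        # passes 0-2 include compilation/autotuning; later passes show steady-state speed
        print(f"pass {i}: {time.perf_counter() - start:.3f}s")

Once it runs end to end, torch.profiler (or a simple timing loop like this) helps find the remaining slow spots.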

Looking forward to the PR! If you get stuck, let us know :)

@sheepymeh
Author

sheepymeh commented Mar 31, 2024

Thank you for your suggestions. I'm currently working on a PR; my current approach is to create fixed-length tensors for everything. For example, final_embedding would be initialized to a shape of (batch, max_position_embeddings, embed_dim). However, this would create a lot of unnecessary and wasteful padding, which would be passed into the LLM. How could I mitigate this, or is this an acceptable compromise? For reference, max_position_embeddings for llava-1.6-mistral is 32768, so most of the tensor would be padding most of the time.
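
To put rough numbers on it (assuming Mistral-7B-like dimensions, embed_dim=4096, fp16; the figures are only an estimate):

# Back-of-the-envelope cost of padding final_embedding to max_position_embeddings.
batch, max_pos, embed_dim, bytes_per_elem = 1, 32768, 4096, 2
padded_mib = batch * max_pos * embed_dim * bytes_per_elem / 2**20
print(f"final_embedding alone: {padded_mib:.0f} MiB per sequence")  # 256 MiB

typical_len = 2048 + 576  # e.g. a prompt plus one image's patches
print(f"only {typical_len / max_pos:.1%} of the padded length is used")  # ~8%

And beyond the memory, the LLM would then run attention over all 32768 positions on every forward pass.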

@sheepymeh
Author

For reference, here is the (very unoptimized) version I'm working on:

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
	num_images, num_image_patches, embed_dim = image_features.shape
	batch_size, sequence_length = input_ids.shape
	left_padding = torch.any(input_ids[:, -1] == torch.tensor(self.pad_token_id)).to(torch.int8) # CHANGE: multiply instead of using booleans

	# 1. Create a mask to know where special image tokens are
	special_image_token_mask = input_ids == self.config.image_token_index
	num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
	# Compute the maximum embed dimension
	max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
	max_pos_embed = self.language_model.config.max_position_embeddings # CHANGE: automatically use maximum position embeddings
	text_tokens = input_ids != self.config.image_token_index # CHANGE: use a boolean mask instead of torch.where

	# 2. Compute the positions where text should be written
	# Calculate new positions for text tokens in merged image-text sequence.
	# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
	# `torch.cumsum` computes how each image token shifts subsequent text token positions.
	# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
	new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
	nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
	new_token_positions += nb_image_pad[:, None] * left_padding  # offset for left padding
	text_to_overwrite = new_token_positions.clone()
	last_token = new_token_positions.max(dim=-1).values + 1
	text_to_overwrite[~text_tokens] = -1 # set to -1 to place image tokens in "waste" element

	# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
	# set the corresponding tensors into their correct target device.
	target_device = inputs_embeds.device
	text_tokens, text_to_overwrite = (
		text_tokens.to(target_device),
		text_to_overwrite.to(target_device),
	)
	attention_mask = attention_mask.to(target_device)

	# 3. Create the full embedding
	# CHANGE: pad to max seq len by default and create a "waste" element at the end
	final_embedding = torch.zeros(
		batch_size, max_pos_embed + 1, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
	)
	final_attention_mask = torch.zeros(
		batch_size, max_pos_embed + 1, dtype=attention_mask.dtype, device=inputs_embeds.device
	)
	if labels is not None:
		final_labels = torch.full(
			(batch_size, max_pos_embed + 1), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
		)

	# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
	# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
	# I'm not sure how I could do this in a vectorized way
	for batch in range(batch_size):
		final_embedding[batch, text_to_overwrite[batch]] = inputs_embeds[batch]
		final_attention_mask[batch, text_to_overwrite[batch]] = attention_mask[batch]
		if labels is not None:
			final_labels[batch, text_to_overwrite[batch]] = labels[batch]

	pad_mask = torch.arange(max_pos_embed + 1, device=target_device).unsqueeze(dim=0).repeat(batch_size, 1) >= last_token.unsqueeze(-1)
	final_embedding[pad_mask] = self.pad_token_id
	final_attention_mask[pad_mask] = 0

	# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
	image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
	image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

	# if image_to_overwrite.sum() != image_features.shape[:-1].numel():
	# 	raise ValueError(
	# 		f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
	# 		f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
	# 	)

	image_features = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
	print(image_to_overwrite)
	final_embedding[image_to_overwrite] = image_features
	final_attention_mask |= image_to_overwrite
	position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

	# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
	batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
	indices_to_mask = new_token_positions[batch_indices, pad_indices]

	final_embedding[batch_indices, indices_to_mask] = 0

	if labels is None:
		final_labels = None

	# CHANGE: incompatible with torch.compile; try to remove the additional tokens
	final_embedding, final_attention_mask, final_labels, position_ids = (
		final_embedding[:, :max_pos_embed],
		final_attention_mask[:, :max_pos_embed],
		final_labels[:, :max_pos_embed] if final_labels is not None else None,
		position_ids[:, :max_pos_embed],
	)

	return final_embedding, final_attention_mask, final_labels, position_ids

@ArthurZucker
Collaborator

> Thank you for your suggestions. I'm currently working on a PR; my current approach is to create fixed-length tensors for everything. For example, final_embedding would be initialized to a shape of (batch, max_position_embeddings, embed_dim). However, this would create a lot of unnecessary and wasteful padding, which would be passed into the LLM. How could I mitigate this, or is this an acceptable compromise? For reference, max_position_embeddings for llava-1.6-mistral is 32768, so most of the tensor would be padding most of the time.

When you generate, you mostly want the decoding part, where the input is only one token, to be fast. That should not require many changes, as the final embeddings would be the same size as the input IDs.

One thing that could help with this is to pre-process the strings, so that the model's input is the only thing whose shape changes.
So instead of creating the final embedding there, you replace "Hey <image> is this nice?" with "Hey <image><image>.......<image> is this nice?" in the processor.
That way the embedding that gets created already has the correct shape.
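
Roughly something like this in the processor (just a sketch; the number of image tokens per image would come from the image preprocessing, e.g. 576 for a 24x24 patch grid, and it varies with resolution for Llava-Next):

def expand_image_tokens(prompt: str, num_image_tokens: int, image_token: str = "<image>") -> str:
    # Replace each single <image> placeholder with num_image_tokens copies,
    # so the tokenized prompt already has one position per image patch.
    return prompt.replace(image_token, image_token * num_image_tokens)

prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
expanded = expand_image_tokens(prompt, num_image_tokens=576)

With the prompt pre-expanded, merging reduces to scattering the image features into the positions where input_ids equals the image token index, so the merged embedding has the same length as input_ids and the shapes stay static.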

@sheepymeh
Author

This would completely break compatibility with previous code and with the original LLaVA codebase. Is it advisable to maintain two code paths in the processor (one with a single token per image and one with multiple)?

@sheepymeh
Author

Thank you for the help so far. I've managed to implement @ArthurZucker's suggestion successfully in the preprocessor. However, the unpad_image function:

def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor

requires original_size to be passed into the model through the preprocessor, and it cannot be in the form of a PyTorch tensor, because slicing by tensor values is unsupported:

torch._dynamo.exc.Unsupported: Dynamic slicing on data-dependent value is not supported

from user code:
   File "/home/sheepymeh/transformers/src/transformers/models/llava_next/modeling_llava_next.py", line 427, in unpad_image
    unpadded_tensor = tensor[:, padding : current_height - padding, :]

Is it possible to somehow pass these values in as a list of ints? Trying that causes this error when using the .to() function:

Traceback (most recent call last):
  File "/home/sheepymeh/transformers/test.py", line 14, in <module>
    inputs = inputs.to("cuda:0")
  File "/home/sheepymeh/transformers/src/transformers/feature_extraction_utils.py", line 229, in to
    if torch.is_floating_point(v):
TypeError: is_floating_point(): argument 'input' (position 1) must be Tensor, not list
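
One workaround I could imagine (just a sketch, not the current transformers API) would be to pop image_sizes out of the BatchFeature before calling .to(), so that only tensors get moved, and then pass the sizes separately as plain Python ints:

inputs = processor(prompt, image, return_tensors="pt")
image_sizes = inputs.pop("image_sizes").tolist()  # e.g. [[height, width]]
inputs = inputs.to("cuda:0")
output = model(**inputs, image_sizes=image_sizes)  # assumes the model accepts a plain list here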

@ArthurZucker
Collaborator

Not super sure, but Llava does not require this; only Llava-Next does.

@zucchini-nlp
Member

Yes, it's only Llava-Next. Btw, not sure how helpful this info is, but in the linked PR we had to convert the image sizes to a list inside modeling.py, because there were mismatches in how resolutions are calculated in the image processor vs. on the modeling side.

@sheepymeh
Author

Thanks for the input. Unfortunately, it's not possible to convert tensors to lists inside a compiled PyTorch model, so we would need to implement vectorized image_size handling first.
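
For reference, a minimal standalone repro of that limitation (the exact error depends on the PyTorch version):

import torch

def glue(image_sizes: torch.Tensor):
    # Data-dependent conversion: dynamo cannot capture this in a single graph.
    return image_sizes.tolist()

compiled = torch.compile(glue, fullgraph=True)
try:
    compiled(torch.tensor([[336, 448]]))
except Exception as e:  # typically torch._dynamo.exc.Unsupported
    print(type(e).__name__, e)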
