
[Llava] Fix llava index errors #28032

Merged: 12 commits merged into huggingface:main on Dec 22, 2023

Conversation

@younesbelkada (Contributor) commented on Dec 14, 2023

What does this PR do?

Fixes errors on the Hub such as https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 and https://huggingface.co/llava-hf/bakLlava-v1-hf/discussions/4

I did not manage to reproduce the issue, as it seems to happen only on some specific custom images for some reason; however, @gullalc managed to find a fix (https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6#657a2aa96cd623f45c3c499f) which does not affect generation, as I can confirm with the slow tests.

The fix is simply to mask out the indices that are out of range of extended_attention_mask (a rough sketch follows at the end of this description). The same fix is also applied to the VipLlava architecture.

cc @amyeroberts

Fixes #28197, Fixes #27901
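
For reference, a rough runnable sketch of the masking idea on toy tensors (an illustration with made-up shapes, not the exact merged diff):

import torch

# Toy stand-ins: the attention mask is shorter than the first layer's KV-cache slice
extended_attention_mask = torch.ones(1, 5, dtype=torch.long)
first_layer_past_key_value = torch.randn(1, 8)
first_layer_past_key_value[0, 2] = 0.0  # a genuine "non-attended" position inside the mask
first_layer_past_key_value[0, 7] = 0.0  # a spurious zero beyond the mask's length

batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)

# Ensure indices are within bounds - and avoid CUDA index errors
valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0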

@amyeroberts (Collaborator) left a comment

Thanks for adding this fix!

This seems like a hack to paper over the matrix creation and indexing logic above. It would be better to prevent this from happening at all. Even if this fix doesn't change our slow generations, I'd rather we were able to reproduce the issue first to make sure the behaviour is what we want. Can the users who reported the issue share an image we can use to trigger the problem?

# Ensuring indices are within bounds - and avoid CUDA index errors
# See https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 for more details
valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
new_batch_index = batch_index[valid_indices]
@amyeroberts (Collaborator) commented on the diff:

I see why this applies for new_non_attended_tokens = non_attended_tokens[valid_indices], but not for new_batch_index, as extended_attention_mask.shape[0] can be different and so would have a different set of valid indices.

A contributor replied:

That would cause a shape mismatch error if we only changed non_attended_tokens.
batch_index and non_attended_tokens come from this operation:
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)

They point to which sample in the batch and which token index, i.e. the row and column in extended_attention_mask.
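
A minimal illustration, with toy values not taken from the model, of how torch.where on a 2D tensor returns paired row/column indices:

import torch

# toy stand-in for first_layer_past_key_value: zeros mark "non-attended" positions
kv = torch.tensor([[1.0, 0.0, 2.0],
                   [0.0, 3.0, 4.0]])
batch_index, non_attended_tokens = torch.where(kv == 0)
print(batch_index)          # tensor([0, 1]) -> row, i.e. sample in the batch
print(non_attended_tokens)  # tensor([1, 0]) -> column, i.e. token index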

@younesbelkada (Author) commented:

Hi @amyeroberts, I agree it is quite hacky; let me take some time to investigate further and provide a proper fix.

@gullalc (Contributor) commented on Dec 14, 2023

@younesbelkada I saw no other issue except from this line onwards: batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)

first_layer_past_key_value.shape[1] is bigger than extended_attention_mask.shape[1], so it should be expected that an index can appear in non_attended_tokens that is larger than extended_attention_mask.shape[1].

That is why filtering those indices out made sense to me, at least. To avoid this hack, the root of the problem lies in how either first_layer_past_key_value or extended_attention_mask is built.
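
A toy demonstration of the resulting failure, with shapes like the ones reported in the logs later in this thread (on CPU this raises an IndexError; on CUDA it surfaces as a device-side assert):

import torch

extended_attention_mask = torch.ones(1, 575, dtype=torch.long)  # only 575 columns
batch_index = torch.tensor([0])
non_attended_tokens = torch.tensor([1881])  # index produced from the longer KV cache

# same indexing as in modeling_llava.py - fails because 1881 >= 575
extended_attention_mask[batch_index, non_attended_tokens] = 0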

@younesbelkada (Author) left a comment

Hi @gullalc,
I tried many combinations to reproduce the issue - batched, batched with multiple images, batched with multiple images and long context - and I am still not able to reproduce it.
Can you give us more insight into how to reproduce your issue? Do you use one image per prompt? Are the prompts you use long? Can you somehow reproduce it with an image that you can find on the internet?

@adilzhan-ismailov-depop commented on Dec 18, 2023

Hey - I'm facing a similar issue: it seems to appear when both the inputs and the generated outputs are long enough, hence the different behaviour for different images. One way to replicate it:

import requests
from PIL import Image

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"

k = 200
user_prompt = "Describe the image:?\n" * k
prompt = f"USER: <image>\n{user_prompt}ASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
).to(0)

processor = AutoProcessor.from_pretrained(model_id)

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

print(k, inputs['input_ids'].size())
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(k, output.size())
print(processor.decode(output[0][-100:], skip_special_tokens=True))

Running on A10G on current main

!CUDA_LAUNCH_BLOCKING=1 python test_llava.py

2023-12-18 14:20:27.237501: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-18 14:20:27.237560: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-18 14:20:27.237608: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-18 14:20:27.244438: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:01<00:00,  1.59it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
200 torch.Size([1, 1412])
/databricks/python/lib/python3.10/site-packages/torch/nn/modules/conv.py:459: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)
  return F.conv2d(input, weight, bias, self.stride,


--- added some prints
Extended attention mask: (1, 575); 
Attention mask: (1, 1413); 
first_layer_past_key_value (1, 1987); 
Target seqlen: 1988; 
Batch index: tensor([0], device='cuda:0'); 
Non attended tokens: tensor([1881], device='cuda:0')
---


../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [0,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
Traceback (most recent call last):
  File "", line 25, in <module>
    output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
  File "/databricks/python/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-e7088835-8718-43c0-b531-9c937824ca9c/lib/python3.10/site-packages/transformers/generation/utils.py", line 1731, in generate
    return self.greedy_search(
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-e7088835-8718-43c0-b531-9c937824ca9c/lib/python3.10/site-packages/transformers/generation/utils.py", line 2592, in greedy_search
    outputs = self(
  File "/databricks/python/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-e7088835-8718-43c0-b531-9c937824ca9c/lib/python3.10/site-packages/transformers/models/llava/modeling_llava.py", line 439, in forward
    extended_attention_mask[batch_index, non_attended_tokens] = 0
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

So if I understand correctly we should be masking the resulting attention_mask instead of extended_attention_mask?

Also, how do these zeros appear in first_layer_past_key_value in the first place?

@younesbelkada (Author) replied:

Thanks for the reproducer, I'll try to run some experiments on my end

Also, how do these zeros appear in first_layer_past_key_value in the first place?

Because the extended hidden states are initialized with all zeros, on the first layer they stay untouched, so the first past key-value cache remains all zeros in the places where you have pad tokens.
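
In other words, the model recovers the padded positions from the all-zero slots in the first layer's cache. A small sketch of the idea (shapes here are illustrative, not the model's real ones):

import torch

batch, num_heads, seq_len, head_dim = 1, 32, 8, 128
first_layer_key = torch.randn(batch, num_heads, seq_len, head_dim)
# pretend positions 0 and 1 were padding: their cache entries were never written and stay zero
first_layer_key[:, :, :2, :] = 0.0

# the pre-fix modeling code inspects a single slice of the cache and treats exact zeros as non-attended
first_layer_past_key_value = first_layer_key[:, 0, :, 0]
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
print(non_attended_tokens)  # tensor([0, 1]) -> the padded positions
# (in fp16, exact zeros can also occur by chance, which is the source of the bug discussed below)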

@gullalc (Contributor) commented on Dec 18, 2023

@adilzhan-ismailov-depop Thanks for sharing the example. I would also agree that it has something to do with the length of the generated output for a certain image. I am not sure it has anything to do with input length: for the three prompts I tried, the difference in prompt length was not large.

@younesbelkada To answer your questions:
"Do you use one image per prompt? Are the prompts you use long? Can you somehow reproduce it with an image that you can find on the internet?"

Yes, I used one image per prompt. Three different prompts were used in different runs. The smallest prompt was 7 words and the biggest was 17 words. I can try to find more example images on the internet for which the same error is thrown, if needed.

@younesbelkada (Author) replied:

Hi @gullalc

I can try to find more example images on the internet for which the same error is thrown, if needed.

Yes, that would be really great, thanks!

@aismlv (Contributor) commented on Dec 20, 2023

Because the extended hidden states are initialized with all zeros, on the first layer they stay untouched, so the first past key-value cache remains all zeros in the places where you have pad tokens.

Thanks - but I think in the example with batch size of one we shouldn't have any pad tokens? We can reproduce the error by adding a padding token to any input manually though:

import requests
from PIL import Image

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"

prompt = f"USER: <image>\nDescribe the image:?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    device_map='auto'
)

processor = AutoProcessor.from_pretrained(model_id)

device = 'cuda'

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(device)

# add padding token manually
pad_token_id = processor.tokenizer.pad_token_id #  32001
inputs['input_ids'] = torch.hstack([torch.ones((1, 1), dtype=torch.int64, device=device) * pad_token_id, inputs['input_ids']])
inputs['attention_mask'] = torch.hstack([torch.zeros((1, 1), dtype=torch.int64, device=device), inputs['attention_mask']])

output = model.generate(**inputs, max_new_tokens=200, do_sample=False)

This fails because, since we have an image, the padding token is not at the first position, so we fail the first time we create extended_attention_mask and try to index it.

So why does this happen without padding tokens, and why is the likelihood higher with longer inputs? It is likely due to half precision. If you run this experiment you can see that exact zeros occur much more frequently for float16 than for float32:

import torch
import altair as alt
import pandas as pd
from tqdm.auto import tqdm

# Function to run the experiment
def run_experiment(dtype, num_runs=1000):
    lengths = []
    for _ in tqdm(range(num_runs), desc=f"Running with dtype={dtype}"):
        random_tensor = torch.rand((2000, 1), dtype=dtype) # num. of elements in the original example
        batch_index, non_attended_tokens = torch.where(random_tensor == 0)
        lengths.append(len(batch_index))
    return lengths

# Running the experiments
lengths_float16 = run_experiment(torch.float16)
lengths_float32 = run_experiment(torch.float32)

# Creating a DataFrame for visualization
df = pd.DataFrame({
    "Length": lengths_float16 + lengths_float32,
    "Dtype": ["float16"] * len(lengths_float16) + ["float32"] * len(lengths_float32)
})

# Plotting the results
chart = alt.Chart(df).mark_bar().encode(
    x=alt.X('Length:Q', title="Num. of zero entries in 2000-element array"),
    y='count(Length):Q',
    column='Dtype:N'
).properties(
    width=220,
    height=200
)

chart

[Chart: histograms of the number of zero entries in a 2000-element array; float16 produces far more exact zeros than float32]

In practical terms I think it's OK, but maybe there is a more elegant way to identify non-attended tokens. The logic that handles the attention mask is still an issue, though, in case we have real padding tokens in the batch.

@floschne commented on Dec 20, 2023

Hi :)
Thanks for investigating this! Just to let you know, I'm facing the same issue when using images from the German split of the XM3600 dataset with a batch size > 1.

Here is a log extract:

  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/homeXXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 416, in test_step
    return self.lightning_module.test_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/gitrepos/lmmm/lmmm/model/mixins/evaluation.py", line 201, in test_step
    return self._in_text_image_out_text(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/gitrepos/lmmm/lmmm/model/mixins/evaluation.py", line 101, in _in_text_image_out_text
    _, pred_text = self.generate(
                   ^^^^^^^^^^^^^^
  File "/home/XXX/gitrepos/lmmm/lmmm/model/lit_llava.py", line 73, in generate
    generated_ids: torch.Tensor = self.model.generate(
                                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
           ^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
              ^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/transformers/models/llava/modeling_llava.py", line 428, in forward
    extended_attention_mask[batch_index, non_attended_tokens] = 0
    ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The fix introduced in this PR resolves the issue, though.


@younesbelkada (Author) left a comment

@amyeroberts this PR is ready for review 🙏
I left a few explanations to help you review some of the diffs; let me know if I should break this down into 2 PRs to fix the SDPA issues separately.

@@ -431,8 +431,11 @@ def forward(
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
@younesbelkada (Author) commented:

A more robust check is to check the entire row, which has a dimension of 128 (head_dim), instead of looking only at a single value, which can randomly be exactly zero in some cases in fp16.
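
A hedged sketch of the principle (not the exact merged code): flagging a position from a single fp16 value is fragile, whereas only flagging it when a whole slice of values across heads is zero is far more robust:

import torch

batch, num_heads, seq_len, head_dim = 1, 32, 10, 128
past_key = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16)

# fragile: a single fp16 scalar per position can be exactly 0.0 by pure chance
fragile_zero_positions = past_key[:, 0, :, 0] == 0         # shape (batch, seq_len)

# more robust: only flag a position when the whole slice across heads is zero
per_head_slice = past_key[:, :, :, 0]                      # shape (batch, num_heads, seq_len)
robust_zero_positions = (per_head_slice == 0).all(dim=1)   # shape (batch, seq_len)

batch_index, non_attended_tokens = torch.where(robust_zero_positions)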


cool!


prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
@younesbelkada (Author) commented on the diff:

@amyeroberts there is a regression issue here with the SDPA support of Llava: when users perform batched generation with a different number of images per prompt, the model outputs gibberish. I suspect this is related to the fact that SDPA might not support arbitrary attention masks.

Therefore I adapted the test here to make sure we perform batched generation with SDPA using the same number of images per prompt, and I added a regression test that tests the previous behaviour by loading the model with attn_implementation="eager".

I also updated the expected values to the ones we should get on T4s, which are used in our CI. The original values were obtained on an A100, which currently leads to failures.

I am happy to break this down into multiple PRs, but I thought we could make all the fixes in a single PR.

@amyeroberts (Collaborator) replied:

Nice 🔥 - I think it's fine to keep both here

@younesbelkada (Author) replied:

Ok perfect then!


@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched_regression(self):
@younesbelkada (Author) commented on Dec 22, 2023:

Here is the regression test I mentioned above. As you can see, here we perform batched generation with 3 images across 2 prompts, and users need to pass attn_implementation="eager" to retrieve the previous behaviour.
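
For reference, a minimal usage sketch of opting back into the previous behaviour (the model id and dtype here are just for illustration):

import torch
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="eager",  # fall back to the eager attention path instead of SDPA
)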


@slow
@require_bitsandbytes
def test_llava_index_error_bug(self):
@younesbelkada (Author) commented:

Here is a new test based on one of the scripts shared by a contributor on this PR. The test fails on main and passes with this PR.

younesbelkada marked this pull request as ready for review on December 22, 2023 at 15:57.
@younesbelkada (Author) commented:

Note: I will try to fix the SDPA regression (for users that perform multi-image & multi-prompt generation, such as #28184) in a separate PR; meanwhile, users can always load the model with attn_implementation="eager" to revert to the previous behaviour.

@amyeroberts (Collaborator) left a comment

@younesbelkada Thank you for digging into this tricky issue, finding a robust solution and adding these tests - great work 🔥

If the slow Llava model tests are all passing, I'm happy to merge!



@younesbelkada (Author) commented:

Thanks a lot for the review! Tests are passing on my VM, which is a 2xT4 with the same PyTorch & bitsandbytes versions as the Docker image we use. Merging! 🚀
Thanks to all contributors for the insightful discussion and the fix!

younesbelkada merged commit 29e7a1e into huggingface:main on Dec 22, 2023.
18 checks passed
younesbelkada deleted the llava-fix-index branch on December 22, 2023 at 16:47.
@NicholasCao commented:

Have the same problem and this is very helpful. Thanks!

@younesbelkada (Author) commented:

For anyone that wants to use this fix before the next release:

pip install -U git+https://github.com/huggingface/transformers.git

@NicholasCao commented on Dec 23, 2023

I tried this PR but it still gives me an error. I think the core issue is that

first_layer_past_key_value.size(-1) > extended_attention_mask.size(-1)  # may induce index errors

which hasn't been addressed. But I really don't quite understand what this piece of code is doing; I think the fix should look like this, though I don't know if it's correct:

extended_attention_mask = torch.ones(
    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
    dtype=attention_mask.dtype,
    device=attention_mask.device,
)

# Zero-out the places where we don't need to attend
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
attention_mask[batch_index, non_attended_tokens] = 0
# attention_mask.size() = first_layer_past_key_value.size() + 1

or

extended_attention_mask = torch.ones(
    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
    dtype=attention_mask.dtype,
    device=attention_mask.device,
)

valid_indices = non_attended_tokens >= attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens - attention_mask.size(-1)] = 0

might be better

@younesbelkada (Author) replied:

Thanks @NicholasCao!
Can you file a separate issue for this and tag me? If you can also provide a reproducer, that would be great. The PR that got merged fixes the issue explained here: #28032 (comment), which should hopefully cover most of the issues related to Llava and index errors.

@younesbelkada (Author) commented:

@NicholasCao I managed to reproduce your issue; it seems to happen when one passes a custom past key value, which is the case for AWQ. It should be fixed in #28239.

@NicholasCao replied:

I'm not using AWQ; I'm having this problem when running batch inference on images. It's harder to reproduce because it's hard to find the specific image that triggers it.

@younesbelkada (Author) replied:

@NicholasCao #28239 should solve it, let me know if the PR fixes your issue

@NicholasCao replied:

thx, it works

@younesbelkada (Author) replied:

Thanks @NicholasCao!

Saibo-creator pushed a commit to epfl-dlab/transformers-GCD-PR that referenced this pull request Jan 4, 2024
* fix llava index errors

* forward contrib credits from original implementation and fix

* better fix

* final fixes and fix all tests

* fix

* fix nit

* fix tests

* add regression tests

---------

Co-authored-by: gullalc <[email protected]>
staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024 (same commit message and co-author as above).
Successfully merging this pull request may close these issues: LLaVA: index error when computing extended_attention_mask.
8 participants