[Llava] Fix llava index errors #28032
Thanks for adding this fix!
This seems like a hack to cover over matrix creation and indexing logic above. It would be better to prevent this from happening at all. Even if this fix doesn't change our slow generations, I'd rather we were able to repro the issue first to make sure the behaviour is what we want. Can the users who report the issue share an image we can use to trigger the problems?
# Ensuring indices are within bounds - and avoid CUDA index errors
# See https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 for more details
valid_indices = non_attended_tokens < extended_attention_mask.shape[1]
new_batch_index = batch_index[valid_indices]
I see why this applies for `new_non_attended_tokens = non_attended_tokens[valid_indices]` but not for `new_batch_index`, as `extended_attention_mask.shape[0]` can be different and so have a different set of valid indices.
That would cause a shape mismatch error if we only change `non_attended_tokens`.
`batch_index` and `non_attended_tokens` come from this operation:
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
Together they point to which sample in the batch and which token index, i.e. the row and column in `extended_attention_mask`.
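To make the point about paired indices concrete, here is a minimal sketch that uses plain Python lists in place of tensors (an assumption made to keep it dependency-free). The helper name `filter_in_bounds` and all values are hypothetical and are not taken from the PR. It shows that the row list and the column list must be filtered by the same condition so that each remaining pair still names one cell of the mask.

```python
def filter_in_bounds(batch_index, token_index, num_columns):
    """Keep only (row, column) pairs whose column fits inside the mask."""
    kept = [(b, t) for b, t in zip(batch_index, token_index) if t < num_columns]
    return [b for b, _ in kept], [t for _, t in kept]


# Hypothetical coordinates of zero entries: (sample in batch, token position).
batch_index = [0, 0, 1, 1]
non_attended_tokens = [0, 7, 1, 9]

# Hypothetical attention mask with 2 rows and 8 columns, all ones.
mask = [[1] * 8 for _ in range(2)]

new_batch_index, new_non_attended_tokens = filter_in_bounds(
    batch_index, non_attended_tokens, num_columns=len(mask[0])
)

# Both lists shrink together, so zipping them still yields valid cells.
for b, t in zip(new_batch_index, new_non_attended_tokens):
    mask[b][t] = 0

print(new_batch_index)          # [0, 0, 1]
print(new_non_attended_tokens)  # [0, 7, 1]
```

With tensors, the same effect comes from indexing both `batch_index` and `non_attended_tokens` with the single boolean mask `valid_indices`, as in the diff under review.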
Hi @amyeroberts, I agree it is quite hacky, let me take some time to further investigate and provide a proper fix
@younesbelkada I saw no other issue except from here onwards
That is why filtering out made sense to me, at least. To avoid this hack, I see that either
hi @gullalc
I tried many combinations to reproduce the issue - batched, batched with multiple images, batched with multiple images and long context - and I am still not able to repro.
Can you give us more insights on how to repro your issue? Do you use one image per prompt? Are the prompts you use long? Can you somehow reproduce it with an image that you can find on the internet?
hey - facing a similar issue: it seems to appear when both the inputs and generated outputs are long enough, hence different behaviour for different images. One way to replicate is:
Running on A10G on current main
So if I understand correctly we should be masking the resulting Also, how do these zeros appear in
Thanks for the reproducer, I'll try to run some experiments on my end
The extended hidden states are initialized with all zeros, so on the first layer they should stay untouched, and the first past KV cache should remain all zeros in the places where you have pad tokens.
@adilzhan-ismailov-depop Thanks for sharing the example. I would also agree that it has something to do with the length of the generated output for a certain image. I am not sure it has anything to do with input length, as the difference in length between the three prompts I tried was small.
@younesbelkada To answer your questions: Yes, I used one image per prompt. Three different prompts were used in different runs. The smallest prompt was 7 words and the biggest one was 17 words. I can try to find more example images on the internet for which the same error is thrown, if needed.
Hi @gullalc
Yes, that would be really great, thanks!
Thanks - but I think in the example with batch size of one we shouldn't have any pad tokens? We can reproduce the error by adding a padding token to any input manually though:
This fails because, since we have an image, the padding token is not at the first position, so we fail the first time we create `extended_attention_mask` and try to index it.
So why does this happen without padding tokens, and why is the likelihood higher with longer inputs? This is likely to do with half precision. If you run this experiment you can see that this happens for float16 much more frequently than for float32:
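The experiment referred to above is not preserved on this page. As a rough illustration of one way exact zeros can arise more readily in half precision, the sketch below round-trips a small value through IEEE 754 half and single precision using only the standard library. The value `1e-8` and the helper name `round_trip` are hypothetical; this is not the original experiment and does not claim to be the exact mechanism inside the model.

```python
import struct


def round_trip(value: float, fmt: str) -> float:
    """Store a Python float in the given struct format and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


tiny = 1e-8  # hypothetical small value, below float16's smallest subnormal

as_fp16 = round_trip(tiny, "e")  # "e" is IEEE 754 half precision
as_fp32 = round_trip(tiny, "f")  # "f" is IEEE 754 single precision

print(as_fp16 == 0.0)  # True: the value underflows to exactly zero
print(as_fp32 == 0.0)  # False: single precision keeps a nonzero value
```

A check of the form `value == 0` therefore has more ways to fire on a real token in float16 than in float32.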
In practical terms I think it's ok, but maybe there is a more elegant way to identify non-attended tokens. The logic that handles the attention mask is still an issue, though, in case we have real padding tokens in the batch.
Hi :) Here is some log extract:
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 134, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/homeXXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 391, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 416, in test_step
return self.lightning_module.test_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/XXX/gitrepos/lmmm/lmmm/model/mixins/evaluation.py", line 201, in test_step
return self._in_text_image_out_text(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/XXX/gitrepos/lmmm/lmmm/model/mixins/evaluation.py", line 101, in _in_text_image_out_text
_, pred_text = self.generate(
^^^^^^^^^^^^^^
File "/home/XXX/gitrepos/lmmm/lmmm/model/lit_llava.py", line 73, in generate
generated_ids: torch.Tensor = self.model.generate(
^^^^^^^^^^^^^^^^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/transformers/generation/utils.py", line 1718, in generate
return self.greedy_search(
^^^^^^^^^^^^^^^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
outputs = self(
^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/XXX/miniforge3/envs/lmmm/lib/python3.11/site-packages/transformers/models/llava/modeling_llava.py", line 428, in forward
extended_attention_mask[batch_index, non_attended_tokens] = 0
~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
The fix introduced in this PR fixes the issue though
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@amyeroberts this PR is ready for review 🙏
I left a few explanations to help you review some of the diffs; let me know if I should break this down into 2 PRs to fix the SDPA issues separately
@@ -431,8 +431,11 @@ def forward(
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
A more robust check is to look at the entire row, which has a dimension of 128 (head_dim), instead of looking only at a single logit value; a single value being exactly zero can randomly happen in some cases in fp16
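The idea in this comment can be sketched as follows, with nested Python lists standing in for a cache slice (an assumption made to avoid dependencies). The function names, the data, and the row length of 4 are all hypothetical, and this is not the code merged in the PR. It contrasts flagging a position when one value is zero with flagging it only when the whole row is zero.

```python
def zero_positions_single_value(cache):
    """Flag a position as non-attended if its first value is exactly zero."""
    return [
        (b, t)
        for b, sample in enumerate(cache)
        for t, row in enumerate(sample)
        if row[0] == 0
    ]


def zero_positions_whole_row(cache):
    """Flag a position as non-attended only if every value in its row is zero."""
    return [
        (b, t)
        for b, sample in enumerate(cache)
        for t, row in enumerate(sample)
        if all(v == 0 for v in row)
    ]


# Hypothetical cache slice: 1 sample, 3 token positions, 4 values per position.
cache = [[
    [0.0, 0.0, 0.0, 0.0],     # untouched position, all zeros
    [0.0, 0.31, -0.12, 0.5],  # real token whose first value happens to be zero
    [0.2, 0.1, 0.4, -0.3],    # real token
]]

print(zero_positions_single_value(cache))  # [(0, 0), (0, 1)]
print(zero_positions_whole_row(cache))     # [(0, 0)]
```

The single-value check wrongly flags position 1, which is the kind of false positive the comment attributes to fp16; the whole-row check does not.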
cool!
prompts = [
    "USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
    "USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
    "USER: <image>\nWhat is this?\nASSISTANT:",
@amyeroberts here there is a regression issue with SDPA support of Llava: when users perform batched generation with a different number of images per prompt, the model outputs gibberish. I suspect this is related to the fact that SDPA might not support arbitrary attention masks.
Therefore I adapted the test here to make sure we perform batched generation with SDPA with the same number of images per prompt, and I added a regression test that tests the previous behaviour by loading the model with `attn_implementation="eager"`.
I also replaced the expected values with the ones we should get on T4s, which are used in our CIs. The original values were obtained on an A100, which is why the tests currently fail.
I am happy to break this down into multiple PRs, but I thought we could make all the fixes in a single PR
Nice 🔥 - I think it's fine to keep both here
Ok perfect then!
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched_regression(self):
Here is the regression test I mentioned above. As you can see, here we run 3 images across 2 prompts, and users need to pass `attn_implementation="eager"` to retrieve the previous behaviour
@slow
@require_bitsandbytes
def test_llava_index_error_bug(self):
Here is a new test that is based on one of the scripts shared by a contributor on the PR. This test fails on main and passes with this PR
Note I will try to fix the SDPA regression (for users that perform multi-image & multi-prompt such as #28184) in a separate PR; meanwhile users can always use the model with
@younesbelkada Thank you for digging into this tricky issue, finding a robust solution and adding these tests - great work 🔥
If slow Llava model tests are all passing happy to merge!
Thanks a lot for the review! Tests are passing on my VM, which is a 2xT4 with the same PyTorch and bnb versions as the Docker image we use! Merging! 🚀
Have the same problem and this is very helpful. Thanks!
For anyone that wants to use this fix before the next release: pip install -U git+https://github.com/huggingface/transformers.git
I tried this PR but it still gives me an error.
first_layer_past_key_value.size(-1) > extended_attention_mask.size(-1)
# May induce index errors, which hasn't been addressed.
extended_attention_mask = torch.ones(
    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
    dtype=attention_mask.dtype,
    device=attention_mask.device,
)
# Zero-out the places where we don't need to attend
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
attention_mask[batch_index, non_attended_tokens] = 0
# attention_mask.size() = first_layer_past_key_value.size() + 1
or
extended_attention_mask = torch.ones(
    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
    dtype=attention_mask.dtype,
    device=attention_mask.device,
)
valid_indices = non_attended_tokens >= attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens - attention_mask.size(-1)] = 0
might be better
Thanks @NicholasCao!
@NicholasCao I managed to repro your issue, which seems to happen when one passes a custom past key value, as is the case for AWQ. It should be fixed in #28239
I'm not using AWQ. I'm hitting this problem when running batched inference on images, and it is hard to find the specific image that reproduces it
@NicholasCao #28239 should solve it, let me know if the PR fixes your issue
thx, it works
Thanks @NicholasCao!
* fix llava index errors
* forward contrib credits from original implementation and fix
* better fix
* final fixes and fix all tests
* fix
* fix nit
* fix tests
* add regression tests
---------
Co-authored-by: gullalc <[email protected]>
What does this PR do?
Fixes errors on the Hub such as https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6 and https://huggingface.co/llava-hf/bakLlava-v1-hf/discussions/4
I did not manage to repro, as the issue seems to happen on some specific custom images for some reason; however, @gullalc managed to find a fix (https://huggingface.co/llava-hf/llava-1.5-7b-hf/discussions/6#657a2aa96cd623f45c3c499f) which does not affect generation, as I can confirm with the slow tests.
The fix is simply to mask out the indices that are out of range of the `extended_attention_mask`.
- also added the same fix to the VipLlava architecture
cc @amyeroberts
Fixes #28197, Fixes #27901