Strange behavior with attn_implementation="eager" #35270
Comments
I can reproduce the error with my own images, and even after removing all the legacy code the behavior persists. Interestingly, generating from text only works well with eager attention, so the weird behavior seems to come from concatenating the image features. I therefore tried loading the vision tower with eager attention while keeping the text backbone on SDPA, using the following code. The generated text matched very well regardless of whether the text model used SDPA or eager:

```python
model: LlavaForConditionalGeneration = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cuda:0",
    attn_implementation={"text_config": "sdpa", "vision_config": "eager"},
)
```

Also, increasing the precision to bfloat16 fixes it.
Thank you for your reply. Could you please advise on how I should proceed to resolve this issue? Can I simply switch the precision to bfloat16 and continue? I noticed, however, that LLaVA 1.5 was trained in float16.
@pspdada Depending on what you are trying to achieve, you are free to use either of the two options I outlined above, and then simply report the weird behavior you observed if this is for research. I don't see a strong preference for either workaround, but forcing the vision model to eager will be more in line with the original LLaVA-VL release: SDPA support in CLIP-like models was added long after the LLaVA models were released, so the authors have most likely been running CLIP with eager attention all along.
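For reference, a minimal sketch of the higher-precision workaround (untested; `model_name` is the same placeholder used in the snippet above):

```python
import torch
from transformers import LlavaForConditionalGeneration

# Load the whole model in bfloat16 instead of float16; eager attention can then be
# requested globally without the float16 overflow issues discussed in this thread.
model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="cuda:0",
    attn_implementation="eager",
)
```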
I found that the behavior I observed is somewhat different from what you described. I tried loading the model the way you suggested:

```python
model: LlavaForConditionalGeneration = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    cache_dir="/root/llm/utils/models/hub",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cuda:0",
    attn_implementation={"text_config": "sdpa", "vision_config": "eager"},
)
```

This results in the following warning:

Additionally, this setup still produces wrong results (generated ids of 0). When I tried `{"text_config": "eager", "vision_config": "eager"}` instead, the warning does not appear, but generation still produces wrong results (generated ids of 0). Only by changing the dtype to bfloat16 does the output become correct.
@zucchini-nlp You mentioned that "eager" is the mode used during the training of LLaVA, but using this mode together with float16 leads to the incorrect generation described above.
Oh sorry, I forgot that I used a different branch for the text-only generation and to verify the generation quality. You're right, the above doesn't fix it and it seems to be the same error as in #34824. The precision change works, though I'd recommend waiting until the PR is merged, because otherwise it will not match the original implementation.
BTW, what type of error do you mean? @ArthurZucker can we merge the VLMs clean-up PR soon to fix the recurring errors? (#34502)
I'm very sorry for not being clear enough and causing confusion. By "error" I mean incorrect generation results: in some scenarios (image-text pairs) generation completes but returns sequences full of token id 0, while in other cases a CUDA error is raised.
From my observations, the issue appears regardless of the attention settings; it is only resolved by changing the data type to bfloat16.
Let's wait for the linked PR to be merged then, as llava currently does some past-kv manipulations due to incorrect indentation. The CUDA error you see probably comes from sampling when the model, for some reason, outputs NaN logits.
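A rough, untested way to check that hypothesis, reusing the `model` and `inputs` objects from the reproduction snippets in this thread:

```python
import torch

# Inspect the logits of a single forward pass for NaN/inf before any sampling happens.
with torch.no_grad():
    out = model(**inputs)
print("any NaN logits:", torch.isnan(out.logits).any().item())
print("any inf logits:", torch.isinf(out.logits).any().item())

# Greedy decoding avoids torch.multinomial, so a device-side assert caused by sampling
# from NaN probabilities would not be triggered (the text may still be wrong, though).
ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
```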
I did some testing on this just now (with the same images the OP used) and found that if the two images are run separately, as opposed to in a single batch, the output seems to be correct.
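A minimal sketch of that one-image-per-call workaround, reusing the `processor`, `model`, `images`, and `prompts` objects from the batched reproduction code posted in this thread (illustrative only):

```python
import torch

# Run each image/prompt pair in its own forward pass instead of one padded batch.
outputs = []
for image, prompt in zip(images, prompts):
    single = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)
    ids = model.generate(**single, max_new_tokens=50)
    outputs.append(processor.batch_decode(ids, skip_special_tokens=True)[0])
print(outputs)
```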
Running the code from PR #34502 did not solve it when testing on my local machine; I got the same output as @pspdada (it skips describing the first image and only describes the second). Attached is the code to reproduce it if desired. It is very similar to his, except for some adaptations so that I could run it locally without any extra setup. The dataset was obtained from this repository and the images correspond to his (a bathroom with a plant is id 339761 and a snowboarder is id 431256), although I think the problem is not specific to these images.
@zucchini-nlp I found that even when loading the model with the configuration suggested above, batched generation over the following five examples still goes wrong:

```python
a = [
    {
        "image_id": "339761.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/339761.jpg",
        "question": "Provide a thorough description of the given image.",
    },
    {
        "image_id": "431256.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/431256.jpg",
        "question": "What is this photo about? Please answer in great detail.",
    },
    {
        "image_id": "501400.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/501400.jpg",
        "question": "Provide a thorough description of the given picture.",
    },
    {
        "image_id": "264619.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/264619.jpg",
        "question": "Explain the narrative or story that the image seems to convey, detailing each part that contributes to it.",
    },
    {
        "image_id": "551791.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/551791.jpg",
        "question": "Please provide a detailed description of the image. Describe the visual elements, colors, shapes, textures, and any objects or people present along with the overall mood or atmosphere portrayed in the image.",
    },
]
images = [Image.open(d["image_path"]) for d in a]
users = [d["question"] for d in a]
```

The output is:
@zucchini-nlp It's strange that even without any additional configuration, using the suggested code from https://huggingface.co/docs/transformers/main/en/model_doc/llava#batched-inference directly also results in issues.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the model in half-precision
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    cache_dir="/root/llm/utils/models/hub",
    torch_dtype=torch.float16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    cache_dir="/root/llm/utils/models/hub",
    padding_side="left",
)

a = [
    {
        "image_id": "339761.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/339761.jpg",
        "question": "Provide a thorough description of the given image.",
    },
    {
        "image_id": "431256.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/431256.jpg",
        "question": "What is this photo about? Please answer in great detail.",
    },
    {
        "image_id": "501400.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/501400.jpg",
        "question": "Provide a thorough description of the given picture.",
    },
    {
        "image_id": "264619.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/264619.jpg",
        "question": "Explain the narrative or story that the image seems to convey, detailing each part that contributes to it.",
    },
    {
        "image_id": "551791.jpg",
        "image_path": "/root/llm/utils/eval/Object_HalBench/images/551791.jpg",
        "question": "Please provide a detailed description of the image. Describe the visual elements, colors, shapes, textures, and any objects or people present along with the overall mood or atmosphere portrayed in the image.",
    },
]
images = [Image.open(d["image_path"]) for d in a]
users = [d["question"] for d in a]

prompts: list[str] = []
for u in users:
    conversation: list[dict[str]] = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": u},
            ],
        },
    ]
    prompt: str = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    prompts.append(prompt)

inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to(model.device, torch.float16)
generate_ids = model.generate(**inputs, max_new_tokens=50)
a = processor.batch_decode(generate_ids, skip_special_tokens=True)
print(a)
```

The output is:

The bug persists in versions 4.46.3 and 4.47.0 of the transformers library. However, changing the model to another one, such as Qwen2-VL, yields correct output.
BTW, do you know if this is a regression (was this broken recently or not)?
I have been looking at this for a while and intend to look more in the coming days. My tests were all with the images from the original message (339761 and 431256 from Object_HalBench). The relevant code is this SDPA-only branch in the mask preparation:

```python
if (
    self.config._attn_implementation == "sdpa"
    and attention_mask is not None
    and attention_mask.device.type == "cuda"
    and not output_attentions
):
    # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
    # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
    # Details: https://github.com/pytorch/pytorch/issues/110213
    min_dtype = torch.finfo(dtype).min
    causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
```

This segment takes input token positions where all entries are masked (making them essentially useless, as their outputs are discarded) and removes the masking, "activating" them again. One such fully masked position is exactly a padding token. When we run with eager attention, this un-masking is not applied, and in float16 the fully masked rows can produce NaNs that then leak into the rest of the batch. I think a possible solution would be to simply multiply the V (value) vectors corresponding to these fully masked padding positions by zero, so that they cannot contaminate the other positions.
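A very rough sketch of that idea, not the actual patch; the shapes and names are assumptions based on the usual eager-attention layout (4D additive mask of shape `(batch, 1, q_len, kv_len)`, values of shape `(batch, num_heads, kv_len, head_dim)`), and zeroing the values alone may not cover every NaN path:

```python
import torch

def zero_fully_masked_values(value_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero the value vectors of key/value positions that no query is allowed to attend to.

    Left-padding positions are masked for every query, so any NaNs produced there in
    float16 would otherwise leak into real positions through 0 * NaN = NaN.
    """
    min_dtype = torch.finfo(value_states.dtype).min
    # A kv position is "fully masked" if it is masked for every query row.
    fully_masked = (attention_mask == min_dtype).all(dim=-2)      # (batch, 1, kv_len)
    keep = (~fully_masked).unsqueeze(-1).to(value_states.dtype)   # (batch, 1, kv_len, 1)
    return value_states * keep
```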
As for the latest @pspdada sample code, where it breaks using the 5 examples: I just ran it (fully on a GPU setup that fits the whole batch), and although it breaks on the main branch, it works correctly in the PR by @zucchini-nlp. The output I get there is:

This seems to be correct; in fact, the last two descriptions match the ones posted above. A summary of my tests, using both the two-example case given above and the five-example case, is given below. Where it says "main", I used the main branch from @zucchini-nlp's fork, but since the PR was not merged yet and my results were consistent with @pspdada's, I think this should be fine.
@hsilva664 Thank you for conducting such a detailed and in-depth study of the question I raised. I apologize for not having enough time and energy to delve deeply into this issue myself. I would like to know whether your proposed solution and PR #34502 can be used together to address the two-image and five-image examples mentioned above. Additionally, have more examples been tested to verify the correctness and stability of the model during batch inference in eager mode?
@pspdada They can be used together. I have created a PR as a candidate to be merged into the PR by @zucchini-nlp; feel free to test the behaviour on the modified code (it is linked above). I have not run extensive testing beyond the cases discussed here.
System Info

- `transformers` version: 4.47.0

Who can help?

@zucchini-nlp

Information

Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction

I am trying to analyze the attention pattern of the LLaVA v1.5 7B model, so I used `attn_implementation="eager"` when initializing the model in order to obtain the attention weights. However, this has led to several issues: first, the output IDs are incorrect, and second, errors may occur. I've noticed that the problem only appears with specific images and user prompts, while it does not occur in other cases, which is quite peculiar. Below is my code.

Notice: the images I used are from the Object_HalBench benchmark.
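A sketch of the kind of setup being described, assuming it mirrors the batched-inference snippet earlier in this thread with eager attention and attention outputs enabled (the exact arguments and local image paths are assumptions):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.float16,
    device_map="cuda:0",
    attn_implementation="eager",  # needed to get attention weights back
)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", padding_side="left")

# Object_HalBench images and prompts, as in the batched example earlier in the thread.
images = [Image.open("images/339761.jpg"), Image.open("images/431256.jpg")]
prompts = [
    processor.apply_chat_template(
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": q}]}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for q in [
        "Provide a thorough description of the given image.",
        "What is this photo about? Please answer in great detail.",
    ]
]

inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to(model.device, torch.float16)
out = model.generate(
    **inputs,
    max_new_tokens=50,
    output_attentions=True,        # the whole point of switching to eager attention
    return_dict_in_generate=True,
)
print(processor.batch_decode(out.sequences, skip_special_tokens=True))
```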
The output is:

Some other warning:

The generated_ids (I removed a large number of `<image>` tokens for readability):

Notice that `<image>` is token id 32000 and `<pad>` is token id 32001.

The output after `batch_decode` is:

It's strange that token id 0 is generated.

Only setting `output_attentions=False` and `return_dict_in_generate=False`, without removing `attn_implementation="eager"`, doesn't change anything. Notice that removing `attn_implementation="eager"` and not returning the dict overcomes this problem; the output then becomes correct:

Besides this, some errors may occur with `attn_implementation="eager"` in some other cases (different image inputs).

Expected behavior

Fix it.