Skip to content

Commit

Permalink
[Llava] Fix llava index errors (huggingface#28032)
Browse files Browse the repository at this point in the history
* fix llava index errors

* forward contrib credits from original implementation and fix

* better fix

* final fixes and fix all tests

* fix

* fix nit

* fix tests

* add regression tests

---------

Co-authored-by: gullalc <[email protected]>
  • Loading branch information
2 people authored and Saibo-creator committed Jan 4, 2024
1 parent 04234f6 commit 30493fa
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 12 deletions.
7 changes: 5 additions & 2 deletions src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,11 @@ def forward(
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_seqlen = first_layer_past_key_value.shape[-1] + 1

Expand Down
9 changes: 6 additions & 3 deletions src/transformers/models/vipllava/modeling_vipllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,13 @@ def forward(
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
first_layer_past_key_value = past_key_values[0][0][:, 0, :, :]

# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)

# Get the target length
target_seqlen = first_layer_past_key_value.shape[-1] + 1
target_seqlen = first_layer_past_key_value.shape[-2] + 1

extended_attention_mask = torch.ones(
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
Expand Down
63 changes: 56 additions & 7 deletions tests/models/llava/test_modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,20 @@ def test_small_model_integration_test_llama_batched(self):
model_id = "llava-hf/llava-1.5-7b-hf"

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id, pad_token="<pad>")
processor = AutoProcessor.from_pretrained(model_id)

prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: the water is calm and clear\n\nThe image shows a wooden pier on a lake, with a', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip

self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

Expand All @@ -272,14 +272,63 @@ def test_small_model_integration_test_batch(self):
# The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!.
prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

inputs = self.processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ['\nUSER: What are the things I should be cautious about when I visit this place? What should I bring with me\nASSISTANT: When visiting this place, bring a camera, Park rules may apply, Per person, Sunrise over', '\nUSER: What is this?\nASSISTANT: Two cats lying on a bed!\nUSER: And this? a dock on a lake with two cats on it (Photo credit: Jocelyn R.'] # fmt: skip
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched_regression(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "llava-hf/llava-1.5-7b-hf"

# Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before)
model = LlavaForConditionalGeneration.from_pretrained(
"llava-hf/llava-1.5-7b-hf", load_in_4bit=True, attn_implementation="eager"
)
processor = AutoProcessor.from_pretrained(model_id, pad_token="<pad>")

prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip

self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_bitsandbytes
def test_llava_index_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
# more details
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)

processor = AutoProcessor.from_pretrained(model_id)

# Simulate a super long prompt
user_prompt = "Describe the image:?\n" * 200
prompt = f"USER: <image>\n{user_prompt}ASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)

# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)

0 comments on commit 30493fa

Please sign in to comment.