add slow tests for batched inference
zucchini-nlp committed Apr 25, 2024
1 parent 1e07818 commit bff40e0
Showing 1 changed file with 53 additions and 0 deletions.
tests/models/llava_next/test_modeling_llava_next.py
@@ -28,6 +28,7 @@
     is_torch_available,
     is_vision_available,
 )
+from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
 from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -462,3 +463,55 @@ def test_small_model_integration_test_batch(self):
 
         EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
         self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
+
+    @slow
+    @require_bitsandbytes
+    def test_small_model_integration_test_batch_different_resolutions(self):
+        model = LlavaNextForConditionalGeneration.from_pretrained(
+            "llava-hf/llava-v1.6-mistral-7b-hf",
+            load_in_4bit=True,
+        )
+
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
+        cats_image = Image.open(requests.get(url, stream=True).raw)
+        lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
+
+        inputs = self.processor(
+            [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
+        ).to(torch_device)
+        pixel_values = inputs["pixel_values"]
+
+        # verify pixel values are padded correctly with 0 when one image has more num_patches than the other
+        image_num_patches = [
+            image_size_to_num_patches(
+                image_size=imsize,
+                grid_pinpoints=model.config.image_grid_pinpoints,
+                patch_size=model.config.vision_config.image_size,
+            )
+            for imsize in inputs["image_sizes"]
+        ]
+        for pix_val, num_patch in zip(pixel_values, image_num_patches):
+            self.assertTrue(torch.all(pix_val[num_patch:] == 0))
+
+        # check loss when labels are passed
+        inputs["labels"] = inputs["input_ids"].clone()
+        with torch.no_grad():
+            output = model(**inputs)
+
+        expected_slice = torch.tensor(
+            [[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5234], [-6.7383, -7.2422, -0.6699]],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+        assert torch.allclose(output.logits[0, :3, :3], expected_slice, atol=1e-3)
+        assert torch.allclose(output.loss, torch.tensor(6.5171, device=torch_device))
+
+        # verify generation
+        output = model.generate(**inputs, max_new_tokens=50)
+        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST]\nThe image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few small plants or flowers. In the background, there are trees and what appears to be a deer' # fmt: skip
+
+        self.assertEqual(
+            self.processor.decode(output[0], skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )

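For context on the padding assertion in the added test: image_size_to_num_patches maps each image size to the number of vision patches the processor produces for it, so the lower-resolution image yields fewer patches and its trailing patch slots in the batched pixel_values are zero-filled. Below is a rough, illustrative sketch of that arithmetic, assuming LLaVA-NeXT's scheme of selecting a target resolution from image_grid_pinpoints, tiling it into vision-tower-sized patches, and appending one base patch for the downscaled full image; approx_num_patches and the concrete shapes are hypothetical stand-ins, not the library implementation.

    import math

    import torch


    def approx_num_patches(best_height: int, best_width: int, patch_size: int) -> int:
        # tiles covering the selected grid resolution, plus one base patch
        # for the downscaled full image (illustrative stand-in for
        # image_size_to_num_patches)
        return math.ceil(best_height / patch_size) * math.ceil(best_width / patch_size) + 1


    # with a 336px vision tower, a 672x672 target yields 2 * 2 + 1 = 5 patches,
    # while a 336x672 target yields 1 * 2 + 1 = 3
    assert approx_num_patches(672, 672, 336) == 5
    assert approx_num_patches(336, 672, 336) == 3

    # batching two such images pads pixel_values along the patch dimension with
    # zeros up to the larger patch count, which is what the test asserts
    pixel_values = torch.zeros(2, 5, 3, 336, 336)      # batch of 2, max 5 patches
    pixel_values[0, :3] = torch.randn(3, 3, 336, 336)  # low-res image: 3 real patches
    pixel_values[1] = torch.randn(5, 3, 336, 336)      # high-res image: 5 real patches
    assert torch.all(pixel_values[0, 3:] == 0)         # trailing slots are zero padding

Since the test is marked @slow and @require_bitsandbytes, it is skipped in the default test run; it executes with RUN_SLOW=1 (e.g. RUN_SLOW=1 pytest tests/models/llava_next/test_modeling_llava_next.py -k batch_different_resolutions) on a GPU machine with bitsandbytes installed.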