diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index be178a1d0750af..050d042de5417c 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -28,6 +28,7 @@
     is_torch_available,
     is_vision_available,
 )
+from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
 from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -462,3 +463,55 @@ def test_small_model_integration_test_batch(self):
         EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush']  # fmt: skip
         self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
+
+    @slow
+    @require_bitsandbytes
+    def test_small_model_integration_test_batch_different_resolutions(self):
+        model = LlavaNextForConditionalGeneration.from_pretrained(
+            "llava-hf/llava-v1.6-mistral-7b-hf",
+            load_in_4bit=True,
+        )
+
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
+        cats_image = Image.open(requests.get(url, stream=True).raw)
+        lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
+
+        inputs = self.processor(
+            [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
+        ).to(torch_device)
+        pixel_values = inputs["pixel_values"]
+
+        # verify pixel values are padded correctly with 0 when one image has more num_patches than the other
+        image_num_patches = [
+            image_size_to_num_patches(
+                image_size=imsize,
+                grid_pinpoints=model.config.image_grid_pinpoints,
+                patch_size=model.config.vision_config.image_size,
+            )
+            for imsize in inputs["image_sizes"]
+        ]
+        for pix_val, num_patch in zip(pixel_values, image_num_patches):
+            self.assertTrue(torch.all(pix_val[num_patch:] == 0))
+
+        # check loss when labels are passed
+        inputs["labels"] = inputs["input_ids"].clone()
+        with torch.no_grad():
+            output = model(**inputs)
+
+        expected_slice = torch.tensor(
+            [[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5234], [-6.7383, -7.2422, -0.6699]],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+        assert torch.allclose(output.logits[0, :3, :3], expected_slice, atol=1e-3)
+        assert torch.allclose(output.loss, torch.tensor(6.5171, device=torch_device))
+
+        # verify generation
+        output = model.generate(**inputs, max_new_tokens=50)
+        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST]\nThe image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few small plants or flowers. In the background, there are trees and what appears to be a deer'  # fmt: skip
+
+        self.assertEqual(
+            self.processor.decode(output[0], skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
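
To make the padding assertion in the new test concrete: LLaVA-NeXT splits each image into a resolution-dependent number of patches, so batching two images with different `num_patches` requires right-padding the shorter one's patch dimension with zeros, which is exactly what the `pix_val[num_patch:] == 0` check verifies. Below is a minimal, hypothetical sketch of that padding scheme (`pad_patch_batch` is illustrative only, not the processor's actual implementation):

```python
import torch

def pad_patch_batch(images: list[torch.Tensor]) -> torch.Tensor:
    """Stack per-image patch tensors of shape (num_patches, C, H, W) into a
    (batch, max_num_patches, C, H, W) tensor, zero-padding shorter entries."""
    max_patches = max(img.shape[0] for img in images)
    padded = []
    for img in images:
        pad_len = max_patches - img.shape[0]
        # append pad_len all-zero patches after the real ones
        pad = img.new_zeros((pad_len, *img.shape[1:]))
        padded.append(torch.cat([img, pad], dim=0))
    return torch.stack(padded)

# Example: a low-res image yielding 2 patches batched with one yielding 5
# (patch counts chosen for illustration, not computed from grid pinpoints).
lowres = torch.randn(2, 3, 336, 336)
highres = torch.randn(5, 3, 336, 336)
pixel_values = pad_patch_batch([lowres, highres])
assert pixel_values.shape == (2, 5, 3, 336, 336)
assert torch.all(pixel_values[0, 2:] == 0)  # trailing patches are zeros
```

In the test itself, the real per-image patch counts come from `image_size_to_num_patches`, computed from each image's size and the model's `image_grid_pinpoints`, so the assertion adapts to whatever resolutions the two downloaded images actually have.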