fix pixtral processor #34486

Merged · 21 commits · Oct 30, 2024
4 changes: 4 additions & 0 deletions src/transformers/models/pixtral/configuration_pixtral.py
@@ -52,6 +52,8 @@ class PixtralVisionConfig(PretrainedConfig):
             Dropout probability for the attention layers.
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
 
     Example:
 
@@ -82,6 +84,7 @@ def __init__(
         hidden_act="gelu",
         attention_dropout=0.0,
         rope_theta=10000.0,
+        initializer_range=0.02,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -97,3 +100,4 @@ def __init__(
         self.hidden_act = hidden_act
         self.rope_theta = rope_theta
         self.head_dim = hidden_size // num_attention_heads
+        self.initializer_range = initializer_range
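
For reference, a minimal sketch of what the new field enables (plain usage, not part of the diff): the vision config now carries its own `initializer_range`, so a standalone vision model no longer depends on a composite config for weight initialization.

```python
from transformers import PixtralVisionConfig, PixtralVisionModel

# 0.02 is just the documented default; any float works here.
config = PixtralVisionConfig(initializer_range=0.02)
model = PixtralVisionModel(config)  # _init_weights can read config.initializer_range directly
```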
2 changes: 1 addition & 1 deletion src/transformers/models/pixtral/modeling_pixtral.py
@@ -407,7 +407,7 @@ def _init_weights(self, module):
         std = (
             self.config.initializer_range
             if hasattr(self.config, "initializer_range")
-            else self.config.text_config.initializer_range
+            else self.config.initializer_range
         )
 
         if isinstance(module, (nn.Linear, nn.Conv2d)):
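
This fallback matters because a standalone `PixtralVisionConfig` has no nested `text_config`. A short sketch of the before/after behavior (illustration only, not from the diff):

```python
from transformers import PixtralVisionConfig

config = PixtralVisionConfig()

# Before this PR: the vision config had no `initializer_range`, so `_init_weights`
# fell back to `config.text_config.initializer_range` and raised AttributeError
# for a standalone vision config.
# After: the attribute exists directly on the vision config (see the config change above).
std = config.initializer_range
print(std)  # 0.02, the documented default
```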
15 changes: 8 additions & 7 deletions src/transformers/models/pixtral/processing_pixtral.py
@@ -206,14 +206,15 @@ def __call__(
         if is_image_or_image_url(images):
             images = [[images]]
         elif isinstance(images, list) and is_image_or_image_url(images[0]):
-            images = [images]
-        elif (
-            not isinstance(images, list)
-            and not isinstance(images[0], list)
-            and not is_image_or_image_url(images[0][0])
-        ):
+            if isinstance(text, list):
+                images = [[im] for im in images]
+            else:
+                images = [images]
+        elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
+            pass
+        else:
             raise ValueError(
-                "Invalid input images. Please provide a single image or a list of images or a list of list of images."
+                "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
             )
         images = [[load_image(im) for im in sample] for sample in images]
         image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
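
The rewritten branching makes the accepted image shapes explicit and, when `text` is a list, pairs one image per prompt. A hedged sketch of the three cases, assuming `processor` is a loaded Pixtral processor and `img` is a PIL image or image URL (both placeholder names):

```python
prompt = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"

# 1) Single image, single prompt: wrapped into [[img]].
processor(text=prompt, images=img, return_tensors="pt")

# 2) Flat list with batched prompts: now split one image per prompt
#    ([[img], [img]]) instead of all images being grouped into one sample.
processor(text=[prompt, prompt], images=[img, img], padding=True, return_tensors="pt")

# 3) List of lists (one inner list per prompt): passed through unchanged.
processor(text=[prompt, prompt], images=[[img], [img]], padding=True, return_tensors="pt")
```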
41 changes: 2 additions & 39 deletions tests/models/pixtral/test_modeling_pixtral.py
@@ -14,22 +14,16 @@
 # limitations under the License.
 """Testing suite for the PyTorch Pixtral model."""
 
-import gc
 import unittest
 
-import requests
 
 from transformers import (
-    AutoProcessor,
     PixtralVisionConfig,
     PixtralVisionModel,
     is_torch_available,
     is_vision_available,
 )
 from transformers.testing_utils import (
-    require_bitsandbytes,
     require_torch,
-    slow,
-    torch_device,
 )
@@ -43,7 +37,7 @@
     is_torch_greater_or_equal_than_2_0 = False
 
 if is_vision_available():
-    from PIL import Image
+    pass


class PixtralVisionModelTester:
@@ -148,6 +142,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
     test_pruning = False
     test_head_masking = False
+    test_torchscript = False
 
     def setUp(self):
         self.model_tester = PixtralVisionModelTester(self)
@@ -258,35 +253,3 @@ def test_disk_offload_safetensors(self):
     @unittest.skip(reason="Not supported yet")
     def test_determinism(self):
         pass
-
-
-@require_torch
-class PixtralVisionModelIntegrationTest(unittest.TestCase):
Collaborator: why did this one have to go away? 😓

Contributor Author: Many reasons: it was looking up an image URL that did not exist (pixtral-vl instead of llava-vl) to predict something the Pixtral model does not predict, and it was testing a VisionModel, which does not support generate.

Collaborator: Yep, it needs to be in the Llava tests! Can you move it around 👀

Contributor Author: sure!

Contributor Author (quoting the removed test as it stood):
def test_pixtral(self):
    model_id = "hf-internal-testing/pixtral-12b"
    model = LlavaForConditionalGeneration.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    IMG_URLS = [
        Image.open(requests.get("https://picsum.photos/id/237/400/300", stream=True).raw),
        Image.open(requests.get("https://picsum.photos/id/231/200/300", stream=True).raw),
        Image.open(requests.get("https://picsum.photos/id/27/500/500", stream=True).raw),
        Image.open(requests.get("https://picsum.photos/id/17/150/600", stream=True).raw),
    ]
    PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
    # image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
    generate_ids = model.generate(**inputs, max_new_tokens=500)
    output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # fmt: off
    EXPECTED_GENERATION = """
Describe the images.
Sure, let's break down each image description:
1. **Image 1:**
   - **Description:** A black dog with a glossy coat is sitting on a wooden floor. The dog has a focused expression and is looking directly at the camera.
   - **Details:** The wooden floor has a rustic appearance with visible wood grain patterns. The dog's eyes are a striking color, possibly brown or amber, which contrasts with its black fur.
2. **Image 2:**
   - **Description:** A scenic view of a mountainous landscape with a winding road cutting through it. The road is surrounded by lush green vegetation and leads to a distant valley.
   - **Details:** The mountains are rugged with steep slopes, and the sky is clear, indicating good weather. The winding road adds a sense of depth and perspective to the image.
3. **Image 3:**
   - **Description:** A beach scene with waves crashing against the shore. There are several people in the water and on the beach, enjoying the waves and the sunset.
   - **Details:** The waves are powerful, creating a dynamic and lively atmosphere. The sky is painted with hues of orange and pink from the setting sun, adding a warm glow to the scene.
4. **Image 4:**
   - **Description:** A garden path leading to a large tree with a bench underneath it. The path is bordered by well-maintained grass and flowers.
   - **Details:** The path is made of small stones or gravel, and the tree provides a shaded area with the bench invitingly placed beneath it. The surrounding area is lush and green, suggesting a well-kept garden.
Each image captures a different scene, from a close-up of a dog to expansive natural landscapes, showcasing various elements of nature and human interaction with it.
"""
    # fmt: on
    # check that both inputs are handled correctly and generate the same output
    self.assertListEqual(output, EXPECTED_GENERATION)

looks like pixtral is already tested under llava tests properly - since it's a llava model it's enough to test it there, no?

Collaborator: ah, then good job and sorry!

Contributor Author: No worries - it's hard to see when tests are measuring a model through another one, but it's convenient on our side. I can move all Pixtral-related tests into the Pixtral test suite so we don't forget them, wdyt? It's a bit more LOC, but it's easier to track down.

-    def setUp(self):
-        self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
-
-    def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    @slow
-    @require_bitsandbytes
-    def test_small_model_integration_test(self):
-        # Let's make sure we test the preprocessing to replace what is used
-        model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
-
-        prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
-        image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
-        raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = self.processor(prompt, raw_image, return_tensors="pt")
-
-        EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722, 315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456, 1633, 28804, 13, 4816, 8048, 12738, 28747]])  # fmt: skip
-        self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
-
-        output = model.generate(**inputs, max_new_tokens=20)
-        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,"  # fmt: skip
-
-        self.assertEqual(
-            self.processor.decode(output[0], skip_special_tokens=True),
-            EXPECTED_DECODED_TEXT,
-        )
21 changes: 20 additions & 1 deletion tests/models/pixtral/test_processor_pixtral.py
@@ -171,7 +171,7 @@ def test_processor_with_multiple_images_single_list(self):
             input_ids[0].tolist(),
             # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
             [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
-            )
+        )
         # fmt: on
 
         # Test passing in a url
@@ -246,6 +246,25 @@ def test_processor_with_multiple_images_multiple_lists(self):
         )
         # fmt: on
 
+    def test_processor_returns_full_length_batches(self):
+        # to avoid https://github.com/huggingface/transformers/issues/34204
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        prompt_string = [
+            "USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
+        ] * 5
+        processor.tokenizer.pad_token = "</s>"
+        image_inputs = [self.image_0] * 5
+
+        # Make small for checking image token expansion
+        processor.image_processor.size = {"longest_edge": 30}
+        processor.image_processor.patch_size = {"height": 2, "width": 2}
+
+        # Test passing in an image
+        inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
+        self.assertIn("input_ids", inputs_image)
+        self.assertTrue(len(inputs_image["input_ids"]) == 5)
+        self.assertTrue(len(inputs_image["pixel_values"]) == 5)
+
     # Override as PixtralProcessor needs nested images to work properly with batched inputs
     @require_vision
     def prepare_image_inputs(self, batch_size: Optional[int] = None):