Skip to content

Commit

Permalink
fix pixtral processor (#34486)
Browse files Browse the repository at this point in the history
* fix pixtral processor

* test out full length batches + remove undue ValueError

* fix up processing

* fix tests

* fix

* last fixup

* style

* [run-slow] pixtral

* [run-slow] pixtral

* fix config key

* skip torchscript tests

* [run-slow] pixtral

* add missing key

* [run-slow] pixtral

* fix docs

* [run-slow] pixtral

* fix wrong url for integration test

* [run-slow] pixtral

* pixtralVisionModel does not have a lm head

* [run-slow] pixtral
  • Loading branch information
molbap authored Oct 30, 2024
1 parent 8a734ea commit 241d790
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 48 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/pixtral/configuration_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class PixtralVisionConfig(PretrainedConfig):
Dropout probability for the attention layers.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
Expand Down Expand Up @@ -82,6 +84,7 @@ def __init__(
hidden_act="gelu",
attention_dropout=0.0,
rope_theta=10000.0,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -97,3 +100,4 @@ def __init__(
self.hidden_act = hidden_act
self.rope_theta = rope_theta
self.head_dim = hidden_size // num_attention_heads
self.initializer_range = initializer_range
2 changes: 1 addition & 1 deletion src/transformers/models/pixtral/modeling_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
else self.config.initializer_range
)

if isinstance(module, (nn.Linear, nn.Conv2d)):
Expand Down
15 changes: 8 additions & 7 deletions src/transformers/models/pixtral/processing_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,15 @@ def __call__(
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
elif (
not isinstance(images, list)
and not isinstance(images[0], list)
and not is_image_or_image_url(images[0][0])
):
if isinstance(text, list):
images = [[im] for im in images]
else:
images = [images]
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
pass
else:
raise ValueError(
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
"Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
)
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
Expand Down
41 changes: 2 additions & 39 deletions tests/models/pixtral/test_modeling_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,16 @@
# limitations under the License.
"""Testing suite for the PyTorch Pixtral model."""

import gc
import unittest

import requests

from transformers import (
AutoProcessor,
PixtralVisionConfig,
PixtralVisionModel,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
slow,
torch_device,
)

Expand All @@ -43,7 +37,7 @@
is_torch_greater_or_equal_than_2_0 = False

if is_vision_available():
from PIL import Image
pass


class PixtralVisionModelTester:
Expand Down Expand Up @@ -148,6 +142,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
test_torchscript = False

def setUp(self):
self.model_tester = PixtralVisionModelTester(self)
Expand Down Expand Up @@ -258,35 +253,3 @@ def test_disk_offload_safetensors(self):
@unittest.skip(reason="Not supported yet")
def test_determinism(self):
pass


@require_torch
class PixtralVisionModelIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

@slow
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let' s make sure we test the preprocessing to replace what is used
model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)

prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(prompt, raw_image, return_tensors="pt")

EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
21 changes: 20 additions & 1 deletion tests/models/pixtral/test_processor_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_processor_with_multiple_images_single_list(self):
input_ids[0].tolist(),
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
)
# fmt: on

# Test passing in a url
Expand Down Expand Up @@ -246,6 +246,25 @@ def test_processor_with_multiple_images_multiple_lists(self):
)
# fmt: on

def test_processor_returns_full_length_batches(self):
# to avoid https://github.com/huggingface/transformers/issues/34204
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = [
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
] * 5
processor.tokenizer.pad_token = "</s>"
image_inputs = [self.image_0] * 5

# Make small for checking image token expansion
processor.image_processor.size = {"longest_edge": 30}
processor.image_processor.patch_size = {"height": 2, "width": 2}

# Test passing in an image
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 5)
self.assertTrue(len(inputs_image["pixel_values"]) == 5)

# Override as PixtralProcessor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
Expand Down

0 comments on commit 241d790

Please sign in to comment.