From 241d79026f1030124dbb957de936b3d617b621f2 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Wed, 30 Oct 2024 14:17:20 +0100 Subject: [PATCH] fix pixtral processor (#34486) * fix pixtral processor * test out full length batches + remove undue ValueError * fix up processing * fix tests * fix * last fixup * style * [run-slow] pixtral * [run-slow] pixtral * fix config key * skip torchscript tests * [run-slow] pixtral * add missing key * [run-slow] pixtral * fix docs * [run-slow] pixtral * fix wrong url for integration test * [run-slow] pixtral * pixtralVisionModel does not have a lm head * [run-slow] pixtral --- .../models/pixtral/configuration_pixtral.py | 4 ++ .../models/pixtral/modeling_pixtral.py | 2 +- .../models/pixtral/processing_pixtral.py | 15 +++---- tests/models/pixtral/test_modeling_pixtral.py | 41 +------------------ .../models/pixtral/test_processor_pixtral.py | 21 +++++++++- 5 files changed, 35 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/pixtral/configuration_pixtral.py b/src/transformers/models/pixtral/configuration_pixtral.py index 32325a929411ba..14db51b947e664 100644 --- a/src/transformers/models/pixtral/configuration_pixtral.py +++ b/src/transformers/models/pixtral/configuration_pixtral.py @@ -52,6 +52,8 @@ class PixtralVisionConfig(PretrainedConfig): Dropout probability for the attention layers. rope_theta (`float`, *optional*, defaults to 10000.0): The base period of the RoPE embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. Example: @@ -82,6 +84,7 @@ def __init__( hidden_act="gelu", attention_dropout=0.0, rope_theta=10000.0, + initializer_range=0.02, **kwargs, ): super().__init__(**kwargs) @@ -97,3 +100,4 @@ def __init__( self.hidden_act = hidden_act self.rope_theta = rope_theta self.head_dim = hidden_size // num_attention_heads + self.initializer_range = initializer_range diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 06b9701a75661a..b65fbd634ba789 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -407,7 +407,7 @@ def _init_weights(self, module): std = ( self.config.initializer_range if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range + else self.config.initializer_range ) if isinstance(module, (nn.Linear, nn.Conv2d)): diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 70d28fb7b79c93..5913e8688d00be 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -206,14 +206,15 @@ def __call__( if is_image_or_image_url(images): images = [[images]] elif isinstance(images, list) and is_image_or_image_url(images[0]): - images = [images] - elif ( - not isinstance(images, list) - and not isinstance(images[0], list) - and not is_image_or_image_url(images[0][0]) - ): + if isinstance(text, list): + images = [[im] for im in images] + else: + images = [images] + elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]): + pass + else: raise ValueError( - "Invalid input images. Please provide a single image or a list of images or a list of list of images." + "Invalid input images. Please provide a single image, a list of images, or a list of lists of images." ) images = [[load_image(im) for im in sample] for sample in images] image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"]) diff --git a/tests/models/pixtral/test_modeling_pixtral.py b/tests/models/pixtral/test_modeling_pixtral.py index 9a128f6ad28823..0c36cb5a4e0554 100644 --- a/tests/models/pixtral/test_modeling_pixtral.py +++ b/tests/models/pixtral/test_modeling_pixtral.py @@ -14,22 +14,16 @@ # limitations under the License. """Testing suite for the PyTorch Pixtral model.""" -import gc import unittest -import requests - from transformers import ( - AutoProcessor, PixtralVisionConfig, PixtralVisionModel, is_torch_available, is_vision_available, ) from transformers.testing_utils import ( - require_bitsandbytes, require_torch, - slow, torch_device, ) @@ -43,7 +37,7 @@ is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): - from PIL import Image + pass class PixtralVisionModelTester: @@ -148,6 +142,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (PixtralVisionModel,) if is_torch_available() else () test_pruning = False test_head_masking = False + test_torchscript = False def setUp(self): self.model_tester = PixtralVisionModelTester(self) @@ -258,35 +253,3 @@ def test_disk_offload_safetensors(self): @unittest.skip(reason="Not supported yet") def test_determinism(self): pass - - -@require_torch -class PixtralVisionModelIntegrationTest(unittest.TestCase): - def setUp(self): - self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b") - - def tearDown(self): - gc.collect() - torch.cuda.empty_cache() - - @slow - @require_bitsandbytes - def test_small_model_integration_test(self): - # Let' s make sure we test the preprocessing to replace what is used - model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True) - - prompt = "[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]" - image_file = "https://pixtral-vl.github.io/static/images/view.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(prompt, raw_image, return_tensors="pt") - - EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip - self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) - - output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip - - self.assertEqual( - self.processor.decode(output[0], skip_special_tokens=True), - EXPECTED_DECODED_TEXT, - ) diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py index 8cdbf93c6476b8..c3496dff3cdf81 100644 --- a/tests/models/pixtral/test_processor_pixtral.py +++ b/tests/models/pixtral/test_processor_pixtral.py @@ -171,7 +171,7 @@ def test_processor_with_multiple_images_single_list(self): input_ids[0].tolist(), # Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"] [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] - ) + ) # fmt: on # Test passing in a url @@ -246,6 +246,25 @@ def test_processor_with_multiple_images_multiple_lists(self): ) # fmt: on + def test_processor_returns_full_length_batches(self): + # to avoid https://github.com/huggingface/transformers/issues/34204 + processor = self.processor_class.from_pretrained(self.tmpdirname) + prompt_string = [ + "USER: [IMG]\nWhat's the content of the image? ASSISTANT:", + ] * 5 + processor.tokenizer.pad_token = "" + image_inputs = [self.image_0] * 5 + + # Make small for checking image token expansion + processor.image_processor.size = {"longest_edge": 30} + processor.image_processor.patch_size = {"height": 2, "width": 2} + + # Test passing in an image + inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + self.assertIn("input_ids", inputs_image) + self.assertTrue(len(inputs_image["input_ids"]) == 5) + self.assertTrue(len(inputs_image["pixel_values"]) == 5) + # Override as PixtralProcessor needs nested images to work properly with batched inputs @require_vision def prepare_image_inputs(self, batch_size: Optional[int] = None):