From 940a6bd343cfd2ff4f4425b4cbc548d1e1d316da Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Thu, 24 Oct 2024 20:00:13 -0400 Subject: [PATCH] Use non nested images and batched text Idefics2/3 (#34222) * add support for non nested images and add tests * add tests error scenario * fix style * added single and no image to error tests --- .../idefics2/image_processing_idefics2.py | 1 + .../models/idefics2/processing_idefics2.py | 17 +++- .../idefics3/image_processing_idefics3.py | 3 + .../models/idefics3/processing_idefics3.py | 38 ++++++--- .../pixtral/image_processing_pixtral.py | 1 + .../idefics2/test_processor_idefics2.py | 77 +++++++++++++++--- .../idefics3/test_processor_idefics3.py | 79 ++++++++++++++++--- 7 files changed, 183 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/idefics2/image_processing_idefics2.py b/src/transformers/models/idefics2/image_processing_idefics2.py index ac9df68871eee2..ce0032f80c5ece 100644 --- a/src/transformers/models/idefics2/image_processing_idefics2.py +++ b/src/transformers/models/idefics2/image_processing_idefics2.py @@ -99,6 +99,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: isinstance(images, (list, tuple)) and len(images) > 0 and isinstance(images[0], (list, tuple)) + and len(images[0]) > 0 and is_valid_image(images[0][0]) ): pass diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py index 68566d182678c2..9a041257c36b5b 100644 --- a/src/transformers/models/idefics2/processing_idefics2.py +++ b/src/transformers/models/idefics2/processing_idefics2.py @@ -16,6 +16,7 @@ Processor class for IDEFICS2. """ +from itertools import accumulate from typing import TYPE_CHECKING, List, Optional, Union from ...feature_extraction_utils import BatchFeature @@ -218,7 +219,21 @@ def __call__( if is_image_or_image_url(images): images = [[images]] elif isinstance(images, list) and is_image_or_image_url(images[0]): - images = [images] + if text is not None: + if sum(n_images_in_text) != len(images): + raise ValueError( + f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed." + f" Found {sum(n_images_in_text)} {image_token} tokens and {len(images)} images." + ) + # Reorganize the images to match the prompts + cumsum_images_in_text = [0] + list(accumulate(n_images_in_text)) + images = [ + images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]] + for i in range(len(n_images_in_text)) + ] + else: + images = [images] + elif ( not isinstance(images, list) and not isinstance(images[0], list) diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py index 495ac04595fbc6..05a1a396dc72d3 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3.py +++ b/src/transformers/models/idefics3/image_processing_idefics3.py @@ -151,9 +151,11 @@ def get_resize_output_image_size( def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: """ Convert a single image or a list of images to a list of numpy arrays. + Args: images (`ImageInput`): A single image or a list of images. + Returns: A list of numpy arrays. """ @@ -168,6 +170,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: isinstance(images, (list, tuple)) and len(images) > 0 and isinstance(images[0], (list, tuple)) + and len(images[0]) > 0 and is_valid_image(images[0][0]) ): pass diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py index ceafa26a8b1187..872f5206f20175 100644 --- a/src/transformers/models/idefics3/processing_idefics3.py +++ b/src/transformers/models/idefics3/processing_idefics3.py @@ -17,6 +17,7 @@ """ import re +from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Union from ...feature_extraction_utils import BatchFeature @@ -241,11 +242,31 @@ def __call__( n_images_in_images = [] inputs = BatchFeature() + if text is not None: + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + n_images_in_text = [sample.count(self.image_token.content) for sample in text] + if images is not None: if is_image_or_image_url(images): images = [[images]] elif isinstance(images, list) and is_image_or_image_url(images[0]): - images = [images] + if text is not None: + if sum(n_images_in_text) != len(images): + raise ValueError( + f"The total number of {self.image_token.content} tokens in the prompts should be the same as the number of images passed." + f" Found {sum(n_images_in_text)} {self.image_token.content} tokens and {len(images)} images." + ) + # Reorganize the images to match the prompts + cumsum_images_in_text = [0] + list(accumulate(n_images_in_text)) + images = [ + images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]] + for i in range(len(n_images_in_text)) + ] + else: + images = [images] elif ( not isinstance(images, list) and not isinstance(images[0], list) @@ -263,10 +284,10 @@ def __call__( inputs.update(image_inputs) if text is not None: - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, or a list of strings") + if n_images_in_images != n_images_in_text: + raise ValueError( + f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same." + ) image_rows = inputs.pop("rows", [[0] * len(text)]) image_cols = inputs.pop("cols", [[0] * len(text)]) @@ -277,8 +298,6 @@ def __call__( prompt_strings = [] for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): - n_images_in_text.append(sample.count(image_token)) - # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len` image_prompt_strings = [] for n_rows, n_cols in zip(sample_rows, sample_cols): @@ -305,11 +324,6 @@ def __call__( text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"]) inputs.update(text_inputs) - if n_images_in_images != n_images_in_text: - raise ValueError( - f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same." - ) - return inputs def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index a75704fc3dbac8..b4ec0e50c9ccc3 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -120,6 +120,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]: isinstance(images, (list, tuple)) and len(images) > 0 and isinstance(images[0], (list, tuple)) + and len(images[0]) > 0 and is_valid_image(images[0][0]) ): pass diff --git a/tests/models/idefics2/test_processor_idefics2.py b/tests/models/idefics2/test_processor_idefics2.py index bf713c6fb8cfbb..d89004679aef0f 100644 --- a/tests/models/idefics2/test_processor_idefics2.py +++ b/tests/models/idefics2/test_processor_idefics2.py @@ -226,6 +226,73 @@ def test_add_special_tokens_processor(self): self.assertEqual(inputs["input_ids"], expected_input_ids) # fmt: on + def test_non_nested_images_with_batched_text(self): + processor = self.get_processor() + processor.image_processor.do_image_splitting = False + + image_str = "" + text_str_1 = "In this image, we see" + text_str_2 = "bla, bla" + + text = [ + image_str + text_str_1, + text_str_2 + image_str + image_str, + ] + images = [self.image1, self.image2, self.image3] + + inputs = processor(text=text, images=images, padding=True) + + self.assertEqual(inputs["pixel_values"].shape, (2, 2, 3, 767, 980)) + self.assertEqual(inputs["pixel_attention_mask"].shape, (2, 2, 767, 980)) + + def test_process_interleaved_images_prompts_image_error(self): + processor = self.get_processor() + + text = [ + "This is a test sentence.", + "In this other sentence we try some good things", + ] + images = [[self.image1], [self.image2]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [[self.image1], []] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + + text = [ + "This is a test sentence.", + "In this other sentence we try some good things", + ] + images = [[self.image1], [self.image2, self.image3]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [[], [self.image2]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1, self.image2, self.image3] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + + text = [ + "This is a test sentence.", + "In this other sentence we try some good things", + ] + images = [[self.image1], []] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [[], [self.image2]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1, self.image2] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + def test_apply_chat_template(self): # Message contains content which a mix of lists with images and image urls and string messages = [ @@ -275,13 +342,3 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None): return ["lower newer ", " upper older longer string"] + [" lower newer"] * ( batch_size - 2 ) - - # Override as PixtralProcessor needs nested images to work properly with batched inputs - @require_vision - def prepare_image_inputs(self, batch_size: Optional[int] = None): - """This function prepares a list of PIL images for testing""" - if batch_size is None: - return super().prepare_image_inputs() - if batch_size < 1: - raise ValueError("batch_size must be greater than 0") - return [[super().prepare_image_inputs()]] * batch_size diff --git a/tests/models/idefics3/test_processor_idefics3.py b/tests/models/idefics3/test_processor_idefics3.py index a53109b02b6951..52d2f1539a4867 100644 --- a/tests/models/idefics3/test_processor_idefics3.py +++ b/tests/models/idefics3/test_processor_idefics3.py @@ -250,6 +250,74 @@ def test_add_special_tokens_processor(self): self.assertEqual(inputs["input_ids"], expected_input_ids) # fmt: on + def test_non_nested_images_with_batched_text(self): + processor = self.get_processor() + processor.image_processor.do_image_splitting = False + + image_str = "" + text_str_1 = "In this image, we see" + text_str_2 = "In this image, we see" + + text = [ + image_str + text_str_1, + image_str + image_str + text_str_2, + ] + images = [self.image1, self.image2, self.image3] + + inputs = processor(text=text, images=images, padding=True) + + self.assertEqual(np.array(inputs["pixel_values"]).shape, (2, 2, 3, 364, 364)) + self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (2, 2, 364, 364)) + + # Copied from tests.models.idefics2.test_processor_idefics2.Idefics2ProcessorTest.test_process_interleaved_images_prompts_image_error + def test_process_interleaved_images_prompts_image_error(self): + processor = self.get_processor() + + text = [ + "This is a test sentence.", + "In this other sentence we try some good things", + ] + images = [[self.image1], [self.image2]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [[self.image1], []] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + + text = [ + "This is a test sentence.", + "In this other sentence we try some good things", + ] + images = [[self.image1], [self.image2, self.image3]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [[], [self.image2]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1, self.image2, self.image3] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + + text = [ + "This is a test sentence.", + "In this other sentence we try some good things", + ] + images = [[self.image1], []] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [[], [self.image2]] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1, self.image2] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + images = [self.image1] + with self.assertRaises(ValueError): + processor(text=text, images=images, padding=True) + def test_apply_chat_template(self): # Message contains content which a mix of lists with images and image urls and string messages = [ @@ -299,16 +367,7 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None): batch_size - 2 ) - # Override as Idefics3Processor needs nested images to work properly with batched inputs - @require_vision - def prepare_image_inputs(self, batch_size: Optional[int] = None): - """This function prepares a list of PIL images for testing""" - if batch_size is None: - return super().prepare_image_inputs() - if batch_size < 1: - raise ValueError("batch_size must be greater than 0") - return [[super().prepare_image_inputs()]] * batch_size - + # Override tests as inputs_ids padded dimension is the second one but not the last one @require_vision @require_torch def test_kwargs_overrides_default_tokenizer_kwargs(self):