diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 5913e8688d00be..9fb1ff485361db 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -204,12 +204,21 @@ def __call__( if images is not None: if is_image_or_image_url(images): - images = [[images]] - elif isinstance(images, list) and is_image_or_image_url(images[0]): - if isinstance(text, list): - images = [[im] for im in images] + if isinstance(text, str) or isinstance(text, list) and len(text) == 1: + # If there's a single sample, the image must belong to it + images = [[images]] else: + raise ValueError( + "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample." + ) + elif isinstance(images, list) and is_image_or_image_url(images[0]): + if isinstance(text, str) or isinstance(text, list) and len(text) == 1: + # If there's a single sample, all images must belong to it images = [images] + else: + raise ValueError( + "You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample." + ) elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]): pass else: diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py index c3496dff3cdf81..d224c531241fa7 100644 --- a/tests/models/pixtral/test_processor_pixtral.py +++ b/tests/models/pixtral/test_processor_pixtral.py @@ -253,7 +253,7 @@ def test_processor_returns_full_length_batches(self): "USER: [IMG]\nWhat's the content of the image? ASSISTANT:", ] * 5 processor.tokenizer.pad_token = "" - image_inputs = [self.image_0] * 5 + image_inputs = [[self.image_0]] * 5 # Make small for checking image token expansion processor.image_processor.size = {"longest_edge": 30}