diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py
index 872f5206f20175..7ca5829e2063d8 100644
--- a/src/transformers/models/idefics3/processing_idefics3.py
+++ b/src/transformers/models/idefics3/processing_idefics3.py
@@ -283,45 +283,53 @@ def __call__(
             image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
             inputs.update(image_inputs)
 
-        if text is not None:
-            if n_images_in_images != n_images_in_text:
-                raise ValueError(
-                    f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
-                )
-
-            image_rows = inputs.pop("rows", [[0] * len(text)])
-            image_cols = inputs.pop("cols", [[0] * len(text)])
-
-            fake_image_token = self.fake_image_token.content
-            image_token = self.image_token.content
-            global_img_token = self.global_image_tag
-
-            prompt_strings = []
-            for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
-                # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
-                image_prompt_strings = []
-                for n_rows, n_cols in zip(sample_rows, sample_cols):
-                    image_prompt_string = get_image_prompt_string(
-                        n_rows,
-                        n_cols,
-                        image_seq_len,
-                        image_token=image_token,
-                        fake_token_around_image=fake_image_token,
-                        global_img_token=global_img_token,
+            if text is not None:
+                if n_images_in_images != n_images_in_text:
+                    raise ValueError(
+                        f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
                     )
-                    image_prompt_strings.append(image_prompt_string)
 
-                split_sample = sample.split(image_token)
-                if len(split_sample) == 0:
-                    raise ValueError("The image token should be present in the text.")
+                image_rows = inputs.pop("rows", [[0] * len(text)])
+                image_cols = inputs.pop("cols", [[0] * len(text)])
+
+                fake_image_token = self.fake_image_token.content
+                image_token = self.image_token.content
+                global_img_token = self.global_image_tag
+
+                prompt_strings = []
+                for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
+                    # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
+                    image_prompt_strings = []
+                    for n_rows, n_cols in zip(sample_rows, sample_cols):
+                        image_prompt_string = get_image_prompt_string(
+                            n_rows,
+                            n_cols,
+                            image_seq_len,
+                            image_token=image_token,
+                            fake_token_around_image=fake_image_token,
+                            global_img_token=global_img_token,
+                        )
+                        image_prompt_strings.append(image_prompt_string)
 
-                # Place in the image prompt strings where the image tokens are
-                sample = split_sample[0]
-                for i, image_prompt_string in enumerate(image_prompt_strings):
-                    sample += image_prompt_string + split_sample[i + 1]
-                prompt_strings.append(sample)
+                    split_sample = sample.split(image_token)
+                    if len(split_sample) == 0:
+                        raise ValueError("The image token should be present in the text.")
 
-            text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
+                    # Place in the image prompt strings where the image tokens are
+                    sample = split_sample[0]
+                    for i, image_prompt_string in enumerate(image_prompt_strings):
+                        sample += image_prompt_string + split_sample[i + 1]
+                    prompt_strings.append(sample)
+
+                text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
+                inputs.update(text_inputs)
+
+        elif text is not None:
+            if any(n_images_in_text):
+                raise ValueError(
+                    f"Found {sum(n_images_in_text)} {self.image_token.content} tokens in the text but no images were passed."
+                )
+            text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
             inputs.update(text_inputs)
 
         return inputs
diff --git a/tests/models/idefics3/test_processor_idefics3.py b/tests/models/idefics3/test_processor_idefics3.py
index 52d2f1539a4867..36c5d294844939 100644
--- a/tests/models/idefics3/test_processor_idefics3.py
+++ b/tests/models/idefics3/test_processor_idefics3.py
@@ -505,3 +505,74 @@ def test_unstructured_kwargs(self):
         self.assertEqual(inputs["pixel_values"].shape[3], 32)
 
         self.assertEqual(len(inputs["input_ids"][0]), 120)
+
+    @require_torch
+    @require_vision
+    def test_text_only_inference(self):
+        """Test that the processor works correctly with text-only input."""
+        processor = self.get_processor()
+
+        text = "This is a simple text without images."
+        inputs = processor(text=text)
+
+        tokenized_sentence = processor.tokenizer(text, add_special_tokens=False)
+        expected_input_ids = [[self.bos_token_id] + tokenized_sentence["input_ids"]]
+
+        self.assertEqual(inputs["input_ids"], expected_input_ids)
+        self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
+        self.assertTrue("pixel_values" not in inputs)
+        self.assertTrue("pixel_attention_mask" not in inputs)
+
+        # Test batch of texts without image tokens
+        texts = ["First text.", "Second piece of text."]
+        batch_inputs = processor(text=texts, padding=True)
+
+        tokenized_1 = processor.tokenizer(texts[0], add_special_tokens=False)
+        tokenized_2 = processor.tokenizer(texts[1], add_special_tokens=False)
+
+        expected_1 = [self.bos_token_id] + tokenized_1["input_ids"]
+        expected_2 = [self.bos_token_id] + tokenized_2["input_ids"]
+
+        # Pad the shorter sequence on the left
+        pad_len = len(expected_2) - len(expected_1)
+        if pad_len > 0:
+            padded_expected_1 = [self.padding_token_id] * pad_len + expected_1
+            expected_attention_1 = [0] * pad_len + [1] * len(expected_1)
+            self.assertEqual(batch_inputs["input_ids"], [padded_expected_1, expected_2])
+            self.assertEqual(batch_inputs["attention_mask"], [expected_attention_1, [1] * len(expected_2)])
+        else:
+            pad_len = -pad_len
+            padded_expected_2 = [self.padding_token_id] * pad_len + expected_2
+            expected_attention_2 = [0] * pad_len + [1] * len(expected_2)
+            self.assertEqual(batch_inputs["input_ids"], [expected_1, padded_expected_2])
+            self.assertEqual(batch_inputs["attention_mask"], [[1] * len(expected_1), expected_attention_2])
+
+    @require_torch
+    @require_vision
+    def test_missing_images_error(self):
+        """Test that an appropriate error is raised when images are referenced but not provided."""
+        processor = self.get_processor()
+
+        # Test single text with image token but no image
+        text = "Let me show you this image: <image> What do you think?"
+        with self.assertRaises(ValueError) as context:
+            processor(text=text)
+        self.assertTrue("tokens in the text but no images were passed" in str(context.exception))
+
+        # Test batch with image tokens but no images
+        texts = [
+            "First text with <image> token.",
+            "Second text with <image> token.",
+        ]
+        with self.assertRaises(ValueError) as context:
+            processor(text=texts)
+        self.assertTrue("tokens in the text but no images were passed" in str(context.exception))
+
+        # Test with None as images
+        with self.assertRaises(ValueError) as context:
+            processor(text=text, images=None)
+        self.assertTrue("tokens in the text but no images were passed" in str(context.exception))
+
+        with self.assertRaises(ValueError) as context:
+            processor(text=texts, images=None)
+        self.assertTrue("tokens in the text but no images were passed" in str(context.exception))
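
For reviewers, a minimal usage sketch of the behavior this patch enables. Illustrative only: the checkpoint name `HuggingFaceM4/Idefics3-8B-Llama3` is an assumed example, and the printed message is inferred from the error string added above rather than verified output.

```python
from transformers import AutoProcessor

# Assumed example checkpoint; any Idefics3 processor should behave the same.
processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")

# Text-only input now tokenizes without requiring images,
# and no pixel_values are returned.
inputs = processor(text="This is a simple text without images.")
assert "pixel_values" not in inputs

# An <image> token without a matching image now raises a clear error
# instead of failing inside the prompt-expansion loop.
try:
    processor(text="Let me show you this image: <image>")
except ValueError as err:
    print(err)  # Found 1 <image> tokens in the text but no images were passed.
```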