Skip to content

Commit

Permalink
bugfix Idefics3 processor - handle gracefully cases with text and no images (#35363)
Browse files Browse the repository at this point in the history

* bugfix processing empty images

* fix

* fix

* Update src/transformers/models/idefics3/processing_idefics3.py

Co-authored-by: Yoni Gozlan <[email protected]>

* adding tests

* fix

* fix

* fix

---------

Co-authored-by: Yoni Gozlan <[email protected]>
  • Loading branch information
mfarre and yonigozlan authored Dec 23, 2024
1 parent 64c05ee commit a1780b7
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 35 deletions.
78 changes: 43 additions & 35 deletions src/transformers/models/idefics3/processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,45 +283,53 @@ def __call__(
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
inputs.update(image_inputs)

if text is not None:
if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)

image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])

fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
global_img_token = self.global_image_tag

prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
n_cols,
image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
if text is not None:
if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)
image_prompt_strings.append(image_prompt_string)

split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")
image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])

fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
global_img_token = self.global_image_tag

prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
n_cols,
image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)

# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")

text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)

text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)

elif text is not None:
if any(n_images_in_text):
raise ValueError(
f"Found {sum(n_images_in_text)} {self.image_token.content} tokens in the text but no images were passed."
)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)

return inputs
Expand Down
71 changes: 71 additions & 0 deletions tests/models/idefics3/test_processor_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,74 @@ def test_unstructured_kwargs(self):

self.assertEqual(inputs["pixel_values"].shape[3], 32)
self.assertEqual(len(inputs["input_ids"][0]), 120)

@require_torch
@require_vision
def test_text_only_inference(self):
    """Test that the processor works correctly with text-only input.

    Two scenarios are exercised:

    * a single string, whose ``input_ids`` must equal the BOS token followed
      by the tokenizer's own encoding of the string, with an all-ones
      attention mask and no ``pixel_values`` / ``pixel_attention_mask`` keys;
    * a batch of two strings processed with ``padding=True``, where the
      shorter sequence is expected to be padded up to the longer one.
    """
    processor = self.get_processor()

    # --- Single text, no image tokens ---
    text = "This is a simple text without images."
    inputs = processor(text=text)

    tokenized_sentence = processor.tokenizer(text, add_special_tokens=False)
    expected_input_ids = [[self.bos_token_id] + tokenized_sentence["input_ids"]]

    self.assertEqual(inputs["input_ids"], expected_input_ids)
    self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
    # No image was supplied, so no image-related entries may appear in the output.
    self.assertNotIn("pixel_values", inputs)
    self.assertNotIn("pixel_attention_mask", inputs)

    # --- Batch of texts without image tokens ---
    texts = ["First text.", "Second piece of text."]
    batch_inputs = processor(text=texts, padding=True)

    unpadded_ids = [
        [self.bos_token_id] + processor.tokenizer(sample, add_special_tokens=False)["input_ids"]
        for sample in texts
    ]

    # The expected values are built with padding on the LEFT: pad tokens (and
    # zeros in the attention mask) are prepended to whichever sequence is
    # shorter. The longest sequence gets a pad length of 0 and stays as is.
    longest = max(len(ids) for ids in unpadded_ids)
    expected_batch_ids = []
    expected_batch_mask = []
    for ids in unpadded_ids:
        pad_len = longest - len(ids)
        expected_batch_ids.append([self.padding_token_id] * pad_len + ids)
        expected_batch_mask.append([0] * pad_len + [1] * len(ids))

    self.assertEqual(batch_inputs["input_ids"], expected_batch_ids)
    self.assertEqual(batch_inputs["attention_mask"], expected_batch_mask)

@require_torch
@require_vision
def test_missing_images_error(self):
    """Test that appropriate error is raised when images are referenced but not provided.

    The processor is called with text containing ``<image>`` tokens while no
    images are supplied, both by omitting ``images`` and by passing
    ``images=None``, for a single string and for a batch. Every call must
    raise ``ValueError`` with a message mentioning the missing images.
    """
    processor = self.get_processor()
    expected_message = "tokens in the text but no images were passed"

    # Single text with an image token but no image
    text = "Let me show you this image: <image> What do you think?"
    # Batch where every sample has an image token but no images are given
    texts = [
        "First text with <image> token.",
        "Second text <image> with token.",
    ]

    # Same four call signatures, in the same order, as the explicit cases:
    # `images` omitted entirely, then `images=None` passed explicitly.
    failing_calls = (
        {"text": text},
        {"text": texts},
        {"text": text, "images": None},
        {"text": texts, "images": None},
    )

    for call_kwargs in failing_calls:
        with self.assertRaises(ValueError) as context:
            processor(**call_kwargs)
        # assertIn reports the actual exception text when the check fails,
        # and the msg identifies which call signature was being exercised.
        self.assertIn(
            expected_message,
            str(context.exception),
            msg=f"processor called with {call_kwargs!r}",
        )

0 comments on commit a1780b7

Please sign in to comment.