From 184a215c8e1c2f9a4cae394740b8d9c82d07decf Mon Sep 17 00:00:00 2001 From: Miquel Farre Date: Fri, 20 Dec 2024 13:28:30 +0000 Subject: [PATCH] bugfix processing empty images --- .../models/idefics3/processing_idefics3.py | 75 ++++++++++--------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py index 872f5206f20175..1b0e6e22a5e79f 100644 --- a/src/transformers/models/idefics3/processing_idefics3.py +++ b/src/transformers/models/idefics3/processing_idefics3.py @@ -283,46 +283,49 @@ def __call__( image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) inputs.update(image_inputs) - if text is not None: - if n_images_in_images != n_images_in_text: - raise ValueError( - f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same." - ) - - image_rows = inputs.pop("rows", [[0] * len(text)]) - image_cols = inputs.pop("cols", [[0] * len(text)]) - - fake_image_token = self.fake_image_token.content - image_token = self.image_token.content - global_img_token = self.global_image_tag - - prompt_strings = [] - for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): - # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len` - image_prompt_strings = [] - for n_rows, n_cols in zip(sample_rows, sample_cols): - image_prompt_string = get_image_prompt_string( - n_rows, - n_cols, - image_seq_len, - image_token=image_token, - fake_token_around_image=fake_image_token, - global_img_token=global_img_token, + if text is not None: + if n_images_in_images != n_images_in_text: + raise ValueError( + f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same." ) - image_prompt_strings.append(image_prompt_string) - split_sample = sample.split(image_token) - if len(split_sample) == 0: - raise ValueError("The image token should be present in the text.") + image_rows = inputs.pop("rows", [[0] * len(text)]) + image_cols = inputs.pop("cols", [[0] * len(text)]) + + fake_image_token = self.fake_image_token.content + image_token = self.image_token.content + global_img_token = self.global_image_tag + + prompt_strings = [] + for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): + # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len` + image_prompt_strings = [] + for n_rows, n_cols in zip(sample_rows, sample_cols): + image_prompt_string = get_image_prompt_string( + n_rows, + n_cols, + image_seq_len, + image_token=image_token, + fake_token_around_image=fake_image_token, + global_img_token=global_img_token, + ) + image_prompt_strings.append(image_prompt_string) + + split_sample = sample.split(image_token) + if len(split_sample) == 0: + raise ValueError("The image token should be present in the text.") + + # Place in the image prompt strings where the image tokens are + sample = split_sample[0] + for i, image_prompt_string in enumerate(image_prompt_strings): + sample += image_prompt_string + split_sample[i + 1] + prompt_strings.append(sample) - # Place in the image prompt strings where the image tokens are - sample = split_sample[0] - for i, image_prompt_string in enumerate(image_prompt_strings): - sample += image_prompt_string + split_sample[i + 1] - prompt_strings.append(sample) + text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"]) + else: + text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) - text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"]) - inputs.update(text_inputs) + inputs.update(text_inputs) return inputs