diff --git a/docs/source/en/main_classes/pipelines.md b/docs/source/en/main_classes/pipelines.md index d5d132aaaba566..59e474fcc49f75 100644 --- a/docs/source/en/main_classes/pipelines.md +++ b/docs/source/en/main_classes/pipelines.md @@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following. - __call__ - all +### ImageTextToTextPipeline + +[[autodoc]] ImageTextToTextPipeline + - __call__ + - all + ### MaskGenerationPipeline [[autodoc]] MaskGenerationPipeline diff --git a/docs/source/ja/main_classes/pipelines.md b/docs/source/ja/main_classes/pipelines.md index bfb9922057d318..3980becebbde36 100644 --- a/docs/source/ja/main_classes/pipelines.md +++ b/docs/source/ja/main_classes/pipelines.md @@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline) - __call__ - all +### ImageTextToTextPipeline + +[[autodoc]] ImageTextToTextPipeline + - __call__ + - all + ### VisualQuestionAnsweringPipeline [[autodoc]] VisualQuestionAnsweringPipeline diff --git a/docs/source/zh/main_classes/pipelines.md b/docs/source/zh/main_classes/pipelines.md index 370b50d2469604..bc16709d8b4832 100644 --- a/docs/source/zh/main_classes/pipelines.md +++ b/docs/source/zh/main_classes/pipelines.md @@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details. - __call__ - all +### ImageTextToTextPipeline + +[[autodoc]] ImageTextToTextPipeline + - __call__ + - all + ### MaskGenerationPipeline [[autodoc]] MaskGenerationPipeline diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e6789c77fb825a..47b43e0b90896f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -868,6 +868,7 @@ "ImageClassificationPipeline", "ImageFeatureExtractionPipeline", "ImageSegmentationPipeline", + "ImageTextToTextPipeline", "ImageToImagePipeline", "ImageToTextPipeline", "JsonPipelineDataFormat", @@ -5794,6 +5795,7 @@ ImageClassificationPipeline, ImageFeatureExtractionPipeline, ImageSegmentationPipeline, + ImageTextToTextPipeline, ImageToImagePipeline, ImageToTextPipeline, JsonPipelineDataFormat, diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 1a70ef05638379..f59b99b490d38d 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -385,6 +385,27 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = return image +def load_images( + images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None +) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]: + """Loads images, handling different levels of nesting. + + Args: + images: A single image, a list of images, or a list of lists of images to load. + timeout: Timeout for loading images. + + Returns: + A single image, a list of images, a list of lists of images. 
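+
+    Example (a minimal illustrative sketch; the file paths below are placeholders):
+
+    ```python
+    from transformers.image_utils import load_images
+
+    # A flat list of image sources comes back as a flat list of `PIL.Image.Image` objects.
+    images = load_images(["cat.png", "dog.png"])
+
+    # A nested list keeps its nesting, e.g. one inner list of images per prompt.
+    image_groups = load_images([["cat.png", "dog.png"], ["parrots.png"]])
+    ```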
+ """ + if isinstance(images, (list, tuple)): + if len(images) and isinstance(images[0], (list, tuple)): + return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images] + else: + return [load_image(image, timeout=timeout) for image in images] + else: + return load_image(images, timeout=timeout) + + def validate_preprocess_arguments( do_rescale: Optional[bool] = None, rescale_factor: Optional[float] = None, diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 5698abe15c8029..a8960d80acc838 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -114,6 +114,7 @@ ("oneformer", ("OneFormerImageProcessor",)), ("owlv2", ("Owlv2ImageProcessor",)), ("owlvit", ("OwlViTImageProcessor",)), + ("paligemma", ("SiglipImageProcessor",)), ("perceiver", ("PerceiverImageProcessor",)), ("pix2struct", ("Pix2StructImageProcessor",)), ("pixtral", ("PixtralImageProcessor",)), diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py index 9552d323ac57c0..b46ff4bcfab902 100644 --- a/src/transformers/models/donut/processing_donut.py +++ b/src/transformers/models/donut/processing_donut.py @@ -24,12 +24,16 @@ from ...image_utils import ImageInput from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging class DonutProcessorKwargs(ProcessingKwargs, total=False): _defaults = {} +logger = logging.get_logger(__name__) + + class DonutProcessor(ProcessorMixin): r""" Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single @@ -85,6 +89,16 @@ def __call__( [`~DonutTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information. """ # For backward compatibility + legacy = kwargs.pop("legacy", True) + if legacy: + # With `add_special_tokens=True`, the performance of donut are degraded when working with both images and text. + logger.warning_once( + "Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. " + "In the new behavior, if both images and text are provided, the default value of `add_special_tokens` " + "will be changed to `False` when calling the tokenizer if `add_special_tokens` is unset. " + "To test the new behavior, set `legacy=False`as a processor call argument." 
+ ) + if self._in_target_context_manager: return self.current_processor(images, text, **kwargs) @@ -100,6 +114,8 @@ def __call__( if images is not None: inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) if text is not None: + if not legacy and images is not None: + output_kwargs["text_kwargs"].setdefault("add_special_tokens", False) encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) if text is None: diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index 255922b8308889..4bb9ea7964d416 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -19,7 +19,7 @@ import numpy as np -from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( pad, resize, @@ -475,6 +475,7 @@ def preprocess( input_data_format = infer_channel_dimension_format(batch_images[0][0]) original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images] + size = get_size_dict(size) # for BC if do_resize: batch_images = [ diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index ff7d2c547dc44c..e24f2fd4d1abd0 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -264,10 +264,10 @@ def _tokenize_prompts_with_image_and_batch( bos_token = tokenizer.vocab["|ENDOFTEXT|"] prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens] if add_beginning_of_answer_token: - boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING] + beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING] # Only add bbox open token to the last subsequence since that is what will be completed for token_seq in prompts_tokens: - token_seq[-1].append(boa) + token_seq[-1].append(beginning_of_answer) # Now we have a list of list of tokens which each list has a different # size. We want to extend this list to: @@ -682,6 +682,32 @@ def tokens_to_points(tokens, original_size): return results + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-processes the output of `FuyuForConditionalGeneration` to only return the text output. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + containing the token ids of the generated sequences. + + Returns: + `List[str]`: The decoded text output. 
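+
+        Example (a minimal sketch; `processor`, `model` and `image` are assumed to be a loaded
+        [`FuyuProcessor`], [`FuyuForCausalLM`] and a `PIL.Image.Image`, and the generated answer
+        depends on the image and prompt):
+
+        ```python
+        inputs = processor(images=image, text="Generate a coco-style caption.\n", return_tensors="pt")
+        generated_ids = model.generate(**inputs, max_new_tokens=20)
+        # Only the tokens after the beginning-of-answer token are decoded, so the prompt is not repeated.
+        answers = processor.post_process_image_text_to_text(generated_ids)
+        ```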
+ """ + beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING) + # get boa index for each outputted sequence tensor + # start all generated sequences from the beginning of the answer token, pad to have consistent length + unpadded_output_sequences = [ + seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs + ] + max_len = max(len(seq) for seq in unpadded_output_sequences) + # convert to torch and pad sequences + padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id) + for i, seq in enumerate(unpadded_output_sequences): + padded_output_sequences[i, : len(seq)] = torch.tensor(seq) + + return self.batch_decode(padded_output_sequences, skip_special_tokens=True) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 3744d81a0aca81..e9e96fa765d841 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -22,12 +22,16 @@ from ...image_utils import ImageInput from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging class GitProcessorKwargs(ProcessingKwargs, total=False): _defaults = {} +logger = logging.get_logger(__name__) + + class GitProcessor(ProcessorMixin): r""" Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor. @@ -91,6 +95,15 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + legacy = kwargs.pop("legacy", True) + if legacy: + logger.warning_once( + "Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. " + "In the new behavior, if both images and text are provided, the last token (EOS token) " + "of the input_ids and attention_mask tensors will be removed. " + "To test the new behavior, set `legacy=False`as a processor call argument." + ) + if text is None and images is None: raise ValueError("You have to specify either text or images. Both cannot be none.") @@ -110,6 +123,10 @@ def __call__( if images is not None: image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) data.update(image_features) + if not legacy: + data["input_ids"] = data["input_ids"][:, :-1] + data["attention_mask"] = data["attention_mask"][:, :-1] + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/kosmos2/processing_kosmos2.py b/src/transformers/models/kosmos2/processing_kosmos2.py index 76108789718b41..d7befd899f3ad3 100644 --- a/src/transformers/models/kosmos2/processing_kosmos2.py +++ b/src/transformers/models/kosmos2/processing_kosmos2.py @@ -428,6 +428,21 @@ def post_process_generation(self, text, cleanup_and_extract=True): return clean_text_and_extract_entities_with_bboxes(caption) return caption + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. 
The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True) + return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts] + @property # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names def model_input_names(self): diff --git a/src/transformers/models/mllama/processing_mllama.py b/src/transformers/models/mllama/processing_mllama.py index eea98f5bd66ac2..f183c3c1b62b52 100644 --- a/src/transformers/models/mllama/processing_mllama.py +++ b/src/transformers/models/mllama/processing_mllama.py @@ -342,6 +342,22 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names diff --git a/src/transformers/models/owlv2/image_processing_owlv2.py b/src/transformers/models/owlv2/image_processing_owlv2.py index dd32dc9f141183..3dcf145ea41ffc 100644 --- a/src/transformers/models/owlv2/image_processing_owlv2.py +++ b/src/transformers/models/owlv2/image_processing_owlv2.py @@ -19,7 +19,7 @@ import numpy as np -from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( center_to_corners_format, pad, @@ -399,6 +399,7 @@ def preprocess( image_std = image_std if image_std is not None else self.image_std size = size if size is not None else self.size + size = get_size_dict(size) # for BC images = make_list_of_images(images) diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index de8c594f94c9f2..bf02531ffb864f 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -21,6 +21,7 @@ from ...feature_extraction_utils import BatchFeature from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput +from ...utils import logging class Pix2StructImagesKwargs(ImagesKwargs, total=False): @@ -48,6 +49,9 @@ class Pix2StructProcessorKwargs(ProcessingKwargs, total=False): } +logger = logging.get_logger(__name__) + + class Pix2StructProcessor(ProcessorMixin): r""" Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single @@ -85,6 +89,15 @@ def __call__( Please refer to the docstring of the above two methods for more information. """ + legacy = kwargs.pop("legacy", True) + if legacy: + logger.warning_once( + "Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. 
" + "In the new behavior, If both images and text are provided, image_processor is not a VQA processor, and `add_special_tokens` is unset, " + "the default value of `add_special_tokens` will be changed to `False` when calling the tokenizer. " + "To test the new behavior, set `legacy=False`as a processor call argument." + ) + if images is None and text is None: raise ValueError("You have to specify either images or text.") @@ -93,8 +106,12 @@ def __call__( tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None) # Get only text if images is None and not self.image_processor.is_vqa: + output_kwargs["text_kwargs"]["add_special_tokens"] = ( + add_special_tokens if add_special_tokens is not None else True + ) self.current_processor = self.tokenizer text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) return text_encoding @@ -108,6 +125,9 @@ def __call__( encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) if text is not None and not self.image_processor.is_vqa: + output_kwargs["text_kwargs"]["add_special_tokens"] = ( + add_special_tokens if add_special_tokens is not None else legacy + ) text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) if "attention_mask" in text_encoding: diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 6c0e8d98014ede..b453b4078c7e81 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -168,6 +168,22 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode( + generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py index ddd5d484a98883..33349af0366d77 100644 --- a/src/transformers/models/udop/processing_udop.py +++ b/src/transformers/models/udop/processing_udop.py @@ -208,20 +208,6 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) - def post_process_image_text_to_text(self, generated_outputs): - """ - Post-process the output of the model to decode the text. - - Args: - generated_outputs (`torch.Tensor` or `np.ndarray`): - The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` - or `(sequence_length,)`. - - Returns: - `List[str]`: The decoded text. 
- """ - return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) - @property def model_input_names(self): return ["pixel_values", "input_ids", "bbox", "attention_mask"] diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 40b3dc1015c001..07156b3cf1dbe2 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -67,6 +67,7 @@ from .image_classification import ImageClassificationPipeline from .image_feature_extraction import ImageFeatureExtractionPipeline from .image_segmentation import ImageSegmentationPipeline +from .image_text_to_text import ImageTextToTextPipeline from .image_to_image import ImageToImagePipeline from .image_to_text import ImageToTextPipeline from .mask_generation import MaskGenerationPipeline @@ -119,6 +120,7 @@ AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, + AutoModelForImageTextToText, AutoModelForMaskedLM, AutoModelForMaskGeneration, AutoModelForObjectDetection, @@ -384,6 +386,17 @@ }, "type": "multimodal", }, + "image-text-to-text": { + "impl": ImageTextToTextPipeline, + "tf": (), + "pt": (AutoModelForImageTextToText,) if is_torch_available() else (), + "default": { + "model": { + "pt": ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "2c9ba3b"), + } + }, + "type": "multimodal", + }, "object-detection": { "impl": ObjectDetectionPipeline, "tf": (), @@ -601,6 +614,7 @@ def pipeline( - `"image-classification"`: will return a [`ImageClassificationPipeline`]. - `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`]. - `"image-segmentation"`: will return a [`ImageSegmentationPipeline`]. + - `"image-text-to-text"`: will return a [`ImageTextToTextPipeline`]. - `"image-to-image"`: will return a [`ImageToImagePipeline`]. - `"image-to-text"`: will return a [`ImageToTextPipeline`]. - `"mask-generation"`: will return a [`MaskGenerationPipeline`]. diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 25c2a11564c3f1..d2d4f198d41847 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -951,6 +951,14 @@ def __init__( self._num_workers = kwargs.pop("num_workers", None) self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) + # In processor only mode, we can get the modality processors from the processor + if self.processor is not None and all( + [self.tokenizer is None, self.feature_extractor is None, self.image_processor is None] + ): + self.tokenizer = getattr(self.processor, "tokenizer", None) + self.feature_extractor = getattr(self.processor, "feature_extractor", None) + self.image_processor = getattr(self.processor, "image_processor", None) + if self.image_processor is None and self.feature_extractor is not None: if isinstance(self.feature_extractor, BaseImageProcessor): # Backward compatible change, if users called diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py new file mode 100644 index 00000000000000..39738ffc385dbe --- /dev/null +++ b/src/transformers/pipelines/image_text_to_text.py @@ -0,0 +1,416 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from typing import Dict, List, Optional, Union + +from ..processing_utils import ProcessingKwargs, Unpack +from ..utils import ( + add_end_docstrings, + is_torch_available, + is_vision_available, + logging, + requires_backends, +) +from .base import Pipeline, build_pipeline_init_args + + +if is_vision_available(): + from PIL import Image + + from ..image_utils import load_images, valid_images + + +if is_torch_available(): + from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + from .pt_utils import KeyDataset + +logger = logging.get_logger(__name__) + +IMAGE_TOKEN = "" + + +class ReturnType(enum.Enum): + TENSORS = 0 + NEW_TEXT = 1 + FULL_TEXT = 2 + + +class Chat: + """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats + to this format because the rest of the pipeline code tends to assume that lists of messages are + actually a batch of samples rather than messages in the same conversation.""" + + def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image", List["Image.Image"]]): + for message in messages: + if not ("role" in message and "content" in message): + raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") + images = retrieve_images_in_messages(messages, images) + + self.messages = messages + self.images = images + + +def retrieve_images_in_messages( + messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]] +): + """ + Retrieve and combine images from the chat and the images passed as input. + """ + if images is None: + images = [] + idx_images = 0 + retrieved_images = [] + for message in messages: + for content in message["content"]: + if isinstance(content, dict) and content.get("type") == "image": + if "image" in content: + retrieved_images.append(content["image"]) + elif idx_images < len(images): + retrieved_images.append(images[idx_images]) + idx_images += 1 + else: + raise ValueError( + "The number of images in the chat messages should be the same as the number of images passed to the pipeline." + ) + + # The number of images passed should be consistent with the number of images in the chat without an image key + if idx_images != len(images): + raise ValueError( + "The number of images in the chat messages should be the same as the number of images passed to the pipeline." + ) + + return retrieved_images + + +@add_end_docstrings(build_pipeline_init_args(has_processor=True)) +class ImageTextToTextPipeline(Pipeline): + """ + Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text. + When the underlying model is a conversational model, it can also accept one or more chats, + in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s). + Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys. 
+ + Example: + + ```python + >>> from transformers import pipeline + + >>> pipe = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base") + >>> pipe("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", text="A photo of") + [{'generated_text': 'a photo of two birds'}] + ``` + + ```python + >>> from transformers import pipeline + + >>> pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf") + >>> messages = [ + >>> { + >>> "role": "user", + >>> "content": [ + >>> { + >>> "type": "image", + >>> "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + >>> }, + >>> {"type": "text", "text": "Describe this image."}, + >>> ], + >>> }, + >>> { + >>> "role": "assistant", + >>> "content": [ + >>> {"type": "text", "text": "There is a dog and"}, + >>> ], + >>> }, + >>> ] + >>> pipe(text=messages, max_new_tokens=20, return_full_text=False) + [{'input_text': [{'role': 'user', + 'content': [{'type': 'image', + 'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'}, + {'type': 'text', 'text': 'Describe this image.'}]}, + {'role': 'assistant', + 'content': [{'type': 'text', 'text': 'There is a dog and'}]}], + 'generated_text': ' a person in the image. The dog is sitting on the sand, and the person is sitting on'}] + ``` + + Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) + + This image-text to text pipeline can currently be loaded from pipeline() using the following task identifier: + "image-text-to-text". + + See the list of available models on + [huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-text-to-text). + """ + + _load_processor = True + _load_image_processor = False + _load_feature_extractor = False + _load_tokenizer = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + requires_backends(self, "vision") + self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES) + + def _sanitize_parameters( + self, + max_new_tokens=None, + generate_kwargs=None, + timeout=None, + return_full_text=None, + return_tensors=None, + return_type=None, + continue_final_message=None, + **kwargs: Unpack[ProcessingKwargs], + ): + forward_kwargs = {} + preprocess_params = {} + postprocess_params = {} + + preprocess_params["processing_kwargs"] = kwargs + + if timeout is not None: + preprocess_params["timeout"] = timeout + + if continue_final_message is not None: + preprocess_params["continue_final_message"] = continue_final_message + + if generate_kwargs is not None: + forward_kwargs["generate_kwargs"] = generate_kwargs + + if max_new_tokens is not None: + if "generate_kwargs" not in forward_kwargs: + forward_kwargs["generate_kwargs"] = {} + if "max_new_tokens" in forward_kwargs["generate_kwargs"]: + raise ValueError( + "'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter," + " please use only one" + ) + forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens + + if return_full_text is not None and return_type is None: + if return_tensors is not None: + raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`") + return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT + if return_tensors is not None and return_type is None: + return_type = ReturnType.TENSORS + if return_type is not None: + postprocess_params["return_type"] = return_type + if continue_final_message is not None: + 
postprocess_params["continue_final_message"] = continue_final_message + + return preprocess_params, forward_kwargs, postprocess_params + + def __call__( + self, + images: Optional[ + Union[str, List[str], List[List[str]], "Image.Image", List["Image.Image"], List[List["Image.Image"]]] + ] = None, + text: Optional[Union[str, List[str], List[dict]]] = None, + **kwargs, + ): + """ + Generate a text given text and the image(s) passed as inputs. + + Args: + images (`str`, `List[str]`, `PIL.Image or `List[PIL.Image]`): + The pipeline handles three types of images: + + - A string containing a HTTP(s) link pointing to an image + - A string containing a local path to an image + - An image loaded in PIL directly + + The pipeline accepts either a single image or a batch of images. + text (str, List[str], `List[Dict[str, Union[str, PIL.Image]]]`): + The text to be used for generation. If a list of strings is passed, the length of the list should be the + same as the number of images. Text can also follow the chat format: a list of dictionaries where each + dictionary represents a message in a conversation. Each dictionary should have two keys: 'role' and + 'content'. 'role' should be one of 'user', 'system' or 'assistant'. 'content' should be a list of dictionary + containing the text of the message and the type of the message. The type of the message can be either + 'text' or 'image'. If the type is 'image', no text is needed. + return_tensors (`bool`, *optional*, defaults to `False`): + Returns the tensors of predictions (as token indices) in the outputs. If set to + `True`, the decoded text is not returned. + return_text (`bool`, *optional*): + Returns the decoded texts in the outputs. + return_full_text (`bool`, *optional*, defaults to `True`): + If set to `False` only added text is returned, otherwise the full text is returned. Cannot be + specified at the same time as `return_text`. + continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the + last message in the input chat rather than starting a new one, allowing you to "prefill" its response. + By default this is `True` when the final message in the input chat has the `assistant` role and + `False` otherwise, but you can manually override that behaviour by setting this flag. + + Return: + A list or a list of list of `dict`: Each result comes as a dictionary with the following key (cannot return a combination + of both `generated_text` and `generated_token_ids`): + + - **generated_text** (`str`, present when `return_text=True`) -- The generated text. + - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token + ids of the generated text. + - **input_text** (`str`) -- The input text. + """ + if images is None and text is None: + raise ValueError("You must at least provide either text or images.") + if images is not None and text is None and not valid_images(images): + """ + Supports the following format + - {"image": image, "text": text} + - [{"image": image, "text": text}] + - Generator and datasets + This is a common pattern in other multimodal pipelines, so we support it here as well. 
+ """ + return super().__call__(images, **kwargs) + + if isinstance(text, (list, tuple, KeyDataset)) and isinstance(text[0], (list, tuple, dict)): + # We have one or more prompts in list-of-dicts format, so this is chat mode + if isinstance(text[0], dict): + return super().__call__(Chat(text, images), **kwargs) + else: + if images is None: + images = [None] * len(text) + chats = [Chat(chat, image) for chat, image in zip(text, images)] # 🐈 🐈 🐈 + return super().__call__(chats, **kwargs) + + # encourage the user to use the chat format if supported + if getattr(self.processor, "chat_template", None) is not None: + logger.warning_once( + "The input data was not formatted as a chat with dicts containing 'role' and 'content' keys, even though this model supports chat. " + "Consider using the chat format for better results. For more information, see https://huggingface.co/docs/transformers/en/chat_templating" + ) + + # support text only generation + if images is None: + return super().__call__(text, **kwargs) + if text is None: + raise ValueError("You must provide text for this pipeline.") + + return super().__call__({"images": images, "text": text}, **kwargs) + + def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None): + # In case we only have text inputs + if isinstance(inputs, (list, tuple, str)): + images = None + text = inputs + inputs_text = inputs + else: + if isinstance(inputs, Chat): + # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default + # because very few models support multiple separate, consecutive assistant messages + if continue_final_message is None: + continue_final_message = inputs.messages[-1]["role"] == "assistant" + text = self.processor.apply_chat_template( + inputs.messages, + add_generation_prompt=not continue_final_message, + continue_final_message=continue_final_message, + return_tensors=self.framework, + ) + inputs_text = inputs + images = inputs.images + else: + text = inputs["text"] + inputs_text = inputs["text"] + images = inputs["images"] + + images = load_images(images) + + # if batched text inputs, we set padding to True unless specified otherwise + if isinstance(text, (list, tuple)) and len(text) > 1: + processing_kwargs.setdefault("padding", True) + model_inputs = self.processor( + images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs + ).to(dtype=self.torch_dtype) + + model_inputs["text"] = inputs_text + + return model_inputs + + def _forward(self, model_inputs, generate_kwargs=None): + generate_kwargs = {} if generate_kwargs is None else generate_kwargs + prompt_text = model_inputs.pop("text") + input_ids = ( + model_inputs["input_ids"] if "input_ids" in model_inputs else model_inputs["decoder_input_ids"] + ) # for decoder-only models + generated_sequence = self.model.generate(**model_inputs, **generate_kwargs) + + return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids} + + def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None): + input_texts = model_outputs["prompt_text"] + input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts + generated_sequence = model_outputs["generated_sequence"] + input_ids = model_outputs["input_ids"] + if return_type == ReturnType.TENSORS: + return [ + {"input_text": input_texts[i], "generated_token_ids": generated_sequence[i]} + for i in range(len(input_texts)) + ] + + # Decode inputs and 
outputs the same way to remove input text from generated text if present + generated_texts = self.processor.post_process_image_text_to_text(generated_sequence) + decoded_inputs = self.processor.post_process_image_text_to_text(input_ids) + + # Force consistent behavior for including the input text in the output + if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}: + # Remove the input text from the generated text if the generated text starts with the input text + # (accounting for the possibility of a space between the input and generated text) + new_generated_texts = [] + for text_generated, decoded_input in zip(generated_texts, decoded_inputs): + # There can be added characters before the input text, so we need to find the beginning of the input text in the generated text + index_input_text = text_generated.find(decoded_input) + # Limit the search to 2 residual characters, like spaces or new lines, to avoid removing a large part of the answer + if 0 <= index_input_text <= 2: + # If the input text is found, we remove it + new_generated_texts.append(text_generated[index_input_text + len(decoded_input) :]) + else: + new_generated_texts.append(text_generated) + generated_texts = new_generated_texts + if return_type == ReturnType.FULL_TEXT: + full_texts = [] + for prompt_text, generated_text in zip(input_texts, generated_texts): + if isinstance(prompt_text, str): + generated_text = prompt_text + generated_text + elif isinstance(prompt_text, Chat): + if continue_final_message is None: + # If the user passes a chat ending in an assistant message, we treat it as a prefill by + # default because very few models support multiple separate, consecutive assistant messages + continue_final_message = prompt_text.messages[-1]["role"] == "assistant" + if continue_final_message: + # With assistant prefill, concat onto the end of the last message + new_text = dict(prompt_text.messages[-1]["content"][-1].items()) + new_text["text"] += generated_text + generated_text = list(prompt_text.messages)[:-1] + [ + { + "role": prompt_text.messages[-1]["role"], + "content": prompt_text.messages[-1]["content"][:-1] + [new_text], + } + ] + else: + # When we're not starting from a prefill, the output is a new assistant message + generated_text = list(prompt_text.messages) + [ + {"role": "assistant", "content": generated_text} + ] + full_texts.append(generated_text) + generated_texts = full_texts + + records = [ + { + "input_text": input_text.messages if isinstance(input_text, Chat) else input_text, + "generated_text": generated_text, + } + for input_text, generated_text in zip(input_texts, generated_texts) + ] + + return records diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py index 0d37ce91dadc89..afd67b6ac9edee 100644 --- a/src/transformers/pipelines/image_to_text.py +++ b/src/transformers/pipelines/image_to_text.py @@ -134,6 +134,10 @@ def preprocess(self, image, prompt=None, timeout=None): image = load_image(image, timeout=timeout) if prompt is not None: + logger.warning_once( + "Passing `prompt` to the `image-to-text` pipeline is deprecated and will be removed in version 4.48" + " of 🤗 Transformers. Use the `image-text-to-text` pipeline instead", + ) if not isinstance(prompt, str): raise ValueError( f"Received an invalid text input, got - {type(prompt)} - but expected a single string. 
" diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py index 9ad575202266ee..ce8da7340bcce5 100644 --- a/src/transformers/pipelines/zero_shot_object_detection.py +++ b/src/transformers/pipelines/zero_shot_object_detection.py @@ -7,7 +7,7 @@ if is_vision_available(): from PIL import Image - from ..image_utils import load_image + from ..image_utils import load_image, valid_images if is_torch_available(): import torch @@ -130,8 +130,23 @@ def __call__( if isinstance(image, (str, Image.Image)): inputs = {"image": image, "candidate_labels": candidate_labels} + elif isinstance(image, (list, tuple)) and valid_images(image): + return list( + super().__call__( + ({"image": img, "candidate_labels": labels} for img, labels in zip(image, candidate_labels)), + **kwargs, + ) + ) else: + """ + Supports the following format + - {"image": image, "candidate_labels": candidate_labels} + - [{"image": image, "candidate_labels": candidate_labels}] + - Generator and datasets + This is a common pattern in other multimodal pipelines, so we support it here as well. + """ inputs = image + results = super().__call__(inputs, **kwargs) return results diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 286ca49de85706..b5b02f6a00aa09 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1107,6 +1107,20 @@ def apply_chat_template( conversation, chat_template=chat_template, tokenize=tokenize, **kwargs ) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of a vlm to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. 
+ """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + def _validate_images_text_input_order(images, text): """ diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 2f8ee3229ff2cd..d60c76393f02bb 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -436,6 +436,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): "feature-extraction": BlipModel, "image-to-text": BlipForConditionalGeneration, "visual-question-answering": BlipForQuestionAnswering, + "image-text-to-text": BlipForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index e5d04bd85a3404..f2b945ef4451e4 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -767,6 +767,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi "feature-extraction": Blip2Model, "image-to-text": Blip2ForConditionalGeneration, "visual-question-answering": Blip2ForConditionalGeneration, + "image-text-to-text": Blip2ForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 2a8e7633ba40c5..bb2ba8b3428174 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -276,6 +276,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester { "feature-extraction": ChameleonModel, "text-generation": ChameleonForConditionalGeneration, + "image-text-to-text": ChameleonForConditionalGeneration, } if is_torch_available() else {} diff --git a/tests/models/fuyu/test_modeling_fuyu.py b/tests/models/fuyu/test_modeling_fuyu.py index 9425bddb6f703c..4bd66ab945f441 100644 --- a/tests/models/fuyu/test_modeling_fuyu.py +++ b/tests/models/fuyu/test_modeling_fuyu.py @@ -265,7 +265,9 @@ def prepare_config_and_inputs_for_common(self): @require_torch class FuyuModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (FuyuForCausalLM,) if is_torch_available() else () - pipeline_model_mapping = {"text-generation": FuyuForCausalLM} if is_torch_available() else {} + pipeline_model_mapping = ( + {"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {} + ) test_head_masking = False test_pruning = False diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index 33da9e26cba03d..ccfb41459caf73 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -401,7 +401,12 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else () all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( - {"feature-extraction": GitModel, "image-to-text": GitForCausalLM, "text-generation": GitForCausalLM} + { + "feature-extraction": GitModel, + "image-to-text": GitForCausalLM, + "text-generation": GitForCausalLM, + "image-text-to-text": GitForCausalLM, + } if is_torch_available() else {} ) diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index d19d10932bfcdc..7be87fd78390ab 100644 --- 
a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -332,7 +332,11 @@ def test_eager_matches_sdpa_generate(self): @require_torch class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else () - pipeline_model_mapping = {"feature-extraction": IdeficsModel} if is_torch_available() else {} + pipeline_model_mapping = ( + {"feature-extraction": IdeficsModel, "image-text-to-text": IdeficsForVisionText2Text} + if is_torch_available() + else {} + ) test_pruning = False test_headmasking = False test_torchscript = False diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 0b0f3c1f3d8483..3dcd0bf5fbcdeb 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -375,6 +375,7 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"image-text-to-text": Idefics2ForConditionalGeneration} if is_torch_available() else () fx_compatible = False test_pruning = False test_resize_embeddings = True diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py index dc5aad2fd04395..598f5882470e99 100644 --- a/tests/models/idefics3/test_modeling_idefics3.py +++ b/tests/models/idefics3/test_modeling_idefics3.py @@ -317,6 +317,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"image-text-to-text": Idefics3ForConditionalGeneration} if is_torch_available() else () fx_compatible = False test_pruning = False test_resize_embeddings = True diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index a33be021353f72..2771dac1e3767e 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -455,6 +455,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration} fx_compatible = False test_head_masking = False test_pruning = False diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index 0f0b595d3d2306..43266a750b8d6c 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -257,7 +257,11 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (Kosmos2ForConditionalGeneration,) if is_torch_available() else () pipeline_model_mapping = ( - {"feature-extraction": Kosmos2Model, "image-to-text": 
Kosmos2ForConditionalGeneration} + { + "feature-extraction": Kosmos2Model, + "image-to-text": Kosmos2ForConditionalGeneration, + "image-text-to-text": Kosmos2ForConditionalGeneration, + } if is_torch_available() else {} ) @@ -269,6 +273,7 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) _is_composite = True # TODO: `image-to-text` pipeline for this model needs Processor. + # TODO: Tiny model needs fixing for `image-text-to-text` (latent_query_num=3 not compatible with num_image_tokens=64). def is_pipeline_test_to_skip( self, pipeline_test_case_name, @@ -279,7 +284,10 @@ def is_pipeline_test_to_skip( feature_extractor_name, processor_name, ): - return pipeline_test_case_name == "ImageToTextPipelineTests" + return ( + pipeline_test_case_name == "ImageToTextPipelineTests" + or pipeline_test_case_name == "ImageTextToTextPipelineTests" + ) def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): inputs_dict = copy.deepcopy(inputs_dict) diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 23317648103b4d..9810ff7c2a56d4 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -183,7 +183,11 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else () - pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {} + pipeline_model_mapping = ( + {"image-to-text": LlavaForConditionalGeneration, "image-text-to-text": LlavaForConditionalGeneration} + if is_torch_available() + else {} + ) test_pruning = False test_head_masking = False _is_composite = True diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index e960f9f6759981..2146c94c18a4b4 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -216,6 +216,7 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"image-text-to-text": LlavaNextForConditionalGeneration} if is_torch_available() else {} test_pruning = False test_head_masking = False _is_composite = True diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index 107b6321b65cff..7a5781fa039b5b 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -217,6 +217,9 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati all_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-text-to-text": LlavaOnevisionForConditionalGeneration} if is_torch_available() else {} + ) test_pruning = False test_head_masking = False _is_composite = True diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 
42bf6fd7081618..91f2169a02f42d 100644
--- a/tests/models/mllama/test_modeling_mllama.py
+++ b/tests/models/mllama/test_modeling_mllama.py
@@ -264,6 +264,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
     all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
     all_generative_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
+    pipeline_model_mapping = {"image-text-to-text": MllamaForConditionalGeneration} if is_torch_available() else ()
     test_pruning = False
     test_head_masking = False
     test_torchscript = False
diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py
index 4c591805766f86..074e0083fd0202 100644
--- a/tests/models/paligemma/test_modeling_paligemma.py
+++ b/tests/models/paligemma/test_modeling_paligemma.py
@@ -183,6 +183,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
     all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
     all_generative_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
+    pipeline_model_mapping = {"image-text-to-text": PaliGemmaForConditionalGeneration}
     fx_compatible = False
     test_pruning = False
     test_torchscript = False
diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py
index 18b79f3fbc9c04..7438dc6d666179 100644
--- a/tests/models/pix2struct/test_modeling_pix2struct.py
+++ b/tests/models/pix2struct/test_modeling_pix2struct.py
@@ -420,7 +420,11 @@ def prepare_config_and_inputs_for_common(self):
 class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
     all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
-    pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
+    pipeline_model_mapping = (
+        {"image-to-text": Pix2StructForConditionalGeneration, "image-text-to-text": Pix2StructForConditionalGeneration}
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index a3272853a78427..6c04ba40df19d6 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -224,6 +224,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
     all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
     all_generative_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
+    pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
     test_pruning = False
     test_head_masking = False
diff --git a/tests/models/udop/test_modeling_udop.py b/tests/models/udop/test_modeling_udop.py
index 9d82173b1aed6c..d55400799dbd30 100644
--- a/tests/models/udop/test_modeling_udop.py
+++ b/tests/models/udop/test_modeling_udop.py
@@ -275,7 +275,11 @@ class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (UdopForConditionalGeneration,) if is_torch_available() else ()
-    pipeline_model_mapping = {"feature-extraction": UdopModel} if is_torch_available() else {}
+    pipeline_model_mapping = (
+        {"feature-extraction": UdopModel, "image-text-to-text": UdopForConditionalGeneration}
+        if is_torch_available()
+        else {}
+    )
     fx_compatible = False
     test_pruning = False
     test_torchscript = False
diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py
index 87e7925ade214c..e2f9ae1ccfdea7 100644
--- a/tests/models/vipllava/test_modeling_vipllava.py
+++ b/tests/models/vipllava/test_modeling_vipllava.py
@@ -170,6 +170,7 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
     all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
     all_generative_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
+    pipeline_model_mapping = {"image-text-to-text": VipLlavaForConditionalGeneration} if is_torch_available() else {}
     fx_compatible = False
     test_pruning = False
     test_resize_embeddings = True
diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py
new file mode 100644
index 00000000000000..b44b9decf98bbd
--- /dev/null
+++ b/tests/pipelines/test_pipelines_image_text_to_text.py
@@ -0,0 +1,260 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
+from transformers.pipelines import ImageTextToTextPipeline, pipeline
+from transformers.testing_utils import (
+    is_pipeline_test,
+    require_torch,
+    require_vision,
+    slow,
+)
+
+from .test_pipelines_common import ANY
+
+
+if is_vision_available():
+    from PIL import Image
+else:
+
+    class Image:
+        @staticmethod
+        def open(*args, **kwargs):
+            pass
+
+
+@is_pipeline_test
+@require_vision
+class ImageTextToTextPipelineTests(unittest.TestCase):
+    model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+
+    def get_test_pipeline(self, model, tokenizer, processor, image_processor, torch_dtype="float32"):
+        pipe = ImageTextToTextPipeline(model=model, processor=processor, torch_dtype=torch_dtype)
+        image_token = getattr(processor.tokenizer, "image_token", "")
+        examples = [
+            {
+                "images": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
+                "text": f"{image_token}This is a ",
+            },
+            {
+                "images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+                "text": f"{image_token}Here I see a ",
+            },
+        ]
+        return pipe, examples
+
+    def run_pipeline_test(self, pipe, examples):
+        outputs = pipe(examples[0].get("images"), text=examples[0].get("text"))
+        self.assertEqual(
+            outputs,
+            [
+                {"input_text": ANY(str), "generated_text": ANY(str)},
+            ],
+        )
+
+    @require_torch
+    def test_small_model_pt_token(self):
+        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
+        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+        text = "<image> What this is? Assistant: This is"
+
+        outputs = pipe(image, text=text)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "input_text": "<image> What this is? Assistant: This is",
+                    "generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
+                }
+            ],
+        )
+
+        outputs = pipe([image, image], text=[text, text])
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "input_text": "<image> What this is? Assistant: This is",
+                    "generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
+                },
+                {
+                    "input_text": "<image> What this is? Assistant: This is",
+                    "generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
+                },
+            ],
+        )
+
+    @require_torch
+    def test_consistent_batching_behaviour(self):
+        pipe = pipeline("image-text-to-text", model="microsoft/kosmos-2-patch14-224")
+        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+        prompt = "a photo of"
+
+        outputs = pipe([image, image], text=[prompt, prompt])
+        outputs_batched = pipe([image, image], text=[prompt, prompt], batch_size=2)
+        self.assertEqual(outputs, outputs_batched)
+
+    @slow
+    @require_torch
+    def test_model_pt_chat_template(self):
+        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
+        image_ny = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+        image_chicago = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What’s the difference between these two images?"},
+                    {"type": "image"},
+                    {"type": "image"},
+                ],
+            }
+        ]
+        outputs = pipe([image_ny, image_chicago], text=messages)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "input_text": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": "What’s the difference between these two images?"},
+                                {"type": "image"},
+                                {"type": "image"},
+                            ],
+                        }
+                    ],
+                    "generated_text": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": "What’s the difference between these two images?"},
+                                {"type": "image"},
+                                {"type": "image"},
+                            ],
+                        },
+                        {
+                            "role": "assistant",
+                            "content": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows",
+                        },
+                    ],
+                }
+            ],
+        )
+
+    @slow
+    @require_torch
+    def test_model_pt_chat_template_continue_final_message(self):
+        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": [
+                    {"type": "text", "text": "There is a dog and"},
+                ],
+            },
+        ]
+        outputs = pipe(text=messages)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "input_text": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "image",
+                                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                                },
+                                {"type": "text", "text": "Describe this image."},
+                            ],
+                        },
+                        {"role": "assistant", "content": [{"type": "text", "text": "There is a dog and"}]},
+                    ],
+                    "generated_text": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "image",
+                                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                                },
+                                {"type": "text", "text": "Describe this image."},
+                            ],
+                        },
+                        {
+                            "role": "assistant",
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": "There is a dog and a person in the image. The dog is sitting on the sand, and the person is sitting on",
+                                }
+                            ],
+                        },
+                    ],
+                }
+            ],
+        )
+
+    @slow
+    @require_torch
+    def test_model_pt_chat_template_new_text(self):
+        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                    },
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            }
+        ]
+        outputs = pipe(text=messages, return_full_text=False)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "input_text": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "image",
+                                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                                },
+                                {"type": "text", "text": "Describe this image."},
+                            ],
+                        }
+                    ],
+                    "generated_text": "In the image, a woman is sitting on the sandy beach, her legs crossed in a relaxed manner",
+                }
+            ],
+        )
diff --git a/tests/pipelines/test_pipelines_zero_shot_object_detection.py b/tests/pipelines/test_pipelines_zero_shot_object_detection.py
index 48cdb9bd15ca53..5ed48de3610eb9 100644
--- a/tests/pipelines/test_pipelines_zero_shot_object_detection.py
+++ b/tests/pipelines/test_pipelines_zero_shot_object_detection.py
@@ -14,7 +14,12 @@
 
 import unittest
 
-from transformers import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, is_vision_available, pipeline
+from transformers import (
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+    ZeroShotObjectDetectionPipeline,
+    is_vision_available,
+    pipeline,
+)
 from transformers.testing_utils import (
     is_pipeline_test,
     nested_simplify,
@@ -52,9 +57,11 @@ def get_test_pipeline(
         processor=None,
         torch_dtype="float32",
     ):
-        object_detector = pipeline(
-            "zero-shot-object-detection",
-            model="hf-internal-testing/tiny-random-owlvit-object-detection",
+        object_detector = ZeroShotObjectDetectionPipeline(
+            model=model,
+            processor=processor,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
             torch_dtype=torch_dtype,
         )
 
@@ -67,7 +74,7 @@ def get_test_pipeline(
         return object_detector, examples
 
     def run_pipeline_test(self, object_detector, examples):
-        outputs = object_detector(examples[0], threshold=0.0)
+        outputs = object_detector(examples[0].get("image"), examples[0].get("candidate_labels"), threshold=0.0)
 
         n = len(outputs)
         self.assertGreater(n, 0)
diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py
index f079bcdd92e580..94bc3d5fae1ad2 100644
--- a/tests/test_pipeline_mixin.py
+++ b/tests/test_pipeline_mixin.py
@@ -71,6 +71,7 @@
 from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
 from .pipelines.test_pipelines_image_feature_extraction import ImageFeatureExtractionPipelineTests
 from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
+from .pipelines.test_pipelines_image_text_to_text import ImageTextToTextPipelineTests
 from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
 from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
 from .pipelines.test_pipelines_mask_generation import MaskGenerationPipelineTests
@@ -102,6 +103,7 @@
     "image-classification": {"test": ImageClassificationPipelineTests},
     "image-feature-extraction": {"test": ImageFeatureExtractionPipelineTests},
     "image-segmentation": {"test": ImageSegmentationPipelineTests},
+    "image-text-to-text": {"test": ImageTextToTextPipelineTests},
     "image-to-image": {"test": ImageToImagePipelineTests},
     "image-to-text": {"test": ImageToTextPipelineTests},
     "mask-generation": {"test": MaskGenerationPipelineTests},
@@ -586,6 +588,18 @@ def test_pipeline_image_segmentation(self):
     def test_pipeline_image_segmentation_fp16(self):
         self.run_task_tests(task="image-segmentation", torch_dtype="float16")
 
+    @is_pipeline_test
+    @require_vision
+    @require_torch
+    def test_pipeline_image_text_to_text(self):
+        self.run_task_tests(task="image-text-to-text")
+
+    @is_pipeline_test
+    @require_vision
+    @require_torch
+    def test_pipeline_image_text_to_text_fp16(self):
+        self.run_task_tests(task="image-text-to-text", torch_dtype="float16")
+
     @is_pipeline_test
     @require_vision
     def test_pipeline_image_to_text(self):
diff --git a/tests/utils/tiny_model_summary.json b/tests/utils/tiny_model_summary.json
index 911783bc5cfbac..f27f720ec3d593 100644
--- a/tests/utils/tiny_model_summary.json
+++ b/tests/utils/tiny_model_summary.json
@@ -2896,7 +2896,7 @@
         "model_classes": [
             "IdeficsForVisionText2Text"
         ],
-        "sha": "2c2f2e2cd6b02a77d0cdd8c3767ba9a6267dbd20"
+        "sha": "a6be81294ff7a3d44f3aef0ed18e42b97c426831"
     },
     "IdeficsModel": {
         "tokenizer_classes": [
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index f31be7cbe1f28d..0be960f4a33e6d 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -335,6 +335,7 @@
     "ImageFeatureExtractionPipeline",
     "ImageGPTConfig",
     "ImageSegmentationPipeline",
+    "ImageTextToTextPipeline",
     "ImageToImagePipeline",
     "ImageToTextPipeline",
     "InformerConfig",
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 1806eb3f03df5a..b6ee1e7c8c13c2 100755
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -69,6 +69,7 @@
     ("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
     ("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
     ("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
+    ("image-text-to-text", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"),
    ("image-to-image", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES", "AutoModelForImageToImage"),
     ("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
     ("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),