Add image text to text pipeline (huggingface#34170)
* Standardize image-text-to-text models' output

add post_process_image_text_to_text to chameleon and cleanup

Fix legacy kwarg behavior and deprecation warning

add post_process_image_text_to_text to qwen2_vl and llava_onevision

Add post_process_image_text_to_text to idefics3, mllama, pixtral processor

* nit var name post_process_image_text_to_text udop

* nit fix deprecation warnings

* Add image-text-to-text pipeline

* add support for image url in chat template for pipeline

* Reformat to be fully compatible with chat templates

* Add tests chat template

* Fix imports and tests

* Add pipeline tag

* change logic handling of single prompt and multiple images

* add pipeline mapping to models

* fix batched inference

* fix tests

* Add manual batching for preprocessing

* Fix outputs with nested images

* Add support for all common processing kwargs

* Add default padding when multiple text inputs (batch size > 1)

* nit change version deprecation warning

* Add support for text only inference

* add chat_template warnings

* Add pipeline tests and add copied from post process function

* Fix batched pipeline tests

* nit

* Fix pipeline tests blip2

* remove unnecessary max_new_tokens

* revert processing kosmos2 and remove unnecessary max_new_tokens

* fix pipeline tests idefics

* Force try loading processor if pipeline supports it

* revert load_processor change

* hardcode loading only processor

* remove unnecessary try except

* skip imagetexttotext tests for kosmos2 as tiny model causes problems

* Make code clearer

* Address review comments

* remove preprocessing logic from pipeline

* fix fuyu

* add BC resize fuyu

* Move post_process_image_text_to_text to ProcessorMixin

* add guard in post_process

* fix zero shot object detection pipeline

* add support for generator input in pipeline

* nit

* change default image-text-to-text model to llava onevision

* fix owlv2 size dict

* Change legacy deprecation warning to only show when True
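To make the headline change concrete, here is a minimal usage sketch of the new pipeline; the checkpoint name, image URL, and generation settings are illustrative assumptions, not part of this commit:

```python
from transformers import pipeline

# "image-text-to-text" is the task tag added by this commit; the default
# checkpoint was switched to LLaVA-OneVision (see the bullets above).
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

# Chat-template input; the commit adds support for image URLs inside messages.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.jpg"},  # placeholder URL
            {"type": "text", "text": "Describe this image in one sentence."},
        ],
    }
]
print(pipe(text=messages, max_new_tokens=30))
```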
yonigozlan authored and frances720 committed Nov 6, 2024
1 parent c443d8d commit 0252da9
Showing 47 changed files with 988 additions and 33 deletions.
6 changes: 6 additions & 0 deletions docs/source/en/main_classes/pipelines.md
@@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following.
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### MaskGenerationPipeline

[[autodoc]] MaskGenerationPipeline
6 changes: 6 additions & 0 deletions docs/source/ja/main_classes/pipelines.md
@@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### VisualQuestionAnsweringPipeline

[[autodoc]] VisualQuestionAnsweringPipeline
6 changes: 6 additions & 0 deletions docs/source/zh/main_classes/pipelines.md
@@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### MaskGenerationPipeline

[[autodoc]] MaskGenerationPipeline
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -868,6 +868,7 @@
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageSegmentationPipeline",
"ImageTextToTextPipeline",
"ImageToImagePipeline",
"ImageToTextPipeline",
"JsonPipelineDataFormat",
@@ -5794,6 +5795,7 @@
ImageClassificationPipeline,
ImageFeatureExtractionPipeline,
ImageSegmentationPipeline,
ImageTextToTextPipeline,
ImageToImagePipeline,
ImageToTextPipeline,
JsonPipelineDataFormat,
21 changes: 21 additions & 0 deletions src/transformers/image_utils.py
@@ -385,6 +385,27 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def load_images(
images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
"""Loads images, handling different levels of nesting.
Args:
images: A single image, a list of images, or a list of lists of images to load.
timeout: Timeout for loading images.
Returns:
A single image, a list of images, or a list of lists of images, matching the nesting of the input.
"""
if isinstance(images, (list, tuple)):
if len(images) and isinstance(images[0], (list, tuple)):
return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
else:
return [load_image(image, timeout=timeout) for image in images]
else:
return load_image(images, timeout=timeout)


def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
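A quick sketch of what the new `load_images` helper returns at each nesting level (file names are placeholders):

```python
from transformers.image_utils import load_images

single = load_images("cat.png")                        # -> PIL.Image.Image
flat   = load_images(["cat.png", "dog.png"])           # -> List[PIL.Image.Image]
nested = load_images([["a.png"], ["b.png", "c.png"]])  # -> List[List[PIL.Image.Image]]
```

The nested form maps naturally onto batched image-text-to-text inputs, where each prompt may carry its own group of images.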
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -114,6 +114,7 @@
("oneformer", ("OneFormerImageProcessor",)),
("owlv2", ("Owlv2ImageProcessor",)),
("owlvit", ("OwlViTImageProcessor",)),
("paligemma", ("SiglipImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
16 changes: 16 additions & 0 deletions src/transformers/models/donut/processing_donut.py
@@ -24,12 +24,16 @@
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


class DonutProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


logger = logging.get_logger(__name__)


class DonutProcessor(ProcessorMixin):
r"""
Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single
@@ -85,6 +89,16 @@ def __call__(
[`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
"""
# For backward compatibility
legacy = kwargs.pop("legacy", True)
if legacy:
# With `add_special_tokens=True`, the performance of Donut is degraded when working with both images and text.
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, if both images and text are provided, the default value of `add_special_tokens` "
"will be changed to `False` when calling the tokenizer if `add_special_tokens` is unset. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)

if self._in_target_context_manager:
return self.current_processor(images, text, **kwargs)

@@ -100,6 +114,8 @@ def __call__(
if images is not None:
inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
if text is not None:
if not legacy and images is not None:
output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])

if text is None:
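Opting in to the new Donut behavior, sketched; the checkpoint and task prompt are assumptions for illustration:

```python
from PIL import Image
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")  # assumed checkpoint
image = Image.new("RGB", (1280, 960))  # stand-in for a real document scan

# With legacy=False and both modalities present, `add_special_tokens` now
# defaults to False when the tokenizer is called, and no warning is emitted.
inputs = processor(images=image, text="<s_docvqa>", legacy=False, return_tensors="pt")
```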
3 changes: 2 additions & 1 deletion src/transformers/models/fuyu/image_processing_fuyu.py
@@ -19,7 +19,7 @@

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
pad,
resize,
@@ -475,6 +475,7 @@ def preprocess(
input_data_format = infer_channel_dimension_format(batch_images[0][0])

original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
size = get_size_dict(size) # for BC

if do_resize:
batch_images = [
30 changes: 28 additions & 2 deletions src/transformers/models/fuyu/processing_fuyu.py
@@ -264,10 +264,10 @@ def _tokenize_prompts_with_image_and_batch(
bos_token = tokenizer.vocab["|ENDOFTEXT|"]
prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
if add_beginning_of_answer_token:
boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
# Only add the beginning-of-answer token to the last subsequence since that is what will be completed
for token_seq in prompts_tokens:
token_seq[-1].append(boa)
token_seq[-1].append(beginning_of_answer)

# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
@@ -682,6 +682,32 @@ def tokens_to_points(tokens, original_size):

return results

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
containing the token ids of the generated sequences.
Returns:
`List[str]`: The decoded text output.
"""
beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
# get boa index for each outputted sequence tensor
# start all generated sequences from the beginning of the answer token, pad to have consistent length
unpadded_output_sequences = [
seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
]
max_len = max(len(seq) for seq in unpadded_output_sequences)
# convert to torch and pad sequences
padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
for i, seq in enumerate(unpadded_output_sequences):
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)

return self.batch_decode(padded_output_sequences, skip_special_tokens=True)

def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
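How the new Fuyu helper fits into a generation loop, sketched; the `model`/`processor` pair and `inputs` are assumed to be set up elsewhere:

```python
# sketch: `model` is a FuyuForConditionalGeneration, `processor` a FuyuProcessor,
# and `inputs` the output of a prior processor(...) call
generated_ids = model.generate(**inputs, max_new_tokens=20)
# Everything up to and including the beginning-of-answer token is dropped,
# the sequences are re-padded to a common length, then batch-decoded.
answers = processor.post_process_image_text_to_text(generated_ids)
```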
17 changes: 17 additions & 0 deletions src/transformers/models/git/processing_git.py
@@ -22,12 +22,16 @@
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


class GitProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


logger = logging.get_logger(__name__)


class GitProcessor(ProcessorMixin):
r"""
Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.
@@ -91,6 +95,15 @@ def __call__(
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, if both images and text are provided, the last token (EOS token) "
"of the input_ids and attention_mask tensors will be removed. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)

if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")

@@ -110,6 +123,10 @@ def __call__(
if images is not None:
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data.update(image_features)
if not legacy:
data["input_ids"] = data["input_ids"][:, :-1]
data["attention_mask"] = data["attention_mask"][:, :-1]

return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))

def batch_decode(self, *args, **kwargs):
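The practical effect of `legacy=False` for GIT, sketched (checkpoint assumed):

```python
from PIL import Image
from transformers import GitProcessor

processor = GitProcessor.from_pretrained("microsoft/git-base")  # assumed checkpoint
image = Image.new("RGB", (224, 224))  # stand-in image

out_legacy = processor(images=image, text="a photo of", return_tensors="pt")
out_new = processor(images=image, text="a photo of", legacy=False, return_tensors="pt")
# The new path drops the last token (EOS) from input_ids and attention_mask.
assert out_new["input_ids"].shape[-1] == out_legacy["input_ids"].shape[-1] - 1
```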
15 changes: 15 additions & 0 deletions src/transformers/models/kosmos2/processing_kosmos2.py
@@ -428,6 +428,21 @@ def post_process_generation(self, text, cleanup_and_extract=True):
return clean_text_and_extract_entities_with_bboxes(caption)
return caption

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]

@property
# Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
def model_input_names(self):
16 changes: 16 additions & 0 deletions src/transformers/models/mllama/processing_mllama.py
@@ -342,6 +342,22 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
3 changes: 2 additions & 1 deletion src/transformers/models/owlv2/image_processing_owlv2.py
@@ -19,7 +19,7 @@

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_to_corners_format,
pad,
@@ -399,6 +399,7 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std

size = size if size is not None else self.size
size = get_size_dict(size) # for BC

images = make_list_of_images(images)

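What the `get_size_dict` guard buys back, sketched: checkpoints whose config stores `size` in the older int or tuple form are normalized into the dict form the rest of `preprocess` expects (expected outputs below are assumptions based on the helper's defaults):

```python
from transformers.image_processing_utils import get_size_dict

get_size_dict({"height": 960, "width": 960})  # already canonical -> unchanged
get_size_dict(960)          # int -> {"height": 960, "width": 960} (square by default)
get_size_dict((960, 1080))  # (height, width) tuple -> {"height": 960, "width": 1080}
```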
20 changes: 20 additions & 0 deletions src/transformers/models/pix2struct/processing_pix2struct.py
@@ -21,6 +21,7 @@
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import logging


class Pix2StructImagesKwargs(ImagesKwargs, total=False):
@@ -48,6 +49,9 @@ class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
}


logger = logging.get_logger(__name__)


class Pix2StructProcessor(ProcessorMixin):
r"""
Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single
@@ -85,6 +89,15 @@ def __call__(
Please refer to the docstring of the above two methods for more information.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, If both images and text are provided, image_processor is not a VQA processor, and `add_special_tokens` is unset, "
"the default value of `add_special_tokens` will be changed to `False` when calling the tokenizer. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)

if images is None and text is None:
raise ValueError("You have to specify either images or text.")

@@ -93,8 +106,12 @@
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None)
# Get only text
if images is None and not self.image_processor.is_vqa:
output_kwargs["text_kwargs"]["add_special_tokens"] = (
add_special_tokens if add_special_tokens is not None else True
)
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
return text_encoding
@@ -108,6 +125,9 @@ def __call__(
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])

if text is not None and not self.image_processor.is_vqa:
output_kwargs["text_kwargs"]["add_special_tokens"] = (
add_special_tokens if add_special_tokens is not None else legacy
)
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])

if "attention_mask" in text_encoding:
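The resulting `add_special_tokens` resolution for Pix2Struct can be summarized in a small sketch (a hypothetical helper mirroring the non-VQA branches above, not code from this commit):

```python
def resolve_add_special_tokens(user_value, images_present, legacy):
    """Hypothetical mirror of the Pix2Struct logic above (non-VQA processor)."""
    if user_value is not None:   # an explicit user choice always wins
        return user_value
    if not images_present:       # text-only path
        return True
    return legacy                # images + text: True in legacy mode, False otherwise
```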
16 changes: 16 additions & 0 deletions src/transformers/models/qwen2_vl/processing_qwen2_vl.py
@@ -168,6 +168,22 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
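The Qwen2-VL hook (and the identical Mllama one above) covers the common chat-model case: a plain batch decode with special tokens skipped. Usage, sketched with a model/processor pair assumed to be loaded elsewhere:

```python
# sketch: `model` and `processor` are a matching Qwen2-VL pair, `inputs` a
# prior processor(...) output
output_ids = model.generate(**inputs, max_new_tokens=30)
texts = processor.post_process_image_text_to_text(output_ids)
# equivalent to: processor.batch_decode(output_ids, skip_special_tokens=True,
#                                       clean_up_tokenization_spaces=False)
```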
