diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 98649c644e728c..3744d81a0aca81 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -16,8 +16,16 @@ Image/Text processor class for GIT """ -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class GitProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class GitProcessor(ProcessorMixin): @@ -42,7 +50,14 @@ def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[GitProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode @@ -51,13 +66,13 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: @@ -68,7 +83,7 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: - [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when @@ -76,29 +91,26 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - tokenizer_kwargs, image_processor_kwargs = {}, {} - if kwargs: - tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys} - image_processor_kwargs = { - k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys - } - if text is None and images is None: raise ValueError("You have to specify either text or images. Both cannot be none.") - if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs) + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + GitProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + data = {} + if text is not None: + text_features = self.tokenizer(text, **output_kwargs["text_kwargs"]) + data.update(text_features) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs) - - if text is not None and images is not None: - encoding["pixel_values"] = image_features.pixel_values - return encoding - elif text is not None: - return encoding - else: - return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + data.update(image_features) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")) def batch_decode(self, *args, **kwargs): """