Skip to content

Commit

Permalink
uniformize git processor
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Sep 23, 2024
1 parent 1456120 commit 3ae28cc
Showing 1 changed file with 38 additions and 26 deletions.
64 changes: 38 additions & 26 deletions src/transformers/models/git/processing_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@
Image/Text processor class for GIT
"""

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from typing import List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput


class GitProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


class GitProcessor(ProcessorMixin):
Expand All @@ -42,7 +50,14 @@ def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio=None,
videos=None,
**kwargs: Unpack[GitProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
Expand All @@ -51,13 +66,13 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
Expand All @@ -68,37 +83,34 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
tokenizer_kwargs, image_processor_kwargs = {}, {}
if kwargs:
tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys}
image_processor_kwargs = {
k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys
}

if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")

if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **tokenizer_kwargs)
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)

output_kwargs = self._merge_kwargs(
GitProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

data = {}
if text is not None:
text_features = self.tokenizer(text, **output_kwargs["text_kwargs"])
data.update(text_features)
if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **image_processor_kwargs)

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data.update(image_features)
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))

def batch_decode(self, *args, **kwargs):
"""
Expand Down

0 comments on commit 3ae28cc

Please sign in to comment.