diff --git a/src/transformers/models/blip/processing_blip.py b/src/transformers/models/blip/processing_blip.py index cd96b46ab1d26f..78e1aa58ef0443 100644 --- a/src/transformers/models/blip/processing_blip.py +++ b/src/transformers/models/blip/processing_blip.py @@ -19,9 +19,25 @@ from typing import List, Optional, Union from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput + + +class BlipProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } class BlipProcessor(ProcessorMixin): @@ -51,84 +67,53 @@ def __init__(self, image_processor, tokenizer, **kwargs): def __call__( self, images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_token_type_ids: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[BlipProcessorKwargs], ) -> BatchEncoding: """ This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. 
""" if images is None and text is None: raise ValueError("You have to specify either images or text.") - # Get only text - if images is None: - self.current_processor = self.tokenizer - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) - return text_encoding - - # add pixel_values - encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) + text_encoding = None + # add pixel_values encoding. If we also have text_encoding, update image encoding and return it. + # else, return the text encoding. + output_kwargs = self._merge_kwargs( + BlipProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) if text is not None: - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) - else: - text_encoding = None - - if text_encoding is not None: - encoding_image_processor.update(text_encoding) - - return encoding_image_processor + text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"]) + if images is not None: + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) + + if text_encoding is not None: + encoding_image_processor.update(text_encoding) + return encoding_image_processor + + return text_encoding def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index e879b41eb15643..606aadc1eab45f 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -18,22 +18,38 @@ from typing import List, Optional, Union +from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import ( AddedToken, BatchEncoding, - PaddingStrategy, PreTokenizedInput, TextInput, - TruncationStrategy, ) -from ...utils import TensorType, logging +from ...utils import logging logger = logging.get_logger(__name__) +class Blip2ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_token_type_ids": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } + + class Blip2Processor(ProcessorMixin): r""" Constructs a BLIP-2 processor 
which wraps a BLIP image processor and an OPT/T5 tokenizer into a single processor. @@ -67,58 +83,44 @@ def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs): def __call__( self, images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_token_type_ids: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Blip2ProcessorKwargs], ) -> BatchEncoding: """ This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. Please refer to the docstring of the above two methods for more information. + Args: + images (`ImageInput`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. 
""" if images is None and text is None: raise ValueError("You have to specify either images or text.") - - # Get only text - if images is None: - self.current_processor = self.tokenizer - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) - return text_encoding - - # add pixel_values - encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) - + output_kwargs = self._merge_kwargs( + Blip2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # BC for explicit return_tensors + if "return_tensors" in output_kwargs["common_kwargs"]: + return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None) + else: + return_tensors = None + encoding = BatchFeature(tensor_type=return_tensors) if text is not None: if isinstance(text, str): text = [text] @@ -126,24 +128,10 @@ def __call__( raise ValueError("Invalid input text. Please provide a string, or a list of strings") text_encoding = {} - _text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=None, # hardcode "None" here for prepending image tokens - **kwargs, - ) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + output_kwargs["text_kwargs"]["return_tensors"] = return_tensors # if we know how many query tokens, expand text inside processor. We need this hacky manipulation # because BLIP expects image tokens to be at the beginning even before BOS token @@ -164,14 +152,14 @@ def __call__( ) # cast to desired return tensors type - text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors) - else: - text_encoding = None - - if text_encoding is not None: - encoding_image_processor.update(text_encoding) - - return encoding_image_processor + encoding.update(BatchEncoding(text_encoding, tensor_type=return_tensors)) + # add pixel_values encoding. If we also have text_encoding, update image encoding and return it. + # else, return the text encoding. 
+ + if images is not None: + image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"]) + encoding.update(image_encoding) + return encoding # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/bridgetower/processing_bridgetower.py b/src/transformers/models/bridgetower/processing_bridgetower.py index 7718c3bf833fec..177eb12051654d 100644 --- a/src/transformers/models/bridgetower/processing_bridgetower.py +++ b/src/transformers/models/bridgetower/processing_bridgetower.py @@ -16,11 +16,29 @@ Processor class for BridgeTower. """ -from typing import List, Optional, Union +from typing import List, Union -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput + + +class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "do_normalize": True, + "do_center_crop": True, + }, + } class BridgeTowerProcessor(ProcessorMixin): @@ -50,21 +68,9 @@ def __call__( self, images, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, + audio=None, + videos=None, + **kwargs: Unpack[BridgeTowerProcessorKwargs], ) -> BatchEncoding: """ This method uses [`BridgeTowerImageProcessor.__call__`] method to prepare image(s) for the model, and @@ -72,28 +78,14 @@ def __call__( Please refer to the docstring of the above two methods for more information. 
""" - encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, + output_kwargs = self._merge_kwargs( + BridgeTowerProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) # add pixel_values + pixel_mask - encoding_image_processor = self.image_processor( - images, return_tensors=return_tensors, do_normalize=True, do_center_crop=True, **kwargs - ) + encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"]) encoding.update(encoding_image_processor) return encoding diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py index daf6e7d1dfe4ab..9552d323ac57c0 100644 --- a/src/transformers/models/donut/processing_donut.py +++ b/src/transformers/models/donut/processing_donut.py @@ -19,8 +19,15 @@ import re import warnings from contextlib import contextmanager +from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class DonutProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} class DonutProcessor(ProcessorMixin): @@ -63,7 +70,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): self.current_processor = self.image_processor self._in_target_context_manager = False - def __call__(self, *args, **kwargs): + def __call__( + self, + images: ImageInput = None, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[DonutProcessorKwargs], + ): """ When used in normal mode, this method forwards all its arguments to AutoImageProcessor's [`~AutoImageProcessor.__call__`] and returns its output. 
If used in the context @@ -72,28 +86,29 @@ def __call__(self, *args, **kwargs): """ # For backward compatibility if self._in_target_context_manager: - return self.current_processor(*args, **kwargs) - - images = kwargs.pop("images", None) - text = kwargs.pop("text", None) - if len(args) > 0: - images = args[0] - args = args[1:] + return self.current_processor(images, text, **kwargs) if images is None and text is None: raise ValueError("You need to specify either an `images` or `text` input to process.") + output_kwargs = self._merge_kwargs( + DonutProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: - inputs = self.image_processor(images, *args, **kwargs) + inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) if text is not None: - encodings = self.tokenizer(text, **kwargs) + encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) if text is None: return inputs elif images is None: return encodings else: - inputs["labels"] = encodings["input_ids"] + inputs["labels"] = encodings["input_ids"] # for BC + inputs["input_ids"] = encodings["input_ids"] return inputs def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py index 10042ca4529204..1fc9a25fd8e9a2 100644 --- a/src/transformers/models/idefics3/processing_idefics3.py +++ b/src/transformers/models/idefics3/processing_idefics3.py @@ -17,12 +17,11 @@ """ import re -import sys from typing import TYPE_CHECKING, Dict, List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image, load_image -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput from ...utils import logging @@ -30,11 +29,6 @@ if TYPE_CHECKING: from ...tokenization_utils_base import PreTokenizedInput -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/mllama/processing_mllama.py b/src/transformers/models/mllama/processing_mllama.py index 84c8eea466e75b..dbfea1a90deab8 100644 --- a/src/transformers/models/mllama/processing_mllama.py +++ b/src/transformers/models/mllama/processing_mllama.py @@ -19,19 +19,9 @@ import numpy as np - -try: - from typing import Unpack -except ImportError: - from typing_extensions import Unpack - from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ( - ImagesKwargs, - ProcessingKwargs, - ProcessorMixin, -) +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import ( BatchEncoding, PreTokenizedInput, diff --git a/src/transformers/models/omdet_turbo/processing_omdet_turbo.py b/src/transformers/models/omdet_turbo/processing_omdet_turbo.py index 909281b0c6867a..e4e2e5197f2e2c 100644 --- a/src/transformers/models/omdet_turbo/processing_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/processing_omdet_turbo.py @@ -16,13 +16,12 @@ Processor class for OmDet-Turbo. 
""" -import sys from typing import List, Optional, Tuple, Union from ...feature_extraction_utils import BatchFeature from ...image_transforms import center_to_corners_format from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import ( TensorType, @@ -31,12 +30,6 @@ ) -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - - class OmDetTurboTextKwargs(TextKwargs, total=False): task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py index 6fe960c78eb10b..4bd4255315fdc5 100644 --- a/src/transformers/models/wav2vec2/processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py @@ -18,12 +18,18 @@ import warnings from contextlib import contextmanager +from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer +class Wav2Vec2ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + class Wav2Vec2Processor(ProcessorMixin): r""" Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor and a Wav2Vec2 CTC tokenizer into a single @@ -66,35 +72,46 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) - def __call__(self, *args, **kwargs): + def __call__( + self, + audio: AudioInput = None, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None, + images=None, + videos=None, + **kwargs: Unpack[Wav2Vec2ProcessorKwargs], + ): """ When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's [`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. """ - # For backward compatibility - if self._in_target_context_manager: - return self.current_processor(*args, **kwargs) if "raw_speech" in kwargs: warnings.warn("Using `raw_speech` as a keyword argument is deprecated. 
Use `audio` instead.") audio = kwargs.pop("raw_speech") - else: - audio = kwargs.pop("audio", None) - sampling_rate = kwargs.pop("sampling_rate", None) - text = kwargs.pop("text", None) - if len(args) > 0: - audio = args[0] - args = args[1:] if audio is None and text is None: raise ValueError("You need to specify either an `audio` or `text` input to process.") + output_kwargs = self._merge_kwargs( + Wav2Vec2ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor( + audio, + **output_kwargs["audio_kwargs"], + **output_kwargs["text_kwargs"], + **output_kwargs["common_kwargs"], + ) + if audio is not None: - inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) if text is not None: - encodings = self.tokenizer(text, **kwargs) + encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) if text is None: return inputs diff --git a/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py index d24c672007d734..8b09e92419ae97 100644 --- a/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py @@ -17,12 +17,18 @@ """ import warnings +from typing import List, Optional, Union -from ...processing_utils import ProcessorMixin +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput from ..seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer +class Wav2Vec2BertProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + class Wav2Vec2BertProcessor(ProcessorMixin): r""" Constructs a Wav2Vec2-BERT processor which wraps a Wav2Vec2-BERT feature extractor and a Wav2Vec2 CTC tokenizer into a single @@ -63,7 +69,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) - def __call__(self, audio=None, text=None, **kwargs): + def __call__( + self, + audio: AudioInput = None, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None, + images=None, + videos=None, + **kwargs: Unpack[Wav2Vec2BertProcessorKwargs], + ): """ Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `audio` and `kwargs` arguments to SeamlessM4TFeatureExtractor's [`~SeamlessM4TFeatureExtractor.__call__`] if `audio` is not @@ -71,17 +84,15 @@ def __call__(self, audio=None, text=None, **kwargs): PreTrainedTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not `None`. Please refer to the doctsring of the above two methods for more information. Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): The audio or batch of audios to be prepared. Each audio can be NumPy array or PyTorch tensor. 
In case of a NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the sample length of the audio. - kwargs (*optional*): - Remaining dictionary of keyword arguments that will be passed to the feature extractor and/or the - tokenizer. + + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: - **input_features** -- Audio input features to be fed to a model. Returned when `audio` is not `None`. @@ -91,15 +102,18 @@ def __call__(self, audio=None, text=None, **kwargs): - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `audio` is `None`. """ - sampling_rate = kwargs.pop("sampling_rate", None) - if audio is None and text is None: raise ValueError("You need to specify either an `audio` or `text` input to process.") + output_kwargs = self._merge_kwargs( + Wav2Vec2BertProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) if audio is not None: - inputs = self.feature_extractor(audio, sampling_rate=sampling_rate, **kwargs) + inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) if text is not None: - encodings = self.tokenizer(text, **kwargs) + encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) if text is None: return inputs diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 58eb3e6ed6f206..062dfe311c1dca 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -820,6 +820,8 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg "common_kwargs": {}, } + used_keys = set() + # get defaults from set model processor kwargs if they exist for modality in default_kwargs: default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy() @@ -846,18 +848,29 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg f"in a dictionary for {modality} and as a **kwarg." 
) elif modality_key in kwargs: - kwarg_value = kwargs.pop(modality_key, "__empty__") + # we get a modality_key instead of popping it because modality-specific processors + # can have overlapping kwargs + kwarg_value = kwargs.get(modality_key, "__empty__") else: kwarg_value = "__empty__" if kwarg_value != "__empty__": output_kwargs[modality][modality_key] = kwarg_value - # if something remains in kwargs, it belongs to common after flattening - if set(kwargs) & set(default_kwargs): - # here kwargs is dictionary-based since it shares keys with default set - [output_kwargs["common_kwargs"].update(subdict) for _, subdict in kwargs.items()] + used_keys.add(modality_key) + + # Determine if kwargs is a flat dictionary or contains nested dictionaries + if any(key in default_kwargs for key in kwargs): + # kwargs is dictionary-based, and some keys match modality names + for modality, subdict in kwargs.items(): + if modality in default_kwargs: + for subkey, subvalue in subdict.items(): + if subkey not in used_keys: + output_kwargs[modality][subkey] = subvalue + used_keys.add(subkey) else: - # here it's a flat dict - output_kwargs["common_kwargs"].update(kwargs) + # kwargs is a flat dictionary + for key in kwargs: + if key not in used_keys: + output_kwargs["common_kwargs"][key] = kwargs[key] # all modality-specific kwargs are updated with common kwargs for modality in output_kwargs: diff --git a/tests/models/altclip/test_processor_altclip.py b/tests/models/altclip/test_processor_altclip.py index 33bff9c77ad263..5b290efb115031 100644 --- a/tests/models/altclip/test_processor_altclip.py +++ b/tests/models/altclip/test_processor_altclip.py @@ -17,17 +17,12 @@ import tempfile import unittest -from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast +from transformers import AltCLIPProcessor, CLIPImageProcessor, XLMRobertaTokenizer, XLMRobertaTokenizerFast from transformers.testing_utils import require_vision -from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin -if is_vision_available(): - from transformers import AltCLIPProcessor, CLIPImageProcessor - - @require_vision class AltClipProcessorTest(ProcessorTesterMixin, unittest.TestCase): processor_class = AltCLIPProcessor diff --git a/tests/models/blip/test_processor_blip.py b/tests/models/blip/test_processor_blip.py index 7b851c618a773d..4d22c6527c07b1 100644 --- a/tests/models/blip/test_processor_blip.py +++ b/tests/models/blip/test_processor_blip.py @@ -17,7 +17,7 @@ import pytest -from transformers.testing_utils import require_vision +from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -139,3 +139,29 @@ def test_model_input_names(self): # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask'] self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"]) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = 
self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"height": 214, "width": 214}, + size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 24) diff --git a/tests/models/blip_2/test_processor_blip_2.py b/tests/models/blip_2/test_processor_blip_2.py index 8c7ca2ab698f48..7151be8ac71200 100644 --- a/tests/models/blip_2/test_processor_blip_2.py +++ b/tests/models/blip_2/test_processor_blip_2.py @@ -17,7 +17,7 @@ import pytest -from transformers.testing_utils import require_vision +from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -94,7 +94,7 @@ def test_tokenizer(self): encoded_tok = tokenizer(input_str, return_token_type_ids=False) for key in encoded_tok.keys(): - self.assertListEqual(encoded_tok[key], encoded_processor[key]) + self.assertListEqual(encoded_tok[key], encoded_processor[key][0]) def test_processor(self): image_processor = self.get_image_processor() @@ -107,7 +107,7 @@ def test_processor(self): inputs = processor(text=input_str, images=image_input) - self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"]) + self.assertCountEqual(list(inputs.keys()), ["input_ids", "pixel_values", "attention_mask"]) # test if it raises when no input is passed with pytest.raises(ValueError): @@ -138,4 +138,31 @@ def test_model_input_names(self): inputs = processor(text=input_str, images=image_input) # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask'] - self.assertListEqual(list(inputs.keys()), ["pixel_values", "input_ids", "attention_mask"]) + self.assertCountEqual(list(inputs.keys()), ["input_ids", "pixel_values", "attention_mask"]) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"height": 214, "width": 214}, + size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 11) diff --git a/tests/models/bridgetower/test_processing_bridgetower.py b/tests/models/bridgetower/test_processing_bridgetower.py new file mode 100644 index 00000000000000..19902a1cc57f3b --- /dev/null +++ b/tests/models/bridgetower/test_processing_bridgetower.py @@ -0,0 +1,218 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    from PIL import Image
+
+    from transformers import (
+        AutoProcessor,
+        BridgeTowerImageProcessor,
+        BridgeTowerProcessor,
+        RobertaTokenizerFast,
+    )
+
+
+@require_vision
+class BridgeTowerProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = BridgeTowerProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = BridgeTowerImageProcessor()
+        tokenizer = RobertaTokenizerFast.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
+
+        processor = BridgeTowerProcessor(image_processor, tokenizer)
+
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
+    def prepare_image_inputs(self):
+        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+        or a list of PyTorch tensors if one specifies torchify=True.
+        """
+
+        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+
+        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+
+        return image_inputs
+
+    # Some kwargs tests are overridden from common tests to handle shortest_edge
+    # and size_divisor behaviour
+
+    @require_torch
+    @require_vision
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component(
+            "image_processor",
+            crop_size={"shortest_edge": 234, "longest_edge": 234},
+        )
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+        self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested_from_dict(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {
+                "crop_size": {"shortest_edge": 214},
+            },
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_kwargs_overrides_default_image_processor_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", crop_size={"shortest_edge": 234})
+        tokenizer = self.get_component("tokenizer", max_length=117)
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(text=input_str, images=image_input, crop_size={"shortest_edge": 224})
+        self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+        if not tokenizer.pad_token:
+            tokenizer.pad_token = "[TEST_PAD]"
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = 
self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"shortest_edge": 214}, + padding="longest", + max_length=76, + ) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 6) + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"shortest_edge": 214}, + padding="max_length", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"crop_size": {"shortest_edge": 214}}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 76) diff --git a/tests/models/donut/test_processing_donut.py b/tests/models/donut/test_processing_donut.py index ace3a109dfbb23..87cdb41a02c7bb 100644 --- a/tests/models/donut/test_processing_donut.py +++ b/tests/models/donut/test_processing_donut.py @@ -14,16 +14,32 @@ # limitations under the License. 
+import tempfile import unittest -from transformers import DonutProcessor +from transformers import DonutImageProcessor, DonutProcessor, XLMRobertaTokenizerFast +from transformers.testing_utils import ( + require_torch, + require_vision, +) +from ...test_processing_common import ProcessorTesterMixin -class DonutProcessorTest(unittest.TestCase): + +class DonutProcessorTest(ProcessorTesterMixin, unittest.TestCase): from_pretrained_id = "naver-clova-ix/donut-base" + processor_class = DonutProcessor def setUp(self): self.processor = DonutProcessor.from_pretrained(self.from_pretrained_id) + self.tmpdirname = tempfile.mkdtemp() + + image_processor = DonutImageProcessor() + tokenizer = XLMRobertaTokenizerFast.from_pretrained(self.from_pretrained_id) + + processor = DonutProcessor(image_processor, tokenizer) + + processor.save_pretrained(self.tmpdirname) def test_token2json(self): expected_json = { @@ -49,3 +65,30 @@ def test_token2json(self): actual_json = self.processor.token2json(sequence) self.assertDictEqual(actual_json, expected_json) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"height": 214, "width": 214}, + size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 7) diff --git a/tests/models/wav2vec2/test_processor_wav2vec2.py b/tests/models/wav2vec2/test_processor_wav2vec2.py index 67883618ca86e9..30c9243e8e4f55 100644 --- a/tests/models/wav2vec2/test_processor_wav2vec2.py +++ b/tests/models/wav2vec2/test_processor_wav2vec2.py @@ -18,14 +18,19 @@ import tempfile import unittest +import numpy as np + from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.utils import FEATURE_EXTRACTOR_NAME +from ...test_processing_common import ProcessorTesterMixin from .test_feature_extraction_wav2vec2 import floats_list -class Wav2Vec2ProcessorTest(unittest.TestCase): +class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Wav2Vec2Processor + def setUp(self): vocab = " | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ") vocab_tokens = dict(zip(vocab, range(len(vocab)))) @@ -53,6 +58,9 @@ def setUp(self): with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: fp.write(json.dumps(feature_extractor_map) + "\n") + tokenizer = self.get_tokenizer() + tokenizer.save_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs_init): kwargs = self.add_kwargs_tokens_map.copy() kwargs.update(kwargs_init) @@ -117,7 +125,6 @@ def test_tokenizer(self): processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) input_str = "This is a test string" - 
encoded_processor = processor(text=input_str) encoded_tok = tokenizer(input_str) @@ -125,6 +132,22 @@ def test_tokenizer(self): for key in encoded_tok.keys(): self.assertListEqual(encoded_tok[key], encoded_processor[key]) + def test_padding_argument_not_ignored(self): + # padding, or any other overlap arg between audio extractor and tokenizer + # should be passed to both text and audio and not ignored + + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) + batch_duration_in_seconds = [1, 3, 2, 6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + # padding = True should not raise an error and will if the audio processor popped its value to None + _ = processor( + input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ) + def test_tokenizer_decode(self): feature_extractor = self.get_feature_extractor() tokenizer = self.get_tokenizer() diff --git a/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py index b6b1506f5e4d68..704d087a56a8e3 100644 --- a/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py +++ b/tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py @@ -18,17 +18,21 @@ import tempfile import unittest +import numpy as np + from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.models.wav2vec2_bert import Wav2Vec2BertProcessor from transformers.utils import FEATURE_EXTRACTOR_NAME +from ...test_processing_common import ProcessorTesterMixin from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list -# Copied from tests.models.wav2vec2.test_processor_wav2vec2.Wav2Vec2ProcessorTest with Wav2Vec2FeatureExtractor->SeamlessM4TFeatureExtractor, Wav2Vec2Processor->Wav2Vec2BertProcessor -class Wav2Vec2BertProcessorTest(unittest.TestCase): +class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = Wav2Vec2BertProcessor + def setUp(self): vocab = " | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ") vocab_tokens = dict(zip(vocab, range(len(vocab)))) @@ -40,7 +44,7 @@ def setUp(self): "eos_token": "", } feature_extractor_map = { - "feature_size": 1, + "feature_size": 80, "padding_value": 0.0, "sampling_rate": 16000, "return_attention_mask": False, @@ -56,6 +60,9 @@ def setUp(self): with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: fp.write(json.dumps(feature_extractor_map) + "\n") + tokenizer = self.get_tokenizer() + tokenizer.save_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs_init): kwargs = self.add_kwargs_tokens_map.copy() kwargs.update(kwargs_init) @@ -122,7 +129,6 @@ def test_tokenizer(self): processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) input_str = "This is a test string" - encoded_processor = processor(text=input_str) encoded_tok = tokenizer(input_str) @@ -130,6 +136,22 @@ def test_tokenizer(self): for key in encoded_tok.keys(): self.assertListEqual(encoded_tok[key], encoded_processor[key]) + def test_padding_argument_not_ignored(self): + # padding, or any other overlap arg between audio extractor and tokenizer + # should be passed to both text and audio and not ignored + 
feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + batch_duration_in_seconds = [1, 3, 2, 6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + # padding = True should not raise an error and will if the audio processor popped its value to None + # processor(input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt") + _ = processor( + input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt" + ) + def test_tokenizer_decode(self): feature_extractor = self.get_feature_extractor() tokenizer = self.get_tokenizer() diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 187cf50c733cb6..9f0d88089129b8 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -16,6 +16,7 @@ import inspect import json +import random import tempfile from typing import Optional @@ -31,11 +32,7 @@ from transformers.utils import is_vision_available -try: - from typing import Unpack -except ImportError: - from typing_extensions import Unpack - +global_rng = random.Random() if is_vision_available(): from PIL import Image @@ -48,6 +45,21 @@ def prepare_image_inputs(): return image_inputs +# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list +def floats_list(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + values = [] + for batch_idx in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + @require_torch @require_vision class ProcessorTesterMixin: @@ -333,6 +345,135 @@ def test_structured_kwargs_nested_from_dict(self): self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0) self.assertEqual(inputs[self.text_input_name].shape[-1], 76) + # text + audio kwargs testing + @require_torch + def test_tokenizer_defaults_preserved_by_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer(max_length=117, padding="max_length") + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + else: + self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + raw_speech = floats_list((3, 1000)) + inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt") + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 117) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 117) + + @require_torch + def test_kwargs_overrides_default_tokenizer_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + if hasattr(self, "get_tokenizer"): + tokenizer = 
self.get_tokenizer(max_length=117) + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer", max_length=117) + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + raw_speech = floats_list((3, 1000)) + inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length") + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 112) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 112) + + @require_torch + def test_unstructured_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer(max_length=117) + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer", max_length=117) + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + raw_speech = floats_list((3, 1000)) + inputs = processor( + text=input_str, + audio=raw_speech, + return_tensors="pt", + padding="max_length", + max_length=76, + ) + + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 76) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 76) + + @require_torch + def test_doubly_passed_kwargs_audio(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer() + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer"] + raw_speech = floats_list((3, 1000)) + with self.assertRaises(ValueError): + _ = processor( + text=input_str, + audio=raw_speech, + audio_kwargs={"padding": "max_length"}, + padding="max_length", + ) + + @require_torch + @require_vision + def test_structured_kwargs_audio_nested(self): + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + feature_extractor = self.get_component("feature_extractor") + if hasattr(self, "get_tokenizer"): + tokenizer = self.get_tokenizer() + elif hasattr(self, "get_component"): + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer"] + raw_speech = floats_list((3, 1000)) + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + "audio_kwargs": {"padding": "max_length", 
"max_length": 66}, + } + + inputs = processor(text=input_str, audio=raw_speech, **all_kwargs) + if "input_ids" in inputs: + self.assertEqual(len(inputs["input_ids"][0]), 76) + elif "labels" in inputs: + self.assertEqual(len(inputs["labels"][0]), 76) + # TODO: the same test, but for audio + text processors that have strong overlap in kwargs # TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication def test_overlapping_text_kwargs_handling(self):