diff --git a/src/transformers/models/align/processing_align.py b/src/transformers/models/align/processing_align.py index 8bcea7eb5dadf6..76d3502028b4df 100644 --- a/src/transformers/models/align/processing_align.py +++ b/src/transformers/models/align/processing_align.py @@ -17,8 +17,11 @@ """ +from typing import Dict, List, Optional, Union + from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import PaddingStrategy, TensorType class AlignProcessor(ProcessorMixin): @@ -42,11 +45,49 @@ class AlignProcessor(ProcessorMixin): def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, padding="max_length", max_length=64, return_tensors=None, **kwargs): + def __call__( + self, + text=None, + images=None, + do_crop_margin: bool = None, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: "PILImageResampling" = None, # noqa: F821 + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_pad: bool = None, + do_rescale: bool = None, + rescale_factor: Union[int, float] = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 + input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821 + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = "max_length", + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = 64, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ): """ Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text` - and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode - the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to + arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` arguments to EfficientNetImageProcessor's [`~EfficientNetImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring of the above two methods for more information. 
@@ -86,11 +127,46 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64, if text is not None: encoding = self.tokenizer( - text, padding=padding, max_length=max_length, return_tensors=return_tensors, **kwargs + text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, ) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) + image_features = self.image_processor( + images, + do_crop_margin=do_crop_margin, + do_resize=do_resize, + size=size, + resample=resample, + do_thumbnail=do_thumbnail, + do_align_long_axis=do_align_long_axis, + do_pad=do_pad, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + data_format=data_format, + input_data_format=input_data_format, + ) if text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values diff --git a/src/transformers/models/altclip/processing_altclip.py b/src/transformers/models/altclip/processing_altclip.py index 9518c55d40eadc..ac9764ab81e64b 100644 --- a/src/transformers/models/altclip/processing_altclip.py +++ b/src/transformers/models/altclip/processing_altclip.py @@ -16,9 +16,11 @@ Image/Text processor class for AltCLIP """ import warnings +from typing import Dict, List, Optional, Union from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import PaddingStrategy, TensorType class AltCLIPProcessor(ProcessorMixin): @@ -34,22 +36,21 @@ class AltCLIPProcessor(ProcessorMixin): The image processor is a required input. tokenizer ([`XLMRobertaTokenizerFast`], *optional*): The tokenizer is a required input. + feature_extractor ([`CLIPFeatureExtractor`], *optional*): + The feature extractor is a deprecated input. 
""" attributes = ["image_processor", "tokenizer"] image_processor_class = "CLIPImageProcessor" tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast") - def __init__(self, image_processor=None, tokenizer=None, **kwargs): - feature_extractor = None - if "feature_extractor" in kwargs: + def __init__(self, image_processor=None, tokenizer=None, feature_extractor=None): + if "feature_extractor": warnings.warn( "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" " instead.", FutureWarning, ) - feature_extractor = kwargs.pop("feature_extractor") - image_processor = image_processor if image_processor is not None else feature_extractor if image_processor is None: raise ValueError("You need to specify an `image_processor`.") @@ -58,7 +59,45 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__( + self, + text=None, + images=None, + do_crop_margin: bool = None, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: "PILImageResampling" = None, # noqa: F821 + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_pad: bool = None, + do_rescale: bool = None, + rescale_factor: Union[int, float] = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 + input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821 + text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ): """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not @@ -97,10 +136,47 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): raise ValueError("You have to specify either text or images. 
Both cannot be none.") if text is not None: - encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + encoding = self.tokenizer( + text, + text_pair=text_pair, + text_target=text_target, + text_pair_target=text_pair_target, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + ) if images is not None: - image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) + image_features = self.image_processor( + images, + do_crop_margin=do_crop_margin, + do_resize=do_resize, + size=size, + resample=resample, + do_thumbnail=do_thumbnail, + do_align_long_axis=do_align_long_axis, + do_pad=do_pad, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + return_tensors=return_tensors, + data_format=data_format, + input_data_format=input_data_format, + ) if text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values diff --git a/src/transformers/models/blip/processing_blip.py b/src/transformers/models/blip/processing_blip.py index 3b9d5c369a4412..3a9f9332afae25 100644 --- a/src/transformers/models/blip/processing_blip.py +++ b/src/transformers/models/blip/processing_blip.py @@ -76,31 +76,7 @@ def __call__( if images is None and text is None: raise ValueError("You have to specify either images or text.") - # Get only text - if images is None: - self.current_processor = self.tokenizer - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) - return text_encoding - - # add pixel_values - encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) + text_encoding = None if text is not None: text_encoding = self.tokenizer( @@ -121,13 +97,16 @@ def __call__( return_tensors=return_tensors, **kwargs, ) - else: - text_encoding = None - if text_encoding is not None: - encoding_image_processor.update(text_encoding) + # add pixel_values encoding. If we also have text_encoding, update image encoding and return it. + # else, return the text encoding. 
+ if images is not None: + encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) + if text_encoding is not None: + encoding_image_processor.update(text_encoding) + return encoding_image_processor - return encoding_image_processor + return text_encoding def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index ff7044c82aedb6..f160e3431d3e0e 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -78,31 +78,7 @@ def __call__( if images is None and text is None: raise ValueError("You have to specify either images or text.") - # Get only text - if images is None: - self.current_processor = self.tokenizer - text_encoding = self.tokenizer( - text=text, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_token_type_ids=return_token_type_ids, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, - **kwargs, - ) - return text_encoding - - # add pixel_values - encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) + text_encoding = None if text is not None: text_encoding = self.tokenizer( @@ -123,13 +99,16 @@ def __call__( return_tensors=return_tensors, **kwargs, ) - else: - text_encoding = None - if text_encoding is not None: - encoding_image_processor.update(text_encoding) + # add pixel_values encoding. If we also have text_encoding, update image encoding and return it. + # else, return the text encoding. + if images is not None: + encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) + if text_encoding is not None: + encoding_image_processor.update(text_encoding) + return encoding_image_processor - return encoding_image_processor + return text_encoding # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower.py b/src/transformers/models/bridgetower/image_processing_bridgetower.py index 8fc62ad3970fa0..bf3c37b9b7cd2c 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower.py @@ -166,6 +166,27 @@ class BridgeTowerImageProcessor(BaseImageProcessor): do_pad (`bool`, *optional*, defaults to `True`): Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by the `do_pad` parameter in the `preprocess` method. + pad_and_return_pixel_mask (`bool`, *optional*): + Deprecated. Whether to pad the image to the `(max_height, max_width)` of the images in the batch. + Sets do_pad. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `"channels_first"`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ model_input_names = ["pixel_values"] @@ -184,12 +205,15 @@ def __init__( do_center_crop: bool = True, crop_size: Dict[str, int] = None, do_pad: bool = True, + pad_and_return_pixel_mask: Optional[bool] = None, + return_tensors: Optional[bool] = None, + data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 + input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821 **kwargs, ) -> None: - if "pad_and_return_pixel_mask" in kwargs: - do_pad = kwargs.pop("pad_and_return_pixel_mask") + if pad_and_return_pixel_mask: + do_pad = pad_and_return_pixel_mask - super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 288} size = get_size_dict(size, default_to_square=False) @@ -222,6 +246,7 @@ def __init__( "return_tensors", "data_format", "input_data_format", + "pad_and_return_pixel_mask", ] # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.resize @@ -407,6 +432,7 @@ def preprocess( return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_and_return_pixel_mask: Optional[bool] = None, **kwargs, ) -> PIL.Image.Image: """ @@ -464,6 +490,9 @@ def preprocess( - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + pad_and_return_pixel_mask (`bool`, *optional*, deprecated, defaults to `self.do_pad`): + Whether to pad the image to the (max_height, max_width) in the batch. If `True`, a pixel mask is also + created and returned. Deprecated version of do_pad. """ do_resize = do_resize if do_resize is not None else self.do_resize size_divisor = size_divisor if size_divisor is not None else self.size_divisor @@ -482,6 +511,7 @@ def preprocess( ) size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) diff --git a/src/transformers/models/bridgetower/processing_bridgetower.py b/src/transformers/models/bridgetower/processing_bridgetower.py index 7718c3bf833fec..56d5a756ca97ee 100644 --- a/src/transformers/models/bridgetower/processing_bridgetower.py +++ b/src/transformers/models/bridgetower/processing_bridgetower.py @@ -16,8 +16,9 @@ Processor class for BridgeTower. 
""" -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union +from ...image_utils import ChannelDimension, PILImageResampling from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy from ...utils import TensorType @@ -50,11 +51,27 @@ def __call__( self, images, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + size_divisor: Optional[int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + pad_and_return_pixel_mask: Optional[bool] = None, + do_center_crop: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = None, max_length: Optional[int] = None, stride: int = 0, + is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -63,8 +80,6 @@ def __call__( return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs, ) -> BatchEncoding: """ This method uses [`BridgeTowerImageProcessor.__call__`] method to prepare image(s) for the model, and @@ -79,6 +94,7 @@ def __call__( truncation=truncation, max_length=max_length, stride=stride, + is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, @@ -86,13 +102,27 @@ def __call__( return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, - verbose=verbose, return_tensors=return_tensors, - **kwargs, + verbose=verbose, ) # add pixel_values + pixel_mask encoding_image_processor = self.image_processor( - images, return_tensors=return_tensors, do_normalize=True, do_center_crop=True, **kwargs + images, + do_resize=do_resize, + size=size, + size_divisor=size_divisor, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + data_format=data_format, + input_data_format=input_data_format, + do_pad=do_pad, + pad_and_return_pixel_mask=pad_and_return_pixel_mask, + return_tensors=return_tensors, ) encoding.update(encoding_image_processor) diff --git a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py index 60f40272bf9271..d92caa79923e92 100644 --- a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py +++ b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -249,6 +249,7 @@ def preprocess( - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ + do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py index 1c6e4723139046..6444b512a81f49 100644 --- a/src/transformers/models/donut/image_processing_donut.py +++ b/src/transformers/models/donut/image_processing_donut.py @@ -86,6 +86,13 @@ class DonutImageProcessor(BaseImageProcessor): channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): Image standard deviation. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. """ model_input_names = ["pixel_values"] @@ -103,6 +110,7 @@ def __init__( do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py index 1f03fd6306fc0a..704b41bd88ca37 100644 --- a/src/transformers/models/donut/processing_donut.py +++ b/src/transformers/models/donut/processing_donut.py @@ -36,21 +36,21 @@ class DonutProcessor(ProcessorMixin): An instance of [`DonutImageProcessor`]. The image processor is a required input. tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`], *optional*): An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input. + feature_extractor ([`CLIPFeatureExtractor`], *optional*): + The feature extractor is a deprecated input. 
""" attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor=None, tokenizer=None, **kwargs): - feature_extractor = None - if "feature_extractor" in kwargs: + def __init__(self, image_processor=None, tokenizer=None, feature_extractor=None): + if "feature_extractor": warnings.warn( "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" " instead.", FutureWarning, ) - feature_extractor = kwargs.pop("feature_extractor") image_processor = image_processor if image_processor is not None else feature_extractor if image_processor is None: @@ -60,32 +60,100 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor + + self.processing_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "text_kwargs": { + "text_pair": None, + "text_target": None, + "text_pair_target": None, + "add_special_tokens": True, + "padding": "max_length", + "truncation": True, + "max_length": 512, + "stride": 0, + "is_split_into_words": False, + "pad_to_multiple_of": None, + "return_token_type_ids": True, + "return_attention_mask": True, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + "do_crop_margin": False, + "do_resize": True, + "size": {"height": 256, "width": 256}, + "resample": "bilinear", + "do_thumbnail": False, + "do_align_long_axis": False, + "do_pad": False, + "do_rescale": False, + "rescale_factor": 1.0, + "do_normalize": True, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225], + "data_format": "channels_first", + "input_data_format": None, + }, + "audio_kwargs": {}, + "videos_kwargs": {}, + } + self._in_target_context_manager = False - def __call__(self, *args, **kwargs): + def __call__( + self, + text=None, + images=None, + audio=None, + videos=None, # end of supported modalities in call + **kwargs, + ): """ When used in normal mode, this method forwards all its arguments to AutoImageProcessor's [`~AutoImageProcessor.__call__`] and returns its output. If used in the context [`~DonutProcessor.as_target_processor`] this method forwards all its arguments to DonutTokenizer's - [`~DonutTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information. + [`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. 
""" # For backward compatibility if self._in_target_context_manager: - return self.current_processor(*args, **kwargs) - - images = kwargs.pop("images", None) - text = kwargs.pop("text", None) - if len(args) > 0: - images = args[0] - args = args[1:] + image_kwargs = { + **self.processing_kwargs.get("images_kwargs", {}), + **self.processing_kwargs.get("common_kwargs"), + **kwargs, + } + return self.current_processor( + images, + **image_kwargs, + ) if images is None and text is None: raise ValueError("You need to specify either an `images` or `text` input to process.") if images is not None: - inputs = self.image_processor(images, *args, **kwargs) + image_kwargs = { + **self.processing_kwargs.get("images_kwargs", {}), + **self.processing_kwargs.get("common_kwargs"), + **kwargs, + } + inputs = self.image_processor( + images, + **image_kwargs, + ) if text is not None: - encodings = self.tokenizer(text, **kwargs) + text_kwargs = { + **self.processing_kwargs.get("text_kwargs", {}), + **self.processing_kwargs.get("common_kwargs"), + **kwargs, + } + + encodings = self.tokenizer( + text, + **text_kwargs, + ) if text is None: return inputs diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py index dc6e9d14ee66cd..47d020343fbd4d 100644 --- a/src/transformers/models/wav2vec2/processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py @@ -45,6 +45,37 @@ def __init__(self, feature_extractor, tokenizer): super().__init__(feature_extractor, tokenizer) self.current_processor = self.feature_extractor self._in_target_context_manager = False + self.processing_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "text_kwargs": { + "text_pair": None, + "text_target": None, + "text_pair_target": None, + "add_special_tokens": True, + "padding": "max_length", + "truncation": True, + "max_length": 512, + "stride": 0, + "is_split_into_words": False, + "pad_to_multiple_of": None, + "return_token_type_ids": True, + "return_attention_mask": True, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": { + }, + "audio_kwargs": { + "sampling_rate": None, + "raw_speech": None, + }, + "videos_kwargs": { + }, + } + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): @@ -65,7 +96,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) - def __call__(self, *args, **kwargs): + def __call__( + self, + audio=None, + text=None, + images=None, + videos=None, # end of supported modalities in call + *deprecated_args, + **deprecated_kwargs, + ): """ When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's [`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context @@ -74,26 +113,23 @@ def __call__(self, *args, **kwargs): """ # For backward compatibility if self._in_target_context_manager: - return self.current_processor(*args, **kwargs) - - if "raw_speech" in kwargs: + return self.current_processor(audio, *deprecated_args, **self.processing_kwargs['audio_kwargs'], **deprecated_kwargs) + print(deprecated_args, deprecated_kwargs) + if "raw_speech" in deprecated_kwargs: + breakpoint() warnings.warn("Using `raw_speech` as a keyword argument is deprecated. 
Use `audio` instead.") - audio = kwargs.pop("raw_speech") + audio = deprecated_kwargs.pop("raw_speech") else: - audio = kwargs.pop("audio", None) - sampling_rate = kwargs.pop("sampling_rate", None) - text = kwargs.pop("text", None) - if len(args) > 0: - audio = args[0] - args = args[1:] + audio = deprecated_kwargs.pop("audio", None) + sampling_rate = deprecated_kwargs.pop("sampling_rate", None) if audio is None and text is None: raise ValueError("You need to specify either an `audio` or `text` input to process.") if audio is not None: - inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + inputs = self.feature_extractor(audio, *deprecated_args, sampling_rate=sampling_rate, **deprecated_kwargs) if text is not None: - encodings = self.tokenizer(text, **kwargs) + encodings = self.tokenizer(text, *deprecated_args, **deprecated_kwargs) if text is None: return inputs diff --git a/tests/models/bridgetower/test_image_processing_bridgetower.py b/tests/models/bridgetower/test_image_processing_bridgetower.py index f8837fdc964a76..84350d34575546 100644 --- a/tests/models/bridgetower/test_image_processing_bridgetower.py +++ b/tests/models/bridgetower/test_image_processing_bridgetower.py @@ -39,7 +39,7 @@ def __init__( do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_normalize: bool = True, - do_center_crop: bool = True, + do_center_crop: bool = False, # Current expected shape are not center-cropped. image_mean: Optional[Union[float, List[float]]] = [0.48145466, 0.4578275, 0.40821073], image_std: Optional[Union[float, List[float]]] = [0.26862954, 0.26130258, 0.27577711], do_pad: bool = True, @@ -66,10 +66,12 @@ def __init__( def prepare_image_processor_dict(self): return { + "do_normalize": self.do_normalize, "image_mean": self.image_mean, "image_std": self.image_std, - "do_normalize": self.do_normalize, + "do_pad": self.do_pad, "do_resize": self.do_resize, + "do_center_crop": self.do_center_crop, "size": self.size, "size_divisor": self.size_divisor, } diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index 90c1a4e7e12708..074da8a2bbab38 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -291,6 +291,7 @@ def test_call_numpy_4_channels(self): ) def test_image_processor_preprocess_arguments(self): + # Test that an instantiated image processor is called with the correct arg spec image_processor = self.image_processing_class(**self.image_processor_dict) if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"): preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 04572d132b9dd1..3143899ac476bf 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -832,7 +832,7 @@ def find_indent(line: str) -> int: def stringify_default(default: Any) -> str: """ Returns the string representation of a default value, as used in docstring: numbers are left as is, all other - objects are in backtiks. + objects are in backticks. Args: default (`Any`): The default value to process @@ -862,7 +862,7 @@ def stringify_default(default: Any) -> str: def eval_math_expression(expression: str) -> Optional[Union[float, int]]: # Mainly taken from the excellent https://stackoverflow.com/a/9558001 """ - Evaluate (safely) a mathematial expression and returns its value. + Evaluate (safely) a mathematical expression and returns its value. 
Args: expression (`str`): The expression to evaluate.
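
Usage sketch (not part of the patch): the snippet below illustrates the explicit-kwargs call pattern this refactor moves the processors towards, using AlignProcessor as the example. The checkpoint name and image path are placeholders; the named arguments are the ones the new `__call__` signature exposes directly instead of swallowing into `**kwargs`.

# A rough, hedged example of calling a refactored processor with explicit kwargs.
# "kakaobrain/align-base" and "example.jpg" are illustrative placeholders.
from PIL import Image
from transformers import AlignProcessor

processor = AlignProcessor.from_pretrained("kakaobrain/align-base")
image = Image.open("example.jpg").convert("RGB")

inputs = processor(
    text=["a photo of a cat"],
    images=image,
    padding="max_length",  # tokenizer argument, now a named parameter
    max_length=64,         # default carried over from the old signature
    return_tensors="pt",   # shared by the tokenizer and the image processor
)
print(sorted(inputs.keys()))  # e.g. attention_mask, input_ids, pixel_values, token_type_ids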
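
The DonutProcessor and Wav2Vec2Processor hunks above rely on dictionary-unpacking precedence to merge stored per-modality defaults with caller overrides. A minimal, self-contained sketch of that merge order (names are illustrative, not library API):

# Per-modality defaults are applied first, then common kwargs, then whatever the
# caller passes; later entries win, so caller overrides take precedence.
processing_kwargs = {
    "common_kwargs": {"return_tensors": "pt"},
    "text_kwargs": {"padding": "max_length", "max_length": 512, "truncation": True},
}

def merge_text_kwargs(caller_kwargs):
    return {
        **processing_kwargs.get("text_kwargs", {}),
        **processing_kwargs.get("common_kwargs", {}),
        **caller_kwargs,
    }

print(merge_text_kwargs({}))                   # stored defaults plus return_tensors="pt"
print(merge_text_kwargs({"max_length": 128}))  # caller's 128 replaces the stored 512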