From 9d6f0ddcec215b24006c74acb7875fd2706a3a84 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 28 Nov 2024 10:04:05 -0500
Subject: [PATCH] Add optimized `PixtralImageProcessorFast` (#34836)

* Add optimized PixtralImageProcessorFast

* make style

* Add dummy_vision_object

* Review comments

* Format

* Fix dummy

* Format

* np.ceil for math.ceil
---
 docs/source/en/_config.py                     |   2 +-
 docs/source/en/model_doc/pixtral.md           |   5 +
 src/transformers/__init__.py                  |   2 +
 src/transformers/image_utils.py               |  39 ++
 .../models/auto/image_processing_auto.py      |   2 +-
 src/transformers/models/pixtral/__init__.py   |  24 +-
 .../pixtral/image_processing_pixtral.py       |   9 +-
 .../pixtral/image_processing_pixtral_fast.py  | 349 ++++++++++++++++++
 .../utils/dummy_torchvision_objects.py        |   7 +
 .../pixtral/test_image_processing_pixtral.py  | 195 ++++++----
 10 files changed, 560 insertions(+), 74 deletions(-)
 create mode 100644 src/transformers/models/pixtral/image_processing_pixtral_fast.py

diff --git a/docs/source/en/_config.py b/docs/source/en/_config.py
index 4381def017ddc5..f49e4e4731965a 100644
--- a/docs/source/en/_config.py
+++ b/docs/source/en/_config.py
@@ -11,4 +11,4 @@
     "{processor_class}": "FakeProcessorClass",
     "{model_class}": "FakeModelClass",
     "{object_class}": "FakeObjectClass",
-}
\ No newline at end of file
+}
diff --git a/docs/source/en/model_doc/pixtral.md b/docs/source/en/model_doc/pixtral.md
index ab604e4521fc73..62bdc004c51718 100644
--- a/docs/source/en/model_doc/pixtral.md
+++ b/docs/source/en/model_doc/pixtral.md
@@ -88,6 +88,11 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
 [[autodoc]] PixtralImageProcessor
     - preprocess
 
+## PixtralImageProcessorFast
+
+[[autodoc]] PixtralImageProcessorFast
+    - preprocess
+
 ## PixtralProcessor
 
 [[autodoc]] PixtralProcessor
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index fa54ced6a13486..9db2e2c51f6c9c 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1260,6 +1260,7 @@
     _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
     _import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
     _import_structure["models.detr"].append("DetrImageProcessorFast")
+    _import_structure["models.pixtral"].append("PixtralImageProcessorFast")
     _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
     _import_structure["models.vit"].append("ViTImageProcessorFast")
 
@@ -6189,6 +6190,7 @@
     from .image_processing_utils_fast import BaseImageProcessorFast
     from .models.deformable_detr import DeformableDetrImageProcessorFast
     from .models.detr import DetrImageProcessorFast
+    from .models.pixtral import PixtralImageProcessorFast
     from .models.rt_detr import RTDetrImageProcessorFast
     from .models.vit import ViTImageProcessorFast
 
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index f59b99b490d38d..51199d9f3698fc 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -24,6 +24,7 @@
 
 from .utils import (
     ExplicitEnum,
+    TensorType,
     is_jax_tensor,
     is_numpy_array,
     is_tf_tensor,
@@ -447,6 +448,44 @@
         raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
 
 
+def validate_fast_preprocess_arguments(
+    do_rescale: Optional[bool] = None,
+    rescale_factor: Optional[float] = None,
+    do_normalize: Optional[bool] = None,
+    image_mean: Optional[Union[float, List[float]]] = None,
+    image_std: Optional[Union[float, List[float]]] = None,
+    do_pad: Optional[bool] = None,
+    size_divisibility: Optional[int] = None,
+    do_center_crop: Optional[bool] = None,
+    crop_size: Optional[Dict[str, int]] = None,
+    do_resize: Optional[bool] = None,
+    size: Optional[Dict[str, int]] = None,
+    resample: Optional["PILImageResampling"] = None,
+    return_tensors: Optional[Union[str, TensorType]] = None,
+    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+):
+    """
+    Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
+    Raises a `ValueError` if incompatible arguments are detected.
+    """
+    validate_preprocess_arguments(
+        do_rescale=do_rescale,
+        rescale_factor=rescale_factor,
+        do_normalize=do_normalize,
+        image_mean=image_mean,
+        image_std=image_std,
+        do_resize=do_resize,
+        size=size,
+        resample=resample,
+    )
+    # Extra checks for ImageProcessorFast
+    if return_tensors != "pt":
+        raise ValueError("Only returning PyTorch tensors is currently supported.")
+
+    if data_format != ChannelDimension.FIRST:
+        raise ValueError("Only channel first data format is currently supported.")
+
+    # In the future we can add a TF implementation here when we have TF models.
+
+
 class ImageFeatureExtractionMixin:
     """
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 0b180272bdb085..11ae15ca461e79 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -117,7 +117,7 @@
         ("paligemma", ("SiglipImageProcessor",)),
         ("perceiver", ("PerceiverImageProcessor",)),
         ("pix2struct", ("Pix2StructImageProcessor",)),
-        ("pixtral", ("PixtralImageProcessor",)),
+        ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
         ("poolformer", ("PoolFormerImageProcessor",)),
         ("pvt", ("PvtImageProcessor",)),
         ("pvt_v2", ("PvtImageProcessor",)),
diff --git a/src/transformers/models/pixtral/__init__.py b/src/transformers/models/pixtral/__init__.py
index 128fd3ebe0485a..400a52a8adf2a1 100644
--- a/src/transformers/models/pixtral/__init__.py
+++ b/src/transformers/models/pixtral/__init__.py
@@ -13,7 +13,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_torch_available,
+    is_torchvision_available,
+    is_vision_available,
+)
 
 
 _import_structure = {
@@ -41,6 +47,14 @@
 else:
     _import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]
 
+try:
+    if not is_torchvision_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["image_processing_pixtral_fast"] = ["PixtralImageProcessorFast"]
+
 
 if TYPE_CHECKING:
     from .configuration_pixtral import PixtralVisionConfig
@@ -65,6 +79,14 @@
     else:
         from .image_processing_pixtral import PixtralImageProcessor
 
+    try:
+        if not is_torchvision_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .image_processing_pixtral_fast import PixtralImageProcessorFast
+
 else:
     import sys
 
diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py
index b4ec0e50c9ccc3..3f3978e1934f5d 100644
--- a/src/transformers/models/pixtral/image_processing_pixtral.py
+++ b/src/transformers/models/pixtral/image_processing_pixtral.py
@@ -14,6 +14,7 @@
 # limitations under the License.
"""Image processor class for Pixtral.""" +import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -179,7 +180,7 @@ def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) def get_resize_output_image_size( - input_image: np.ndarray, + input_image: ImageInput, size: Union[int, Tuple[int, int], List[int], Tuple[int]], patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]], input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -189,7 +190,7 @@ def get_resize_output_image_size( size. Args: - input_image (`np.ndarray`): + input_image (`ImageInput`): The image to resize. size (`int` or `Tuple[int, int]`): Max image size an input image can be. Must be a dictionary with the key "longest_edge". @@ -210,8 +211,8 @@ def get_resize_output_image_size( if ratio > 1: # Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results - height = int(np.ceil(height / ratio)) - width = int(np.ceil(width / ratio)) + height = int(math.ceil(height / ratio)) + width = int(math.ceil(width / ratio)) num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width)) return num_height_tokens * patch_height, num_width_tokens * patch_width diff --git a/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/src/transformers/models/pixtral/image_processing_pixtral_fast.py new file mode 100644 index 00000000000000..82fbf3b2c094a6 --- /dev/null +++ b/src/transformers/models/pixtral/image_processing_pixtral_fast.py @@ -0,0 +1,349 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Pixtral.""" + +from typing import Dict, List, Optional, Union + +from ...image_processing_utils import get_size_dict +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...image_utils import ( + ChannelDimension, + ImageInput, + ImageType, + PILImageResampling, + get_image_size, + get_image_type, + infer_channel_dimension_format, + validate_fast_preprocess_arguments, + validate_kwargs, +) +from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, +) +from .image_processing_pixtral import ( + BatchMixFeature, + convert_to_rgb, + get_resize_output_image_size, + make_list_of_images, +) + + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_vision_available(): + from ...image_utils import pil_torch_interpolation_mapping + + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class PixtralImageProcessorFast(BaseImageProcessorFast): + r""" + Constructs a fast Pixtral image processor that leverages torchvision. 
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+            `do_resize` in the `preprocess` method.
+        size (`Dict[str, int]`, *optional*, defaults to `{"longest_edge": 1024}`):
+            Size of the maximum dimension of either the height or width dimension of the image. Used to control how
+            images are resized. If either the height or width is greater than `size["longest_edge"]`, both the height
+            and width are rescaled to `height / ratio`, `width / ratio`, where
+            `ratio = max(height / longest_edge, width / longest_edge)`.
+        patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
+            Size of the patches in the model, used to calculate the output image size. Can be overridden by
+            `patch_size` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
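+
+    Example (an illustrative sketch with default settings; it assumes the 640x480 COCO sample image that this
+    patch's tests also download, whose dimensions are already multiples of the 16x16 patch size):
+
+    ```python
+    >>> import requests
+    >>> from PIL import Image
+    >>> from transformers import PixtralImageProcessorFast
+
+    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+    >>> processor = PixtralImageProcessorFast()
+    >>> inputs = processor(image, return_tensors="pt")
+    >>> tuple(inputs.pixel_values[0][0].shape)  # (num_channels, height, width); no resizing needed here
+    (3, 480, 640)
+    ```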
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + patch_size: Dict[str, int] = None, + resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"longest_edge": 1024} + patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16} + patch_size = get_size_dict(patch_size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.patch_size = patch_size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073] + self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711] + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "patch_size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def resize( + self, + image: torch.Tensor, + size: Dict[str, int], + patch_size: Dict[str, int], + interpolation: "F.InterpolationMode" = None, + **kwargs, + ) -> torch.Tensor: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`Dict[str, int]`): + Dict containing the longest possible edge of the image. + patch_size (`Dict[str, int]`): + Patch size used to calculate the size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Resampling filter to use when resiizing the image. 
+ """ + interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR + if "longest_edge" in size: + size = (size["longest_edge"], size["longest_edge"]) + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.") + + if "height" in patch_size and "width" in patch_size: + patch_size = (patch_size["height"], patch_size["width"]) + else: + raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + patch_size=patch_size, + ) + return F.resize( + image, + size=output_size, + interpolation=interpolation, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + patch_size: Dict[str, int] = None, + resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchMixFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Describes the maximum input dimensions to the model. + patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`): + Patch size in the model. Used to calculate the image after resizing. + resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. The fast processor currently only supports
+                `"channels_first"` or `ChannelDimension.FIRST` (image in (num_channels, height, width) format); any
+                other value raises a `ValueError`.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        patch_size = patch_size if patch_size is not None else self.patch_size
+        patch_size = get_size_dict(patch_size, default_to_square=True)
+
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        size = size if size is not None else self.size
+        resample = resample if resample is not None else self.resample
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+        device = kwargs.pop("device", None)
+
+        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
+
+        images_list = make_list_of_images(images)
+        image_type = get_image_type(images_list[0][0])
+
+        if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
+            raise ValueError(f"Unsupported input image type {image_type}")
+
+        validate_fast_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+            return_tensors=return_tensors,
+            data_format=data_format,
+        )
+
+        if do_convert_rgb:
+            images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
+
+        if image_type == ImageType.PIL:
+            images_list = [[F.pil_to_tensor(image) for image in images] for images in images_list]
+        elif image_type == ImageType.NUMPY:
+            # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
+            images_list = [[torch.from_numpy(image).contiguous() for image in images] for images in images_list]
+
+        if device is not None:
+            images_list = [[image.to(device) for image in images] for images in images_list]
+
+        # We assume that all images have the same channel dimension format.
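+        # For instance, `F.pil_to_tensor` already yields (num_channels, height, width) tensors, while numpy
+        # inputs often arrive as (height, width, num_channels); the format is inferred from the first image
+        # only, and channels-last inputs are permuted below so the rest of the pipeline sees channels-first.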
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(images_list[0][0])
+        if input_data_format == ChannelDimension.LAST:
+            images_list = [[image.permute(2, 0, 1).contiguous() for image in images] for images in images_list]
+            input_data_format = ChannelDimension.FIRST
+
+        if do_rescale and do_normalize:
+            # fused rescale and normalize: normalizing with mean / rescale_factor and std / rescale_factor
+            # is equivalent to rescaling first and then normalizing with mean and std
+            new_mean = torch.tensor(image_mean, device=images_list[0][0].device) * (1.0 / rescale_factor)
+            new_std = torch.tensor(image_std, device=images_list[0][0].device) * (1.0 / rescale_factor)
+
+        batch_images = []
+        batch_image_sizes = []
+        for sample_images in images_list:
+            images = []
+            image_sizes = []
+            for image in sample_images:
+                if do_resize:
+                    interpolation = (
+                        pil_torch_interpolation_mapping[resample]
+                        if isinstance(resample, (PILImageResampling, int))
+                        else resample
+                    )
+                    image = self.resize(
+                        image=image,
+                        size=size,
+                        patch_size=patch_size,
+                        interpolation=interpolation,
+                    )
+
+                if do_rescale and do_normalize:
+                    # fused rescale and normalize
+                    image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
+                elif do_rescale:
+                    image = image * rescale_factor
+                elif do_normalize:
+                    image = F.normalize(image, image_mean, image_std)
+
+                images.append(image)
+                image_sizes.append(get_image_size(image, input_data_format))
+            batch_images.append(images)
+            batch_image_sizes.append(image_sizes)
+
+        return BatchMixFeature(data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, tensor_type=None)
diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py
index 343eda60135630..747f75386490fc 100644
--- a/src/transformers/utils/dummy_torchvision_objects.py
+++ b/src/transformers/utils/dummy_torchvision_objects.py
@@ -23,6 +23,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torchvision"])
 
 
+class PixtralImageProcessorFast(metaclass=DummyObject):
+    _backends = ["torchvision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torchvision"])
+
+
 class RTDetrImageProcessorFast(metaclass=DummyObject):
     _backends = ["torchvision"]
 
diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py
index 3994201c065c45..8b49b5aa60b99a 100644
--- a/tests/models/pixtral/test_image_processing_pixtral.py
+++ b/tests/models/pixtral/test_image_processing_pixtral.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 
 import random
+import time
 import unittest
 
 import numpy as np
+import requests
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -32,6 +34,9 @@
 
     from transformers import PixtralImageProcessor
 
+    if is_torchvision_available():
+        from transformers import PixtralImageProcessorFast
+
 
 class PixtralImageProcessingTester(unittest.TestCase):
     def __init__(
@@ -51,6 +56,7 @@ def __init__(
         image_std=[0.26862954, 0.26130258, 0.27577711],
         do_convert_rgb=True,
     ):
+        super().__init__()
         size = size if size is not None else {"longest_edge": 24}
         patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
         self.parent = parent
@@ -128,6 +134,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
 @require_vision
 class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = PixtralImageProcessor if is_vision_available() else None
+    fast_image_processing_class = PixtralImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -138,79 +145,133 @@ def image_processor_dict(self):
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "patch_size"))
-        self.assertTrue(hasattr(image_processing, "do_rescale"))
-        self.assertTrue(hasattr(image_processing, "rescale_factor"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
-        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "patch_size"))
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "rescale_factor"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
+            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
 
     def test_call_pil(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random PIL images
-        image_inputs_list = self.image_processor_tester.prepare_image_inputs()
-        for image_inputs in image_inputs_list:
-            for image in image_inputs:
-                self.assertIsInstance(image, Image.Image)
-
-        # Test not batched input
-        encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
-        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
-        self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
-
-        # Test batched
-        batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
-        for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
-            for encoded_image, image in zip(encoded_images, images):
-                expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
-                self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PIL images
+            image_inputs_list = self.image_processor_tester.prepare_image_inputs()
+            for image_inputs in image_inputs_list:
+                for image in image_inputs:
+                    self.assertIsInstance(image, Image.Image)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
+                image_inputs_list[0][0]
+            )
+            self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
+
+            # Test batched
+            batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
+            for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
+                for encoded_image, image in zip(encoded_images, images):
+                    expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
+                    self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
 
     def test_call_numpy(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random numpy tensors
-        image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
-        for image_inputs in image_inputs_list:
-            for image in image_inputs:
-                self.assertIsInstance(image, np.ndarray)
-
-        # Test not batched input
-        encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
-        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
-        self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
-
-        # Test batched
-        batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
-        for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
-            for encoded_image, image in zip(encoded_images, images):
-                expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
-                self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random numpy tensors
+            image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
+            for image_inputs in image_inputs_list:
+                for image in image_inputs:
+                    self.assertIsInstance(image, np.ndarray)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
+                image_inputs_list[0][0]
+            )
+            self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
+
+            # Test batched
+            batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
+            for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
+                for encoded_image, image in zip(encoded_images, images):
+                    expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
+                    self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
 
     def test_call_pytorch(self):
-        # Initialize image_processing
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        # create random PyTorch tensors
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PyTorch tensors
+            image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
+            for image_inputs in image_inputs_list:
+                for image in image_inputs:
+                    self.assertIsInstance(image, torch.Tensor)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
+                image_inputs_list[0][0]
+            )
+            self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
+
+            # Test batched
+            batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
+            for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
+                for encoded_image, image in zip(encoded_images, images):
+                    expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
+                    self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
+
+    @require_vision
+    @require_torch
+    def test_fast_is_faster_than_slow(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping speed test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping speed test as one of the image processors is not defined")
+
+        def measure_time(image_processor, image):
+            start = time.time()
+            _ = image_processor(image, return_tensors="pt")
+            return time.time() - start
+
         image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
-        for image_inputs in image_inputs_list:
-            for image in image_inputs:
-                self.assertIsInstance(image, torch.Tensor)
-
-        # Test not batched input
-        encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
-        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
-        self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
-
-        # Test batched
-        batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
-        for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
-            for encoded_image, image in zip(encoded_images, images):
-                expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
-                self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        fast_time = measure_time(image_processor_fast, image_inputs_list)
+        slow_time = measure_time(image_processor_slow, image_inputs_list)
+
+        self.assertLessEqual(fast_time, slow_time)
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence(self):
+        dummy_image = Image.open(
+            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
+        )
+
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
+
+        self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], atol=1e-2))
 
     @unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
     def test_call_numpy_4_channels(self):