From 68a1e75768d9c2a2742483fbb3c0da0c77163fb3 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Tue, 9 Jan 2024 20:00:24 +0000
Subject: [PATCH 01/40] Draft fast image processors
---
src/transformers/image_processing_base.py | 539 ++++++++++++++++++
src/transformers/image_processing_utils.py | 13 +-
.../image_processing_utils_fast.py | 13 +
.../models/vit/image_processing_vit_fast.py | 326 +++++++++++
4 files changed, 882 insertions(+), 9 deletions(-)
create mode 100644 src/transformers/image_processing_base.py
create mode 100644 src/transformers/image_processing_utils_fast.py
create mode 100644 src/transformers/models/vit/image_processing_vit_fast.py
diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py
new file mode 100644
index 00000000000000..54cd3d9f9098f2
--- /dev/null
+++ b/src/transformers/image_processing_base.py
@@ -0,0 +1,539 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+import json
+import os
+import warnings
+from io import BytesIO
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import requests
+
+from .feature_extraction_utils import BatchFeature as BaseBatchFeature
+from .dynamic_module_utils import custom_object_save
+from .utils import (
+ IMAGE_PROCESSOR_NAME,
+ PushToHubMixin,
+ add_model_info_to_auto_map,
+ cached_file,
+ download_url,
+ is_offline_mode,
+ is_remote_url,
+ is_vision_available,
+)
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+
+# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils_fast
+# We override the class docstring here, but the logic is the same.
+class BatchFeature(BaseBatchFeature):
+ r"""
+ Holds the output of the image processor specific `__call__` methods.
+
+ This class is derived from a python dictionary and can be used as a dictionary.
+
+ Args:
+ data (`dict`):
+ Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+            You can give a `tensor_type` here to convert the lists of integers into PyTorch/TensorFlow/NumPy
+            tensors at initialization.
+ """
+
+
+# TODO: (Amy) - factor out the common parts of this and the feature extractor
+class ImageProcessingMixin(PushToHubMixin):
+ """
+    This is an image processor mixin used to provide saving/loading functionality for image processors.
+ """
+
+ _auto_class = None
+
+ def __init__(self, **kwargs):
+ """Set elements of `kwargs` as attributes."""
+ # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
+ # `XXXImageProcessor`, this attribute and its value are misleading.
+ kwargs.pop("feature_extractor_type", None)
+ # Pop "processor_class" as it should be saved as private attribute
+ self._processor_class = kwargs.pop("processor_class", None)
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def _set_processor_class(self, processor_class: str):
+ """Sets processor class as an attribute."""
+ self._processor_class = processor_class
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ):
+ r"""
+ Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
+ namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+                - a path to a *directory* containing an image processor file saved using the
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved image processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force (re-)downloading the image processor files and override the cached versions
+                if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete an incompletely received file. Attempts to resume the download if such a
+                file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+                <Tip>
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
+
+                </Tip>
+
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final image processor object. If `True`, then this
+ functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ kwargs (`Dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are image processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Returns:
+            An image processor of type [`~image_processing_utils.ImageProcessingMixin`].
+
+ Examples:
+
+ ```python
+        # We can't instantiate the base class *ImageProcessingMixin* directly, so let's show the examples on a
+        # derived class: *CLIPImageProcessor*
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32"
+ ) # Download image_processing_config from huggingface.co and cache.
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "./test/saved_model/"
+ ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
+ image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False
+ )
+ assert image_processor.do_normalize is False
+ image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
+ )
+ assert image_processor.do_normalize is False
+ assert unused_kwargs == {"foo": False}
+ ```"""
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(image_processor_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the image processor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token", None) is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
+
+ self.to_json_file(output_image_processor_file)
+ logger.info(f"Image processor saved in {output_image_processor_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return [output_image_processor_file]
+
+ @classmethod
+ def get_image_processor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters to be used for instantiating
+        an image processor of type [`~image_processing_utils.ImageProcessingMixin`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+
+ Returns:
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", "")
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_image_processor_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ image_processor_file = pretrained_model_name_or_path
+ resolved_image_processor_file = download_url(pretrained_model_name_or_path)
+ else:
+ image_processor_file = IMAGE_PROCESSOR_NAME
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_image_processor_file = cached_file(
+ pretrained_model_name_or_path,
+ image_processor_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ )
+ except EnvironmentError:
+            # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {IMAGE_PROCESSOR_NAME} file"
+ )
+
+ try:
+ # Load image_processor dict
+ with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+
+ except json.JSONDecodeError:
+ raise EnvironmentError(
+ f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
+ )
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_image_processor_file}")
+ else:
+ logger.info(
+ f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
+ )
+
+ if "auto_map" in image_processor_dict and not is_local:
+ image_processor_dict["auto_map"] = add_model_info_to_auto_map(
+ image_processor_dict["auto_map"], pretrained_model_name_or_path
+ )
+
+ return image_processor_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
+ """
+ Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
+
+ Args:
+ image_processor_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
+ kwargs (`Dict[str, Any]`):
+ Additional parameters from which to initialize the image processor object.
+
+ Returns:
+ [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
+ parameters.
+ """
+ image_processor_dict = image_processor_dict.copy()
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
+ # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
+ # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
+ if "size" in kwargs and "size" in image_processor_dict:
+ image_processor_dict["size"] = kwargs.pop("size")
+ if "crop_size" in kwargs and "crop_size" in image_processor_dict:
+ image_processor_dict["crop_size"] = kwargs.pop("crop_size")
+
+ image_processor = cls(**image_processor_dict)
+
+ # Update image_processor with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(image_processor, key):
+ setattr(image_processor, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Image processor {image_processor}")
+ if return_unused_kwargs:
+ return image_processor, kwargs
+ else:
+ return image_processor
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["image_processor_type"] = self.__class__.__name__
+
+ return output
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]):
+ """
+        Instantiates an image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a
+        JSON file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+            An image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image processor object
+            instantiated from that JSON file.
+ """
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+ return cls(**image_processor_dict)
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+            `str`: String containing all the attributes that make up this image processor instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ # make sure private name "_processor_class" is correctly
+ # saved as "processor_class"
+ _processor_class = dictionary.pop("_processor_class", None)
+ if _processor_class is not None:
+ dictionary["processor_class"] = _processor_class
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this image_processor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
+ """
+        Register this class with a given auto class. This should only be used for custom image processors as the
+        ones in the library are already mapped with `AutoImageProcessor`.
+
+        <Tip warning={true}>
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        </Tip>
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor"`):
+                The auto class to register this new image processor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
+ """
+        Convert a single URL or a list of URLs into the corresponding `PIL.Image` object(s).
+
+        If a single URL is passed, the return value is a single `PIL.Image` object. If a list is passed, a list of
+        `PIL.Image` objects is returned.
+ """
+ headers = {
+ "User-Agent": (
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
+ " Safari/537.36"
+ )
+ }
+ if isinstance(image_url_or_urls, list):
+ return [self.fetch_images(x) for x in image_url_or_urls]
+ elif isinstance(image_url_or_urls, str):
+ response = requests.get(image_url_or_urls, stream=True, headers=headers)
+ response.raise_for_status()
+ return Image.open(BytesIO(response.content))
+ else:
+ raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py
index d60b5c6f805c20..a04d36e49aeecf 100644
--- a/src/transformers/image_processing_utils.py
+++ b/src/transformers/image_processing_utils.py
@@ -13,19 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
-import json
-import os
-import warnings
-from io import BytesIO
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Iterable, Optional, Union
import numpy as np
-import requests
-from .dynamic_module_utils import custom_object_save
-from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .image_transforms import center_crop, normalize, rescale
+from .utils import copy_func, logging
+
+from .image_processing_base import ImageProcessingMixin, BatchFeature
from .image_utils import ChannelDimension
from .utils import (
IMAGE_PROCESSOR_NAME,
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
new file mode 100644
index 00000000000000..bcb6e890a7640d
--- /dev/null
+++ b/src/transformers/image_processing_utils_fast.py
@@ -0,0 +1,13 @@
+from .image_processing_base import ImageProcessingMixin
+
+
+class BaseImageProcessorFast(ImageProcessingMixin):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, images, **kwargs):
+ return self.preprocess(images, **kwargs)
+
+ def preprocess(self, images, **kwargs):
+ raise NotImplementedError
+
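At this stage `BaseImageProcessorFast` is only a stub that routes `__call__` to `preprocess`; a hypothetical subclass makes the contract visible:

```python
class EchoProcessorFast(BaseImageProcessorFast):
    def preprocess(self, images, **kwargs):
        # A real subclass would resize/rescale/normalize here.
        return {"pixel_values": images}

EchoProcessorFast()(["img"])  # __call__ forwards to preprocess -> {"pixel_values": ["img"]}
```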
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
new file mode 100644
index 00000000000000..57ce34b564962f
--- /dev/null
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -0,0 +1,326 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ViT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ..image_utils import ChannelDimension
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+)
+from ...utils import TensorType, logging
+from ...utils.import_utils import is_torchvision_available
+
+logger = logging.get_logger(__name__)
+
+
+if is_torchvision_available():
+ from torchvision.transforms import Resize, Compose, ToTensor, Normalize, Lambda
+
+
+
+class ViTImageProcessorFast(BaseImageProcessor):
+ r"""
+ Constructs a ViT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+ _transform_params = ["do_resize", "do_rescale", "do_normalize", "size", "resample", "rescale_factor", "image_mean", "image_std"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.size = size
+ self.resample = resample
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+ self._transform_settings = self.set_transforms(
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ size=size,
+ resample=resample,
+ rescale_factor=rescale_factor,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+
+ def _set_transform_settings(self, **kwargs):
+ settings = {}
+ for k, v in kwargs.items():
+            if k not in self._transform_params:
+ raise ValueError(f"Invalid transform parameter {k}={v}.")
+ settings[k] = v
+ self._transform_settings = settings
+
+ def __same_transforms_settings(self, **kwargs):
+ """
+        Check whether the given settings match those of the currently built transforms.
+        """
+        for key, value in kwargs.items():
+            if key not in self._transform_settings or value != self._transform_settings[key]:
+ return False
+ return True
+
+ def _build_transforms(
+ self,
+ do_resize: bool,
+ size_dict: Dict[str, int],
+ resample: PILImageResampling,
+ do_rescale: bool,
+ do_normalize: bool,
+ image_mean: Union[float, List[float]],
+ image_std: Union[float, List[float]],
+ data_format: Union[str, ChannelDimension],
+ ) -> Compose:
+ transforms = []
+ if do_resize:
+ # FIXME - convert the interpolation mode to the pytorch equivalent
+ transforms.append(Resize(size_dict["height"], size_dict["width"], interpolation=resample))
+ if do_rescale:
+ transforms.append(ToTensor())
+ if do_normalize:
+ transforms.append(Normalize(image_mean, image_std))
+ if data_format is not None and data_format == ChannelDimension.LAST:
+ transforms.append(Lambda(lambda x: x.permute(1, 2, 0)))
+ return Compose(transforms)
+
+ def set_transforms(
+ self,
+ do_resize: bool,
+ do_rescale: bool,
+ do_normalize: bool,
+ size: Dict[str, int],
+ resample: PILImageResampling,
+ rescale_factor: float,
+ image_mean: Union[float, List[float]],
+ image_std: Union[float, List[float]],
+ data_format: Union[str, ChannelDimension],
+ ):
+ if self.__same_transforms_settings(
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ size=size,
+ resample=resample,
+ rescale_factor=rescale_factor,
+ image_mean=image_mean,
+ image_std=image_std,
+ ):
+ return self._transforms
+ transforms = self._build_transforms(
+ do_resize=do_resize,
+ size_dict=size_dict,
+ resample=resample,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ )
+ self._set_transform_settings(
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ size=size,
+ resample=resample,
+ rescale_factor=rescale_factor,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+ self._transforms = transforms
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ):
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ resample = resample if resample is not None else self.resample
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ if return_tensors != "pt":
+ raise ValueError("Only returning PyTorch tensors is currently supported.")
+
+ if data_format != ChannelDimension.FIRST:
+ raise ValueError("Only channel first data format is currently supported.")
+
+ if not self.__same_transforms_settings(
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ size=size,
+ resample=resample,
+ rescale_factor=rescale_factor,
+ image_mean=image_mean,
+ image_std=image_std,
+ ):
+ self.set_transforms(
+ do_resize=do_resize,
+ do_rescale=do_rescale,
+ do_normalize=do_normalize,
+ size=size,
+ resample=resample,
+ rescale_factor=rescale_factor,
+ image_mean=image_mean,
+ image_std=image_std,
+ )
+
+ transformed_images = self._transforms(images)
+
+ size = size if size is not None else self.size
+ size_dict = get_size_dict(size)
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
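For the default ViT settings, the pipeline this patch assembles is equivalent to the following standalone torchvision sketch (assuming `images` is a list of `PIL.Image` objects; `0.5` is the `IMAGENET_STANDARD` mean/std, and the interpolation mapping is the corrected one introduced in the next patch):

```python
import torch
from torchvision.transforms import Compose, InterpolationMode, Normalize, Resize, ToTensor

transforms = Compose(
    [
        Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
        ToTensor(),  # PIL/uint8 HWC -> float32 CHW, rescaled to [0, 1]
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
pixel_values = torch.stack([transforms(image) for image in images])  # (N, 3, 224, 224)
```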
From dbf9959ec565b2bef66d59cf783a526f845dd035 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 2 Feb 2024 20:06:25 +0000
Subject: [PATCH 02/40] Draft working fast version
---
.../source/en/main_classes/image_processor.md | 5 +
docs/source/en/model_doc/vit.md | 7 +-
src/transformers/__init__.py | 15 +-
src/transformers/image_processing_base.py | 7 +-
src/transformers/image_processing_utils.py | 2 +
.../image_processing_utils_fast.py | 59 ++++-
src/transformers/models/vit/__init__.py | 2 +
.../models/vit/image_processing_vit_fast.py | 212 ++++++++----------
.../utils/dummy_vision_objects.py | 28 +++
9 files changed, 203 insertions(+), 134 deletions(-)
diff --git a/docs/source/en/main_classes/image_processor.md b/docs/source/en/main_classes/image_processor.md
index 04a3cd1337a526..1c65be6f350088 100644
--- a/docs/source/en/main_classes/image_processor.md
+++ b/docs/source/en/main_classes/image_processor.md
@@ -32,3 +32,8 @@ An image processor is in charge of preparing input features for vision models an
## BaseImageProcessor
[[autodoc]] image_processing_utils.BaseImageProcessor
+
+
+## BaseImageProcessorFast
+
+[[autodoc]] image_processing_utils_fast.BaseImageProcessorFast
diff --git a/docs/source/en/model_doc/vit.md b/docs/source/en/model_doc/vit.md
index b49cb821859f59..53a550895ce22e 100644
--- a/docs/source/en/model_doc/vit.md
+++ b/docs/source/en/model_doc/vit.md
@@ -62,7 +62,7 @@ Following the original Vision Transformer, some follow-up works have been made:
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
found [here](https://github.com/google-research/vision_transformer).
-Note that we converted the weights from Ross Wightman's [timm library](https://github.com/rwightman/pytorch-image-models),
+Note that we converted the weights from Ross Wightman's [timm library](https://github.com/rwightman/pytorch-image-models),
who already converted the weights from JAX to PyTorch. Credits go to him!
## Usage tips
@@ -158,6 +158,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] ViTImageProcessor
- preprocess
+## ViTImageProcessorFast
+
+[[autodoc]] ViTImageProcessorFast
+ - preprocess
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 973764da0b7a55..ef702df9c4be2b 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1104,7 +1104,9 @@
name for name in dir(dummy_vision_objects) if not name.startswith("_")
]
else:
- _import_structure["image_processing_utils"] = ["ImageProcessingMixin"]
+ _import_structure["image_processing_base"] = ["ImageProcessingMixin"]
+ _import_structure["image_processing_utils"] = ["BaseImageProcessor"]
+ _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
_import_structure["models.bit"].extend(["BitImageProcessor"])
@@ -1162,7 +1164,8 @@
_import_structure["models.video_llava"].append("VideoLlavaImageProcessor")
_import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"])
_import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"])
- _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"])
+ _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor", "ViTImageProcessorFast"])
+ _import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"])
_import_structure["models.vitmatte"].append("VitMatteImageProcessor")
_import_structure["models.vivit"].append("VivitImageProcessor")
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
@@ -5703,7 +5706,9 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_vision_objects import *
else:
- from .image_processing_utils import ImageProcessingMixin
+ from .image_processing_base import ImageProcessingMixin
+ from .image_processing_utils import BaseImageProcessor
+ from .image_processing_utils_fast import BaseImageProcessorFast
from .image_utils import ImageFeatureExtractionMixin
from .models.beit import BeitFeatureExtractor, BeitImageProcessor
from .models.bit import BitImageProcessor
@@ -5788,11 +5793,9 @@
from .models.video_llava import VideoLlavaImageProcessor
from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor
from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor
- from .models.vit import ViTFeatureExtractor, ViTImageProcessor
- from .models.vitmatte import VitMatteImageProcessor
+ from .models.vit import ViTFeatureExtractor, ViTImageProcessor, ViTImageProcessorFast
from .models.vivit import VivitImageProcessor
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
-
# Modeling
try:
if not is_torch_available():
diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py
index 54cd3d9f9098f2..2146afa2108cc2 100644
--- a/src/transformers/image_processing_base.py
+++ b/src/transformers/image_processing_base.py
@@ -19,13 +19,13 @@
import os
import warnings
from io import BytesIO
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import requests
-from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .dynamic_module_utils import custom_object_save
+from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .utils import (
IMAGE_PROCESSOR_NAME,
PushToHubMixin,
@@ -35,6 +35,7 @@
is_offline_mode,
is_remote_url,
is_vision_available,
+ logging,
)
@@ -42,6 +43,8 @@
from PIL import Image
+logger = logging.get_logger(__name__)
+
 # TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils_fast
 # We override the class docstring here, but the logic is the same.
diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py
index a04d36e49aeecf..95c4b42abcf24c 100644
--- a/src/transformers/image_processing_utils.py
+++ b/src/transformers/image_processing_utils.py
@@ -17,7 +17,9 @@
import numpy as np
+from .image_processing_base import BatchFeature, ImageProcessingMixin
from .image_transforms import center_crop, normalize, rescale
+from .image_utils import ChannelDimension
from .utils import copy_func, logging
from .image_processing_base import ImageProcessingMixin, BatchFeature
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index bcb6e890a7640d..e84c55d03ae1f8 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -1,13 +1,58 @@
-from .image_processing_base import ImageProcessingMixin
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import cache
-class BaseImageProcessorFast(ImageProcessingMixin):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
+from .image_processing_utils import BaseImageProcessor
- def __call__(self, images, **kwargs):
- return self.preprocess(images, **kwargs)
- def preprocess(self, images, **kwargs):
+class BaseImageProcessorFast(BaseImageProcessor):
+ _transform_params = None
+
+ def _set_transform_settings(self, **kwargs):
+ settings = {}
+ for k, v in kwargs.items():
+ if k not in self._transform_params:
+ raise ValueError(f"Invalid transform parameter {k}={v}.")
+ settings[k] = v
+ self._transform_settings = settings
+
+ def _same_transforms_settings(self, **kwargs):
+ """
+        Check whether the given settings match those of the currently built transforms.
+        """
+        for key, value in kwargs.items():
+            if key not in self._transform_settings or value != self._transform_settings[key]:
+ return False
+ return True
+
+ def _build_transforms(self, **kwargs):
raise NotImplementedError
+ def set_transforms(self, **kwargs):
+        # FIXME - add input validation for the kwargs passed to these methods
+
+ if self._same_transforms_settings(**kwargs):
+ return self._transforms
+
+ transforms = self._build_transforms(**kwargs)
+ self._set_transform_settings(**kwargs)
+ self._transforms = transforms
+
+ @cache
+ def _maybe_update_transforms(self, **kwargs):
+ if self._same_transforms_settings(**kwargs):
+ return
+ self.set_transforms(**kwargs)
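The base class now implements a build-once, reuse-until-settings-change cache around `_build_transforms`. A hypothetical minimal subclass showing the contract (note that `_transform_settings` must be initialised before the first `set_transforms` call, as the ViT processor below does in its `__init__`):

```python
class ToyImageProcessorFast(BaseImageProcessorFast):
    _transform_params = ["do_flip"]

    def __init__(self, do_flip=True, **kwargs):
        super().__init__(**kwargs)
        self._transform_settings = {}
        self.set_transforms(do_flip=do_flip)

    def _build_transforms(self, do_flip):
        # Only rebuilt when the settings actually change.
        return (lambda x: x[::-1]) if do_flip else (lambda x: x)
```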
diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py
index db41e881faafa6..25f55487c4bfd1 100644
--- a/src/transformers/models/vit/__init__.py
+++ b/src/transformers/models/vit/__init__.py
@@ -33,6 +33,7 @@
else:
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
_import_structure["image_processing_vit"] = ["ViTImageProcessor"]
+ _import_structure["image_processing_vit_fast"] = ["ViTImageProcessorFast"]
try:
if not is_torch_available():
@@ -82,6 +83,7 @@
else:
from .feature_extraction_vit import ViTFeatureExtractor
from .image_processing_vit import ViTImageProcessor
+ from .image_processing_vit_fast import ViTImageProcessorFast
try:
if not is_torch_available():
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 57ce34b564962f..bb440978f9469c 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -14,37 +14,63 @@
# limitations under the License.
"""Image processor class for ViT."""
-from typing import Dict, List, Optional, Union
+from dataclasses import dataclass
+from functools import cache
+from typing import Any, Dict, List, Optional, Union
import numpy as np
-from ..image_utils import ChannelDimension
-from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
-from ...image_transforms import to_channel_dimension_format
+from ...image_processing_utils import get_size_dict
+from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
- infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
- to_numpy_array,
- valid_images,
)
from ...utils import TensorType, logging
-from ...utils.import_utils import is_torchvision_available
+from ...utils.import_utils import is_torch_available, is_vision_available
+
logger = logging.get_logger(__name__)
-if is_torchvision_available():
- from torchvision.transforms import Resize, Compose, ToTensor, Normalize, Lambda
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+ from torchvision.transforms import Compose, InterpolationMode, Normalize, Resize, ToTensor
+
+
+pil_torch_interpolation_mapping = {
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
+ PILImageResampling.BOX: InterpolationMode.BOX,
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
+}
+
+@dataclass(frozen=True)
+class SizeDict:
+ height: int = None
+ width: int = None
+ longest_edge: int = None
+ shortest_edge: int = None
+ def __getitem__(self, key):
+ if hasattr(self, key):
+ return getattr(self, key)
+ raise KeyError(f"Key {key} not found in SizeDict.")
-class ViTImageProcessorFast(BaseImageProcessor):
+
+class ViTImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a ViT image processor.
@@ -76,7 +102,16 @@ class ViTImageProcessorFast(BaseImageProcessor):
"""
model_input_names = ["pixel_values"]
- _transform_params = ["do_resize", "do_rescale", "do_normalize", "size", "resample", "rescale_factor", "image_mean", "image_std"]
+ _transform_params = [
+ "do_resize",
+ "do_rescale",
+ "do_normalize",
+ "size",
+ "resample",
+ "rescale_factor",
+ "image_mean",
+ "image_std",
+ ]
def __init__(
self,
@@ -101,7 +136,8 @@ def __init__(
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
- self._transform_settings = self.set_transforms(
+ self._transform_settings = {}
+ self.set_transforms(
do_resize=do_resize,
do_rescale=do_rescale,
do_normalize=do_normalize,
@@ -112,90 +148,53 @@ def __init__(
image_std=image_std,
)
- def _set_transform_settings(self, **kwargs):
- settings = {}
- for k, v in kwargs.items():
-            if k not in self._transform_params:
- raise ValueError(f"Invalid transform parameter {k}={v}.")
- settings[k] = v
- self._transform_settings = settings
-
- def __same_transforms_settings(self, **kwargs):
- """
-        Check whether the given settings match those of the currently built transforms.
-        """
-        for key, value in kwargs.items():
-            if key not in self._transform_settings or value != self._transform_settings[key]:
- return False
- return True
-
def _build_transforms(
self,
do_resize: bool,
- size_dict: Dict[str, int],
+ size: Dict[str, int],
resample: PILImageResampling,
do_rescale: bool,
+        rescale_factor: float,  # unused by the transforms themselves; kept so all transform settings share one signature
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
- data_format: Union[str, ChannelDimension],
) -> Compose:
transforms = []
if do_resize:
- # FIXME - convert the interpolation mode to the pytorch equivalent
- transforms.append(Resize(size_dict["height"], size_dict["width"], interpolation=resample))
+ transforms.append(
+ Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
+ )
if do_rescale:
transforms.append(ToTensor())
if do_normalize:
transforms.append(Normalize(image_mean, image_std))
- if data_format is not None and data_format == ChannelDimension.LAST:
- transforms.append(Lambda(lambda x: x.permute(1, 2, 0)))
return Compose(transforms)
- def set_transforms(
+ @cache
+ def _validate_input_arguments(
self,
+ return_tensors: Union[str, TensorType],
do_resize: bool,
- do_rescale: bool,
- do_normalize: bool,
size: Dict[str, int],
resample: PILImageResampling,
+ do_rescale: bool,
rescale_factor: float,
+ do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
data_format: Union[str, ChannelDimension],
):
- if self.__same_transforms_settings(
- do_resize=do_resize,
- do_rescale=do_rescale,
- do_normalize=do_normalize,
- size=size,
- resample=resample,
- rescale_factor=rescale_factor,
- image_mean=image_mean,
- image_std=image_std,
- ):
- return self._transforms
- transforms = self._build_transforms(
- do_resize=do_resize,
- size_dict=size_dict,
- resample=resample,
- do_rescale=do_rescale,
- do_normalize=do_normalize,
- image_mean=image_mean,
- image_std=image_std,
- data_format=data_format,
- )
- self._set_transform_settings(
- do_resize=do_resize,
- do_rescale=do_rescale,
- do_normalize=do_normalize,
- size=size,
- resample=resample,
- rescale_factor=rescale_factor,
- image_mean=image_mean,
- image_std=image_std,
- )
- self._transforms = transforms
+ if return_tensors != "pt":
+ raise ValueError("Only returning PyTorch tensors is currently supported.")
+
+ if data_format != ChannelDimension.FIRST:
+ raise ValueError("Only channel first data format is currently supported.")
+
+ if do_resize and size is None:
+ raise ValueError("Size must be specified if do_resize is True.")
+
+ if do_rescale and rescale_factor is None:
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
def preprocess(
self,
@@ -264,14 +263,24 @@ def preprocess(
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
+ size = size if size is not None else self.size
+ # Make hashable for cache
+ size = SizeDict(**size)
+ image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
+ image_std = tuple(image_std) if isinstance(image_std, list) else image_std
- if return_tensors != "pt":
- raise ValueError("Only returning PyTorch tensors is currently supported.")
+ images = make_list_of_images(images)
- if data_format != ChannelDimension.FIRST:
- raise ValueError("Only channel first data format is currently supported.")
+ if do_rescale:
+ if isinstance(images[0], np.ndarray) and is_scaled_image(images[0]):
+ raise ValueError(
+ "Images are expected to have pixel values in the range [0, 255] when do_rescale=True. "
+ "Got pixel values in the range [0, 1]."
+ )
+            elif not isinstance(images[0], (Image.Image, np.ndarray)):
+ raise ValueError("Images must be of type PIL.Image.Image or np.ndarray when do_rescale=True.")
- if not self.__same_transforms_settings(
+ self._maybe_update_transforms(
do_resize=do_resize,
do_rescale=do_rescale,
do_normalize=do_normalize,
@@ -280,47 +289,14 @@ def preprocess(
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
- ):
- self.set_transforms(
- do_resize=do_resize,
- do_rescale=do_rescale,
- do_normalize=do_normalize,
- size=size,
- resample=resample,
- rescale_factor=rescale_factor,
- image_mean=image_mean,
- image_std=image_std,
- )
-
- transformed_images = self._transforms(images)
-
- size = size if size is not None else self.size
- size_dict = get_size_dict(size)
-
- images = make_list_of_images(images)
-
- if not valid_images(images):
- raise ValueError(
- "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
- "torch.Tensor, tf.Tensor or jax.ndarray."
- )
-
- if do_resize and size is None:
- raise ValueError("Size must be specified if do_resize is True.")
-
- if do_rescale and rescale_factor is None:
- raise ValueError("Rescale factor must be specified if do_rescale is True.")
-
- # All transformations expect numpy arrays.
- images = [to_numpy_array(image) for image in images]
-
- if input_data_format is None:
- # We assume that all images have the same channel dimension format.
- input_data_format = infer_channel_dimension_format(images[0])
+ )
+ transformed_images = [self._transforms(image) for image in images]
- images = [
- to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
- ]
+ data = {"pixel_values": torch.vstack(transformed_images)}
+ return data
- data = {"pixel_values": images}
- return BatchFeature(data=data, tensor_type=return_tensors)
+ def to_dict(self) -> Dict[str, Any]:
+ result = super().to_dict()
+ result.pop("_transforms", None)
+ result.pop("_transform_settings", None)
+ return result
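`SizeDict` and the list-to-tuple conversions in `preprocess` exist because `functools.lru_cache` requires every argument to be hashable; a sketch of the failure mode they avoid:

```python
import functools

@functools.lru_cache(maxsize=1)
def build(size, image_mean):
    return (size, image_mean)

build(SizeDict(height=224, width=224), (0.5, 0.5, 0.5))  # OK: frozen dataclass and tuple are hashable
build({"height": 224, "width": 224}, [0.5, 0.5, 0.5])    # TypeError: unhashable type: 'dict'
```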
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index aae31e9e4dd7f4..34741a703509ec 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -9,6 +9,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class BaseImageProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class BaseImageProcessorFast(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class ImageFeatureExtractionMixin(metaclass=DummyObject):
_backends = ["vision"]
@@ -597,6 +611,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class ViTImageProcessorFast(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class ViTHybridImageProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class VitMatteImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
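For context, these dummies keep the public names importable when the vision backend (Pillow) is missing and defer the failure to instantiation time, where a clear message can be raised. A sketch of the behaviour, assuming Pillow is not installed:

    from transformers import ViTImageProcessorFast  # the dummy makes this import succeed

    try:
        ViTImageProcessorFast()  # DummyObject raises here instead of at import time
    except ImportError as err:
        print(err)  # message explains that the "vision" backend is required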
From 3632cf704cff919a00e51a3827cf358f3f3ec4ff Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Tue, 13 Feb 2024 20:20:48 +0000
Subject: [PATCH 03/40] py3.8 compatible cache
---
.../image_processing_utils_fast.py | 6 ++--
.../models/vit/image_processing_vit_fast.py | 33 +++++++++++--------
2 files changed, 21 insertions(+), 18 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index e84c55d03ae1f8..b8128538bea312 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from functools import cache
+import functools
from .image_processing_utils import BaseImageProcessor
@@ -42,8 +42,6 @@ def _build_transforms(self, **kwargs):
raise NotImplementedError
def set_transforms(self, **kwargs):
- # FIXME - put input validation or kwargs for all these methods
-
if self._same_transforms_settings(**kwargs):
return self._transforms
@@ -51,7 +49,7 @@ def set_transforms(self, **kwargs):
self._set_transform_settings(**kwargs)
self._transforms = transforms
- @cache
+ @functools.lru_cache(maxsize=1)
def _maybe_update_transforms(self, **kwargs):
if self._same_transforms_settings(**kwargs):
return
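For reference, `functools.cache` only exists on Python >= 3.9 (where it is defined as `lru_cache(maxsize=None)`), while `functools.lru_cache` has been available since 3.2, so the bounded form is the 3.8-compatible spelling. A standalone sketch of the `maxsize=1` behaviour relied on here, which remembers only the most recent argument set:

    import functools

    @functools.lru_cache(maxsize=1)
    def build(settings: frozenset):
        print("building transforms for", settings)
        return object()

    build(frozenset({"do_resize"}))   # miss: builds
    build(frozenset({"do_resize"}))   # hit: silent
    build(frozenset({"do_rescale"}))  # evicts the previous entry, builds again

Note that every argument must be hashable for `lru_cache` to work, which constrains the kwargs these methods can accept.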
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index bb440978f9469c..7ade1a2c572651 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -14,8 +14,8 @@
# limitations under the License.
"""Image processor class for ViT."""
+import functools
from dataclasses import dataclass
-from functools import cache
from typing import Any, Dict, List, Optional, Union
import numpy as np
@@ -119,7 +119,7 @@ def __init__(
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
- rescale_factor: Union[int, float] = 1 / 255,
+ rescale_factor: Union[int, float] = None,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
@@ -170,7 +170,7 @@ def _build_transforms(
transforms.append(Normalize(image_mean, image_std))
return Compose(transforms)
- @cache
+ @functools.lru_cache(maxsize=1)
def _validate_input_arguments(
self,
return_tensors: Union[str, TensorType],
@@ -190,12 +190,15 @@ def _validate_input_arguments(
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
- if do_resize and size is None:
- raise ValueError("Size must be specified if do_resize is True.")
+ if do_resize and None in (size, resample):
+ raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
+ if do_normalize and None in (image_mean, image_std):
+ raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
+
def preprocess(
self,
images: ImageInput,
@@ -238,17 +241,10 @@ def preprocess(
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
- The type of tensors to return. Can be one of:
- - Unset: Return a list of `np.ndarray`.
- - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ The type of tensors to return. Only "pt" is currently supported.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
- The channel dimension format for the output image. Can be one of:
+ The channel dimension format for the output image. The following formats are currently supported:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- - Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
@@ -256,6 +252,15 @@ def preprocess(
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
+ if return_tensors != "pt":
+ raise ValueError("Only returning PyTorch tensors is currently supported.")
+
+ if input_data_format is not None and input_data_format != ChannelDimension.FIRST:
+ raise ValueError("Only channel first data format is currently supported.")
+
+ if data_format != ChannelDimension.FIRST:
+ raise ValueError("Only channel first data format is currently supported.")
+
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
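To make the tightened contract concrete, a usage sketch (import path as added in this series; `rescale_factor` is passed explicitly because this patch leaves its default unset):

    import numpy as np
    from PIL import Image

    from transformers.models.vit.image_processing_vit_fast import ViTImageProcessorFast

    processor = ViTImageProcessorFast(rescale_factor=1 / 255)
    image = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))

    batch = processor.preprocess(image, return_tensors="pt")  # OK: channel-first torch batch
    print(batch["pixel_values"].shape)

    try:
        processor.preprocess(image, return_tensors="np")  # anything but "pt" is rejected
    except ValueError as err:
        print(err)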
From 49196c81f26b2a974d735e60cbec59601b673af4 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 10 May 2024 14:27:03 +0000
Subject: [PATCH 04/40] Enable loading fast image processors through auto
---
.../models/auto/image_processing_auto.py | 284 +++++++++++-------
1 file changed, 176 insertions(+), 108 deletions(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index eb21b58e20f14e..88d6edda6197b8 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -19,13 +19,14 @@
import os
import warnings
from collections import OrderedDict
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, TYPE_CHECKING
# Build the list of all image processors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
-from ...image_processing_utils import ImageProcessingMixin
-from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging
+from ...image_processing_utils import ImageProcessingMixin, BaseImageProcessor
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging, is_torchvision_available
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
@@ -37,104 +38,111 @@
logger = logging.get_logger(__name__)
-IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
- [
- ("align", "EfficientNetImageProcessor"),
- ("beit", "BeitImageProcessor"),
- ("bit", "BitImageProcessor"),
- ("blip", "BlipImageProcessor"),
- ("blip-2", "BlipImageProcessor"),
- ("bridgetower", "BridgeTowerImageProcessor"),
- ("chinese_clip", "ChineseCLIPImageProcessor"),
- ("clip", "CLIPImageProcessor"),
- ("clipseg", "ViTImageProcessor"),
- ("conditional_detr", "ConditionalDetrImageProcessor"),
- ("convnext", "ConvNextImageProcessor"),
- ("convnextv2", "ConvNextImageProcessor"),
- ("cvt", "ConvNextImageProcessor"),
- ("data2vec-vision", "BeitImageProcessor"),
- ("deformable_detr", "DeformableDetrImageProcessor"),
- ("deit", "DeiTImageProcessor"),
- ("depth_anything", "DPTImageProcessor"),
- ("deta", "DetaImageProcessor"),
- ("detr", "DetrImageProcessor"),
- ("dinat", "ViTImageProcessor"),
- ("dinov2", "BitImageProcessor"),
- ("donut-swin", "DonutImageProcessor"),
- ("dpt", "DPTImageProcessor"),
- ("efficientformer", "EfficientFormerImageProcessor"),
- ("efficientnet", "EfficientNetImageProcessor"),
- ("flava", "FlavaImageProcessor"),
- ("focalnet", "BitImageProcessor"),
- ("fuyu", "FuyuImageProcessor"),
- ("git", "CLIPImageProcessor"),
- ("glpn", "GLPNImageProcessor"),
- ("grounding-dino", "GroundingDinoImageProcessor"),
- ("groupvit", "CLIPImageProcessor"),
- ("idefics", "IdeficsImageProcessor"),
- ("idefics2", "Idefics2ImageProcessor"),
- ("imagegpt", "ImageGPTImageProcessor"),
- ("instructblip", "BlipImageProcessor"),
- ("kosmos-2", "CLIPImageProcessor"),
- ("layoutlmv2", "LayoutLMv2ImageProcessor"),
- ("layoutlmv3", "LayoutLMv3ImageProcessor"),
- ("levit", "LevitImageProcessor"),
- ("llava", "CLIPImageProcessor"),
- ("llava_next", "LlavaNextImageProcessor"),
- ("mask2former", "Mask2FormerImageProcessor"),
- ("maskformer", "MaskFormerImageProcessor"),
- ("mgp-str", "ViTImageProcessor"),
- ("mobilenet_v1", "MobileNetV1ImageProcessor"),
- ("mobilenet_v2", "MobileNetV2ImageProcessor"),
- ("mobilevit", "MobileViTImageProcessor"),
- ("mobilevit", "MobileViTImageProcessor"),
- ("mobilevitv2", "MobileViTImageProcessor"),
- ("nat", "ViTImageProcessor"),
- ("nougat", "NougatImageProcessor"),
- ("oneformer", "OneFormerImageProcessor"),
- ("owlv2", "Owlv2ImageProcessor"),
- ("owlvit", "OwlViTImageProcessor"),
- ("paligemma", "CLIPImageProcessor"),
- ("perceiver", "PerceiverImageProcessor"),
- ("pix2struct", "Pix2StructImageProcessor"),
- ("poolformer", "PoolFormerImageProcessor"),
- ("pvt", "PvtImageProcessor"),
- ("pvt_v2", "PvtImageProcessor"),
- ("regnet", "ConvNextImageProcessor"),
- ("resnet", "ConvNextImageProcessor"),
- ("sam", "SamImageProcessor"),
- ("segformer", "SegformerImageProcessor"),
- ("seggpt", "SegGptImageProcessor"),
- ("siglip", "SiglipImageProcessor"),
- ("swiftformer", "ViTImageProcessor"),
- ("swin", "ViTImageProcessor"),
- ("swin2sr", "Swin2SRImageProcessor"),
- ("swinv2", "ViTImageProcessor"),
- ("table-transformer", "DetrImageProcessor"),
- ("timesformer", "VideoMAEImageProcessor"),
- ("tvlt", "TvltImageProcessor"),
- ("tvp", "TvpImageProcessor"),
- ("udop", "LayoutLMv3ImageProcessor"),
- ("upernet", "SegformerImageProcessor"),
- ("van", "ConvNextImageProcessor"),
- ("video_llava", "VideoLlavaImageProcessor"),
- ("videomae", "VideoMAEImageProcessor"),
- ("vilt", "ViltImageProcessor"),
- ("vipllava", "CLIPImageProcessor"),
- ("vit", "ViTImageProcessor"),
- ("vit_hybrid", "ViTHybridImageProcessor"),
- ("vit_mae", "ViTImageProcessor"),
- ("vit_msn", "ViTImageProcessor"),
- ("vitmatte", "VitMatteImageProcessor"),
- ("xclip", "CLIPImageProcessor"),
- ("yolos", "YolosImageProcessor"),
- ]
-)
+
+if TYPE_CHECKING:
+ # This significantly improves completion suggestion performance when
+ # the transformers package is used with Microsoft's Pylance language server.
+ IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
+else:
+ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
+ [
+ ("align", ("EfficientNetImageProcessor", None)),
+ ("beit", ("BeitImageProcessor", None)),
+ ("bit", ("BitImageProcessor", None)),
+ ("blip", ("BlipImageProcessor", None)),
+ ("blip-2", ("BlipImageProcessor", None)),
+ ("bridgetower", ("BridgeTowerImageProcessor", None)),
+ ("chinese_clip", ("ChineseCLIPImageProcessor", None)),
+ ("clip", ("CLIPImageProcessor", None)),
+ ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("conditional_detr", ("ConditionalDetrImageProcessor", None)),
+ ("convnext", ("ConvNextImageProcessor", None)),
+ ("convnextv2", ("ConvNextImageProcessor", None)),
+ ("cvt", ("ConvNextImageProcessor", None)),
+ ("data2vec-vision", ("BeitImageProcessor", None)),
+ ("deformable_detr", ("DeformableDetrImageProcessor", None)),
+ ("deit", ("DeiTImageProcessor", None)),
+ ("depth_anything", ("DPTImageProcessor", None)),
+ ("deta", ("DetaImageProcessor", None)),
+ ("detr", ("DetrImageProcessor", None)),
+ ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("dinov2", ("BitImageProcessor", None)),
+ ("donut-swin", ("DonutImageProcessor", None)),
+ ("dpt", ("DPTImageProcessor", None)),
+ ("efficientformer", ("EfficientFormerImageProcessor", None)),
+ ("efficientnet", ("EfficientNetImageProcessor", None)),
+ ("flava", ("FlavaImageProcessor", None)),
+ ("focalnet", ("BitImageProcessor", None)),
+ ("fuyu", ("FuyuImageProcessor", None)),
+ ("git", ("CLIPImageProcessor", None)),
+ ("glpn", ("GLPNImageProcessor", None)),
+ ("grounding-dino", ("GroundingDinoImageProcessor", None)),
+ ("groupvit", ("CLIPImageProcessor", None)),
+ ("idefics", ("IdeficsImageProcessor", None)),
+ ("idefics2", ("Idefics2ImageProcessor", None)),
+ ("imagegpt", ("ImageGPTImageProcessor", None)),
+ ("instructblip", ("BlipImageProcessor", None)),
+ ("kosmos-2", ("CLIPImageProcessor", None)),
+ ("layoutlmv2", ("LayoutLMv2ImageProcessor", None)),
+ ("layoutlmv3", ("LayoutLMv3ImageProcessor", None)),
+ ("levit", ("LevitImageProcessor", None)),
+ ("llava", ("CLIPImageProcessor", None)),
+ ("llava_next", ("LlavaNextImageProcessor", None)),
+ ("mask2former", ("Mask2FormerImageProcessor", None)),
+ ("maskformer", ("MaskFormerImageProcessor", None)),
+ ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("mobilenet_v1", ("MobileNetV1ImageProcessor", None)),
+ ("mobilenet_v2", ("MobileNetV2ImageProcessor", None)),
+ ("mobilevit", ("MobileViTImageProcessor", None)),
+ ("mobilevit", ("MobileViTImageProcessor", None)),
+ ("mobilevitv2", ("MobileViTImageProcessor", None)),
+ ("nat", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("nougat", ("NougatImageProcessor", None)),
+ ("oneformer", ("OneFormerImageProcessor", None)),
+ ("owlv2", ("Owlv2ImageProcessor", None)),
+ ("owlvit", ("OwlViTImageProcessor", None)),
+ ("perceiver", ("PerceiverImageProcessor", None)),
+ ("pix2struct", ("Pix2StructImageProcessor", None)),
+ ("poolformer", ("PoolFormerImageProcessor", None)),
+ ("pvt", ("PvtImageProcessor", None)),
+ ("pvt_v2", ("PvtImageProcessor", None)),
+ ("regnet", ("ConvNextImageProcessor", None)),
+ ("resnet", ("ConvNextImageProcessor", None)),
+ ("sam", ("SamImageProcessor", None)),
+ ("segformer", ("SegformerImageProcessor", None)),
+ ("seggpt", ("SegGptImageProcessor", None)),
+ ("siglip", ("SiglipImageProcessor", None)),
+ ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("swin2sr", ("Swin2SRImageProcessor", None)),
+ ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("table-transformer", ("DetrImageProcessor", None)),
+ ("timesformer", ("VideoMAEImageProcessor", None)),
+ ("tvlt", ("TvltImageProcessor", None)),
+ ("tvp", ("TvpImageProcessor", None)),
+ ("udop", ("LayoutLMv3ImageProcessor", None)),
+ ("upernet", ("SegformerImageProcessor", None)),
+ ("van", ("ConvNextImageProcessor", None)),
+ ("videomae", ("VideoMAEImageProcessor", None)),
+ ("vilt", ("ViltImageProcessor", None)),
+ ("vipllava", ("CLIPImageProcessor", None)),
+ ("vit", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("vit_hybrid", ("ViTHybridImageProcessor", None)),
+ ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("vitmatte", ("VitMatteImageProcessor", None)),
+ ("xclip", ("CLIPImageProcessor", None)),
+ ("yolos", ("YolosImageProcessor", None)),
+ ]
+ )
IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
def image_processor_class_from_name(class_name: str):
+ if class_name == "BaseImageProcessorFast":
+ return BaseImageProcessorFast
+
for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
if class_name in extractors:
module_name = model_type_to_module_name(module_name)
@@ -145,11 +153,12 @@ def image_processor_class_from_name(class_name: str):
except AttributeError:
continue
- for _, extractor in IMAGE_PROCESSOR_MAPPING._extra_content.items():
- if getattr(extractor, "__name__", None) == class_name:
- return extractor
+ for _, extractors in IMAGE_PROCESSOR_MAPPING._extra_content.items():
+ for extractor in extractors:
+ if getattr(extractor, "__name__", None) == class_name:
+ return extractor
- # We did not fine the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
# init and we return the proper dummy to get an appropriate error message.
main_module = importlib.import_module("transformers")
if hasattr(main_module, class_name):
@@ -274,7 +283,7 @@ def __init__(self):
@classmethod
@replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r"""
Instantiate one of the image processor classes of the library from a pretrained model vocabulary.
@@ -314,6 +323,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
+ use_fast (`bool`, *optional*, defaults to `True`):
+ Use a fast torchvision-based image processor if one is supported for the given model.
+ If a fast image processor is not available for the given model, a normal numpy-based image processor
+ is returned instead.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
If `False`, then this function returns just the final image processor object. If `True`, then this
functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
@@ -358,6 +371,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
kwargs["token"] = use_auth_token
config = kwargs.pop("config", None)
+ use_fast = kwargs.pop("use_fast", True)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
@@ -395,10 +409,16 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
+ if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
+ # In some configs, only the slow image processor class is stored
+ image_processor_auto_map = (image_processor_auto_map, None)
+
if has_remote_code and trust_remote_code:
- image_processor_class = get_class_from_dynamic_module(
- image_processor_auto_map, pretrained_model_name_or_path, **kwargs
- )
+ if use_fast and image_processor_auto_map[1] is not None:
+ class_ref = image_processor_auto_map[1]
+ else:
+ class_ref = image_processor_auto_map[0]
+ image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
_ = kwargs.pop("code_revision", None)
if os.path.isdir(pretrained_model_name_or_path):
image_processor_class.register_for_auto_class()
@@ -407,8 +427,19 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
return image_processor_class.from_dict(config_dict, **kwargs)
# Last try: we use the IMAGE_PROCESSOR_MAPPING.
elif type(config) in IMAGE_PROCESSOR_MAPPING:
- image_processor_class = IMAGE_PROCESSOR_MAPPING[type(config)]
- return image_processor_class.from_dict(config_dict, **kwargs)
+ image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
+
+ image_processor_class_py, image_processor_fast_class = image_processor_tuple
+
+ if image_processor_fast_class and (use_fast or image_processor_class_py is None):
+ return image_processor_fast_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ if image_processor_class_py is not None:
+ return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ raise ValueError(
+ "This image processor cannot be instantiated. Please make sure you have `torchvision` installed."
+ )
raise ValueError(
f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
@@ -417,7 +448,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
)
@staticmethod
- def register(config_class, image_processor_class, exist_ok=False):
+ def register(config_class, image_processor_class=None, slow_image_processor_class=None, fast_image_processor_class=None, exist_ok=False):
"""
Register a new image processor for this class.
@@ -426,4 +457,41 @@ def register(config_class, image_processor_class, exist_ok=False):
The configuration corresponding to the model to register.
image_processor_class ([`ImageProcessingMixin`]): The image processor to register.
"""
- IMAGE_PROCESSOR_MAPPING.register(config_class, image_processor_class, exist_ok=exist_ok)
+ if image_processor_class is not None:
+ if slow_image_processor_class is not None:
+ raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
+ warnings.warn(
+ "The image_processor_class argument is deprecated and will be removed in v4.42. Please use slow_image_processor_class, or fast_image_processor_class instead",
+ FutureWarning
+ )
+ slow_image_processor_class = image_processor_class
+
+ if slow_image_processor_class is None and fast_image_processor_class is None:
+ raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
+ if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
+ raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
+ if fast_tokenizer_class is not None and issubclass(fast_image_processor_class, BaseImageProcessor):
+ raise ValueError("You passed a slow image processor in as the `fast_image_processor_class`.")
+
+ if (
+ slow_image_processor_class is not None
+ and fast_image_processor_class is not None
+ and issubclass(fast_image_processor_class, PreTrainedTokenizerFast)
+ and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
+ ):
+ raise ValueError(
+ "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
+ "consistent with the slow tokenizer class you passed (fast tokenizer has "
+ f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
+ "so they match!"
+ )
+
+ # Avoid resetting a set slow/fast image processor if we are passing just the other ones.
+ if config_class in IMAGE_PROCESSOR_MAPPING._extra_content:
+ existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class]
+ if slow_image_processor_class is None:
+ slow_image_processor_class = existing_slow
+ if fast_image_processor_class is None:
+ fast_image_processor_class = existing_fast
+
+ IMAGE_PROCESSOR_MAPPING.register(config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok)
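A sketch of the extended registration surface (all names below are placeholders, not real classes):

    from transformers import AutoImageProcessor, PretrainedConfig
    from transformers.image_processing_utils import BaseImageProcessor
    from transformers.image_processing_utils_fast import BaseImageProcessorFast

    class MyConfig(PretrainedConfig):
        model_type = "my-model"

    class MySlowImageProcessor(BaseImageProcessor):
        pass

    class MyFastImageProcessor(BaseImageProcessorFast):
        pass

    # The two halves of the (slow, fast) tuple can be registered separately;
    # re-registering fills in the missing half instead of overwriting it.
    AutoImageProcessor.register(MyConfig, slow_image_processor_class=MySlowImageProcessor)
    AutoImageProcessor.register(MyConfig, fast_image_processor_class=MyFastImageProcessor, exist_ok=True)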
From 6fc2901e4288708756a6960450768887df11de27 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 10 May 2024 16:11:38 +0000
Subject: [PATCH 05/40] Tidy up; rescale behaviour based on input type
---
.../image_processing_utils_fast.py | 15 ++++
.../models/auto/image_processing_auto.py | 22 ++++--
.../models/vit/image_processing_vit_fast.py | 77 ++++++++++++++-----
3 files changed, 86 insertions(+), 28 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index b8128538bea312..aa13fc6ad4442d 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -20,6 +20,7 @@
class BaseImageProcessorFast(BaseImageProcessor):
_transform_params = None
+ _transform_settings = None
def _set_transform_settings(self, **kwargs):
settings = {}
@@ -33,15 +34,25 @@ def _same_transforms_settings(self, **kwargs):
"""
Check if the given settings match those used to build the current transforms.
"""
+ if self._transform_settings is None:
+ raise ValueError("Transform settings have not been set.")
+
for key, value in kwargs.items():
if key not in self._transform_settings or value != self._transform_settings[key]:
return False
return True
def _build_transforms(self, **kwargs):
+ """
+ Given the input settings, e.g. `do_resize`, build the image transforms.
+ """
raise NotImplementedError
def set_transforms(self, **kwargs):
+ """
+ Set the image transforms based on the given settings.
+ If the settings are the same as the current ones, do nothing.
+ """
if self._same_transforms_settings(**kwargs):
return self._transforms
@@ -51,6 +62,10 @@ def set_transforms(self, **kwargs):
@functools.lru_cache(maxsize=1)
def _maybe_update_transforms(self, **kwargs):
+ """
+ If settings are different from those stored in `self._transform_settings`, update
+ the image transforms to apply.
+ """
if self._same_transforms_settings(**kwargs):
return
self.set_transforms(**kwargs)
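The intended contract, sketched schematically rather than as a literal call (the real `preprocess` passes the full settings set, including `image_type`, and every value must be hashable because of the `lru_cache` decorator, which is why lists such as `image_mean`/`image_std` are converted to tuples first):

    # First call: nothing stored yet, so the torchvision Compose is built and cached
    proc._maybe_update_transforms(do_resize=True, size=(224, 224))
    # Identical settings: _same_transforms_settings() is True, nothing is rebuilt
    proc._maybe_update_transforms(do_resize=True, size=(224, 224))
    # Changed settings: the Compose is rebuilt and the stored settings replaced
    proc._maybe_update_transforms(do_resize=True, size=(256, 256))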
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 88d6edda6197b8..28d7c194ce5a58 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -19,14 +19,14 @@
import os
import warnings
from collections import OrderedDict
-from typing import Dict, Optional, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict, Optional, Union
# Build the list of all image processors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
-from ...image_processing_utils import ImageProcessingMixin, BaseImageProcessor
+from ...image_processing_utils import BaseImageProcessor, ImageProcessingMixin
from ...image_processing_utils_fast import BaseImageProcessorFast
-from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, logging, is_torchvision_available
+from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, is_torchvision_available, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
@@ -448,7 +448,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
)
@staticmethod
- def register(config_class, image_processor_class=None, slow_image_processor_class=None, fast_image_processor_class=None, exist_ok=False):
+ def register(
+ config_class,
+ image_processor_class=None,
+ slow_image_processor_class=None,
+ fast_image_processor_class=None,
+ exist_ok=False,
+ ):
"""
Register a new image processor for this class.
@@ -461,8 +467,8 @@ def register(config_class, image_processor_class=None, slow_image_processor_clas
if slow_image_processor_class is not None:
raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
warnings.warn(
- "The image_processor_class argument is deprecated and will be removed in v4.42. Please use slow_image_processor_class, or fast_image_processor_class instead",
- FutureWarning
+ "The image_processor_class argument is deprecated and will be removed in v4.42. Please use slow_image_processor_class, or fast_image_processor_class instead",
+ FutureWarning,
)
slow_image_processor_class = image_processor_class
@@ -494,4 +500,6 @@ def register(config_class, image_processor_class=None, slow_image_processor_clas
if fast_image_processor_class is None:
fast_image_processor_class = existing_fast
- IMAGE_PROCESSOR_MAPPING.register(config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok)
+ IMAGE_PROCESSOR_MAPPING.register(
+ config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok
+ )
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 7ade1a2c572651..4aa9d79f6b1aca 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -28,10 +28,10 @@
ChannelDimension,
ImageInput,
PILImageResampling,
- is_scaled_image,
make_list_of_images,
)
from ...utils import TensorType, logging
+from ...utils.generic import ExplictEnum
from ...utils.import_utils import is_torch_available, is_vision_available
@@ -42,8 +42,7 @@
import torch
if is_vision_available():
- from PIL import Image
- from torchvision.transforms import Compose, InterpolationMode, Normalize, Resize, ToTensor
+ from torchvision.transforms import Compose, InterpolationMode, Lambda, Normalize, Resize, ToTensor
pil_torch_interpolation_mapping = {
@@ -70,6 +69,22 @@ def __getitem__(self, key):
raise KeyError(f"Key {key} not found in SizeDict.")
+class ImageType(ExplictEnum):
+ PIL = "pillow"
+ TORCH = "torch"
+ NUMPY = "numpy"
+
+
+def get_image_type(image):
+ if is_vision_available() and isinstance(image, PIL.Image.Image):
+ return ImageType.PIL
+ if is_torch_available() and isinstance(image, torch.Tensor):
+ return ImageType.TORCH
+ if isinstance(image, np.ndarray):
+ return ImageType.NUMPY
+ raise ValueError(f"Unrecognised image type {type(image)}")
+
+
class ViTImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a ViT image processor.
@@ -158,14 +173,34 @@ def _build_transforms(
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
+ image_type: ImageType,
) -> Compose:
+ """
+ Given the input settings build the image transforms using `torchvision.transforms.Compose`.
+ """
+
+ def rescale_image(image, rescale_factor):
+ return image * rescale_factor
+
transforms = []
if do_resize:
transforms.append(
Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
)
if do_rescale:
- transforms.append(ToTensor())
+ # To maintain cross-compatibility between the slow and fast image processors, we need to
+ # accept PIL images, numpy arrays and torch.Tensors as input.
+ if image_type in (ImageType.PIL, ImageType.NUMPY):
+ transforms.append(ToTensor())
+ # ToTensor already scales the pixel values to [0, 1], so only apply the residual factor
+ if rescale_factor != 1 / 255:
+ rescale_factor = rescale_factor * 255
+ transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
+ # torch.Tensor inputs are not scaled by ToTensor, so the full rescale factor is applied
+ elif image_type == ImageType.TORCH:
+ transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
+ else:
+ raise ValueError(f"Unsupported image type {image_type}")
if do_normalize:
transforms.append(Normalize(image_mean, image_std))
return Compose(transforms)
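To pin down the numerics of the two rescale paths, a standalone torchvision sketch (the factor is deliberately non-default): for PIL/numpy inputs `ToTensor` already divides by 255, so only the residual factor `rescale_factor * 255` is applied; float `torch.Tensor` inputs bypass `ToTensor` and receive the full factor.

    import numpy as np
    import torch
    from torchvision.transforms import Compose, Lambda, ToTensor

    rescale_factor = 1 / 127.5  # illustrative non-default factor

    # PIL/numpy path: ToTensor divides by 255, the Lambda applies the residual 255 * factor
    numpy_pipeline = Compose([ToTensor(), Lambda(lambda t: t * (rescale_factor * 255))])
    img = np.full((2, 2, 3), 255, dtype=np.uint8)
    assert torch.allclose(numpy_pipeline(img), torch.full((3, 2, 2), 2.0))

    # torch path: no ToTensor, the full factor is applied directly
    tensor_pipeline = Compose([Lambda(lambda t: t * rescale_factor)])
    assert torch.allclose(tensor_pipeline(torch.full((3, 2, 2), 255.0)), torch.full((3, 2, 2), 2.0))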
@@ -183,6 +218,7 @@ def _validate_input_arguments(
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
data_format: Union[str, ChannelDimension],
+ image_type: ImageType,
):
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
@@ -252,15 +288,6 @@ def preprocess(
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
- if return_tensors != "pt":
- raise ValueError("Only returning PyTorch tensors is currently supported.")
-
- if input_data_format is not None and input_data_format != ChannelDimension.FIRST:
- raise ValueError("Only channel first data format is currently supported.")
-
- if data_format != ChannelDimension.FIRST:
- raise ValueError("Only channel first data format is currently supported.")
-
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
@@ -276,14 +303,21 @@ def preprocess(
images = make_list_of_images(images)
- if do_rescale:
- if isinstance(images[0], np.ndarray) and is_scaled_image(images[0]):
- raise ValueError(
- "Images are expected to have pixel values in the range [0, 255] when do_rescale=True. "
- "Got pixel values in the range [0, 1]."
- )
- elif not isinstance(images[0], Image.Image):
- raise ValueError("Images must be of type PIL.Image.Image or np.ndarray when do_rescale=True.")
+ image_type = get_image_type(images[0])
+
+ self._validate_input_arguments(
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ return_tensors=return_tensors,
+ data_format=data_format,
+ image_type=image_type,
+ )
self._maybe_update_transforms(
do_resize=do_resize,
@@ -294,6 +328,7 @@ def preprocess(
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
+ image_type=image_type,
)
transformed_images = [self._transforms(image) for image in images]
From 834ae6af2f6b00d9e9c21f912160b8e42591a709 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 10 May 2024 17:52:47 +0000
Subject: [PATCH 06/40] Enable tests for fast image processors
---
.../image_processing_utils_fast.py | 2 +-
.../models/auto/image_processing_auto.py | 10 +-
.../models/vit/image_processing_vit_fast.py | 25 +-
.../models/beit/test_image_processing_beit.py | 1 +
.../models/blip/test_image_processing_blip.py | 2 +
.../test_image_processing_bridgetower.py | 1 +
.../test_image_processing_chinese_clip.py | 2 +
.../models/clip/test_image_processing_clip.py | 1 +
.../test_image_processing_conditional_detr.py | 1 +
.../test_image_processing_convnext.py | 1 +
.../test_image_processing_deformable_detr.py | 1 +
.../models/deit/test_image_processing_deit.py | 1 +
.../models/detr/test_image_processing_detr.py | 1 +
.../donut/test_image_processing_donut.py | 1 +
tests/models/dpt/test_image_processing_dpt.py | 1 +
.../test_image_processing_efficientnet.py | 1 +
.../flava/test_image_processing_flava.py | 1 +
.../models/glpn/test_image_processing_glpn.py | 1 +
.../test_image_processing_grounding_dino.py | 1 +
.../idefics/test_image_processing_idefics.py | 1 +
.../test_image_processing_idefics2.py | 1 +
.../test_image_processing_imagegpt.py | 1 +
.../test_image_processing_layoutlmv2.py | 1 +
.../test_image_processing_layoutlmv3.py | 1 +
.../levit/test_image_processing_levit.py | 1 +
.../test_image_processor_llava_next.py | 1 +
.../test_image_processing_mask2former.py | 1 +
.../test_image_processing_maskformer.py | 1 +
.../test_image_processing_mobilenet_v1.py | 1 +
.../test_image_processing_mobilenet_v2.py | 1 +
.../test_image_processing_mobilevit.py | 1 +
.../nougat/test_image_processing_nougat.py | 1 +
.../test_image_processing_oneformer.py | 1 +
.../owlv2/test_image_processor_owlv2.py | 1 +
.../owlvit/test_image_processing_owlvit.py | 1 +
.../test_image_processing_pix2struct.py | 2 +
.../test_image_processing_poolformer.py | 1 +
tests/models/pvt/test_image_processing_pvt.py | 1 +
.../test_image_processing_segformer.py | 1 +
.../seggpt/test_image_processing_seggpt.py | 1 +
.../siglip/test_image_processor_siglip.py | 1 +
.../test_image_processing_superpoint.py | 1 +
.../swin2sr/test_image_processing_swin2sr.py | 1 +
tests/models/tvp/test_image_processing_tvp.py | 1 +
.../test_image_processing_videomae.py | 1 +
.../models/vilt/test_image_processing_vilt.py | 1 +
tests/models/vit/test_image_processing_vit.py | 23 +-
.../test_image_processing_vitmatte.py | 1 +
.../vivit/test_image_processing_vivit.py | 1 +
.../yolos/test_image_processing_yolos.py | 1 +
tests/test_image_processing_common.py | 300 ++++++++++--------
51 files changed, 240 insertions(+), 169 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index aa13fc6ad4442d..873bbd1bb5293a 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -35,7 +35,7 @@ def _same_transforms_settings(self, **kwargs):
Check if the given settings match those used to build the current transforms.
"""
if self._transform_settings is None:
- raise ValueError("Transform settings have not been set.")
+ return False
for key, value in kwargs.items():
if key not in self._transform_settings or value != self._transform_settings[key]:
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 28d7c194ce5a58..ab0533416552f7 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -19,7 +19,7 @@
import os
import warnings
from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
# Build the list of all image processors
from ...configuration_utils import PretrainedConfig
@@ -476,19 +476,19 @@ def register(
raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
- if fast_tokenizer_class is not None and issubclass(fast_image_processor_class, BaseImageProcessor):
+ if fast_image_processor_class is not None and issubclass(fast_image_processor_class, BaseImageProcessor):
raise ValueError("You passed a slow image processor in as the `fast_image_processor_class`.")
if (
slow_image_processor_class is not None
and fast_image_processor_class is not None
- and issubclass(fast_image_processor_class, PreTrainedTokenizerFast)
+ and issubclass(fast_image_processor_class, BaseImageProcessorFast)
and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
):
raise ValueError(
- "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
+ "The fast tokenizer class you are passing has a `slow_image_processor_class` attribute that is not "
"consistent with the slow tokenizer class you passed (fast tokenizer has "
- f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
+ f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}. Fix one of those "
"so they match!"
)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 4aa9d79f6b1aca..6fcec725e7a2f9 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -31,8 +31,8 @@
make_list_of_images,
)
from ...utils import TensorType, logging
-from ...utils.generic import ExplictEnum
-from ...utils.import_utils import is_torch_available, is_vision_available
+from ...utils.generic import ExplicitEnum
+from ...utils.import_utils import is_torch_available, is_torchvision_available, is_vision_available
logger = logging.get_logger(__name__)
@@ -42,6 +42,9 @@
import torch
if is_vision_available():
+ from PIL import Image
+
+if is_torchvision_available():
from torchvision.transforms import Compose, InterpolationMode, Lambda, Normalize, Resize, ToTensor
@@ -69,14 +72,14 @@ def __getitem__(self, key):
raise KeyError(f"Key {key} not found in SizeDict.")
-class ImageType(ExplictEnum):
+class ImageType(ExplicitEnum):
PIL = "pillow"
TORCH = "torch"
NUMPY = "numpy"
def get_image_type(image):
- if is_vision_available() and isinstance(image, PIL.Image.Image):
+ if is_vision_available() and isinstance(image, Image.Image):
return ImageType.PIL
if is_torch_available() and isinstance(image, torch.Tensor):
return ImageType.TORCH
@@ -126,6 +129,7 @@ class ViTImageProcessorFast(BaseImageProcessorFast):
"rescale_factor",
"image_mean",
"image_std",
+ "image_type",
]
def __init__(
@@ -134,7 +138,7 @@ def __init__(
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
- rescale_factor: Union[int, float] = None,
+ rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
@@ -152,16 +156,6 @@ def __init__(
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self._transform_settings = {}
- self.set_transforms(
- do_resize=do_resize,
- do_rescale=do_rescale,
- do_normalize=do_normalize,
- size=size,
- resample=resample,
- rescale_factor=rescale_factor,
- image_mean=image_mean,
- image_std=image_std,
- )
def _build_transforms(
self,
@@ -302,7 +296,6 @@ def preprocess(
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
images = make_list_of_images(images)
-
image_type = get_image_type(images[0])
self._validate_input_arguments(
diff --git a/tests/models/beit/test_image_processing_beit.py b/tests/models/beit/test_image_processing_beit.py
index d23e54db0d35eb..e91517b3dbe08c 100644
--- a/tests/models/beit/test_image_processing_beit.py
+++ b/tests/models/beit/test_image_processing_beit.py
@@ -121,6 +121,7 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = BeitImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = BeitImageProcessingTester(self)
@property
diff --git a/tests/models/blip/test_image_processing_blip.py b/tests/models/blip/test_image_processing_blip.py
index 1d7e7f12ee9bfa..905e1dad55e269 100644
--- a/tests/models/blip/test_image_processing_blip.py
+++ b/tests/models/blip/test_image_processing_blip.py
@@ -90,6 +90,7 @@ class BlipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = BlipImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = BlipImageProcessingTester(self)
@property
@@ -112,6 +113,7 @@ class BlipImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.Tes
image_processing_class = BlipImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = BlipImageProcessingTester(self, num_channels=4)
self.expected_encoded_image_num_channels = 3
diff --git a/tests/models/bridgetower/test_image_processing_bridgetower.py b/tests/models/bridgetower/test_image_processing_bridgetower.py
index f8837fdc964a76..1dc5419b77c886 100644
--- a/tests/models/bridgetower/test_image_processing_bridgetower.py
+++ b/tests/models/bridgetower/test_image_processing_bridgetower.py
@@ -136,6 +136,7 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = BridgeTowerImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = BridgeTowerImageProcessingTester(self)
@property
diff --git a/tests/models/chinese_clip/test_image_processing_chinese_clip.py b/tests/models/chinese_clip/test_image_processing_chinese_clip.py
index 7eea00f885201c..94e41e8eaa06a3 100644
--- a/tests/models/chinese_clip/test_image_processing_chinese_clip.py
+++ b/tests/models/chinese_clip/test_image_processing_chinese_clip.py
@@ -98,6 +98,7 @@ class ChineseCLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = ChineseCLIPImageProcessingTester(self, do_center_crop=True)
@property
@@ -135,6 +136,7 @@ class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unitt
image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=4, do_center_crop=True)
self.expected_encoded_image_num_channels = 3
diff --git a/tests/models/clip/test_image_processing_clip.py b/tests/models/clip/test_image_processing_clip.py
index a35a23d8da9b72..740399d13fbb11 100644
--- a/tests/models/clip/test_image_processing_clip.py
+++ b/tests/models/clip/test_image_processing_clip.py
@@ -94,6 +94,7 @@ class CLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = CLIPImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = CLIPImageProcessingTester(self)
@property
diff --git a/tests/models/conditional_detr/test_image_processing_conditional_detr.py b/tests/models/conditional_detr/test_image_processing_conditional_detr.py
index 7bbee7e83140c9..171ec2d44f499a 100644
--- a/tests/models/conditional_detr/test_image_processing_conditional_detr.py
+++ b/tests/models/conditional_detr/test_image_processing_conditional_detr.py
@@ -131,6 +131,7 @@ class ConditionalDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcess
image_processing_class = ConditionalDetrImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = ConditionalDetrImageProcessingTester(self)
@property
diff --git a/tests/models/convnext/test_image_processing_convnext.py b/tests/models/convnext/test_image_processing_convnext.py
index 0c331741807c59..d2eaae453432ba 100644
--- a/tests/models/convnext/test_image_processing_convnext.py
+++ b/tests/models/convnext/test_image_processing_convnext.py
@@ -87,6 +87,7 @@ class ConvNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ConvNextImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = ConvNextImageProcessingTester(self)
@property
diff --git a/tests/models/deformable_detr/test_image_processing_deformable_detr.py b/tests/models/deformable_detr/test_image_processing_deformable_detr.py
index 59ba5b59e34f13..51fbfc33f8c195 100644
--- a/tests/models/deformable_detr/test_image_processing_deformable_detr.py
+++ b/tests/models/deformable_detr/test_image_processing_deformable_detr.py
@@ -131,6 +131,7 @@ class DeformableDetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessi
image_processing_class = DeformableDetrImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = DeformableDetrImageProcessingTester(self)
@property
diff --git a/tests/models/deit/test_image_processing_deit.py b/tests/models/deit/test_image_processing_deit.py
index 21dc3d9e95a79f..462ad56d6bf45c 100644
--- a/tests/models/deit/test_image_processing_deit.py
+++ b/tests/models/deit/test_image_processing_deit.py
@@ -93,6 +93,7 @@ class DeiTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
test_cast_dtype = True
def setUp(self):
+ super().setUp()
self.image_processor_tester = DeiTImageProcessingTester(self)
@property
diff --git a/tests/models/detr/test_image_processing_detr.py b/tests/models/detr/test_image_processing_detr.py
index 7f9f18b9d49f4f..fc6d5651272459 100644
--- a/tests/models/detr/test_image_processing_detr.py
+++ b/tests/models/detr/test_image_processing_detr.py
@@ -130,6 +130,7 @@ class DetrImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMixi
image_processing_class = DetrImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = DetrImageProcessingTester(self)
@property
diff --git a/tests/models/donut/test_image_processing_donut.py b/tests/models/donut/test_image_processing_donut.py
index c1a2bd3b26ec46..9d96eb8ede27f4 100644
--- a/tests/models/donut/test_image_processing_donut.py
+++ b/tests/models/donut/test_image_processing_donut.py
@@ -99,6 +99,7 @@ class DonutImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = DonutImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = DonutImageProcessingTester(self)
@property
diff --git a/tests/models/dpt/test_image_processing_dpt.py b/tests/models/dpt/test_image_processing_dpt.py
index 2cc72274c4a7d9..aa1b954a08a26f 100644
--- a/tests/models/dpt/test_image_processing_dpt.py
+++ b/tests/models/dpt/test_image_processing_dpt.py
@@ -86,6 +86,7 @@ class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = DPTImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = DPTImageProcessingTester(self)
@property
diff --git a/tests/models/efficientnet/test_image_processing_efficientnet.py b/tests/models/efficientnet/test_image_processing_efficientnet.py
index fd754d8eb9e97c..28b701c5c9aee0 100644
--- a/tests/models/efficientnet/test_image_processing_efficientnet.py
+++ b/tests/models/efficientnet/test_image_processing_efficientnet.py
@@ -86,6 +86,7 @@ class EfficientNetImageProcessorTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = EfficientNetImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = EfficientNetImageProcessorTester(self)
@property
diff --git a/tests/models/flava/test_image_processing_flava.py b/tests/models/flava/test_image_processing_flava.py
index d89a1a6f6bfb58..04457e51acfdb5 100644
--- a/tests/models/flava/test_image_processing_flava.py
+++ b/tests/models/flava/test_image_processing_flava.py
@@ -175,6 +175,7 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
maxDiff = None
def setUp(self):
+ super().setUp()
self.image_processor_tester = FlavaImageProcessingTester(self)
@property
diff --git a/tests/models/glpn/test_image_processing_glpn.py b/tests/models/glpn/test_image_processing_glpn.py
index f9cadb33137843..abffb31a66936c 100644
--- a/tests/models/glpn/test_image_processing_glpn.py
+++ b/tests/models/glpn/test_image_processing_glpn.py
@@ -93,6 +93,7 @@ class GLPNImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = GLPNImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = GLPNImageProcessingTester(self)
@property
diff --git a/tests/models/grounding_dino/test_image_processing_grounding_dino.py b/tests/models/grounding_dino/test_image_processing_grounding_dino.py
index 6d20a019814b65..68618fb256aa7a 100644
--- a/tests/models/grounding_dino/test_image_processing_grounding_dino.py
+++ b/tests/models/grounding_dino/test_image_processing_grounding_dino.py
@@ -146,6 +146,7 @@ class GroundingDinoImageProcessingTest(AnnotationFormatTestMixin, ImageProcessin
image_processing_class = GroundingDinoImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = GroundingDinoImageProcessingTester(self)
@property
diff --git a/tests/models/idefics/test_image_processing_idefics.py b/tests/models/idefics/test_image_processing_idefics.py
index de42a421cd877e..0273480333f1be 100644
--- a/tests/models/idefics/test_image_processing_idefics.py
+++ b/tests/models/idefics/test_image_processing_idefics.py
@@ -127,6 +127,7 @@ class IdeficsImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = IdeficsImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = IdeficsImageProcessingTester(self)
@property
diff --git a/tests/models/idefics2/test_image_processing_idefics2.py b/tests/models/idefics2/test_image_processing_idefics2.py
index 4b3af1f6320608..2e0d36e75c8a08 100644
--- a/tests/models/idefics2/test_image_processing_idefics2.py
+++ b/tests/models/idefics2/test_image_processing_idefics2.py
@@ -185,6 +185,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Idefics2ImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = Idefics2ImageProcessingTester(self)
@property
diff --git a/tests/models/imagegpt/test_image_processing_imagegpt.py b/tests/models/imagegpt/test_image_processing_imagegpt.py
index 4596d742a282bc..0d91824f195ea5 100644
--- a/tests/models/imagegpt/test_image_processing_imagegpt.py
+++ b/tests/models/imagegpt/test_image_processing_imagegpt.py
@@ -96,6 +96,7 @@ class ImageGPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ImageGPTImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = ImageGPTImageProcessingTester(self)
@property
diff --git a/tests/models/layoutlmv2/test_image_processing_layoutlmv2.py b/tests/models/layoutlmv2/test_image_processing_layoutlmv2.py
index eebb7420be30b0..4413c8d756b2bb 100644
--- a/tests/models/layoutlmv2/test_image_processing_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_image_processing_layoutlmv2.py
@@ -76,6 +76,7 @@ class LayoutLMv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
image_processing_class = LayoutLMv2ImageProcessor if is_pytesseract_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = LayoutLMv2ImageProcessingTester(self)
@property
diff --git a/tests/models/layoutlmv3/test_image_processing_layoutlmv3.py b/tests/models/layoutlmv3/test_image_processing_layoutlmv3.py
index 8d4b64c2ccd409..a12fb6af0d599a 100644
--- a/tests/models/layoutlmv3/test_image_processing_layoutlmv3.py
+++ b/tests/models/layoutlmv3/test_image_processing_layoutlmv3.py
@@ -76,6 +76,7 @@ class LayoutLMv3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
image_processing_class = LayoutLMv3ImageProcessor if is_pytesseract_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = LayoutLMv3ImageProcessingTester(self)
@property
diff --git a/tests/models/levit/test_image_processing_levit.py b/tests/models/levit/test_image_processing_levit.py
index 756993c6b67400..882707629036f7 100644
--- a/tests/models/levit/test_image_processing_levit.py
+++ b/tests/models/levit/test_image_processing_levit.py
@@ -91,6 +91,7 @@ class LevitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = LevitImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = LevitImageProcessingTester(self)
@property
diff --git a/tests/models/llava_next/test_image_processor_llava_next.py b/tests/models/llava_next/test_image_processor_llava_next.py
index 8b1f98bbcaefc4..ff5c9e970874cf 100644
--- a/tests/models/llava_next/test_image_processor_llava_next.py
+++ b/tests/models/llava_next/test_image_processor_llava_next.py
@@ -105,6 +105,7 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->LlavaNext
def setUp(self):
+ super().setUp()
self.image_processor_tester = LlavaNextImageProcessingTester(self)
@property
diff --git a/tests/models/mask2former/test_image_processing_mask2former.py b/tests/models/mask2former/test_image_processing_mask2former.py
index 9e7045c480699f..ae0fff89069054 100644
--- a/tests/models/mask2former/test_image_processing_mask2former.py
+++ b/tests/models/mask2former/test_image_processing_mask2former.py
@@ -149,6 +149,7 @@ class Mask2FormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = Mask2FormerImageProcessor if (is_vision_available() and is_torch_available()) else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = Mask2FormerImageProcessingTester(self)
@property
diff --git a/tests/models/maskformer/test_image_processing_maskformer.py b/tests/models/maskformer/test_image_processing_maskformer.py
index fca65765959bc7..5d30431f1f2aad 100644
--- a/tests/models/maskformer/test_image_processing_maskformer.py
+++ b/tests/models/maskformer/test_image_processing_maskformer.py
@@ -149,6 +149,7 @@ class MaskFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
image_processing_class = MaskFormerImageProcessor if (is_vision_available() and is_torch_available()) else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = MaskFormerImageProcessingTester(self)
@property
diff --git a/tests/models/mobilenet_v1/test_image_processing_mobilenet_v1.py b/tests/models/mobilenet_v1/test_image_processing_mobilenet_v1.py
index ce0ecba34c0e54..c9d32b0bab679e 100644
--- a/tests/models/mobilenet_v1/test_image_processing_mobilenet_v1.py
+++ b/tests/models/mobilenet_v1/test_image_processing_mobilenet_v1.py
@@ -82,6 +82,7 @@ class MobileNetV1ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = MobileNetV1ImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = MobileNetV1ImageProcessingTester(self)
@property
diff --git a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py
index 4c94be47212f2f..e9cdf4a4359e57 100644
--- a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py
+++ b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py
@@ -82,6 +82,7 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
image_processing_class = MobileNetV2ImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = MobileNetV2ImageProcessingTester(self)
@property
diff --git a/tests/models/mobilevit/test_image_processing_mobilevit.py b/tests/models/mobilevit/test_image_processing_mobilevit.py
index 92e1a55947b168..9895befc8f4fd9 100644
--- a/tests/models/mobilevit/test_image_processing_mobilevit.py
+++ b/tests/models/mobilevit/test_image_processing_mobilevit.py
@@ -112,6 +112,7 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = MobileViTImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = MobileViTImageProcessingTester(self)
@property
diff --git a/tests/models/nougat/test_image_processing_nougat.py b/tests/models/nougat/test_image_processing_nougat.py
index fc61ecbc1988d5..5ab2901d31e862 100644
--- a/tests/models/nougat/test_image_processing_nougat.py
+++ b/tests/models/nougat/test_image_processing_nougat.py
@@ -111,6 +111,7 @@ class NougatImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = NougatImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = NougatImageProcessingTester(self)
@property
diff --git a/tests/models/oneformer/test_image_processing_oneformer.py b/tests/models/oneformer/test_image_processing_oneformer.py
index 245af190c9a9d1..e60cc31b30feee 100644
--- a/tests/models/oneformer/test_image_processing_oneformer.py
+++ b/tests/models/oneformer/test_image_processing_oneformer.py
@@ -159,6 +159,7 @@ class OneFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = image_processing_class
def setUp(self):
+ super().setUp()
self.image_processor_tester = OneFormerImageProcessorTester(self)
@property
diff --git a/tests/models/owlv2/test_image_processor_owlv2.py b/tests/models/owlv2/test_image_processor_owlv2.py
index 87b96d06547cdf..51814b6dd806aa 100644
--- a/tests/models/owlv2/test_image_processor_owlv2.py
+++ b/tests/models/owlv2/test_image_processor_owlv2.py
@@ -90,6 +90,7 @@ class Owlv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Owlv2ImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = Owlv2ImageProcessingTester(self)
@property
diff --git a/tests/models/owlvit/test_image_processing_owlvit.py b/tests/models/owlvit/test_image_processing_owlvit.py
index f4897c051ec34b..4442b1a65a7f81 100644
--- a/tests/models/owlvit/test_image_processing_owlvit.py
+++ b/tests/models/owlvit/test_image_processing_owlvit.py
@@ -92,6 +92,7 @@ class OwlViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = OwlViTImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = OwlViTImageProcessingTester(self)
@property
diff --git a/tests/models/pix2struct/test_image_processing_pix2struct.py b/tests/models/pix2struct/test_image_processing_pix2struct.py
index f0b94c4cf5a071..09e1abd8068989 100644
--- a/tests/models/pix2struct/test_image_processing_pix2struct.py
+++ b/tests/models/pix2struct/test_image_processing_pix2struct.py
@@ -87,6 +87,7 @@ class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
image_processing_class = Pix2StructImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = Pix2StructImageProcessingTester(self)
@property
@@ -288,6 +289,7 @@ class Pix2StructImageProcessingTestFourChannels(ImageProcessingTestMixin, unitte
image_processing_class = Pix2StructImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = Pix2StructImageProcessingTester(self, num_channels=4)
self.expected_encoded_image_num_channels = 3
diff --git a/tests/models/poolformer/test_image_processing_poolformer.py b/tests/models/poolformer/test_image_processing_poolformer.py
index 017a511c408511..af4c2bcbb55e13 100644
--- a/tests/models/poolformer/test_image_processing_poolformer.py
+++ b/tests/models/poolformer/test_image_processing_poolformer.py
@@ -88,6 +88,7 @@ class PoolFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
image_processing_class = PoolFormerImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = PoolFormerImageProcessingTester(self)
@property
diff --git a/tests/models/pvt/test_image_processing_pvt.py b/tests/models/pvt/test_image_processing_pvt.py
index d6b11313d81147..d24421fc74102e 100644
--- a/tests/models/pvt/test_image_processing_pvt.py
+++ b/tests/models/pvt/test_image_processing_pvt.py
@@ -84,6 +84,7 @@ class PvtImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = PvtImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = PvtImageProcessingTester(self)
@property
diff --git a/tests/models/segformer/test_image_processing_segformer.py b/tests/models/segformer/test_image_processing_segformer.py
index bee6a4a24b3f1c..988843b710f6bb 100644
--- a/tests/models/segformer/test_image_processing_segformer.py
+++ b/tests/models/segformer/test_image_processing_segformer.py
@@ -112,6 +112,7 @@ class SegformerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SegformerImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = SegformerImageProcessingTester(self)
@property
diff --git a/tests/models/seggpt/test_image_processing_seggpt.py b/tests/models/seggpt/test_image_processing_seggpt.py
index 04cefb70d0efb4..f79b7ea44370dc 100644
--- a/tests/models/seggpt/test_image_processing_seggpt.py
+++ b/tests/models/seggpt/test_image_processing_seggpt.py
@@ -114,6 +114,7 @@ class SegGptImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SegGptImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = SegGptImageProcessingTester(self)
@property
diff --git a/tests/models/siglip/test_image_processor_siglip.py b/tests/models/siglip/test_image_processor_siglip.py
index 5f43d6f08ab111..7dbd05070c66b4 100644
--- a/tests/models/siglip/test_image_processor_siglip.py
+++ b/tests/models/siglip/test_image_processor_siglip.py
@@ -91,6 +91,7 @@ class SiglipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SiglipImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = SiglipImageProcessingTester(self)
@property
diff --git a/tests/models/superpoint/test_image_processing_superpoint.py b/tests/models/superpoint/test_image_processing_superpoint.py
index 19406bc91ad06f..90bbf82d1ed80a 100644
--- a/tests/models/superpoint/test_image_processing_superpoint.py
+++ b/tests/models/superpoint/test_image_processing_superpoint.py
@@ -77,6 +77,7 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
image_processing_class = SuperPointImageProcessor if is_vision_available() else None
def setUp(self) -> None:
+ super().setUp()
self.image_processor_tester = SuperPointImageProcessingTester(self)
@property
diff --git a/tests/models/swin2sr/test_image_processing_swin2sr.py b/tests/models/swin2sr/test_image_processing_swin2sr.py
index 719ac79d09db23..732a7e95412a88 100644
--- a/tests/models/swin2sr/test_image_processing_swin2sr.py
+++ b/tests/models/swin2sr/test_image_processing_swin2sr.py
@@ -98,6 +98,7 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Swin2SRImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = Swin2SRImageProcessingTester(self)
@property
diff --git a/tests/models/tvp/test_image_processing_tvp.py b/tests/models/tvp/test_image_processing_tvp.py
index 1c9a84beb8427b..7de45d4bee06bf 100644
--- a/tests/models/tvp/test_image_processing_tvp.py
+++ b/tests/models/tvp/test_image_processing_tvp.py
@@ -127,6 +127,7 @@ class TvpImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = TvpImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = TvpImageProcessingTester(self)
@property
diff --git a/tests/models/videomae/test_image_processing_videomae.py b/tests/models/videomae/test_image_processing_videomae.py
index 4a6f0b93c4dde9..319e39fcc2cced 100644
--- a/tests/models/videomae/test_image_processing_videomae.py
+++ b/tests/models/videomae/test_image_processing_videomae.py
@@ -99,6 +99,7 @@ class VideoMAEImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = VideoMAEImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = VideoMAEImageProcessingTester(self)
@property
diff --git a/tests/models/vilt/test_image_processing_vilt.py b/tests/models/vilt/test_image_processing_vilt.py
index 607a8b929d1f8b..f68b2d2628ad7c 100644
--- a/tests/models/vilt/test_image_processing_vilt.py
+++ b/tests/models/vilt/test_image_processing_vilt.py
@@ -130,6 +130,7 @@ class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ViltImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = ViltImageProcessingTester(self)
@property
diff --git a/tests/models/vit/test_image_processing_vit.py b/tests/models/vit/test_image_processing_vit.py
index c1c22c0a800a40..1c376f55aa3e98 100644
--- a/tests/models/vit/test_image_processing_vit.py
+++ b/tests/models/vit/test_image_processing_vit.py
@@ -84,6 +84,7 @@ class ViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ViTImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = ViTImageProcessingTester(self)
@property
@@ -91,16 +92,18 @@ def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
- image_processing = self.image_processing_class(**self.image_processor_dict)
- self.assertTrue(hasattr(image_processing, "image_mean"))
- self.assertTrue(hasattr(image_processing, "image_std"))
- self.assertTrue(hasattr(image_processing, "do_normalize"))
- self.assertTrue(hasattr(image_processing, "do_resize"))
- self.assertTrue(hasattr(image_processing, "size"))
+ for image_processing_class in self.image_processor_list:
+ image_processing = image_processing_class(**self.image_processor_dict)
+ self.assertTrue(hasattr(image_processing, "image_mean"))
+ self.assertTrue(hasattr(image_processing, "image_std"))
+ self.assertTrue(hasattr(image_processing, "do_normalize"))
+ self.assertTrue(hasattr(image_processing, "do_resize"))
+ self.assertTrue(hasattr(image_processing, "size"))
def test_image_processor_from_dict_with_kwargs(self):
- image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
- self.assertEqual(image_processor.size, {"height": 18, "width": 18})
+ for image_processing_class in self.image_processor_list:
+ image_processor = image_processing_class.from_dict(self.image_processor_dict)
+ self.assertEqual(image_processor.size, {"height": 18, "width": 18})
- image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
- self.assertEqual(image_processor.size, {"height": 42, "width": 42})
+ image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
+ self.assertEqual(image_processor.size, {"height": 42, "width": 42})
diff --git a/tests/models/vitmatte/test_image_processing_vitmatte.py b/tests/models/vitmatte/test_image_processing_vitmatte.py
index e86cfde1e5cb5d..8aebee3735f4f9 100644
--- a/tests/models/vitmatte/test_image_processing_vitmatte.py
+++ b/tests/models/vitmatte/test_image_processing_vitmatte.py
@@ -94,6 +94,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = VitMatteImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = VitMatteImageProcessingTester(self)
@property
diff --git a/tests/models/vivit/test_image_processing_vivit.py b/tests/models/vivit/test_image_processing_vivit.py
index dad120ef818e9b..0e8301f66734fd 100644
--- a/tests/models/vivit/test_image_processing_vivit.py
+++ b/tests/models/vivit/test_image_processing_vivit.py
@@ -99,6 +99,7 @@ class VivitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = VivitImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = VivitImageProcessingTester(self)
@property
diff --git a/tests/models/yolos/test_image_processing_yolos.py b/tests/models/yolos/test_image_processing_yolos.py
index f04015ac0c9b22..a94cd8b883e834 100644
--- a/tests/models/yolos/test_image_processing_yolos.py
+++ b/tests/models/yolos/test_image_processing_yolos.py
@@ -143,6 +143,7 @@ class YolosImageProcessingTest(AnnotationFormatTestMixin, ImageProcessingTestMix
image_processing_class = YolosImageProcessor if is_vision_available() else None
def setUp(self):
+ super().setUp()
self.image_processor_tester = YolosImageProcessingTester(self)
@property
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index 90c1a4e7e12708..815d3fa7271bd2 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -129,176 +129,202 @@ def prepare_video_inputs(
class ImageProcessingTestMixin:
test_cast_dtype = None
+ image_processing_class = None
+ fast_image_processing_class = None
+ image_processor_list = None
+ test_slow_image_processor = True
+ test_fast_image_processor = True
+
+ def setUp(self):
+ image_processor_list = []
+
+ if self.test_slow_image_processor and self.image_processing_class:
+ image_processor_list.append(self.image_processing_class)
+
+ if self.test_fast_image_processor and self.fast_image_processing_class:
+ image_processor_list.append(self.fast_image_processing_class)
+
+ self.image_processor_list = image_processor_list
def test_image_processor_to_json_string(self):
- image_processor = self.image_processing_class(**self.image_processor_dict)
- obj = json.loads(image_processor.to_json_string())
- for key, value in self.image_processor_dict.items():
- self.assertEqual(obj[key], value)
+ for image_processing_class in self.image_processor_list:
+ image_processor = image_processing_class(**self.image_processor_dict)
+ obj = json.loads(image_processor.to_json_string())
+ for key, value in self.image_processor_dict.items():
+ self.assertEqual(obj[key], value)
def test_image_processor_to_json_file(self):
- image_processor_first = self.image_processing_class(**self.image_processor_dict)
+ for image_processing_class in self.image_processor_list:
+ image_processor_first = image_processing_class(**self.image_processor_dict)
- with tempfile.TemporaryDirectory() as tmpdirname:
- json_file_path = os.path.join(tmpdirname, "image_processor.json")
- image_processor_first.to_json_file(json_file_path)
- image_processor_second = self.image_processing_class.from_json_file(json_file_path)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ json_file_path = os.path.join(tmpdirname, "image_processor.json")
+ image_processor_first.to_json_file(json_file_path)
+ image_processor_second = image_processing_class.from_json_file(json_file_path)
- self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
+ self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_image_processor_from_and_save_pretrained(self):
- image_processor_first = self.image_processing_class(**self.image_processor_dict)
+ for image_processing_class in self.image_processor_list:
+ image_processor_first = image_processing_class(**self.image_processor_dict)
- with tempfile.TemporaryDirectory() as tmpdirname:
- saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
- check_json_file_has_correct_format(saved_file)
- image_processor_second = self.image_processing_class.from_pretrained(tmpdirname)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
+ image_processor_second = image_processing_class.from_pretrained(tmpdirname)
- self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
+ self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_init_without_params(self):
- image_processor = self.image_processing_class()
- self.assertIsNotNone(image_processor)
+ for image_processing_class in self.image_processor_list:
+ image_processor = image_processing_class()
+ self.assertIsNotNone(image_processor)
@require_torch
@require_vision
def test_cast_dtype_device(self):
- if self.test_cast_dtype is not None:
- # Initialize image_processor
- image_processor = self.image_processing_class(**self.image_processor_dict)
+ for image_processing_class in self.image_processor_list:
+ if self.test_cast_dtype is not None:
+ # Initialize image_processor
+ image_processor = image_processing_class(**self.image_processor_dict)
- # create random PyTorch tensors
- image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+ # create random PyTorch tensors
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
- encoding = image_processor(image_inputs, return_tensors="pt")
- # for layoutLM compatiblity
- self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
- self.assertEqual(encoding.pixel_values.dtype, torch.float32)
+ encoding = image_processor(image_inputs, return_tensors="pt")
+ # for LayoutLM compatibility
+ self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+ self.assertEqual(encoding.pixel_values.dtype, torch.float32)
- encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16)
- self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
- self.assertEqual(encoding.pixel_values.dtype, torch.float16)
+ encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16)
+ self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+ self.assertEqual(encoding.pixel_values.dtype, torch.float16)
- encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
- self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
- self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
+ encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
+ self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+ self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
- with self.assertRaises(TypeError):
- _ = image_processor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
+ with self.assertRaises(TypeError):
+ _ = image_processor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
- # Try with text + image feature
- encoding = image_processor(image_inputs, return_tensors="pt")
- encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
- encoding = encoding.to(torch.float16)
+ # Try with text + image feature
+ encoding = image_processor(image_inputs, return_tensors="pt")
+ encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
+ encoding = encoding.to(torch.float16)
- self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
- self.assertEqual(encoding.pixel_values.dtype, torch.float16)
- self.assertEqual(encoding.input_ids.dtype, torch.long)
+ self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
+ self.assertEqual(encoding.pixel_values.dtype, torch.float16)
+ self.assertEqual(encoding.input_ids.dtype, torch.long)
def test_call_pil(self):
- # Initialize image_processing
- image_processing = self.image_processing_class(**self.image_processor_dict)
- # create random PIL images
- image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
- for image in image_inputs:
- self.assertIsInstance(image, Image.Image)
-
- # Test not batched input
- encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
- self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
-
- # Test batched
- encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
- self.assertEqual(
- tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
- )
+ for image_processing_class in self.image_processor_list:
+ # Initialize image_processing
+ image_processing = image_processing_class(**self.image_processor_dict)
+ # create random PIL images
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+ self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ # Test batched
+ encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+ self.assertEqual(
+ tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+ )
def test_call_numpy(self):
- # Initialize image_processing
- image_processing = self.image_processing_class(**self.image_processor_dict)
- # create random numpy tensors
- image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
- for image in image_inputs:
- self.assertIsInstance(image, np.ndarray)
-
- # Test not batched input
- encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
- self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
-
- # Test batched
- encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
- self.assertEqual(
- tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
- )
+ for image_processing_class in self.image_processor_list:
+ # Initialize image_processing
+ image_processing = image_processing_class(**self.image_processor_dict)
+ # create random numpy tensors
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+ self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ # Test batched
+ encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+ self.assertEqual(
+ tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+ )
def test_call_pytorch(self):
- # Initialize image_processing
- image_processing = self.image_processing_class(**self.image_processor_dict)
- # create random PyTorch tensors
- image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
-
- for image in image_inputs:
- self.assertIsInstance(image, torch.Tensor)
-
- # Test not batched input
- encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
- self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
-
- # Test batched
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
- encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
- self.assertEqual(
- tuple(encoded_images.shape),
- (self.image_processor_tester.batch_size, *expected_output_image_shape),
- )
+ for image_processing_class in self.image_processor_list:
+ # Initialize image_processing
+ image_processing = image_processing_class(**self.image_processor_dict)
+ # create random PyTorch tensors
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+ self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ # Test batched
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+ encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ tuple(encoded_images.shape),
+ (self.image_processor_tester.batch_size, *expected_output_image_shape),
+ )
def test_call_numpy_4_channels(self):
- # Test that can process images which have an arbitrary number of channels
- # Initialize image_processing
- image_processor = self.image_processing_class(**self.image_processor_dict)
-
- # create random numpy tensors
- self.image_processor_tester.num_channels = 4
- image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
-
- # Test not batched input
- encoded_images = image_processor(
- image_inputs[0],
- return_tensors="pt",
- input_data_format="channels_first",
- image_mean=0,
- image_std=1,
- ).pixel_values
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
- self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
-
- # Test batched
- encoded_images = image_processor(
- image_inputs,
- return_tensors="pt",
- input_data_format="channels_first",
- image_mean=0,
- image_std=1,
- ).pixel_values
- expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
- self.assertEqual(
- tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
- )
+ for image_processing_class in self.image_processor_list:
+ # Test that images with an arbitrary number of channels can be processed
+ # Initialize image_processing
+ image_processor = image_processing_class(**self.image_processor_dict)
+
+ # create random numpy tensors
+ self.image_processor_tester.num_channels = 4
+ image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+
+ # Test not batched input
+ encoded_images = image_processor(
+ image_inputs[0],
+ return_tensors="pt",
+ input_data_format="channels_first",
+ image_mean=0,
+ image_std=1,
+ ).pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+ self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+ # Test batched
+ encoded_images = image_processor(
+ image_inputs,
+ return_tensors="pt",
+ input_data_format="channels_first",
+ image_mean=0,
+ image_std=1,
+ ).pixel_values
+ expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+ self.assertEqual(
+ tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
+ )
def test_image_processor_preprocess_arguments(self):
- image_processor = self.image_processing_class(**self.image_processor_dict)
- if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"):
- preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args
- preprocess_parameter_names.remove("self")
- preprocess_parameter_names.sort()
- valid_processor_keys = image_processor._valid_processor_keys
- valid_processor_keys.sort()
- self.assertEqual(preprocess_parameter_names, valid_processor_keys)
+ for image_processing_class in self.image_processor_list:
+ image_processor = image_processing_class(**self.image_processor_dict)
+ if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"):
+ preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args
+ preprocess_parameter_names.remove("self")
+ preprocess_parameter_names.sort()
+ valid_processor_keys = sorted(image_processor._valid_processor_keys)
+ self.assertEqual(preprocess_parameter_names, valid_processor_keys)
class AnnotationFormatTestMixin:
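For reference, a minimal sketch of how a model test suite opts in to the fast path with the new mixin attributes (the `NewModel` names here are hypothetical, not from this patch series):

    class NewModelImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        image_processing_class = NewModelImageProcessor if is_vision_available() else None
        fast_image_processing_class = NewModelImageProcessorFast if is_torchvision_available() else None

        def setUp(self):
            super().setUp()  # collects the available classes into self.image_processor_list
            self.image_processor_tester = NewModelImageProcessingTester(self)

Each shared test then loops over `self.image_processor_list`, so a model without a fast processor runs exactly the same tests as before.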
From 6d5c328ea2e391eac9b4303cf0f52b68f41f85da Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Mon, 13 May 2024 15:19:55 +0000
Subject: [PATCH 07/40] Smarter rescaling
---
.../models/vit/image_processing_vit_fast.py | 37 +++++++++++++------
1 file changed, 26 insertions(+), 11 deletions(-)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 6fcec725e7a2f9..2f7cc053d1a2d8 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -32,7 +32,7 @@
)
from ...utils import TensorType, logging
from ...utils.generic import ExplicitEnum
-from ...utils.import_utils import is_torch_available, is_vision_available, is_torchvision_available
+from ...utils.import_utils import is_torch_available, is_torchvision_available, is_vision_available
logger = logging.get_logger(__name__)
@@ -181,20 +181,35 @@ def rescale_image(image, rescale_factor):
transforms.append(
Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
)
+
+ # Regardless of whether we rescale, all PIL and numpy values need to be converted to a torch tensor
+ # to maintain cross-compatibility with the slow image processors
+ convert_to_tensor = image_type in (ImageType.PIL, ImageType.NUMPY)
+ if convert_to_tensor:
+ transforms.append(ToTensor())
+
if do_rescale:
- # To maintain cross-compatibility between the slow and fast image processors, we need to
- # be able to accept both PIL images as torch.Tensor or numpy images.
- if image_type in (ImageType.PIL, ImageType.NUMPY):
- transforms.append(ToTensor())
- # ToTensor scales the pixel values to [0, 1]
+ if convert_to_tensor:
+ # ToTensor scales uint8 pixel values to [0, 1] by dividing by 255. By default,
+ # the image processor's rescale factor is also 1 / 255, so the conversion alone
+ # already applies it. If a different factor was requested, we undo the implicit
+ # 1 / 255 scaling and then rescale by the requested factor, so the fast processor
+ # produces the same [min_value * rescale_factor, max_value * rescale_factor]
+ # pixel values as the slow image processors
if rescale_factor != 1 / 255:
rescale_factor = rescale_factor * 255
- transforms.append(Lambda(rescale_image))
- # If do_rescale is `True`, we should still respect it
- elif image_type == torch.Tensor:
- transforms.append(Lambda(rescale_image))
+ transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
else:
- raise ValueError(f"Unsupported image type {image_type}")
+ # If do_rescale is `True`, we should still respect it
+ transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
+ elif convert_to_tensor:
+ # If we've converted to a tensor and do_rescale=False, then we need to undo the scaling.
+ # As with do_rescale=True, we assume ToTensor rescaled the pixel values by 1 / 255
+ rescale_factor = 255
+ transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
+
if do_normalize:
transforms.append(Normalize(image_mean, image_std))
return Compose(transforms)
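As a sanity check on the rescale-factor arithmetic above: ToTensor divides uint8 inputs by 255, so compensating with `rescale_factor * 255` reproduces a plain multiplication by `rescale_factor`. A standalone sketch, assuming a uint8 RGB input (not part of the patch):

    import numpy as np
    import torch
    from PIL import Image
    from torchvision.transforms import Compose, Lambda, ToTensor

    rescale_factor = 1 / 127.5  # anything other than the default 1 / 255
    image = Image.fromarray(np.full((2, 2, 3), 255, dtype=np.uint8))

    # fast path: ToTensor (implicit / 255), then compensate
    fast = Compose([ToTensor(), Lambda(lambda t: t * (rescale_factor * 255))])
    # slow path: plain multiplication by the rescale factor
    slow = torch.from_numpy(np.array(image)).permute(2, 0, 1) * rescale_factor

    assert torch.allclose(fast(image), slow)  # a 255 pixel maps to 2.0 either way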
From eb701c1197a36aae94b2f8e836c0fe6f797d0ff7 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 16 May 2024 12:28:12 +0000
Subject: [PATCH 08/40] Don't default to Fast
---
src/transformers/models/auto/image_processing_auto.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index ab0533416552f7..74e66f819cd049 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -323,7 +323,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
- use_fast (`bool`, *optional*, defaults to `True`):
+ use_fast (`bool`, *optional*, defaults to `False`):
Use a fast torchvision-based image processor if it is supported for a given model.
If a fast image processor is not available for a given model, a normal numpy-based image processor
is returned instead.
@@ -371,7 +371,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
kwargs["token"] = use_auth_token
config = kwargs.pop("config", None)
- use_fast = kwargs.pop("use_fast", True)
+ use_fast = kwargs.pop("use_fast", False)
trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True
From 415be88dceb6c078a48c474fd65437845d21b3a9 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 16 May 2024 12:51:31 +0000
Subject: [PATCH 09/40] Safer imports
---
.../source/en/main_classes/image_processor.md | 2 +-
src/transformers/__init__.py | 30 ++++++++++++++++---
.../models/auto/image_processing_auto.py | 2 +-
src/transformers/models/vit/__init__.py | 16 ++++++++++
.../models/vit/image_processing_vit_fast.py | 21 +++++++------
.../utils/dummy_torchvision_objects.py | 16 ++++++++++
.../utils/dummy_vision_objects.py | 14 ---------
.../test_image_processing_video_llava.py | 1 +
8 files changed, 71 insertions(+), 31 deletions(-)
create mode 100644 src/transformers/utils/dummy_torchvision_objects.py
diff --git a/docs/source/en/main_classes/image_processor.md b/docs/source/en/main_classes/image_processor.md
index 1c65be6f350088..59a78e68214d6d 100644
--- a/docs/source/en/main_classes/image_processor.md
+++ b/docs/source/en/main_classes/image_processor.md
@@ -36,4 +36,4 @@ An image processor is in charge of preparing input features for vision models an
## BaseImageProcessorFast
-[[autodoc]] image_processing_utils.BaseImageProcessorFast
+[[autodoc]] image_processing_utils_fast.BaseImageProcessorFast
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ef702df9c4be2b..c668c397979745 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1106,7 +1106,6 @@
else:
_import_structure["image_processing_base"] = ["ImageProcessingMixin"]
_import_structure["image_processing_utils"] = ["BaseImageProcessor"]
- _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
_import_structure["models.bit"].extend(["BitImageProcessor"])
@@ -1164,12 +1163,24 @@
_import_structure["models.video_llava"].append("VideoLlavaImageProcessor")
_import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"])
_import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"])
- _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor", "ViTImageProcessorFast"])
+ _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"])
_import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"])
_import_structure["models.vitmatte"].append("VitMatteImageProcessor")
_import_structure["models.vivit"].append("VivitImageProcessor")
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
+try:
+ if not is_torchvision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_torchvision_objects
+
+ _import_structure["utils.dummy_torchvision_objects"] = [
+ name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
+ _import_structure["models.vit"].append("ViTImageProcessorFast")
# PyTorch-backed objects
try:
@@ -5708,7 +5719,6 @@
else:
from .image_processing_base import ImageProcessingMixin
from .image_processing_utils import BaseImageProcessor
- from .image_processing_utils_fast import BaseImageProcessorFast
from .image_utils import ImageFeatureExtractionMixin
from .models.beit import BeitFeatureExtractor, BeitImageProcessor
from .models.bit import BitImageProcessor
@@ -5793,9 +5803,21 @@
from .models.video_llava import VideoLlavaImageProcessor
from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor
from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor
- from .models.vit import ViTFeatureExtractor, ViTImageProcessor, ViTImageProcessorFast
+ from .models.vit import ViTFeatureExtractor, ViTImageProcessor
+ from .models.vit_hybrid import ViTHybridImageProcessor
+ from .models.vitmatte import VitMatteImageProcessor
from .models.vivit import VivitImageProcessor
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
+
+ try:
+ if not is_torchvision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_torchvision_objects import *
+ else:
+ from .image_processing_utils_fast import BaseImageProcessorFast
+ from .models.vit import ViTImageProcessorFast
+
# Modeling
try:
if not is_torch_available():
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 74e66f819cd049..803b31b13b6b46 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -438,7 +438,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
raise ValueError(
- "This image processor cannot be instantiated. Please make sure you have `torchvision` installed."
+ "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
)
raise ValueError(
diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py
index 25f55487c4bfd1..3066331278e44f 100644
--- a/src/transformers/models/vit/__init__.py
+++ b/src/transformers/models/vit/__init__.py
@@ -19,6 +19,7 @@
is_flax_available,
is_tf_available,
is_torch_available,
+ is_torchvision_available,
is_vision_available,
)
@@ -33,6 +34,14 @@
else:
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
_import_structure["image_processing_vit"] = ["ViTImageProcessor"]
+
+
+try:
+ if not is_torchvision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["image_processing_vit_fast"] = ["ViTImageProcessorFast"]
try:
@@ -83,6 +92,13 @@
else:
from .feature_extraction_vit import ViTFeatureExtractor
from .image_processing_vit import ViTImageProcessor
+
+ try:
+ if not is_torchvision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .image_processing_vit_fast import ViTImageProcessorFast
try:
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 2f7cc053d1a2d8..7a2f740b2b018d 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -47,16 +47,15 @@
if is_torchvision_available():
from torchvision.transforms import Compose, InterpolationMode, Lambda, Normalize, Resize, ToTensor
-
-pil_torch_interpolation_mapping = {
- PILImageResampling.NEAREST: InterpolationMode.NEAREST,
- PILImageResampling.BOX: InterpolationMode.BOX,
- PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
- PILImageResampling.HAMMING: InterpolationMode.HAMMING,
- PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
- PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
- PILImageResampling.NEAREST: InterpolationMode.NEAREST,
-}
+ pil_torch_interpolation_mapping = {
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
+ PILImageResampling.BOX: InterpolationMode.BOX,
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
+ }
@dataclass(frozen=True)
@@ -168,7 +167,7 @@ def _build_transforms(
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
image_type: ImageType,
- ) -> Compose:
+ ) -> "Compose":
"""
Given the input settings build the image transforms using `torchvision.transforms.Compose`.
"""
diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py
new file mode 100644
index 00000000000000..1d532aeea2a4de
--- /dev/null
+++ b/src/transformers/utils/dummy_torchvision_objects.py
@@ -0,0 +1,16 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class BaseImageProcessorFast(metaclass=DummyObject):
+ _backends = ["torchvision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torchvision"])
+
+
+class ViTImageProcessorFast(metaclass=DummyObject):
+ _backends = ["torchvision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torchvision"])
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index 34741a703509ec..e60e869dcf7af0 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -16,13 +16,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
-class BaseImageProcessorFast(metaclass=DummyObject):
- _backends = ["vision"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["vision"])
-
-
class ImageFeatureExtractionMixin(metaclass=DummyObject):
_backends = ["vision"]
@@ -611,13 +604,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
-class ViTImageProcessorFast(metaclass=DummyObject):
- _backends = ["vision"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["vision"])
-
-
class ViTHybridImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/tests/models/video_llava/test_image_processing_video_llava.py b/tests/models/video_llava/test_image_processing_video_llava.py
index 4b69022bae0b82..808001d2814def 100644
--- a/tests/models/video_llava/test_image_processing_video_llava.py
+++ b/tests/models/video_llava/test_image_processing_video_llava.py
@@ -128,6 +128,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->VideoLlava
def setUp(self):
+ super().setUp()
self.image_processor_tester = VideoLlavaImageProcessingTester(self)
@property
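With the guarded imports, an environment without torchvision still imports transformers cleanly and only fails when a fast class is actually used. A sketch of the expected behaviour (error message paraphrased, not verbatim):

    from transformers import ViTImageProcessorFast  # resolves to the dummy object

    ViTImageProcessorFast()
    # ImportError: ViTImageProcessorFast requires the torchvision library, but it was not found ...

`requires_backends` raises at instantiation time, which is exactly what the `DummyObject` metaclass is for.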
From fb89515487bd6b4c98de78dc0af6e7eab20ce505 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 16 May 2024 19:08:27 +0000
Subject: [PATCH 10/40] Add necessary Pillow requirement
---
examples/pytorch/_tests_requirements.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/examples/pytorch/_tests_requirements.txt b/examples/pytorch/_tests_requirements.txt
index 2a854b12e6aa30..819b49c799aec7 100644
--- a/examples/pytorch/_tests_requirements.txt
+++ b/examples/pytorch/_tests_requirements.txt
@@ -29,3 +29,4 @@ timm
albumentations >= 1.4.5
torchmetrics
pycocotools
+Pillow>=10.0.1,<=15.0
From d2eb99f21db218c57ad3910abbb83dc12cb665af Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 16 May 2024 19:47:44 +0000
Subject: [PATCH 11/40] Woops
---
src/transformers/models/auto/image_processing_auto.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 803b31b13b6b46..843eae599694b8 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -429,12 +429,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
elif type(config) in IMAGE_PROCESSOR_MAPPING:
image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
- image_processor_class_py, image_processor_fast_class = image_processor_tuple
+ image_processor_class_py, image_processor_class_fast = image_processor_tuple
- if image_processor_fast_class and (use_fast or image_processor_class_py is None):
- return image_processor_fast_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ if image_processor_class_fast and (use_fast or image_processor_class_py is None):
+ return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
- if image_processor_class_py is None:
+ if image_processor_class_py is not None:
return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
raise ValueError(
From fc1530e8dbebcded47448295f22c890e98c11ec8 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 16 May 2024 20:05:53 +0000
Subject: [PATCH 12/40] Add AutoImageProcessor test
---
tests/test_image_processing_common.py | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index 815d3fa7271bd2..a46eaa30764501 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -19,7 +19,7 @@
import pathlib
import tempfile
-from transformers import BatchFeature
+from transformers import BatchFeature, AutoImageProcessor
from transformers.image_utils import AnnotationFormat, AnnotionFormat
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
@@ -175,6 +175,18 @@ def test_image_processor_from_and_save_pretrained(self):
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
+ def test_image_processor_save_load_with_autoimageprocessor(self):
+ for image_processing_class in self.image_processor_list:
+ image_processor_first = image_processing_class(**self.image_processor_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
+
+ image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname)
+
+ self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
+
def test_init_without_params(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class()
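The round trip matters because `save_pretrained` records the concrete class name in the saved config, which is what `AutoImageProcessor` inspects on reload. A sketch of the path the test exercises (assuming ViT with torchvision installed):

    import tempfile

    from transformers import AutoImageProcessor, ViTImageProcessorFast

    fast = ViTImageProcessorFast(size={"height": 18, "width": 18})
    with tempfile.TemporaryDirectory() as tmpdirname:
        fast.save_pretrained(tmpdirname)  # writes image_processor_type = "ViTImageProcessorFast"
        reloaded = AutoImageProcessor.from_pretrained(tmpdirname)
    assert reloaded.to_dict() == fast.to_dict()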
From a3f6d0273156fe51f3b2a0c9908d227fb46eb071 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 16 May 2024 20:11:06 +0000
Subject: [PATCH 13/40] Fix up
---
tests/test_image_processing_common.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index a46eaa30764501..d929997ee87369 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -19,7 +19,7 @@
import pathlib
import tempfile
-from transformers import BatchFeature, AutoImageProcessor
+from transformers import AutoImageProcessor, BatchFeature
from transformers.image_utils import AnnotationFormat, AnnotionFormat
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
From e0bd18dff8ebaf4645c09bee5a43deeda546451e Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 16 May 2024 20:31:57 +0000
Subject: [PATCH 14/40] Fix test for imagegpt
---
.../test_image_processing_imagegpt.py | 45 +++++++++++++------
1 file changed, 32 insertions(+), 13 deletions(-)
diff --git a/tests/models/imagegpt/test_image_processing_imagegpt.py b/tests/models/imagegpt/test_image_processing_imagegpt.py
index 0d91824f195ea5..669e87a20ef078 100644
--- a/tests/models/imagegpt/test_image_processing_imagegpt.py
+++ b/tests/models/imagegpt/test_image_processing_imagegpt.py
@@ -22,7 +22,8 @@
import numpy as np
from datasets import load_dataset
-from transformers.testing_utils import require_torch, require_vision, slow
+from transformers import AutoImageProcessor
+from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -142,18 +143,36 @@ def test_image_processor_to_json_file(self):
self.assertEqual(image_processor_first[key], value)
def test_image_processor_from_and_save_pretrained(self):
- image_processor_first = self.image_processing_class(**self.image_processor_dict)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- image_processor_first.save_pretrained(tmpdirname)
- image_processor_second = self.image_processing_class.from_pretrained(tmpdirname).to_dict()
-
- image_processor_first = image_processor_first.to_dict()
- for key, value in image_processor_first.items():
- if key == "clusters":
- self.assertTrue(np.array_equal(value, image_processor_second[key]))
- else:
- self.assertEqual(image_processor_first[key], value)
+ for image_processing_class in self.image_processor_list:
+ image_processor_first = image_processing_class(**self.image_processor_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ image_processor_first.save_pretrained(tmpdirname)
+ image_processor_second = image_processing_class.from_pretrained(tmpdirname).to_dict()
+
+ image_processor_first = image_processor_first.to_dict()
+ for key, value in image_processor_first.items():
+ if key == "clusters":
+ self.assertTrue(np.array_equal(value, image_processor_second[key]))
+ else:
+ self.assertEqual(image_processor_first[key], value)
+
+ def test_image_processor_save_load_with_autoimageprocessor(self):
+ for image_processing_class in self.image_processor_list:
+ image_processor_first = image_processing_class(**self.image_processor_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ saved_file = image_processor_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
+
+ image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname)
+
+ image_processor_first = image_processor_first.to_dict()
+ for key, value in image_processor_first.items():
+ if key == "clusters":
+ self.assertTrue(np.array_equal(value, image_processor_second[key]))
+ else:
+ self.assertEqual(image_processor_first[key], value)
@unittest.skip("ImageGPT requires clusters at initialization")
def test_init_without_params(self):
From 8c4761a842e351b1a241caabddac7909fe5f7e95 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 17 May 2024 13:43:37 +0000
Subject: [PATCH 15/40] Fix test
---
tests/models/imagegpt/test_image_processing_imagegpt.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/tests/models/imagegpt/test_image_processing_imagegpt.py b/tests/models/imagegpt/test_image_processing_imagegpt.py
index 669e87a20ef078..a9dbc636ef302d 100644
--- a/tests/models/imagegpt/test_image_processing_imagegpt.py
+++ b/tests/models/imagegpt/test_image_processing_imagegpt.py
@@ -168,6 +168,8 @@ def test_image_processor_save_load_with_autoimageprocessor(self):
image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname)
image_processor_first = image_processor_first.to_dict()
+ image_processor_second = image_processor_second.to_dict()
+
for key, value in image_processor_first.items():
if key == "clusters":
self.assertTrue(np.array_equal(value, image_processor_second[key]))
From 687da888048581fdac9b93f601b21890ea401d86 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 14:36:48 +0000
Subject: [PATCH 16/40] Review comments
---
.../image_processing_utils_fast.py | 37 ++++++++++
src/transformers/image_utils.py | 15 ++++
.../models/vit/image_processing_vit_fast.py | 73 ++++++-------------
3 files changed, 75 insertions(+), 50 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index 873bbd1bb5293a..1f53368c3433de 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -16,6 +16,11 @@
import functools
from .image_processing_utils import BaseImageProcessor
+from .utils import is_torchvision_available
+
+
+if is_torchvision_available():
+ import torch
+
+ from torchvision.transforms import functional as F
class BaseImageProcessorFast(BaseImageProcessor):
@@ -69,3 +74,35 @@ def _maybe_update_transforms(self, **kwargs):
if self._same_transforms_settings(**kwargs):
return
self.set_transforms(**kwargs)
+
+
+def _cast_tensor_to_float(x):
+ if x.is_floating_point():
+ return x
+ return x.float()
+
+
+class FusedRescaleNormalize:
+ """
+ Rescale and normalize the input image in one step.
+ """
+
+ def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
+ self.mean = torch.tensor(mean) * (1.0 / rescale_factor)
+ self.std = torch.tensor(std) * (1.0 / rescale_factor)
+ self.inplace = inplace
+
+ def __call__(self, image):
+ image = _cast_tensor_to_float(image)
+ return F.normalize(image, self.mean, self.std, inplace=self.inplace)
+
+
+class Rescale:
+ """
+ Rescale the input image by rescale factor: image *= rescale_factor.
+ """
+
+ def __init__(self, rescale_factor: float = 1.0):
+ self.rescale_factor = rescale_factor
+
+ def __call__(self, image):
+ return image.mul(self.rescale_factor)
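The fused transform relies on the identity (x * r - mean) / std == (x - mean / r) / (std / r), which lets rescale and normalize run as a single normalize call. A standalone check of that identity (not part of the patch):

    import torch
    from torchvision.transforms import functional as F

    x = torch.randint(0, 256, (3, 4, 4)).float()
    r, mean, std = 1 / 255, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

    # fused: normalize with mean / r and std / r, no separate rescale step
    fused = F.normalize(x, [m / r for m in mean], [s / r for s in std])
    # two-step reference: rescale first, then normalize
    two_step = F.normalize(x * r, mean, std)

    assert torch.allclose(fused, two_step, atol=1e-5)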
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index aaa9e4eadc6a2a..8d3b955c918053 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -28,6 +28,7 @@
is_tf_tensor,
is_torch_available,
is_torch_tensor,
+ is_torchvision_available,
is_vision_available,
logging,
requires_backends,
@@ -52,6 +53,20 @@
else:
PILImageResampling = PIL.Image
+ if is_torchvision_available():
+ from torchvision.transforms import InterpolationMode
+
+ pil_torch_interpolation_mapping = {
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
+ PILImageResampling.BOX: InterpolationMode.BOX,
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
+ }
+
+
if TYPE_CHECKING:
if is_torch_available():
import torch
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 7a2f740b2b018d..30961dd3911f6e 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -12,16 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Image processor class for ViT."""
+"""Fast Image processor class for ViT."""
import functools
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
-import numpy as np
-
from ...image_processing_utils import get_size_dict
-from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_processing_utils_fast import BaseImageProcessorFast, FusedRescaleNormalize, Rescale
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
@@ -29,8 +27,9 @@
ImageInput,
PILImageResampling,
make_list_of_images,
+ pil_torch_interpolation_mapping,
)
-from ...utils import TensorType, logging
+from ...utils import TensorType, is_numpy_array, is_torch_tensor, logging
from ...utils.generic import ExplicitEnum
from ...utils.import_utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -44,18 +43,9 @@
if is_vision_available():
from PIL import Image
-if is_torchvision_available():
- from torchvision.transforms import Compose, InterpolationMode, Lambda, Normalize, Resize, ToTensor
- pil_torch_interpolation_mapping = {
- PILImageResampling.NEAREST: InterpolationMode.NEAREST,
- PILImageResampling.BOX: InterpolationMode.BOX,
- PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
- PILImageResampling.HAMMING: InterpolationMode.HAMMING,
- PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
- PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
- PILImageResampling.NEAREST: InterpolationMode.NEAREST,
- }
+if is_torchvision_available():
+ from torchvision.transforms import Compose, Lambda, Normalize, PILToTensor, Resize
@dataclass(frozen=True)
@@ -80,9 +70,9 @@ class ImageType(ExplicitEnum):
def get_image_type(image):
if is_vision_available() and isinstance(image, Image.Image):
return ImageType.PIL
- if is_torch_available() and isinstance(image, torch.Tensor):
+ if is_torch_tensor(image):
return ImageType.TORCH
- if isinstance(image, np.ndarray):
+ if is_numpy_array(image):
return ImageType.NUMPY
raise ValueError(f"Unrecognised image type {type(image)}")
@@ -171,46 +161,29 @@ def _build_transforms(
"""
Given the input settings build the image transforms using `torchvision.transforms.Compose`.
"""
-
- def rescale_image(image, rescale_factor):
- return image * rescale_factor
-
transforms = []
if do_resize:
transforms.append(
Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
)
- # Regardless of whether we rescale, all PIL and numpy values need to be converted to a torch tensor
+ # All PIL and numpy values need to be converted to a torch tensor
# to keep cross compatibility with slow image processors
- convert_to_tensor = image_type in (ImageType.PIL, ImageType.NUMPY)
- if convert_to_tensor:
- transforms.append(ToTensor())
-
- if do_rescale:
- if convert_to_tensor:
- # ToTensor scales the pixel values to [0, 1] by dividing by the largest value in the image.
- # By default, the rescale factor for the image processor is 1 / 255, i.e. assuming the maximum
- # possible value is 255. Here, if it's different, we need to undo the (assumed) 1/255 scaling
- # and then rescale again
- #
- # NB: This means that the final pixel values will be different in the torchvision transform
- # depending on the pixels in the image as they become [min_val / max_value, max_value / max_value]
- # whereas in the image processors they are [min_value * rescale_factor, max_value * rescale_factor]
- if rescale_factor != 1 / 255:
- rescale_factor = rescale_factor * 255
- transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
- else:
- # If do_rescale is `True`, we should still respect it
- transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
- elif convert_to_tensor:
- # If we've converted to a tensor and do_rescale=False, then we need to unscale.
- # As with do_scale=True, we assume that the pixel values were rescaled by 1/255
- rescale_factor = 255
- transforms.append(Lambda(functools.partial(rescale_image, rescale_factor=rescale_factor)))
-
- if do_normalize:
+ if image_type == ImageType.PIL:
+ transforms.append(PILToTensor())
+
+ elif image_type == ImageType.NUMPY:
+ # Do we want to permute the channels here?
+ transforms.append(Lambda(lambda x: torch.from_numpy(x)))
+
+ # We can combine rescale and normalize into a single operation for speed
+ if do_rescale and do_normalize:
+ transforms.append(FusedRescaleNormalize(image_mean, image_std, rescale_factor=rescale_factor))
+ elif do_rescale:
+ transforms.append(Rescale(rescale_factor=rescale_factor))
+ elif do_normalize:
transforms.append(Normalize(image_mean, image_std))
+
return Compose(transforms)
@functools.lru_cache(maxsize=1)
From fc1e12160b91061187f36cd3df0848db45b5a0c1 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 14:52:11 +0000
Subject: [PATCH 17/40] Add warning for TF and JAX input types
---
src/transformers/models/vit/image_processing_vit_fast.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 30961dd3911f6e..669ce5bb2e6a67 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -162,6 +162,7 @@ def _build_transforms(
Given the input settings build the image transforms using `torchvision.transforms.Compose`.
"""
transforms = []
+
if do_resize:
transforms.append(
Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
@@ -285,6 +286,9 @@ def preprocess(
images = make_list_of_images(images)
image_type = get_image_type(images[0])
+ if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
+ raise ValueError(f"Unsupported input image type {image_type}")
+
self._validate_input_arguments(
do_resize=do_resize,
size=size,
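At this point `ImageType` still only covers PIL, torch, and numpy, so the check mainly future-proofs the entry point: once `get_image_type` learns to classify TensorFlow and JAX tensors (patch 18), they are rejected here with an explicit error instead of failing somewhere inside the transform pipeline. An illustration of the accepted types (the import location reflects the later patches, so treat the path as an assumption):

    import numpy as np
    from transformers.image_utils import ImageType, get_image_type

    supported = {ImageType.PIL, ImageType.TORCH, ImageType.NUMPY}
    image_type = get_image_type(np.zeros((2, 2, 3), dtype=np.uint8))
    assert image_type is ImageType.NUMPY and image_type in supported
    # A tf.Tensor or a jax array would classify as ImageType.TENSORFLOW /
    # ImageType.JAX and trip the new ValueError above.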
From 1077938bd01052a7f18d26ffb392b8f6afed2448 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 15:08:37 +0000
Subject: [PATCH 18/40] Rearrange
---
.../image_processing_utils_fast.py | 53 +++++++------------
src/transformers/image_transforms.py | 36 +++++++++++++
src/transformers/image_utils.py | 31 ++++++++---
.../models/vit/image_processing_vit_fast.py | 45 +++-------------
4 files changed, 85 insertions(+), 80 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index 1f53368c3433de..d0c5733edc2259 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -14,13 +14,28 @@
# limitations under the License.
import functools
+from dataclasses import dataclass
from .image_processing_utils import BaseImageProcessor
-from .utils import is_torchvision_available
-if is_torchvision_available():
- from torchvision.transforms import functional as F
+@dataclass(frozen=True)
+class SizeDict:
+ """
+ Hashable dictionary to store image size information.
+ """
+
+ height: int = None
+ width: int = None
+ longest_edge: int = None
+ shortest_edge: int = None
+ max_height: int = None
+ max_width: int = None
+
+ def __getitem__(self, key):
+ if hasattr(self, key):
+ return getattr(self, key)
+ raise KeyError(f"Key {key} not found in SizeDict.")
class BaseImageProcessorFast(BaseImageProcessor):
@@ -74,35 +89,3 @@ def _maybe_update_transforms(self, **kwargs):
if self._same_transforms_settings(**kwargs):
return
self.set_transforms(**kwargs)
-
-
-def _cast_tensor_to_float(x):
- if x.is_floating_point():
- return x
- return x.float()
-
-
-class FusedRescaleNormalize:
- """
- Rescale and normalize the input image in one step.
- """
-
- def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
- self.mean = mean * (1.0 / rescale_factor)
- self.std = std * (1.0 / rescale_factor)
-
- def __call__(self, image):
- image = _cast_tensor_to_float(image)
- return F.normalize(image, self.mean, self.std, inplace=self.inplace)
-
-
-class Rescale:
- """
- Rescale the input image by rescale factor: image *= rescale_factor.
- """
-
- def __init__(self, rescale_factor: float = 1.0):
- self.rescale_factor = rescale_factor
-
- def __call__(self, image):
- return image.mul(self.rescale_factor)
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index 65d6413db73789..739bedc755199a 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -31,6 +31,7 @@
is_flax_available,
is_tf_available,
is_torch_available,
+ is_torchvision_available,
is_vision_available,
requires_backends,
)
@@ -50,6 +51,9 @@
if is_flax_available():
import jax.numpy as jnp
+if is_torchvision_available():
+ from torchvision.transforms import functional as F
+
def to_channel_dimension_format(
image: np.ndarray,
@@ -802,3 +806,35 @@ def flip_channel_order(
if data_format is not None:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
+
+
+def _cast_tensor_to_float(x):
+ if x.is_floating_point():
+ return x
+ return x.float()
+
+
+class FusedRescaleNormalize:
+ """
+ Rescale and normalize the input image in one step.
+ """
+
+ def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
+ self.mean = mean * (1.0 / rescale_factor)
+ self.std = std * (1.0 / rescale_factor)
+
+ def __call__(self, image: "torch.Tensor"):
+ image = _cast_tensor_to_float(image)
+ return F.normalize(image, self.mean, self.std, inplace=self.inplace)
+
+
+class Rescale:
+ """
+ Rescale the input image by rescale factor: image *= rescale_factor.
+ """
+
+ def __init__(self, rescale_factor: float = 1.0):
+ self.rescale_factor = rescale_factor
+
+ def __call__(self, image: "torch.Tensor"):
+ return image.mul(self.rescale_factor)
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index 8d3b955c918053..aa09e74558a389 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -25,6 +25,7 @@
from .utils import (
ExplicitEnum,
is_jax_tensor,
+ is_numpy_array,
is_tf_tensor,
is_torch_available,
is_torch_tensor,
@@ -105,14 +106,30 @@ def is_pil_image(img):
return is_vision_available() and isinstance(img, PIL.Image.Image)
+class ImageType(ExplicitEnum):
+ PIL = "pillow"
+ TORCH = "torch"
+ NUMPY = "numpy"
+ TENSORFLOW = "tensorflow"
+ JAX = "jax"
+
+
+def get_image_type(image):
+ if is_pil_image(image):
+ return ImageType.PIL
+ if is_torch_tensor(image):
+ return ImageType.TORCH
+ if is_numpy_array(image):
+ return ImageType.NUMPY
+ if is_tf_tensor(image):
+ return ImageType.TENSORFLOW
+ if is_jax_tensor(image):
+ return ImageType.JAX
+ raise ValueError(f"Unrecognised image type {type(image)}")
+
+
def is_valid_image(img):
- return (
- (is_vision_available() and isinstance(img, PIL.Image.Image))
- or isinstance(img, np.ndarray)
- or is_torch_tensor(img)
- or is_tf_tensor(img)
- or is_jax_tensor(img)
- )
+ return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
def valid_images(imgs):
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 669ce5bb2e6a67..7604da15a8e3e5 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -15,23 +15,24 @@
"""Fast Image processor class for ViT."""
import functools
-from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from ...image_processing_utils import get_size_dict
-from ...image_processing_utils_fast import BaseImageProcessorFast, FusedRescaleNormalize, Rescale
+from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
+from ...image_transforms import FusedRescaleNormalize, Rescale
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
+ ImageType,
PILImageResampling,
+ get_image_type,
make_list_of_images,
pil_torch_interpolation_mapping,
)
-from ...utils import TensorType, is_numpy_array, is_torch_tensor, logging
-from ...utils.generic import ExplicitEnum
-from ...utils.import_utils import is_torch_available, is_torchvision_available, is_vision_available
+from ...utils import TensorType, logging
+from ...utils.import_utils import is_torch_available, is_torchvision_available
logger = logging.get_logger(__name__)
@@ -40,43 +41,11 @@
if is_torch_available():
import torch
-if is_vision_available():
- from PIL import Image
-
if is_torchvision_available():
from torchvision.transforms import Compose, Lambda, Normalize, PILToTensor, Resize
-@dataclass(frozen=True)
-class SizeDict:
- height: int = None
- width: int = None
- longest_edge: int = None
- shortest_edge: int = None
-
- def __getitem__(self, key):
- if hasattr(self, key):
- return getattr(self, key)
- raise KeyError(f"Key {key} not found in SizeDict.")
-
-
-class ImageType(ExplicitEnum):
- PIL = "pillow"
- TORCH = "torch"
- NUMPY = "numpy"
-
-
-def get_image_type(image):
- if is_vision_available() and isinstance(image, Image.Image):
- return ImageType.PIL
- if is_torch_tensor(image):
- return ImageType.TORCH
- if is_numpy_array(image):
- return ImageType.NUMPY
- raise ValueError(f"Unrecognised image type {type(image)}")
-
-
class ViTImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a ViT image processor.
@@ -162,7 +131,7 @@ def _build_transforms(
Given the input settings build the image transforms using `torchvision.transforms.Compose`.
"""
transforms = []
-
+
if do_resize:
transforms.append(
Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
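Moving `SizeDict` into the shared module keeps the property it exists for: `functools.lru_cache` can only key on hashable arguments, which a frozen dataclass satisfies and a plain `dict` does not. A minimal sketch of that constraint:

    import functools
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class SizeDict:
        height: int = None
        width: int = None

    @functools.lru_cache(maxsize=1)
    def build(size):
        return (size.height, size.width)

    build(SizeDict(height=224, width=224))   # fine: frozen dataclasses are hashable
    # build({"height": 224, "width": 224})   # TypeError: unhashable type: 'dict'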
From 5cb11df6505b104a086404e40985bbde04f8794b Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 15:13:11 +0000
Subject: [PATCH 19/40] Return transforms
---
.../image_processing_utils_fast.py | 20 ++++++++++++-------
.../models/vit/image_processing_vit_fast.py | 4 ++--
2 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index d0c5733edc2259..a7ffc6411f3f96 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -17,6 +17,11 @@
from dataclasses import dataclass
from .image_processing_utils import BaseImageProcessor
+from .utils.import_utils import is_torchvision_available
+
+
+if is_torchvision_available():
+ from torchvision.transforms import Compose
@dataclass(frozen=True)
@@ -42,7 +47,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
_transform_params = None
_transform_settings = None
- def _set_transform_settings(self, **kwargs):
+ def _set_transform_settings(self, **kwargs) -> None:
settings = {}
for k, v in kwargs.items():
if k not in self._transform_params:
@@ -50,7 +55,7 @@ def _set_transform_settings(self, **kwargs):
settings[k] = v
self._transform_settings = settings
- def _same_transforms_settings(self, **kwargs):
+ def _same_transforms_settings(self, **kwargs) -> bool:
"""
Check if the current settings are the same as the current transforms.
"""
@@ -62,13 +67,13 @@ def _same_transforms_settings(self, **kwargs):
return False
return True
- def _build_transforms(self, **kwargs):
+ def _build_transforms(self, **kwargs) -> Compose:
"""
Given the input settings e.g. do_resize, build the image transforms.
"""
raise NotImplementedError
- def set_transforms(self, **kwargs):
+ def set_transforms(self, **kwargs) -> Compose:
"""
Set the image transforms based on the given settings.
If the settings are the same as the current ones, do nothing.
@@ -79,13 +84,14 @@ def set_transforms(self, **kwargs):
transforms = self._build_transforms(**kwargs)
self._set_transform_settings(**kwargs)
self._transforms = transforms
+ return transforms
@functools.lru_cache(maxsize=1)
- def _maybe_update_transforms(self, **kwargs):
+ def _maybe_update_transforms(self, **kwargs) -> Compose:
"""
If settings are different from those stored in `self._transform_settings`, update
the image transforms to apply
"""
if self._same_transforms_settings(**kwargs):
- return
- self.set_transforms(**kwargs)
+ return self._transforms
+ return self.set_transforms(**kwargs)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 7604da15a8e3e5..f2d4d4ded42862 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -272,7 +272,7 @@ def preprocess(
image_type=image_type,
)
- self._maybe_update_transforms(
+ transforms = self._maybe_update_transforms(
do_resize=do_resize,
do_rescale=do_rescale,
do_normalize=do_normalize,
@@ -283,7 +283,7 @@ def preprocess(
image_std=image_std,
image_type=image_type,
)
- transformed_images = [self._transforms(image) for image in images]
+ transformed_images = [transforms(image) for image in images]
data = {"pixel_values": torch.vstack(transformed_images)}
return data
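Having `_maybe_update_transforms` return the `Compose` (and `preprocess` use that return value) avoids reading `self._transforms` back after a cached call; with `functools.lru_cache` on the method, the memoized return value is the single source of truth. A sketch of the pattern under that reading:

    import functools

    class Cached:
        @functools.lru_cache(maxsize=1)
        def get(self, key):
            # Returning the built object keeps the cached value and the value
            # the caller uses in sync, rather than stashing it on self and
            # reading the attribute back afterwards.
            return [key]

    c = Cached()
    assert c.get("a") is c.get("a")   # the repeated call is a cache hit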
From fff70c3ed2873fc3d8cf5d44dcef689e9dbfdd73 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 15:16:10 +0000
Subject: [PATCH 20/40] NumpyToTensor transformation
---
src/transformers/image_transforms.py | 8 ++++++++
src/transformers/models/vit/image_processing_vit_fast.py | 4 ++--
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index 739bedc755199a..137c2c1836e186 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -838,3 +838,11 @@ def __init__(self, rescale_factor: float = 1.0):
def __call__(self, image: "torch.Tensor"):
return image.mul(self.rescale_factor)
+
+
+class NumpyToTensor:
+ """
+ Convert a numpy array to a PyTorch tensor.
+ """
+ def __call__(self, image: np.ndarray):
+ return torch.from_numpy(image)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index f2d4d4ded42862..ee2b891ca308c9 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -19,7 +19,7 @@
from ...image_processing_utils import get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
-from ...image_transforms import FusedRescaleNormalize, Rescale
+from ...image_transforms import FusedRescaleNormalize, Rescale, NumpyToTensor
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
@@ -144,7 +144,7 @@ def _build_transforms(
elif image_type == ImageType.NUMPY:
# Do we want to permute the channels here?
- transforms.append(Lambda(lambda x: torch.from_numpy(x)))
+ transforms.append(NumpyToTensor())
# We can combine rescale and normalize into a single operation for speed
if do_rescale and do_normalize:
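Swapping the `Lambda` for a named `NumpyToTensor` transform makes the pipeline introspectable and picklable (no closure to serialize). Note that, unlike torchvision's `ToTensor`, `torch.from_numpy` neither rescales values to [0, 1] nor permutes HWC to CHW, which is presumably why the channel-permutation question above remains open. A small usage sketch mirroring the class in the diff:

    import numpy as np
    import torch

    class NumpyToTensor:
        """Convert a numpy array to a PyTorch tensor (zero-copy, same layout)."""

        def __call__(self, image: np.ndarray) -> torch.Tensor:
            return torch.from_numpy(image)

    arr = np.zeros((224, 224, 3), dtype=np.uint8)   # HWC, as a decoded image
    tensor = NumpyToTensor()(arr)
    assert tensor.shape == (224, 224, 3) and tensor.dtype == torch.uint8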
From 8b09622624d02fb455ba569530b9f225775d8397 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 16:14:14 +0000
Subject: [PATCH 21/40] Rebase - include changes from upstream in
ImageProcessingMixin
---
src/transformers/image_processing_base.py | 34 +-
src/transformers/image_processing_utils.py | 527 +-----------------
.../models/vit/image_processing_vit_fast.py | 4 +-
3 files changed, 26 insertions(+), 539 deletions(-)
diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py
index 2146afa2108cc2..6c80aee0164722 100644
--- a/src/transformers/image_processing_base.py
+++ b/src/transformers/image_processing_base.py
@@ -30,7 +30,9 @@
IMAGE_PROCESSOR_NAME,
PushToHubMixin,
add_model_info_to_auto_map,
+ add_model_info_to_custom_pipelines,
cached_file,
+ copy_func,
download_url,
is_offline_mode,
is_remote_url,
@@ -110,8 +112,7 @@ def from_pretrained(
This can be either:
- a string, the *model id* of a pretrained image_processor hosted inside a model repo on
- huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
- namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ huggingface.co.
- a path to a *directory* containing a image processor file saved using the
[`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
`./my_model_directory/`.
@@ -123,9 +124,9 @@ def from_pretrained(
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the image processor files and override the cached versions if
they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file
- exists.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
@@ -287,7 +288,7 @@ def get_image_processor_dict(
"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", False)
+ resume_download = kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
@@ -375,11 +376,15 @@ def get_image_processor_dict(
f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
)
- if "auto_map" in image_processor_dict and not is_local:
- image_processor_dict["auto_map"] = add_model_info_to_auto_map(
- image_processor_dict["auto_map"], pretrained_model_name_or_path
- )
-
+ if not is_local:
+ if "auto_map" in image_processor_dict:
+ image_processor_dict["auto_map"] = add_model_info_to_auto_map(
+ image_processor_dict["auto_map"], pretrained_model_name_or_path
+ )
+ if "custom_pipelines" in image_processor_dict:
+ image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
+ image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
+ )
return image_processor_dict, kwargs
@classmethod
@@ -540,3 +545,10 @@ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
return Image.open(BytesIO(response.content))
else:
raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+
+
+ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
+if ImageProcessingMixin.push_to_hub.__doc__ is not None:
+ ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
+ object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
+ )
diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py
index 95c4b42abcf24c..4b263446b54e2a 100644
--- a/src/transformers/image_processing_utils.py
+++ b/src/transformers/image_processing_utils.py
@@ -20,28 +20,9 @@
from .image_processing_base import BatchFeature, ImageProcessingMixin
from .image_transforms import center_crop, normalize, rescale
from .image_utils import ChannelDimension
-from .utils import copy_func, logging
-
-from .image_processing_base import ImageProcessingMixin, BatchFeature
-from .image_utils import ChannelDimension
-from .utils import (
- IMAGE_PROCESSOR_NAME,
- PushToHubMixin,
- add_model_info_to_auto_map,
- add_model_info_to_custom_pipelines,
- cached_file,
- copy_func,
- download_url,
- is_offline_mode,
- is_remote_url,
- is_vision_available,
- logging,
-)
+from .utils import logging
-if is_vision_available():
- from PIL import Image
-
logger = logging.get_logger(__name__)
@@ -51,505 +32,6 @@
]
-# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
-# We override the class string here, but logic is the same.
-class BatchFeature(BaseBatchFeature):
- r"""
- Holds the output of the image processor specific `__call__` methods.
-
- This class is derived from a python dictionary and can be used as a dictionary.
-
- Args:
- data (`dict`):
- Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
- tensor_type (`Union[None, str, TensorType]`, *optional*):
- You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
- initialization.
- """
-
-
-# TODO: (Amy) - factor out the common parts of this and the feature extractor
-class ImageProcessingMixin(PushToHubMixin):
- """
- This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
- extractors.
- """
-
- _auto_class = None
-
- def __init__(self, **kwargs):
- """Set elements of `kwargs` as attributes."""
- # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
- # `XXXImageProcessor`, this attribute and its value are misleading.
- kwargs.pop("feature_extractor_type", None)
- # Pop "processor_class" as it should be saved as private attribute
- self._processor_class = kwargs.pop("processor_class", None)
- # Additional attributes without default values
- for key, value in kwargs.items():
- try:
- setattr(self, key, value)
- except AttributeError as err:
- logger.error(f"Can't set {key} with value {value} for {self}")
- raise err
-
- def _set_processor_class(self, processor_class: str):
- """Sets processor class as an attribute."""
- self._processor_class = processor_class
-
- @classmethod
- def from_pretrained(
- cls,
- pretrained_model_name_or_path: Union[str, os.PathLike],
- cache_dir: Optional[Union[str, os.PathLike]] = None,
- force_download: bool = False,
- local_files_only: bool = False,
- token: Optional[Union[str, bool]] = None,
- revision: str = "main",
- **kwargs,
- ):
- r"""
- Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
-
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- This can be either:
-
- - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
- huggingface.co.
- - a path to a *directory* containing a image processor file saved using the
- [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
- `./my_model_directory/`.
- - a path or url to a saved image processor JSON *file*, e.g.,
- `./my_model_directory/preprocessor_config.json`.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model image processor should be cached if the
- standard cache should not be used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force to (re-)download the image processor files and override the cached versions if
- they exist.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible.
- Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
- token (`str` or `bool`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
-
-
- <Tip>
-
- To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
-
- </Tip>
-
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
- If `False`, then this function returns just the final image processor object. If `True`, then this
- functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
- consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
- `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
- subfolder (`str`, *optional*, defaults to `""`):
- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
- specify the folder name here.
- kwargs (`Dict[str, Any]`, *optional*):
- The values in kwargs of any keys which are image processor attributes will be used to override the
- loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
- controlled by the `return_unused_kwargs` keyword parameter.
-
- Returns:
- A image processor of type [`~image_processing_utils.ImageProcessingMixin`].
-
- Examples:
-
- ```python
- # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
- # derived class: *CLIPImageProcessor*
- image_processor = CLIPImageProcessor.from_pretrained(
- "openai/clip-vit-base-patch32"
- ) # Download image_processing_config from huggingface.co and cache.
- image_processor = CLIPImageProcessor.from_pretrained(
- "./test/saved_model/"
- ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
- image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
- image_processor = CLIPImageProcessor.from_pretrained(
- "openai/clip-vit-base-patch32", do_normalize=False, foo=False
- )
- assert image_processor.do_normalize is False
- image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
- "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
- )
- assert image_processor.do_normalize is False
- assert unused_kwargs == {"foo": False}
- ```"""
- kwargs["cache_dir"] = cache_dir
- kwargs["force_download"] = force_download
- kwargs["local_files_only"] = local_files_only
- kwargs["revision"] = revision
-
- use_auth_token = kwargs.pop("use_auth_token", None)
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if token is not None:
- raise ValueError(
- "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
- )
- token = use_auth_token
-
- if token is not None:
- kwargs["token"] = token
-
- image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
-
- return cls.from_dict(image_processor_dict, **kwargs)
-
- def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
- """
- Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
- [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
-
- Args:
- save_directory (`str` or `os.PathLike`):
- Directory where the image processor JSON file will be saved (will be created if it does not exist).
- push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
- repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
- namespace).
- kwargs (`Dict[str, Any]`, *optional*):
- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
- """
- use_auth_token = kwargs.pop("use_auth_token", None)
-
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if kwargs.get("token", None) is not None:
- raise ValueError(
- "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
- )
- kwargs["token"] = use_auth_token
-
- if os.path.isfile(save_directory):
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
-
- os.makedirs(save_directory, exist_ok=True)
-
- if push_to_hub:
- commit_message = kwargs.pop("commit_message", None)
- repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
- repo_id = self._create_repo(repo_id, **kwargs)
- files_timestamps = self._get_files_timestamps(save_directory)
-
- # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
- # loaded from the Hub.
- if self._auto_class is not None:
- custom_object_save(self, save_directory, config=self)
-
- # If we save using the predefined names, we can load using `from_pretrained`
- output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
-
- self.to_json_file(output_image_processor_file)
- logger.info(f"Image processor saved in {output_image_processor_file}")
-
- if push_to_hub:
- self._upload_modified_files(
- save_directory,
- repo_id,
- files_timestamps,
- commit_message=commit_message,
- token=kwargs.get("token"),
- )
-
- return [output_image_processor_file]
-
- @classmethod
- def get_image_processor_dict(
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
- """
- From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
- image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`.
-
- Parameters:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
- subfolder (`str`, *optional*, defaults to `""`):
- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
- specify the folder name here.
-
- Returns:
- `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
- """
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", None)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- use_auth_token = kwargs.pop("use_auth_token", None)
- local_files_only = kwargs.pop("local_files_only", False)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", "")
-
- from_pipeline = kwargs.pop("_from_pipeline", None)
- from_auto_class = kwargs.pop("_from_auto", False)
-
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if token is not None:
- raise ValueError(
- "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
- )
- token = use_auth_token
-
- user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
- if from_pipeline is not None:
- user_agent["using_pipeline"] = from_pipeline
-
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- is_local = os.path.isdir(pretrained_model_name_or_path)
- if os.path.isdir(pretrained_model_name_or_path):
- image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
- if os.path.isfile(pretrained_model_name_or_path):
- resolved_image_processor_file = pretrained_model_name_or_path
- is_local = True
- elif is_remote_url(pretrained_model_name_or_path):
- image_processor_file = pretrained_model_name_or_path
- resolved_image_processor_file = download_url(pretrained_model_name_or_path)
- else:
- image_processor_file = IMAGE_PROCESSOR_NAME
- try:
- # Load from local folder or from cache or download from model Hub and cache
- resolved_image_processor_file = cached_file(
- pretrained_model_name_or_path,
- image_processor_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- )
- except EnvironmentError:
- # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
- # the original exception.
- raise
- except Exception:
- # For any other exception, we throw a generic error.
- raise EnvironmentError(
- f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
- " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
- f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
- f" directory containing a {IMAGE_PROCESSOR_NAME} file"
- )
-
- try:
- # Load image_processor dict
- with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
- text = reader.read()
- image_processor_dict = json.loads(text)
-
- except json.JSONDecodeError:
- raise EnvironmentError(
- f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
- )
-
- if is_local:
- logger.info(f"loading configuration file {resolved_image_processor_file}")
- else:
- logger.info(
- f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
- )
-
- if not is_local:
- if "auto_map" in image_processor_dict:
- image_processor_dict["auto_map"] = add_model_info_to_auto_map(
- image_processor_dict["auto_map"], pretrained_model_name_or_path
- )
- if "custom_pipelines" in image_processor_dict:
- image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
- image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
- )
- return image_processor_dict, kwargs
-
- @classmethod
- def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
- """
- Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
-
- Args:
- image_processor_dict (`Dict[str, Any]`):
- Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
- retrieved from a pretrained checkpoint by leveraging the
- [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
- kwargs (`Dict[str, Any]`):
- Additional parameters from which to initialize the image processor object.
-
- Returns:
- [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
- parameters.
- """
- image_processor_dict = image_processor_dict.copy()
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
-
- # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
- # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
- # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
- if "size" in kwargs and "size" in image_processor_dict:
- image_processor_dict["size"] = kwargs.pop("size")
- if "crop_size" in kwargs and "crop_size" in image_processor_dict:
- image_processor_dict["crop_size"] = kwargs.pop("crop_size")
-
- image_processor = cls(**image_processor_dict)
-
- # Update image_processor with kwargs if needed
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(image_processor, key):
- setattr(image_processor, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
-
- logger.info(f"Image processor {image_processor}")
- if return_unused_kwargs:
- return image_processor, kwargs
- else:
- return image_processor
-
- def to_dict(self) -> Dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
-
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
- """
- output = copy.deepcopy(self.__dict__)
- output["image_processor_type"] = self.__class__.__name__
-
- return output
-
- @classmethod
- def from_json_file(cls, json_file: Union[str, os.PathLike]):
- """
- Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
- file of parameters.
-
- Args:
- json_file (`str` or `os.PathLike`):
- Path to the JSON file containing the parameters.
-
- Returns:
- A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
- instantiated from that JSON file.
- """
- with open(json_file, "r", encoding="utf-8") as reader:
- text = reader.read()
- image_processor_dict = json.loads(text)
- return cls(**image_processor_dict)
-
- def to_json_string(self) -> str:
- """
- Serializes this instance to a JSON string.
-
- Returns:
- `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
- """
- dictionary = self.to_dict()
-
- for key, value in dictionary.items():
- if isinstance(value, np.ndarray):
- dictionary[key] = value.tolist()
-
- # make sure private name "_processor_class" is correctly
- # saved as "processor_class"
- _processor_class = dictionary.pop("_processor_class", None)
- if _processor_class is not None:
- dictionary["processor_class"] = _processor_class
-
- return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
-
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
- """
- Save this instance to a JSON file.
-
- Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this image_processor instance's parameters will be saved.
- """
- with open(json_file_path, "w", encoding="utf-8") as writer:
- writer.write(self.to_json_string())
-
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
-
- @classmethod
- def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
- """
- Register this class with a given auto class. This should only be used for custom image processors as the ones
- in the library are already mapped with `AutoImageProcessor `.
-
- <Tip warning={true}>
-
- This API is experimental and may have some slight breaking changes in the next releases.
-
- </Tip>
-
- Args:
- auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`):
- The auto class to register this new image processor with.
- """
- if not isinstance(auto_class, str):
- auto_class = auto_class.__name__
-
- import transformers.models.auto as auto_module
-
- if not hasattr(auto_module, auto_class):
- raise ValueError(f"{auto_class} is not a valid auto class.")
-
- cls._auto_class = auto_class
-
- def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
- """
- Convert a single or a list of urls into the corresponding `PIL.Image` objects.
-
- If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
- returned.
- """
- headers = {
- "User-Agent": (
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
- " Safari/537.36"
- )
- }
- if isinstance(image_url_or_urls, list):
- return [self.fetch_images(x) for x in image_url_or_urls]
- elif isinstance(image_url_or_urls, str):
- response = requests.get(image_url_or_urls, stream=True, headers=headers)
- response.raise_for_status()
- return Image.open(BytesIO(response.content))
- else:
- raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
-
-
class BaseImageProcessor(ImageProcessingMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -798,10 +280,3 @@ def select_best_resolution(original_size: tuple, possible_resolutions: list) ->
best_fit = (height, width)
return best_fit
-
-
-ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
-if ImageProcessingMixin.push_to_hub.__doc__ is not None:
- ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
- object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
- )
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index ee2b891ca308c9..61b84ce3b268e1 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -19,7 +19,7 @@
from ...image_processing_utils import get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
-from ...image_transforms import FusedRescaleNormalize, Rescale, NumpyToTensor
+from ...image_transforms import FusedRescaleNormalize, NumpyToTensor, Rescale
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
@@ -43,7 +43,7 @@
if is_torchvision_available():
- from torchvision.transforms import Compose, Lambda, Normalize, PILToTensor, Resize
+ from torchvision.transforms import Compose, Normalize, PILToTensor, Resize
class ViTImageProcessorFast(BaseImageProcessorFast):
From 8d82609c128637fe71577e7a10aef8d359ecac3b Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 16:20:33 +0000
Subject: [PATCH 22/40] Safe typing
---
src/transformers/image_processing_utils_fast.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index a7ffc6411f3f96..236639fdd663ce 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -67,13 +67,13 @@ def _same_transforms_settings(self, **kwargs) -> bool:
return False
return True
- def _build_transforms(self, **kwargs) -> Compose:
+ def _build_transforms(self, **kwargs) -> "Compose":
"""
Given the input settings e.g. do_resize, build the image transforms.
"""
raise NotImplementedError
- def set_transforms(self, **kwargs) -> Compose:
+ def set_transforms(self, **kwargs) -> "Compose":
"""
Set the image transforms based on the given settings.
If the settings are the same as the current ones, do nothing.
@@ -87,7 +87,7 @@ def set_transforms(self, **kwargs) -> Compose:
return transforms
@functools.lru_cache(maxsize=1)
- def _maybe_update_transforms(self, **kwargs) -> Compose:
+ def _maybe_update_transforms(self, **kwargs) -> "Compose":
"""
If settings are different from those stored in `self._transform_settings`, update
the image transforms to apply
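Quoting the `Compose` annotations means they are never evaluated at runtime, so the module imports cleanly when torchvision is absent; only static type checkers resolve the name. The usual shape of that pattern:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; at runtime torchvision may be missing.
        from torchvision.transforms import Compose

    def _build_transforms(**kwargs) -> "Compose":
        raise NotImplementedError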
From 849e27bb8c61cdbf78d7b01b80f8ce25f9752a0a Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 16:23:51 +0000
Subject: [PATCH 23/40] Fix up
---
src/transformers/image_transforms.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index 137c2c1836e186..f64eb3882db1df 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -844,5 +844,6 @@ class NumpyToTensor:
"""
Convert a numpy array to a PyTorch tensor.
"""
+
def __call__(self, image: np.ndarray):
return torch.from_numpy(image)
From fdd4e5decc011a1e6ec8dc1337977a12a43b1055 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 22 May 2024 18:28:45 +0000
Subject: [PATCH 24/40] Convert mean/std to tensors for rescaling
---
src/transformers/image_transforms.py | 9 ++++++---
src/transformers/models/vit/image_processing_vit_fast.py | 7 ++++---
2 files changed, 10 insertions(+), 6 deletions(-)
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index f64eb3882db1df..62c607ba1e1fe0 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -378,6 +378,7 @@ def normalize(
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
+
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
num_channels = image.shape[channel_axis]
@@ -820,8 +821,9 @@ class FusedRescaleNormalize:
"""
def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
- self.mean = mean * (1.0 / rescale_factor)
- self.std = std * (1.0 / rescale_factor)
+ self.mean = torch.tensor(mean) * (1.0 / rescale_factor)
+ self.std = torch.tensor(std) * (1.0 / rescale_factor)
+ self.inplace = inplace
def __call__(self, image: "torch.Tensor"):
image = _cast_tensor_to_float(image)
@@ -837,7 +839,8 @@ def __init__(self, rescale_factor: float = 1.0):
self.rescale_factor = rescale_factor
def __call__(self, image: "torch.Tensor"):
- return image.mul(self.rescale_factor)
+ image = image * self.rescale_factor
+ return image
class NumpyToTensor:
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 61b84ce3b268e1..aa8d0a864234e3 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -18,6 +18,7 @@
from typing import Any, Dict, List, Optional, Union
from ...image_processing_utils import get_size_dict
+from ...image_processing_base import BatchFeature
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
from ...image_transforms import FusedRescaleNormalize, NumpyToTensor, Rescale
from ...image_utils import (
@@ -43,7 +44,7 @@
if is_torchvision_available():
- from torchvision.transforms import Compose, Normalize, PILToTensor, Resize
+ from torchvision.transforms import Compose, PILToTensor, Resize, Lambda, Normalize
class ViTImageProcessorFast(BaseImageProcessorFast):
@@ -121,7 +122,7 @@ def _build_transforms(
size: Dict[str, int],
resample: PILImageResampling,
do_rescale: bool,
- rescale_factor: float, # dummy
+ rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
@@ -286,7 +287,7 @@ def preprocess(
transformed_images = [transforms(image) for image in images]
data = {"pixel_values": torch.vstack(transformed_images)}
- return data
+ return BatchFeature(data, tensor_type=return_tensors)
def to_dict(self) -> Dict[str, Any]:
result = super().to_dict()
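Besides wiring up the previously unset `self.inplace`, this fixes the constant folding itself: `image_mean` and `image_std` typically arrive as Python lists, and scaling a list by a float raises (while scaling by an int silently repeats the list). Converting to tensors first makes the arithmetic elementwise:

    import torch

    mean, rescale_factor = [0.5, 0.5, 0.5], 1 / 255

    # mean * (1.0 / rescale_factor) -> TypeError: can't multiply sequence by float
    # mean * 255                    -> a 765-element list, not a scaled mean

    scaled = torch.tensor(mean) * (1.0 / rescale_factor)
    print(scaled)   # tensor([127.5000, 127.5000, 127.5000])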
From 0ad7e710fd10cdcc3a61d63f58178778c2143069 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 24 May 2024 18:11:28 +0000
Subject: [PATCH 25/40] Don't store transforms in state
---
.../image_processing_utils_fast.py | 48 +++----------------
.../models/vit/image_processing_vit_fast.py | 12 ++---
2 files changed, 13 insertions(+), 47 deletions(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index 236639fdd663ce..2aa749cb8cb676 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -45,27 +45,6 @@ def __getitem__(self, key):
class BaseImageProcessorFast(BaseImageProcessor):
_transform_params = None
- _transform_settings = None
-
- def _set_transform_settings(self, **kwargs) -> None:
- settings = {}
- for k, v in kwargs.items():
- if k not in self._transform_params:
- raise ValueError(f"Invalid transform parameter {k}={v}.")
- settings[k] = v
- self._transform_settings = settings
-
- def _same_transforms_settings(self, **kwargs) -> bool:
- """
- Check if the current settings are the same as the current transforms.
- """
- if self._transform_settings is None:
- return False
-
- for key, value in kwargs.items():
- if value not in self._transform_settings or value != self._transform_settings[key]:
- return False
- return True
def _build_transforms(self, **kwargs) -> "Compose":
"""
@@ -73,25 +52,12 @@ def _build_transforms(self, **kwargs) -> "Compose":
"""
raise NotImplementedError
- def set_transforms(self, **kwargs) -> "Compose":
- """
- Set the image transforms based on the given settings.
- If the settings are the same as the current ones, do nothing.
- """
- if self._same_transforms_settings(**kwargs):
- return self._transforms
-
- transforms = self._build_transforms(**kwargs)
- self._set_transform_settings(**kwargs)
- self._transforms = transforms
- return transforms
+ def _validate_params(self, **kwargs) -> None:
+ for k, v in kwargs.items():
+ if k not in self._transform_params:
+ raise ValueError(f"Invalid transform parameter {k}={v}.")
@functools.lru_cache(maxsize=1)
- def _maybe_update_transforms(self, **kwargs) -> "Compose":
- """
- If settings are different from those stored in `self._transform_settings`, update
- the image transforms to apply
- """
- if self._same_transforms_settings(**kwargs):
- return self._transforms
- return self.set_transforms(**kwargs)
+ def get_transforms(self, **kwargs) -> "Compose":
+ self._validate_params(**kwargs)
+ return self._build_transforms(**kwargs)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index aa8d0a864234e3..db2fcb2f29791b 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -133,11 +133,6 @@ def _build_transforms(
"""
transforms = []
- if do_resize:
- transforms.append(
- Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
- )
-
# All PIL and numpy values need to be converted to a torch tensor
# to keep cross compatibility with slow image processors
if image_type == ImageType.PIL:
@@ -147,6 +142,11 @@ def _build_transforms(
# Do we want to permute the channels here?
transforms.append(NumpyToTensor())
+ if do_resize:
+ transforms.append(
+ Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
+ )
+
# We can combine rescale and normalize into a single operation for speed
if do_rescale and do_normalize:
transforms.append(FusedRescaleNormalize(image_mean, image_std, rescale_factor=rescale_factor))
@@ -273,7 +273,7 @@ def preprocess(
image_type=image_type,
)
- transforms = self._maybe_update_transforms(
+ transforms = self.get_transforms(
do_resize=do_resize,
do_rescale=do_rescale,
do_normalize=do_normalize,
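Dropping `_transform_settings` removes hand-rolled bookkeeping that `functools.lru_cache` already performs: the cache keys on the (hashable) keyword arguments, so identical settings return the previously built pipeline, and (one plausible motivation) no mutable transform state lands in the processor's `__dict__`, which `to_dict()` would otherwise try to serialize. The caching behavior in isolation:

    import functools

    class Processor:
        calls = 0

        @functools.lru_cache(maxsize=1)
        def get_transforms(self, **settings):
            Processor.calls += 1
            return tuple(settings.items())   # stand-in for a Compose pipeline

    p = Processor()
    p.get_transforms(do_resize=True, do_rescale=True)
    p.get_transforms(do_resize=True, do_rescale=True)   # cache hit
    assert Processor.calls == 1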
From 1b5885b45e4ac8599facc569f2532ab105525f0b Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 24 May 2024 18:20:03 +0000
Subject: [PATCH 26/40] Fix up
---
src/transformers/models/vit/image_processing_vit_fast.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index db2fcb2f29791b..ebe5b26317a2cb 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -17,8 +17,8 @@
import functools
from typing import Any, Dict, List, Optional, Union
-from ...image_processing_utils import get_size_dict
from ...image_processing_base import BatchFeature
+from ...image_processing_utils import get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
from ...image_transforms import FusedRescaleNormalize, NumpyToTensor, Rescale
from ...image_utils import (
@@ -44,7 +44,7 @@
if is_torchvision_available():
- from torchvision.transforms import Compose, PILToTensor, Resize, Lambda, Normalize
+ from torchvision.transforms import Compose, Normalize, PILToTensor, Resize
class ViTImageProcessorFast(BaseImageProcessorFast):
From e29150ca6a7b9d4ab6e17e2dcb6b1e19e77acf90 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:43:38 +0100
Subject: [PATCH 27/40] Update src/transformers/image_processing_utils_fast.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
src/transformers/image_processing_utils_fast.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index 2aa749cb8cb676..daeee3e1bd5bba 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team.
+# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
From a1f718b6dd7955aec03599dd505457348bad7ab2 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:51:41 +0100
Subject: [PATCH 28/40] Update
src/transformers/models/auto/image_processing_auto.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
src/transformers/models/auto/image_processing_auto.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 843eae599694b8..6787415c401716 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -467,7 +467,7 @@ def register(
if slow_image_processor_class is not None:
raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
warnings.warn(
- "The image_processor_class argument is deprecated and will be removed in v4.42. Please use slow_image_processor_class, or fast_image_processor_class instead",
+ "The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead",
FutureWarning,
)
slow_image_processor_class = image_processor_class
From af52ee2c12cc478a601ddc00c16f4d20cbf5e0cc Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:52:10 +0100
Subject: [PATCH 29/40] Update
src/transformers/models/auto/image_processing_auto.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
src/transformers/models/auto/image_processing_auto.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 6787415c401716..a477d0a51838e2 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -486,7 +486,7 @@ def register(
and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
):
raise ValueError(
- "The fast tokenizer class you are passing has a `slow_image_processor_class` attribute that is not "
+ "The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
"consistent with the slow tokenizer class you passed (fast tokenizer has "
f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}. Fix one of those "
"so they match!"
From 34b8859b832778732ee0acfe4221e83b129f48b6 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:52:25 +0100
Subject: [PATCH 30/40] Update
src/transformers/models/auto/image_processing_auto.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
src/transformers/models/auto/image_processing_auto.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index a477d0a51838e2..768703eddb0316 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -487,7 +487,7 @@ def register(
):
raise ValueError(
"The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
- "consistent with the slow tokenizer class you passed (fast tokenizer has "
+ "consistent with the slow processor class you passed (fast tokenizer has "
f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}. Fix one of those "
"so they match!"
)
From 5e7a30d48d7b8bd94715760360c475a3a7357949 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:26:38 +0000
Subject: [PATCH 31/40] Warn if fast image processor available
---
.../models/auto/image_processing_auto.py | 13 +++++++++++++
1 file changed, 13 insertions(+)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 768703eddb0316..3c5121a4791cad 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -267,6 +267,13 @@ def get_image_processor_config(
return json.load(reader)
+def _warning_fast_image_processor_available(fast_class):
+ logger.warning(
+ f"Fast image processor class {fast_class} is available for this model. "
+ "Using slow image processor class. To use the fast image processor class set `use_fast=True`."
+ )
+
+
class AutoImageProcessor:
r"""
This is a generic image processor class that will be instantiated as one of the image processor classes of the
@@ -414,6 +421,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
image_processor_auto_map = (image_processor_auto_map, None)
if has_remote_code and trust_remote_code:
+ if not use_fast and image_processor_auto_map[1] is not None:
+ _warning_fast_image_processor_available(image_processor_auto_map[1])
+
if use_fast and image_processor_auto_map[1] is not None:
class_ref = image_processor_auto_map[1]
else:
@@ -431,6 +441,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
image_processor_class_py, image_processor_class_fast = image_processor_tuple
+ if not use_fast and image_processor_class_fast is not None:
+ _warning_fast_image_processor_available(image_processor_class_fast)
+
if image_processor_class_fast and (use_fast or image_processor_class_py is None):
return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
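
In user terms, the new `_warning_fast_image_processor_available` hook means `AutoImageProcessor` keeps returning the slow class by default but nudges users towards the fast one. Roughly, under this series (the checkpoint name is illustrative):

from transformers import AutoImageProcessor

# Default: the slow class is returned, and the new warning is logged
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Opting in returns the torchvision-backed class, with no warning
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)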
From 2d756071ccf9bec509b696cd201961e76e2f583a Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 17:29:34 +0100
Subject: [PATCH 32/40] Update
src/transformers/models/vit/image_processing_vit_fast.py
---
src/transformers/models/vit/image_processing_vit_fast.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index ebe5b26317a2cb..48100899dd4de4 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -139,7 +139,6 @@ def _build_transforms(
transforms.append(PILToTensor())
elif image_type == ImageType.NUMPY:
- # Do we want to permute the channels here?
transforms.append(NumpyToTensor())
if do_resize:
From a43cabc398b89a5a2d6f304a5c2924ba756560a3 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:34:49 +0000
Subject: [PATCH 33/40] Transpose incoming numpy images to be in CHW format
---
src/transformers/image_transforms.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index 62c607ba1e1fe0..4e4812879eed1c 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -849,4 +849,6 @@ class NumpyToTensor:
"""
def __call__(self, image: np.ndarray):
- return torch.from_numpy(image)
+ # As in torchvision, we assume incoming numpy images are in HWC format
+ # cf. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
+ return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
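
`torch.from_numpy` alone would have kept the HWC layout, which disagrees with the CHW layout that `PILToTensor` produces and every downstream transform expects; the transpose aligns the two input paths, and `.contiguous()` materializes the permuted view so later kernels do not pay for strided access. A quick shape check (toy array, not library code):

import numpy as np
import torch

image_hwc = np.arange(2 * 2 * 3, dtype=np.float32).reshape(2, 2, 3)
tensor_chw = torch.from_numpy(image_hwc.transpose(2, 0, 1)).contiguous()

assert tensor_chw.shape == (3, 2, 2)
assert tensor_chw.is_contiguous()  # transpose alone would return a non-contiguous view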
From 6acf27f4e1b07ca59c18adde3ed1847618b087f7 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 17:25:21 +0000
Subject: [PATCH 34/40] Update mapping names based on packages, auto set fast
to None
---
.../models/auto/image_processing_auto.py | 193 ++++++++++--------
1 file changed, 106 insertions(+), 87 deletions(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 3c5121a4791cad..cfbefc46f605d1 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -26,7 +26,14 @@
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...image_processing_utils import BaseImageProcessor, ImageProcessingMixin
from ...image_processing_utils_fast import BaseImageProcessorFast
-from ...utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, is_torchvision_available, logging
+from ...utils import (
+ CONFIG_NAME,
+ IMAGE_PROCESSOR_NAME,
+ get_file_from_repo,
+ is_torchvision_available,
+ is_vision_available,
+ logging,
+)
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
@@ -46,96 +53,108 @@
else:
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
[
- ("align", ("EfficientNetImageProcessor", None)),
- ("beit", ("BeitImageProcessor", None)),
- ("bit", ("BitImageProcessor", None)),
- ("blip", ("BlipImageProcessor", None)),
- ("blip-2", ("BlipImageProcessor", None)),
- ("bridgetower", ("BridgeTowerImageProcessor", None)),
- ("chinese_clip", ("ChineseCLIPImageProcessor", None)),
- ("clip", ("CLIPImageProcessor", None)),
- ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("conditional_detr", ("ConditionalDetrImageProcessor", None)),
- ("convnext", ("ConvNextImageProcessor", None)),
- ("convnextv2", ("ConvNextImageProcessor", None)),
- ("cvt", ("ConvNextImageProcessor", None)),
- ("data2vec-vision", ("BeitImageProcessor", None)),
- ("deformable_detr", ("DeformableDetrImageProcessor", None)),
- ("deit", ("DeiTImageProcessor", None)),
- ("depth_anything", ("DPTImageProcessor", None)),
- ("deta", ("DetaImageProcessor", None)),
- ("detr", ("DetrImageProcessor", None)),
- ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("dinov2", ("BitImageProcessor", None)),
- ("donut-swin", ("DonutImageProcessor", None)),
- ("dpt", ("DPTImageProcessor", None)),
- ("efficientformer", ("EfficientFormerImageProcessor", None)),
- ("efficientnet", ("EfficientNetImageProcessor", None)),
- ("flava", ("FlavaImageProcessor", None)),
- ("focalnet", ("BitImageProcessor", None)),
- ("fuyu", ("FuyuImageProcessor", None)),
- ("git", ("CLIPImageProcessor", None)),
- ("glpn", ("GLPNImageProcessor", None)),
- ("grounding-dino", ("GroundingDinoImageProcessor", None)),
- ("groupvit", ("CLIPImageProcessor", None)),
- ("idefics", ("IdeficsImageProcessor", None)),
- ("idefics2", ("Idefics2ImageProcessor", None)),
- ("imagegpt", ("ImageGPTImageProcessor", None)),
- ("instructblip", ("BlipImageProcessor", None)),
- ("kosmos-2", ("CLIPImageProcessor", None)),
- ("layoutlmv2", ("LayoutLMv2ImageProcessor", None)),
- ("layoutlmv3", ("LayoutLMv3ImageProcessor", None)),
- ("levit", ("LevitImageProcessor", None)),
- ("llava", ("CLIPImageProcessor", None)),
- ("llava_next", ("LlavaNextImageProcessor", None)),
- ("mask2former", ("Mask2FormerImageProcessor", None)),
- ("maskformer", ("MaskFormerImageProcessor", None)),
- ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("mobilenet_v1", ("MobileNetV1ImageProcessor", None)),
- ("mobilenet_v2", ("MobileNetV2ImageProcessor", None)),
- ("mobilevit", ("MobileViTImageProcessor", None)),
- ("mobilevit", ("MobileViTImageProcessor", None)),
- ("mobilevitv2", ("MobileViTImageProcessor", None)),
- ("nat", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("nougat", ("NougatImageProcessor", None)),
- ("oneformer", ("OneFormerImageProcessor", None)),
- ("owlv2", ("Owlv2ImageProcessor", None)),
- ("owlvit", ("OwlViTImageProcessor", None)),
- ("perceiver", ("PerceiverImageProcessor", None)),
- ("pix2struct", ("Pix2StructImageProcessor", None)),
- ("poolformer", ("PoolFormerImageProcessor", None)),
- ("pvt", ("PvtImageProcessor", None)),
- ("pvt_v2", ("PvtImageProcessor", None)),
- ("regnet", ("ConvNextImageProcessor", None)),
- ("resnet", ("ConvNextImageProcessor", None)),
- ("sam", ("SamImageProcessor", None)),
- ("segformer", ("SegformerImageProcessor", None)),
- ("seggpt", ("SegGptImageProcessor", None)),
- ("siglip", ("SiglipImageProcessor", None)),
- ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
+ ("align", ("EfficientNetImageProcessor",)),
+ ("beit", ("BeitImageProcessor",)),
+ ("bit", ("BitImageProcessor",)),
+ ("blip", ("BlipImageProcessor",)),
+ ("blip-2", ("BlipImageProcessor",)),
+ ("bridgetower", ("BridgeTowerImageProcessor",)),
+ ("chinese_clip", ("ChineseCLIPImageProcessor",)),
+ ("clip", ("CLIPImageProcessor",)),
+ ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("conditional_detr", ("ConditionalDetrImageProcessor",)),
+ ("convnext", ("ConvNextImageProcessor",)),
+ ("convnextv2", ("ConvNextImageProcessor",)),
+ ("cvt", ("ConvNextImageProcessor",)),
+ ("data2vec-vision", ("BeitImageProcessor",)),
+ ("deformable_detr", ("DeformableDetrImageProcessor",)),
+ ("deit", ("DeiTImageProcessor",)),
+ ("depth_anything", ("DPTImageProcessor",)),
+ ("deta", ("DetaImageProcessor",)),
+ ("detr", ("DetrImageProcessor",)),
+ ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("dinov2", ("BitImageProcessor",)),
+ ("donut-swin", ("DonutImageProcessor",)),
+ ("dpt", ("DPTImageProcessor",)),
+ ("efficientformer", ("EfficientFormerImageProcessor",)),
+ ("efficientnet", ("EfficientNetImageProcessor",)),
+ ("flava", ("FlavaImageProcessor",)),
+ ("focalnet", ("BitImageProcessor",)),
+ ("fuyu", ("FuyuImageProcessor",)),
+ ("git", ("CLIPImageProcessor",)),
+ ("glpn", ("GLPNImageProcessor",)),
+ ("grounding-dino", ("GroundingDinoImageProcessor",)),
+ ("groupvit", ("CLIPImageProcessor",)),
+ ("idefics", ("IdeficsImageProcessor",)),
+ ("idefics2", ("Idefics2ImageProcessor",)),
+ ("imagegpt", ("ImageGPTImageProcessor",)),
+ ("instructblip", ("BlipImageProcessor",)),
+ ("kosmos-2", ("CLIPImageProcessor",)),
+ ("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
+ ("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
+ ("levit", ("LevitImageProcessor",)),
+ ("llava", ("CLIPImageProcessor",)),
+ ("llava_next", ("LlavaNextImageProcessor",)),
+ ("mask2former", ("Mask2FormerImageProcessor",)),
+ ("maskformer", ("MaskFormerImageProcessor",)),
+ ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
+ ("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
+ ("mobilevit", ("MobileViTImageProcessor",)),
+ ("mobilevit", ("MobileViTImageProcessor",)),
+ ("mobilevitv2", ("MobileViTImageProcessor",)),
+ ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("nougat", ("NougatImageProcessor",)),
+ ("oneformer", ("OneFormerImageProcessor",)),
+ ("owlv2", ("Owlv2ImageProcessor",)),
+ ("owlvit", ("OwlViTImageProcessor",)),
+ ("perceiver", ("PerceiverImageProcessor",)),
+ ("pix2struct", ("Pix2StructImageProcessor",)),
+ ("poolformer", ("PoolFormerImageProcessor",)),
+ ("pvt", ("PvtImageProcessor",)),
+ ("pvt_v2", ("PvtImageProcessor",)),
+ ("regnet", ("ConvNextImageProcessor",)),
+ ("resnet", ("ConvNextImageProcessor",)),
+ ("sam", ("SamImageProcessor",)),
+ ("segformer", ("SegformerImageProcessor",)),
+ ("seggpt", ("SegGptImageProcessor",)),
+ ("siglip", ("SiglipImageProcessor",)),
+ ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
- ("swin2sr", ("Swin2SRImageProcessor", None)),
- ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("table-transformer", ("DetrImageProcessor", None)),
- ("timesformer", ("VideoMAEImageProcessor", None)),
- ("tvlt", ("TvltImageProcessor", None)),
- ("tvp", ("TvpImageProcessor", None)),
- ("udop", ("LayoutLMv3ImageProcessor", None)),
- ("upernet", ("SegformerImageProcessor", None)),
- ("van", ("ConvNextImageProcessor", None)),
- ("videomae", ("VideoMAEImageProcessor", None)),
- ("vilt", ("ViltImageProcessor", None)),
- ("vipllava", ("CLIPImageProcessor", None)),
- ("vit", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("vit_hybrid", ("ViTHybridImageProcessor", None)),
- ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast" if is_torchvision_available() else None)),
- ("vitmatte", ("VitMatteImageProcessor", None)),
- ("xclip", ("CLIPImageProcessor", None)),
- ("yolos", ("YolosImageProcessor", None)),
+ ("swin2sr", ("Swin2SRImageProcessor",)),
+ ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("table-transformer", ("DetrImageProcessor",)),
+ ("timesformer", ("VideoMAEImageProcessor",)),
+ ("tvlt", ("TvltImageProcessor",)),
+ ("tvp", ("TvpImageProcessor",)),
+ ("udop", ("LayoutLMv3ImageProcessor",)),
+ ("upernet", ("SegformerImageProcessor",)),
+ ("van", ("ConvNextImageProcessor",)),
+ ("videomae", ("VideoMAEImageProcessor",)),
+ ("vilt", ("ViltImageProcessor",)),
+ ("vipllava", ("CLIPImageProcessor",)),
+ ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("vit_hybrid", ("ViTHybridImageProcessor",)),
+ ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
+ ("vitmatte", ("VitMatteImageProcessor",)),
+ ("xclip", ("CLIPImageProcessor",)),
+ ("yolos", ("YolosImageProcessor",)),
]
)
+for model_type, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
+ slow_image_processor_class, *fast_image_processor_class = image_processors
+ if not is_vision_available():
+ slow_image_processor_class = None
+
+ # If the fast image processor is not defined, or torchvision is not available, we set it to None
+ if not fast_image_processor_class or fast_image_processor_class[0] is None or not is_torchvision_available():
+ fast_image_processor_class = None
+
+ IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_image_processor_class, fast_image_processor_class)
+
+
IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
From a38d3ee81d72007a02228b3eb6c2565da2349c76 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 17:38:30 +0000
Subject: [PATCH 35/40] Fix up
---
src/transformers/__init__.py | 2 --
src/transformers/utils/dummy_vision_objects.py | 7 -------
2 files changed, 9 deletions(-)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index c668c397979745..4976a4a1b90e7e 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1164,7 +1164,6 @@
_import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"])
_import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"])
_import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"])
- _import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"])
_import_structure["models.vitmatte"].append("VitMatteImageProcessor")
_import_structure["models.vivit"].append("VivitImageProcessor")
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
@@ -5804,7 +5803,6 @@
from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor
from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor
from .models.vit import ViTFeatureExtractor, ViTImageProcessor
- from .models.vit_hybrid import ViTHybridImageProcessor
from .models.vitmatte import VitMatteImageProcessor
from .models.vivit import VivitImageProcessor
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index e60e869dcf7af0..a27dc024447f42 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -604,13 +604,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
-class ViTHybridImageProcessor(metaclass=DummyObject):
- _backends = ["vision"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["vision"])
-
-
class VitMatteImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
From 942286f70e9abcaadeb98113e7ca07d393a40581 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 17:50:36 +0000
Subject: [PATCH 36/40] Fix
---
src/transformers/models/auto/image_processing_auto.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index cfbefc46f605d1..547fe7f2b8dbc9 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -151,6 +151,8 @@
# If the fast image processor is not defined, or torchvision is not available, we set it to None
if not fast_image_processor_class or fast_image_processor_class[0] is None or not is_torchvision_available():
fast_image_processor_class = None
+ else:
+ fast_image_processor_class = fast_image_processor_class[0]
IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_image_processor_class, fast_image_processor_class)
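
With this fix, the post-processing loop leaves every value in `IMAGE_PROCESSOR_MAPPING_NAMES` as a uniform `(slow_or_None, fast_or_None)` pair: the star-unpacked fast entry is a list, so it has to be unwrapped once the availability checks pass. A toy re-run of the loop with stubbed availability flags (names and flags here are illustrative, not the library's):

from collections import OrderedDict

vision_available, torchvision_available = True, False

names = OrderedDict(
    [
        ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("beit", ("BeitImageProcessor",)),
    ]
)
for model_type, image_processors in names.items():
    slow, *fast = image_processors
    if not vision_available:
        slow = None
    # Mirror of the fixed branch: drop the fast entry unless it exists and torchvision does too
    if not fast or fast[0] is None or not torchvision_available:
        fast = None
    else:
        fast = fast[0]
    names[model_type] = (slow, fast)

print(names["vit"])   # ('ViTImageProcessor', None) -- torchvision stubbed as missing
print(names["beit"])  # ('BeitImageProcessor', None)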
From 954ee20f59385941a57b02ba93d80b52920122b6 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 5 Jun 2024 18:40:04 +0000
Subject: [PATCH 37/40] Add AutoImageProcessor.from_pretrained(checkpoint,
use_fast=True) test
---
.../models/auto/image_processing_auto.py | 5 +++++
.../models/auto/test_image_processing_auto.py | 21 ++++++++++++++++++-
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 547fe7f2b8dbc9..b316a1a55ddeed 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -429,6 +429,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
image_processor_auto_map = config.auto_map["AutoImageProcessor"]
if image_processor_class is not None:
+ # Update the class name to reflect the use_fast option; if the toggled class is not found, image_processor_class_from_name returns None.
+ if use_fast and not image_processor_class.endswith("Fast"):
+ image_processor_class += "Fast"
+ elif not use_fast and image_processor_class.endswith("Fast"):
+ image_processor_class = image_processor_class[:-4]
image_processor_class = image_processor_class_from_name(image_processor_class)
has_remote_code = image_processor_auto_map is not None
diff --git a/tests/models/auto/test_image_processing_auto.py b/tests/models/auto/test_image_processing_auto.py
index 0fb22b6c2b1f16..b571e7a860b04a 100644
--- a/tests/models/auto/test_image_processing_auto.py
+++ b/tests/models/auto/test_image_processing_auto.py
@@ -27,8 +27,10 @@
AutoImageProcessor,
CLIPConfig,
CLIPImageProcessor,
+ ViTImageProcessor,
+ ViTImageProcessorFast,
)
-from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
+from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_torchvision, require_vision
sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
@@ -133,6 +135,23 @@ def test_image_processor_not_found(self):
):
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/config-no-model")
+ @require_vision
+ @require_torchvision
+ def test_use_fast_selection(self):
+ checkpoint = "hf-internal-testing/tiny-random-vit"
+
+ # Slow image processor is selected by default
+ image_processor = AutoImageProcessor.from_pretrained(checkpoint)
+ self.assertIsInstance(image_processor, ViTImageProcessor)
+
+ # Fast image processor is selected when use_fast=True
+ image_processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)
+ self.assertIsInstance(image_processor, ViTImageProcessorFast)
+
+ # Slow image processor is selected when use_fast=False
+ image_processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=False)
+ self.assertIsInstance(image_processor, ViTImageProcessor)
+
def test_from_pretrained_dynamic_image_processor(self):
# If remote code is not set, we will time out when asking whether to load the model.
with self.assertRaises(ValueError):
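
The suffix toggle in `from_pretrained` is plain string surgery on the configured class name before `image_processor_class_from_name` resolves it; when the toggled name has no registered class, resolution returns None and the auto-mapping path takes over. Restated as a standalone helper (the function name is ours, not the library's):

def toggle_fast_suffix(class_name: str, use_fast: bool) -> str:
    if use_fast and not class_name.endswith("Fast"):
        return class_name + "Fast"
    if not use_fast and class_name.endswith("Fast"):
        return class_name[: -len("Fast")]
    return class_name


assert toggle_fast_suffix("ViTImageProcessor", use_fast=True) == "ViTImageProcessorFast"
assert toggle_fast_suffix("ViTImageProcessorFast", use_fast=False) == "ViTImageProcessor"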
From ee06a6af09119bf2da07a098b24f787206f24a8b Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Thu, 6 Jun 2024 21:55:21 +0100
Subject: [PATCH 38/40] Update
src/transformers/models/vit/image_processing_vit_fast.py
Co-authored-by: Pavel Iakubovskii
---
src/transformers/models/vit/image_processing_vit_fast.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 48100899dd4de4..630e41cfed91ce 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -287,9 +287,3 @@ def preprocess(
data = {"pixel_values": torch.vstack(transformed_images)}
return BatchFeature(data, tensor_type=return_tensors)
-
- def to_dict(self) -> Dict[str, Any]:
- result = super().to_dict()
- result.pop("_transforms", None)
- result.pop("_transform_settings", None)
- return result
From 1d1d41660303effa542a4c6b8c05edf0431e480a Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 7 Jun 2024 16:51:09 +0000
Subject: [PATCH 39/40] Add equivalence and speed tests
---
tests/test_image_processing_common.py | 51 +++++++++++++++++++++++++++
1 file changed, 51 insertions(+)
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index d929997ee87369..e9b9467f580a2a 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -19,6 +19,8 @@
import pathlib
import tempfile
+import requests
+
from transformers import AutoImageProcessor, BatchFeature
from transformers.image_utils import AnnotationFormat, AnnotionFormat
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision
@@ -146,6 +148,55 @@ def setUp(self):
self.image_processor_list = image_processor_list
+ @require_vision
+ @require_torch
+ def test_slow_fast_equivalence(self):
+ if not self.test_slow_image_processor or not self.test_fast_image_processor:
+ self.skipTest("Skipping slow/fast equivalence test")
+
+ if self.image_processing_class is None or self.fast_image_processing_class is None:
+ self.skipTest("Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+ # Only download the reference image once we know the test will actually run
+ dummy_image = Image.open(
+ requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
+ )
+
+ image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+ image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+ encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
+ encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
+
+ self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-3))
+
+ @require_vision
+ @require_torch
+ def test_fast_is_faster_than_slow(self):
+ import time
+
+ # Local helper (not a method): time a single preprocessing call
+ def measure_time(image_processor, dummy_image):
+ start = time.time()
+ _ = image_processor(dummy_image, return_tensors="pt")
+ return time.time() - start
+
+ if not self.test_slow_image_processor or not self.test_fast_image_processor:
+ self.skipTest("Skipping speed test")
+
+ if self.image_processing_class is None or self.fast_image_processing_class is None:
+ self.skipTest("Skipping speed test as one of the image processors is not defined")
+
+ dummy_image = Image.open(
+ requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
+ )
+
+ image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+ image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+ slow_time = measure_time(image_processor_slow, dummy_image)
+ fast_time = measure_time(image_processor_fast, dummy_image)
+
+ self.assertLessEqual(fast_time, slow_time)
+
def test_image_processor_to_json_string(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
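
Taken together, the series makes the following round trip possible. This is a sketch assuming the patches are applied and torchvision is installed (sizes are arbitrary):

import numpy as np
from PIL import Image
from transformers import ViTImageProcessorFast

processor = ViTImageProcessorFast(do_resize=True, size={"height": 224, "width": 224})
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
batch = processor(image, return_tensors="pt")
print(batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 224, 224])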
From d598b5aa5a3918308d92924771b9e04ec3e7b561 Mon Sep 17 00:00:00 2001
From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com>
Date: Fri, 7 Jun 2024 17:01:23 +0000
Subject: [PATCH 40/40] Fix up
---
src/transformers/models/vit/image_processing_vit_fast.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index 630e41cfed91ce..09113761655b93 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -15,7 +15,7 @@
"""Fast Image processor class for ViT."""
import functools
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
from ...image_processing_base import BatchFeature
from ...image_processing_utils import get_size_dict