From d4722d4f58628bf2c6f8c99fa595014105723391 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Fri, 2 Feb 2024 20:25:29 +0000 Subject: [PATCH] Tidy up --- docs/source/en/model_doc/vit.md | 7 +- src/transformers/__init__.py | 8 +- .../image_processing_utils_fast.py | 60 +++++++-- .../models/vit/image_processing_vit_fast.py | 115 +----------------- .../utils/dummy_vision_objects.py | 14 +++ 5 files changed, 83 insertions(+), 121 deletions(-) diff --git a/docs/source/en/model_doc/vit.md b/docs/source/en/model_doc/vit.md index 25c3a6c8f537f4..f785da9c74faf2 100644 --- a/docs/source/en/model_doc/vit.md +++ b/docs/source/en/model_doc/vit.md @@ -62,7 +62,7 @@ Following the original Vision Transformer, some follow-up works have been made: This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be found [here](https://github.com/google-research/vision_transformer). -Note that we converted the weights from Ross Wightman's [timm library](https://github.com/rwightman/pytorch-image-models), +Note that we converted the weights from Ross Wightman's [timm library](https://github.com/rwightman/pytorch-image-models), who already converted the weights from JAX to PyTorch. Credits go to him! ## Usage tips @@ -130,6 +130,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] ViTImageProcessor - preprocess +## ViTImageProcessorFast + +[[autodoc]] ViTImageProcessorFast + - preprocess + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7cb41c82f6aedb..bdc5faa02a932f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1266,7 +1266,9 @@ name for name in dir(dummy_vision_objects) if not name.startswith("_") ] else: - _import_structure["image_processing_utils"] = ["ImageProcessingMixin"] + _import_structure["image_processing_base"] = ["ImageProcessingMixin"] + _import_structure["image_processing_utils"] = ["BaseImageProcessor"] + _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] _import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"]) _import_structure["models.bit"].extend(["BitImageProcessor"]) @@ -5976,7 +5978,9 @@ except OptionalDependencyNotAvailable: from .utils.dummy_vision_objects import * else: - from .image_processing_utils import ImageProcessingMixin + from .image_processing_base import ImageProcessingMixin + from .image_processing_utils import BaseImageProcessor + from .image_processing_utils_fast import BaseImageProcessorFast from .image_utils import ImageFeatureExtractionMixin from .models.beit import BeitFeatureExtractor, BeitImageProcessor from .models.bit import BitImageProcessor diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 44a12bb9b50541..e84c55d03ae1f8 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -1,12 +1,58 @@ -from .image_processing_base import ImageProcessingMixin +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import cache

-class BaseImageProcessorFast(ImageProcessingMixin):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+from .image_processing_utils import BaseImageProcessor

-    def __call__(self, images, **kwargs):
-        return self.preprocess(images, **kwargs)

-    def preprocess(self, images, **kwargs):
+class BaseImageProcessorFast(BaseImageProcessor):
+    _transform_params = None
+
+    def _set_transform_settings(self, **kwargs):
+        settings = {}
+        for k, v in kwargs.items():
+            if k not in self._transform_params:
+                raise ValueError(f"Invalid transform parameter {k}={v}.")
+            settings[k] = v
+        self._transform_settings = settings
+
+    def _same_transforms_settings(self, **kwargs):
+        """
+        Check if the given settings match the currently stored transform settings.
+        """
+        for key, value in kwargs.items():
+            if key not in self._transform_settings or value != self._transform_settings[key]:
+                return False
+        return True
+
+    def _build_transforms(self, **kwargs):
         raise NotImplementedError
+
+    def set_transforms(self, **kwargs):
+        # FIXME - add input validation on kwargs for all these methods
+
+        if self._same_transforms_settings(**kwargs):
+            return self._transforms
+
+        transforms = self._build_transforms(**kwargs)
+        self._set_transform_settings(**kwargs)
+        self._transforms = transforms
+
+    @cache
+    def _maybe_update_transforms(self, **kwargs):
+        if self._same_transforms_settings(**kwargs):
+            return
+        self.set_transforms(**kwargs)
diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py
index fcc08c91b04c4b..8dcfbe340fb456 100644
--- a/src/transformers/models/vit/image_processing_vit_fast.py
+++ b/src/transformers/models/vit/image_processing_vit_fast.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 """Image processor class for ViT."""

+from dataclasses import dataclass
 from functools import cache
 from typing import Any, Dict, List, Optional, Union

 import numpy as np

-from ...image_processing_utils import BaseImageProcessor, get_size_dict
+from ...image_processing_utils import get_size_dict
+from ...image_processing_utils_fast import BaseImageProcessorFast
 from ...image_utils import (
     IMAGENET_STANDARD_MEAN,
     IMAGENET_STANDARD_STD,
@@ -55,9 +57,6 @@
 }


-from dataclasses import dataclass
-
-
 @dataclass(frozen=True)
 class SizeDict:
     height: int = None
@@ -65,25 +64,13 @@ class SizeDict:
     longest_edge: int = None
     shortest_edge: int = None

-    def todict(self):
-        output = {}
-        if self.height is not None:
-            output["height"] = self.height
-        if self.width is not None:
-            output["width"] = self.width
-        if self.longest_edge is not None:
-            output["longest_edge"] = self.longest_edge
-        if self.shortest_edge is not None:
-            output["shortest_edge"] = self.shortest_edge
-        return output
-
     def __getitem__(self, key):
         if hasattr(self, key):
             return getattr(self, key)
         raise KeyError(f"Key {key} not found in SizeDict.")


-class ViTImageProcessorFast(BaseImageProcessor):
+class ViTImageProcessorFast(BaseImageProcessorFast):
     r"""
     Constructs a ViT image processor.

@@ -161,23 +148,6 @@ def __init__( image_std=image_std, ) - def _set_transform_settings(self, **kwargs): - settings = {} - for k, v in kwargs.items(): - if k not in self._transform_params: - raise ValueError(f"Invalid transform parameter {k}={v}.") - settings[k] = v - self._transform_settings = settings - - def _same_transforms_settings(self, **kwargs): - """ - Check if the current settings are the same as the current transforms. - """ - for key, value in kwargs.items(): - if value not in self._transform_settings or value != self._transform_settings[key]: - return False - return True - def _build_transforms( self, do_resize: bool, @@ -201,49 +171,6 @@ def _build_transforms( transforms.append(Normalize(image_mean, image_std)) return Compose(transforms) - def set_transforms( - self, - do_resize: bool, - do_rescale: bool, - do_normalize: bool, - size: Dict[str, int], - resample: PILImageResampling, - rescale_factor: float, - image_mean: Union[float, List[float]], - image_std: Union[float, List[float]], - ): - if self._same_transforms_settings( - do_resize=do_resize, - do_rescale=do_rescale, - do_normalize=do_normalize, - size=size, - resample=resample, - rescale_factor=rescale_factor, - image_mean=image_mean, - image_std=image_std, - ): - return self._transforms - transforms = self._build_transforms( - do_resize=do_resize, - size_dict=size, - resample=resample, - do_rescale=do_rescale, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - ) - self._set_transform_settings( - do_resize=do_resize, - do_rescale=do_rescale, - do_normalize=do_normalize, - size=size, - resample=resample, - rescale_factor=rescale_factor, - image_mean=image_mean, - image_std=image_std, - ) - self._transforms = transforms - @cache def _validate_input_arguments( self, @@ -270,40 +197,6 @@ def _validate_input_arguments( if do_rescale and rescale_factor is None: raise ValueError("Rescale factor must be specified if do_rescale is True.") - @cache - def _maybe_update_transforms( - self, - do_resize: bool, - size: Dict[str, int], - resample: PILImageResampling, - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: Union[float, List[float]], - image_std: Union[float, List[float]], - ): - if self._same_transforms_settings( - do_resize=do_resize, - do_rescale=do_rescale, - do_normalize=do_normalize, - size=size, - resample=resample, - rescale_factor=rescale_factor, - image_mean=image_mean, - image_std=image_std, - ): - return - self.set_transforms( - do_resize=do_resize, - do_rescale=do_rescale, - do_normalize=do_normalize, - size=size, - resample=resample, - rescale_factor=rescale_factor, - image_mean=image_mean, - image_std=image_std, - ) - def preprocess( self, images: ImageInput, diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 3c17d96096a2e2..29f13b1f29da13 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -9,6 +9,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class BaseImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + +class BaseImageProcessorFast(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class ImageFeatureExtractionMixin(metaclass=DummyObject): _backends = ["vision"]
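
Note (not part of the diff): a minimal sketch of why `SizeDict` is a frozen
dataclass rather than a plain dict -- `functools.cache`, used above on
`_validate_input_arguments` and `_maybe_update_transforms`, requires hashable
arguments, and a dict like `{"height": 224, "width": 224}` is not hashable.
The `build` function below is purely illustrative:

    from dataclasses import dataclass
    from functools import cache

    @dataclass(frozen=True)
    class SizeDict:
        height: int = None
        width: int = None

    @cache
    def build(size: SizeDict):
        # Runs once per distinct size; passing a plain dict here would raise
        # "TypeError: unhashable type: 'dict'".
        print(f"building transforms for {size}")
        return (size.height, size.width)

    build(SizeDict(height=224, width=224))  # builds
    build(SizeDict(height=224, width=224))  # cache hit, no rebuild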
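
Note (not part of the diff): hypothetical end-to-end usage once this patch is
applied -- the checkpoint name and the exact `preprocess` kwargs are
assumptions, and torchvision must be installed for the fast path:

    import numpy as np
    from PIL import Image

    from transformers import ViTImageProcessorFast

    processor = ViTImageProcessorFast.from_pretrained("google/vit-base-patch16-224")
    image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))

    # The first call builds and caches a torchvision Compose pipeline for the
    # current settings (size, resample, mean, std, ...).
    batch = processor(image, return_tensors="pt")
    # Later calls with unchanged settings reuse the cached pipeline instead of
    # rebuilding it image by image.
    batch = processor(image, return_tensors="pt")
    print(batch["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224])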