Skip to content

Commit

Permalink
Tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Feb 2, 2024
1 parent 88c87e0 commit d4722d4
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 121 deletions.
7 changes: 6 additions & 1 deletion docs/source/en/model_doc/vit.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Following the original Vision Transformer, some follow-up works have been made:
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
found [here](https://github.com/google-research/vision_transformer).

Note that we converted the weights from Ross Wightman's [timm library](https://github.com/rwightman/pytorch-image-models),
Note that we converted the weights from Ross Wightman's [timm library](https://github.com/rwightman/pytorch-image-models),
who already converted the weights from JAX to PyTorch. Credits go to him!

## Usage tips
Expand Down Expand Up @@ -130,6 +130,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] ViTImageProcessor
- preprocess

## ViTImageProcessorFast

[[autodoc]] ViTImageProcessorFast
- preprocess

<frameworkcontent>
<pt>

Expand Down
8 changes: 6 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,9 @@
name for name in dir(dummy_vision_objects) if not name.startswith("_")
]
else:
_import_structure["image_processing_utils"] = ["ImageProcessingMixin"]
_import_structure["image_processing_base"] = ["ImageProcessingMixin"]
_import_structure["image_processing_utils"] = ["BaseImageProcessor"]
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
_import_structure["models.bit"].extend(["BitImageProcessor"])
Expand Down Expand Up @@ -5976,7 +5978,9 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_vision_objects import *
else:
from .image_processing_utils import ImageProcessingMixin
from .image_processing_base import ImageProcessingMixin
from .image_processing_utils import BaseImageProcessor
from .image_processing_utils_fast import BaseImageProcessorFast
from .image_utils import ImageFeatureExtractionMixin
from .models.beit import BeitFeatureExtractor, BeitImageProcessor
from .models.bit import BitImageProcessor
Expand Down
60 changes: 53 additions & 7 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,58 @@
from .image_processing_base import ImageProcessingMixin
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cache

class BaseImageProcessorFast(ImageProcessingMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
from .image_processing_utils import BaseImageProcessor

def __call__(self, images, **kwargs):
return self.preprocess(images, **kwargs)

def preprocess(self, images, **kwargs):
class BaseImageProcessorFast(BaseImageProcessor):
    """
    Base class for fast image processors that build and cache a composed
    transform pipeline.

    Subclasses declare the accepted transform parameter names in
    `_transform_params` and implement `_build_transforms` to create the
    pipeline from those parameters.
    """

    # Names of the kwargs accepted by `_build_transforms`; subclasses must override.
    _transform_params = None
    # Settings used to build the cached `_transforms`. The class-level default is
    # never mutated in place: `_set_transform_settings` rebinds an instance
    # attribute. Without this default, the first call to
    # `_same_transforms_settings` raised AttributeError.
    _transform_settings = {}

    def _set_transform_settings(self, **kwargs):
        """Validate each setting name against `_transform_params` and store them.

        Raises:
            ValueError: if a kwarg name is not listed in `_transform_params`.
        """
        settings = {}
        for k, v in kwargs.items():
            if k not in self._transform_params:
                raise ValueError(f"Invalid transform parameter {k}={v}.")
            settings[k] = v
        self._transform_settings = settings

    def _same_transforms_settings(self, **kwargs):
        """
        Check if the current settings are the same as the current transforms.
        """
        for key, value in kwargs.items():
            # Fixed: membership must be tested with the key. The original tested
            # `value not in self._transform_settings`, which looked the *value*
            # up among the dict's keys and gave wrong results.
            if key not in self._transform_settings or value != self._transform_settings[key]:
                return False
        return True

    def _build_transforms(self, **kwargs):
        """Build the transform pipeline from `kwargs`; implemented by subclasses."""
        raise NotImplementedError

    def set_transforms(self, **kwargs):
        """Rebuild and cache the transforms unless `kwargs` match the cached settings."""
        # FIXME - put input validation or kwargs for all these methods
        if self._same_transforms_settings(**kwargs):
            return self._transforms

        transforms = self._build_transforms(**kwargs)
        self._set_transform_settings(**kwargs)
        self._transforms = transforms

    def _maybe_update_transforms(self, **kwargs):
        """Rebuild the cached transforms only if the requested settings changed.

        NOTE: the original `@cache` decorator was removed. `functools.cache` on
        an instance method keys on `self` and keeps every instance alive for the
        process lifetime (ruff B019), and it raises TypeError when any kwarg
        value is unhashable (e.g. `image_mean` passed as a list). The settings
        comparison below already avoids redundant rebuilds.
        """
        if self._same_transforms_settings(**kwargs):
            return
        self.set_transforms(**kwargs)
115 changes: 4 additions & 111 deletions src/transformers/models/vit/image_processing_vit_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
# limitations under the License.
"""Image processor class for ViT."""

from dataclasses import dataclass
from functools import cache
from typing import Any, Dict, List, Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, get_size_dict
from ...image_processing_utils import get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
Expand Down Expand Up @@ -55,35 +57,20 @@
}


from dataclasses import dataclass


@dataclass(frozen=True)
class SizeDict:
    """Frozen, hashable stand-in for a size dict with optional dimensions.

    Any field left as ``None`` is considered unset and omitted from
    :meth:`todict`. Items can also be read with dict-style indexing.
    """

    height: int = None
    width: int = None
    longest_edge: int = None
    shortest_edge: int = None

    def todict(self):
        """Return only the fields that were set, as a plain dict."""
        return {
            name: value
            for name, value in (
                ("height", self.height),
                ("width", self.width),
                ("longest_edge", self.longest_edge),
                ("shortest_edge", self.shortest_edge),
            )
            if value is not None
        }

    def __getitem__(self, key):
        # Dict-style access to any existing attribute; unknown keys raise KeyError.
        if not hasattr(self, key):
            raise KeyError(f"Key {key} not found in SizeDict.")
        return getattr(self, key)


class ViTImageProcessorFast(BaseImageProcessor):
class ViTImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a ViT image processor.
Expand Down Expand Up @@ -161,23 +148,6 @@ def __init__(
image_std=image_std,
)

def _set_transform_settings(self, **kwargs):
    """Record the transform settings, rejecting unknown parameter names.

    Raises:
        ValueError: on the first kwarg whose name is not in `_transform_params`.
    """
    for name, value in kwargs.items():
        if name not in self._transform_params:
            raise ValueError(f"Invalid transform parameter {name}={value}.")
    # All names are valid; keep a private copy of the settings.
    self._transform_settings = dict(kwargs)

def _same_transforms_settings(self, **kwargs):
    """
    Check if the current settings are the same as the current transforms.

    Returns True only when every (key, value) pair in `kwargs` matches the
    corresponding entry stored in `self._transform_settings`.
    """
    for key, value in kwargs.items():
        # Fixed: membership must be tested with the key. The original tested
        # `value not in self._transform_settings`, which looked the *value* up
        # among the dict's keys and gave wrong results.
        if key not in self._transform_settings or value != self._transform_settings[key]:
            return False
    return True

def _build_transforms(
self,
do_resize: bool,
Expand All @@ -201,49 +171,6 @@ def _build_transforms(
transforms.append(Normalize(image_mean, image_std))
return Compose(transforms)

def set_transforms(
    self,
    do_resize: bool,
    do_rescale: bool,
    do_normalize: bool,
    size: Dict[str, int],
    resample: PILImageResampling,
    rescale_factor: float,
    image_mean: Union[float, List[float]],
    image_std: Union[float, List[float]],
):
    """Build and cache the transform pipeline for the given settings.

    If the settings match the ones used for the currently cached pipeline,
    the cached `self._transforms` is returned unchanged; otherwise a new
    pipeline is built, the settings are recorded, and it is stored.
    """
    # Single source of truth for the settings forwarded to the helpers.
    settings = dict(
        do_resize=do_resize,
        do_rescale=do_rescale,
        do_normalize=do_normalize,
        size=size,
        resample=resample,
        rescale_factor=rescale_factor,
        image_mean=image_mean,
        image_std=image_std,
    )
    if self._same_transforms_settings(**settings):
        return self._transforms

    # NOTE: `_build_transforms` receives the size under the name `size_dict`
    # and is not passed `rescale_factor`, mirroring the original call.
    transforms = self._build_transforms(
        do_resize=do_resize,
        size_dict=size,
        resample=resample,
        do_rescale=do_rescale,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
    )
    self._set_transform_settings(**settings)
    self._transforms = transforms

@cache
def _validate_input_arguments(
self,
Expand All @@ -270,40 +197,6 @@ def _validate_input_arguments(
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

def _maybe_update_transforms(
    self,
    do_resize: bool,
    size: Dict[str, int],
    resample: PILImageResampling,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: Union[float, List[float]],
    image_std: Union[float, List[float]],
):
    """Rebuild the cached transform pipeline only if any setting changed.

    NOTE: the original `@cache` decorator was removed. `functools.cache` on an
    instance method keys on `self` and keeps every instance alive for the
    process lifetime (ruff B019), and it raises TypeError when any argument is
    unhashable (e.g. `image_mean` passed as a list). The settings comparison
    below already avoids redundant rebuilds.
    """
    if self._same_transforms_settings(
        do_resize=do_resize,
        do_rescale=do_rescale,
        do_normalize=do_normalize,
        size=size,
        resample=resample,
        rescale_factor=rescale_factor,
        image_mean=image_mean,
        image_std=image_std,
    ):
        return
    self.set_transforms(
        do_resize=do_resize,
        do_rescale=do_rescale,
        do_normalize=do_normalize,
        size=size,
        resample=resample,
        rescale_factor=rescale_factor,
        image_mean=image_mean,
        image_std=image_std,
    )

def preprocess(
self,
images: ImageInput,
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/utils/dummy_vision_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


# Placeholder exposed when the "vision" extra is not installed; any
# instantiation raises an informative missing-backend error.
class BaseImageProcessor(metaclass=DummyObject):
    _backends = ["vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


# Placeholder exposed when the "vision" extra is not installed; any
# instantiation raises an informative missing-backend error.
class BaseImageProcessorFast(metaclass=DummyObject):
    _backends = ["vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


class ImageFeatureExtractionMixin(metaclass=DummyObject):
_backends = ["vision"]

Expand Down

0 comments on commit d4722d4

Please sign in to comment.