Skip to content

Commit

Permalink
Tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Apr 25, 2024
1 parent 89f6926 commit 3d9ec2d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
6 changes: 2 additions & 4 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cache
import functools

from .image_processing_utils import BaseImageProcessor

Expand Down Expand Up @@ -42,16 +42,14 @@ def _build_transforms(self, **kwargs):
raise NotImplementedError

def set_transforms(self, **kwargs):
# FIXME - put input validation or kwargs for all these methods

if self._same_transforms_settings(**kwargs):
return self._transforms

transforms = self._build_transforms(**kwargs)
self._set_transform_settings(**kwargs)
self._transforms = transforms

@cache
@functools.lru_cache(maxsize=1)
def _maybe_update_transforms(self, **kwargs):
if self._same_transforms_settings(**kwargs):
return
Expand Down
29 changes: 17 additions & 12 deletions src/transformers/models/vit/image_processing_vit_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
rescale_factor: Union[int, float] = None,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
Expand Down Expand Up @@ -190,12 +190,15 @@ def _validate_input_arguments(
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_resize and None in (size, resample):
raise ValueError("Size and resample must be specified if do_resize is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and None in (image_mean, image_std):
raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")

def preprocess(
self,
images: ImageInput,
Expand Down Expand Up @@ -238,24 +241,26 @@ def preprocess(
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
The type of tensors to return. Only "pt" is supported
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
The channel dimension format for the output image. The following formats are currently supported:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")

if input_data_format is not None and input_data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")

if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")

do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
Expand Down

0 comments on commit 3d9ec2d

Please sign in to comment.