diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index e84c55d03ae1f8..b8128538bea312 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import cache +import functools from .image_processing_utils import BaseImageProcessor @@ -42,8 +42,6 @@ def _build_transforms(self, **kwargs): raise NotImplementedError def set_transforms(self, **kwargs): - # FIXME - put input validation or kwargs for all these methods - if self._same_transforms_settings(**kwargs): return self._transforms @@ -51,7 +49,7 @@ def set_transforms(self, **kwargs): self._set_transform_settings(**kwargs) self._transforms = transforms - @cache + @functools.lru_cache(maxsize=1) def _maybe_update_transforms(self, **kwargs): if self._same_transforms_settings(**kwargs): return diff --git a/src/transformers/models/vit/image_processing_vit_fast.py b/src/transformers/models/vit/image_processing_vit_fast.py index c05f2847a648cd..7ade1a2c572651 100644 --- a/src/transformers/models/vit/image_processing_vit_fast.py +++ b/src/transformers/models/vit/image_processing_vit_fast.py @@ -119,7 +119,7 @@ def __init__( size: Optional[Dict[str, int]] = None, resample: PILImageResampling = PILImageResampling.BILINEAR, do_rescale: bool = True, - rescale_factor: Union[int, float] = 1 / 255, + rescale_factor: Union[int, float] = None, do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, @@ -190,12 +190,15 @@ def _validate_input_arguments( if data_format != ChannelDimension.FIRST: raise ValueError("Only channel first data format is currently supported.") - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") + if do_resize and None in (size, resample): + raise ValueError("Size and resample must be specified if do_resize is True.") if do_rescale and rescale_factor is None: raise ValueError("Rescale factor must be specified if do_rescale is True.") + if do_normalize and None in (image_mean, image_std): + raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.") + def preprocess( self, images: ImageInput, @@ -238,17 +241,10 @@ def preprocess( image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): Image standard deviation to use if `do_normalize` is set to `True`. return_tensors (`str` or `TensorType`, *optional*): - The type of tensors to return. Can be one of: - - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + The type of tensors to return. Only "pt" is supported data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): - The channel dimension format for the output image. Can be one of: + The channel dimension format for the output image. The following formats are currently supported: - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - - Unset: Use the channel dimension format of the input image. input_data_format (`ChannelDimension` or `str`, *optional*): The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of: @@ -256,6 +252,15 @@ def preprocess( - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. """ + if return_tensors != "pt": + raise ValueError("Only returning PyTorch tensors is currently supported.") + + if input_data_format is not None and input_data_format != ChannelDimension.FIRST: + raise ValueError("Only channel first data format is currently supported.") + + if data_format != ChannelDimension.FIRST: + raise ValueError("Only channel first data format is currently supported.") + do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_normalize = do_normalize if do_normalize is not None else self.do_normalize