diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 99eac953bc3208..2f2868507fb362 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -337,6 +337,47 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = return image +def validate_preprocess_arguments( + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisibility: Optional[int] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[Dict[str, int]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional["PILImageResampling"] = None, +): + """ + Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method. + Raises `ValueError` if arguments are incompatible. + Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`, + sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow + existing arguments when possible. + + """ + if do_rescale and rescale_factor is None: + raise ValueError("rescale_factor must be specified if do_rescale is True.") + + if do_pad and size_divisibility is None: + # Here, size_divisor might be passed as the value of size + raise ValueError( + "Depending on model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True." 
+ ) + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("image_mean and image_std must both be specified if do_normalize is True.") + + if do_center_crop and crop_size is None: + raise ValueError("crop_size must be specified if do_center_crop is True.") + + if do_resize and (size is None or resample is None): + raise ValueError("size and resample must be specified if do_resize is True.") + + # In the future we can add a TF implementation here when we have TF models. class ImageFeatureExtractionMixin: """ diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py index 6f8ce403e0a59c..52c1a813f6091a 100644 --- a/src/transformers/models/beit/image_processing_beit.py +++ b/src/transformers/models/beit/image_processing_beit.py @@ -32,6 +32,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging @@ -396,32 +397,33 @@ def preprocess( do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels images = make_list_of_images(images) + if segmentation_maps is not None: segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) - if not valid_images(images): + if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if segmentation_maps is not None and not valid_images(segmentation_maps): + if not valid_images(images): raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) images = [ self._preprocess_image( diff --git a/src/transformers/models/bit/image_processing_bit.py b/src/transformers/models/bit/image_processing_bit.py index 7aa49145ae0527..df9336c347955b 100644 --- a/src/transformers/models/bit/image_processing_bit.py +++ b/src/transformers/models/bit/image_processing_bit.py @@ -36,6 +36,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -263,17 +264,18 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." 
) - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # PIL RGBA images are converted to RGB if do_convert_rgb: diff --git a/src/transformers/models/blip/image_processing_blip.py b/src/transformers/models/blip/image_processing_blip.py index d915c5e48b3f56..fa65624937f35e 100644 --- a/src/transformers/models/blip/image_processing_blip.py +++ b/src/transformers/models/blip/image_processing_blip.py @@ -31,6 +31,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -239,15 +240,16 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." 
) - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) # PIL RGBA images are converted to RGB if do_convert_rgb: images = [convert_to_rgb(image) for image in images] diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower.py b/src/transformers/models/bridgetower/image_processing_bridgetower.py index 2332fa7bc70df6..3053c72a4c5bb7 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower.py @@ -32,6 +32,7 @@ is_scaled_image, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -128,7 +129,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor): do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the `do_resize` parameter in the `preprocess` method. - size (`Dict[str, int]` *optional*, defaults to 288): + size (`Dict[str, int]` *optional*, defaults to `{'shortest_edge': 288}`): Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under `int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if `do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method. 
@@ -158,6 +159,9 @@ class BridgeTowerImageProcessor(BaseImageProcessor): do_center_crop (`bool`, *optional*, defaults to `True`): Whether to center crop the image. Can be overridden by the `do_center_crop` parameter in the `preprocess` method. + crop_size (`Dict[str, int]`, *optional*): + Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`. + Can be overridden by the `crop_size` parameter in the `preprocess` method. If unset, defaults to `size`. do_pad (`bool`, *optional*, defaults to `True`): Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by the `do_pad` parameter in the `preprocess` method. @@ -168,7 +172,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor): def __init__( self, do_resize: bool = True, - size: Dict[str, int] = 288, + size: Dict[str, int] = None, size_divisor: int = 32, resample: PILImageResampling = PILImageResampling.BICUBIC, do_rescale: bool = True, @@ -177,6 +181,7 @@ def __init__( image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_center_crop: bool = True, + crop_size: Dict[str, int] = None, do_pad: bool = True, **kwargs, ) -> None: @@ -198,6 +203,7 @@ def __init__( self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_pad = do_pad self.do_center_crop = do_center_crop + self.crop_size = crop_size # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.resize def resize( @@ -378,6 +384,7 @@ def preprocess( image_std: Optional[Union[float, List[float]]] = None, do_pad: Optional[bool] = None, do_center_crop: Optional[bool] = None, + crop_size: Dict[str, int] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -417,6 +424,9 @@ def preprocess( do_center_crop (`bool`, *optional*, defaults 
to `self.do_center_crop`): Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image is padded with 0's and then center cropped. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be + padded with zeros and then cropped. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -446,6 +456,11 @@ def preprocess( image_std = image_std if image_std is not None else self.image_std do_pad = do_pad if do_pad is not None else self.do_pad do_center_crop if do_center_crop is not None else self.do_center_crop + # For backwards compatibility. Initial version of this processor was cropping to the "size" argument, which + # it should default to if crop_size is undefined. + crop_size = ( + crop_size if crop_size is not None else (self.crop_size if self.crop_size is not None else self.size) + ) size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) @@ -458,16 +473,21 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - + # Here, crop_size is used only if it is set, else size will be used. 
+ validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisor, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] @@ -491,7 +511,7 @@ def preprocess( if do_center_crop: images = [ - self.center_crop(image=image, size=size, input_data_format=input_data_format) for image in images + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images ] if do_rescale: diff --git a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py index 4f1048a45e6ac6..0216bc5431ea7f 100644 --- a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py +++ b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -36,6 +36,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -251,20 +252,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - - # PIL RGBA images are converted to RGB + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) if do_convert_rgb: images = [convert_to_rgb(image) for image in images] diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py index 2c829d0aab948a..6549a572d864f3 100644 --- a/src/transformers/models/clip/image_processing_clip.py +++ b/src/transformers/models/clip/image_processing_clip.py @@ -36,6 +36,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -265,20 +266,19 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - - # PIL RGBA images are converted to RGB if do_convert_rgb: images = [convert_to_rgb(image) for image in images] diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index d266ef9a899ea6..0af79bbcb93efa 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -49,6 +49,7 @@ to_numpy_array, valid_images, validate_annotations, + validate_preprocess_arguments, ) from ...utils import ( TensorType, @@ -1291,16 +1292,27 @@ def preprocess( do_pad = self.do_pad if do_pad is None else do_pad format = self.format if format is None else format - if do_resize is not None and size is None: - raise ValueError("Size and max_size must be specified if do_resize is True.") + images = make_list_of_images(images) - if do_rescale is not None and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") + if not valid_images(images): + raise ValueError( + "Invalid image type. 
Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) - images = make_list_of_images(images) if annotations is not None and isinstance(annotations, dict): annotations = [annotations] @@ -1309,12 +1321,6 @@ def preprocess( f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." ) - if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - format = AnnotationFormat(format) if annotations is not None: validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) diff --git a/src/transformers/models/convnext/image_processing_convnext.py b/src/transformers/models/convnext/image_processing_convnext.py index 09944527bbb905..6d6476e77214b0 100644 --- a/src/transformers/models/convnext/image_processing_convnext.py +++ b/src/transformers/models/convnext/image_processing_convnext.py @@ -36,6 +36,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -267,17 +268,16 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." 
) - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_resize and size["shortest_edge"] < 384 and crop_pct is None: - raise ValueError("crop_pct must be specified if size < 384.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py index 5bedc7d15e752f..ef4dc7f3e5763f 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py @@ -49,6 +49,7 @@ to_numpy_array, valid_images, validate_annotations, + validate_preprocess_arguments, ) from ...utils import ( TensorType, @@ -1289,16 +1290,27 @@ def preprocess( do_pad = self.do_pad if do_pad is None else do_pad format = self.format if format is None else format - if do_resize is not None and size is None: - raise ValueError("Size and max_size must be specified if do_resize is True.") + images = make_list_of_images(images) - if do_rescale is not None and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+ ) - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) - images = make_list_of_images(images) if annotations is not None and isinstance(annotations, dict): annotations = [annotations] @@ -1307,12 +1319,6 @@ def preprocess( f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." ) - if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - format = AnnotationFormat(format) if annotations is not None: validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) diff --git a/src/transformers/models/deit/image_processing_deit.py b/src/transformers/models/deit/image_processing_deit.py index 96425278adbd17..15e820570c08fe 100644 --- a/src/transformers/models/deit/image_processing_deit.py +++ b/src/transformers/models/deit/image_processing_deit.py @@ -31,6 +31,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -244,19 +245,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/deta/image_processing_deta.py b/src/transformers/models/deta/image_processing_deta.py index 69dc8bafd7ef4f..45c5c6cb285a8f 100644 --- a/src/transformers/models/deta/image_processing_deta.py +++ b/src/transformers/models/deta/image_processing_deta.py @@ -46,6 +46,7 @@ to_numpy_array, valid_images, validate_annotations, + validate_preprocess_arguments, ) from ...utils import ( is_flax_available, @@ -955,29 +956,32 @@ def preprocess( do_pad = self.do_pad if do_pad is None else do_pad format = self.format if format is None else format - if do_resize is not None and size is None: - raise ValueError("Size and max_size must be specified if do_resize is True.") - - if do_rescale is not None and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + # Here, the pad() method pads to the maximum of (width, height). 
It does not need to be validated. + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) if not is_batched(images): images = [images] annotations = [annotations] if annotations is not None else None - if annotations is not None and len(images) != len(annotations): - raise ValueError( - f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." - ) - if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." + ) format = AnnotationFormat(format) if annotations is not None: diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index e481321dabf889..0a7a6e2dbd5c38 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -48,6 +48,7 @@ to_numpy_array, valid_images, validate_annotations, + validate_preprocess_arguments, ) from ...utils import ( TensorType, @@ -1261,16 +1262,27 @@ def preprocess( do_pad = self.do_pad if do_pad is None else do_pad format = self.format if format is None else format - if do_resize is not None and size is None: - raise ValueError("Size and max_size must be specified if do_resize is True.") + images = make_list_of_images(images) - if do_rescale is not None and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") + if not valid_images(images): + raise ValueError( + "Invalid image type. 
Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) - images = make_list_of_images(images) if annotations is not None and isinstance(annotations, dict): annotations = [annotations] @@ -1279,12 +1291,6 @@ def preprocess( f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." ) - if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - format = AnnotationFormat(format) if annotations is not None: validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py index 2a1672e22041fb..a17593316248ac 100644 --- a/src/transformers/models/donut/image_processing_donut.py +++ b/src/transformers/models/donut/image_processing_donut.py @@ -37,6 +37,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging from ...utils.import_utils import is_vision_available @@ -392,18 +393,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_pad and size is None: - raise ValueError("Size must be specified if do_pad is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index ec1b8fead27747..29aac9d005b406 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -35,6 +35,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -354,19 +355,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - - if do_pad and size_divisor is None: - raise ValueError("Size divisibility must be specified if do_pad is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisor, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/efficientformer/image_processing_efficientformer.py b/src/transformers/models/efficientformer/image_processing_efficientformer.py index be8477678c5f98..7db37c20b7f9dc 100644 --- a/src/transformers/models/efficientformer/image_processing_efficientformer.py +++ b/src/transformers/models/efficientformer/image_processing_efficientformer.py @@ -35,6 +35,7 @@ is_scaled_image, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -245,16 +246,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/efficientnet/image_processing_efficientnet.py b/src/transformers/models/efficientnet/image_processing_efficientnet.py index 5f75d1692e8847..ee4690e0fb9cc4 100644 --- a/src/transformers/models/efficientnet/image_processing_efficientnet.py +++ b/src/transformers/models/efficientnet/image_processing_efficientnet.py @@ -31,6 +31,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -301,19 +302,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/flava/image_processing_flava.py b/src/transformers/models/flava/image_processing_flava.py index b098b7c634dd96..168e3e8e2e3ff4 100644 --- a/src/transformers/models/flava/image_processing_flava.py +++ b/src/transformers/models/flava/image_processing_flava.py @@ -34,6 +34,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -403,14 +404,19 @@ def _preprocess_image( input_data_format: Optional[ChannelDimension] = None, ) -> np.ndarray: """Preprocesses a single image.""" - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + 
rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. image = to_numpy_array(image) diff --git a/src/transformers/models/fuyu/image_processing_fuyu.py b/src/transformers/models/fuyu/image_processing_fuyu.py index 2257dfa8e918b9..70ff3e725d2e00 100644 --- a/src/transformers/models/fuyu/image_processing_fuyu.py +++ b/src/transformers/models/fuyu/image_processing_fuyu.py @@ -35,6 +35,7 @@ is_valid_image, make_list_of_images, to_numpy_array, + validate_preprocess_arguments, ) from ...utils import ( TensorType, @@ -446,15 +447,18 @@ def preprocess( batch_images = make_list_of_list_of_images(images) - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and image_mean is None or image_std is None: - raise ValueError("image_mean and image_std must be specified if do_normalize is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
batch_images = [[to_numpy_array(image) for image in images] for images in batch_images] diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py index afed9188f7abac..2be3e3c90b3751 100644 --- a/src/transformers/models/glpn/image_processing_glpn.py +++ b/src/transformers/models/glpn/image_processing_glpn.py @@ -30,6 +30,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -173,13 +174,21 @@ def preprocess( size_divisor = size_divisor if size_divisor is not None else self.size_divisor resample = resample if resample is not None else self.resample - if do_resize and size_divisor is None: - raise ValueError("size_divisor is required for resizing") - images = make_list_of_images(images) if not valid_images(images): - raise ValueError("Invalid image(s)") + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # Here, the rescale() method uses a constant rescale_factor. It does not need to be validated + # with a rescale_factor. + validate_preprocess_arguments( + do_resize=do_resize, + size=size_divisor, # Here, size_divisor is used as a parameter for optimal resizing instead of size. + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(img) for img in images] diff --git a/src/transformers/models/imagegpt/image_processing_imagegpt.py b/src/transformers/models/imagegpt/image_processing_imagegpt.py index ad421c910536fc..d85803a5a611c8 100644 --- a/src/transformers/models/imagegpt/image_processing_imagegpt.py +++ b/src/transformers/models/imagegpt/image_processing_imagegpt.py @@ -29,6 +29,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -243,8 +244,13 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." ) - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") + # Here, normalize() is using a constant factor to divide pixel values. + # Hence, the method does not need image_mean and image_std. + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + ) if do_color_quantize and clusters is None: raise ValueError("Clusters must be specified if do_color_quantize is True.") diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py index b1e6c0731d2954..a56cb8dd10a417 100644 --- a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py @@ -28,6 +28,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_pytesseract_available, is_vision_available, logging, requires_backends @@ -248,9 +249,11 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray."
) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py index 26a5c7a1641837..c2461ad60dae4f 100644 --- a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py @@ -31,6 +31,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_pytesseract_available, is_vision_available, logging, requires_backends @@ -295,7 +296,6 @@ def preprocess( apply_ocr = apply_ocr if apply_ocr is not None else self.apply_ocr ocr_lang = ocr_lang if ocr_lang is not None else self.ocr_lang tesseract_config = tesseract_config if tesseract_config is not None else self.tesseract_config - images = make_list_of_images(images) if not valid_images(images): @@ -303,15 +303,16 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("If do_normalize is True, image_mean and image_std must be specified.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/levit/image_processing_levit.py b/src/transformers/models/levit/image_processing_levit.py index 77de1ec33366dc..a21e5750c7048a 100644 --- a/src/transformers/models/levit/image_processing_levit.py +++ b/src/transformers/models/levit/image_processing_levit.py @@ -35,6 +35,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -251,7 +252,6 @@ def preprocess( size = get_size_dict(size, default_to_square=False) crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size") - images = make_list_of_images(images) if not valid_images(images): @@ -259,19 +259,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py b/src/transformers/models/mask2former/image_processing_mask2former.py index 3a6d6f783b535d..154a531c8b0d72 100644 --- a/src/transformers/models/mask2former/image_processing_mask2former.py +++ b/src/transformers/models/mask2former/image_processing_mask2former.py @@ -39,6 +39,7 @@ is_scaled_image, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import ( IMAGENET_DEFAULT_MEAN, @@ -707,21 +708,23 @@ def preprocess( ignore_index = ignore_index if ignore_index is not None else self.ignore_index reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels - if do_resize is not None and size is None or size_divisor is None: - raise ValueError("If `do_resize` is True, `size` and `size_divisor` must be provided.") - - if do_rescale is not None and rescale_factor is None: - raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.") - - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.") - if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( "Invalid segmentation map type. 
Must be of type PIL.Image.Image, numpy.ndarray, " diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index 151868eb235b08..a5d940c6531482 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -39,6 +39,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import ( IMAGENET_DEFAULT_MEAN, @@ -724,20 +725,21 @@ def preprocess( ignore_index = ignore_index if ignore_index is not None else self.ignore_index do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels - if do_resize is not None and size is None or size_divisor is None: - raise ValueError("If `do_resize` is True, `size` and `size_divisor` must be provided.") - - if do_rescale is not None and rescale_factor is None: - raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.") - - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.") - if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( diff --git a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py index 73bb296d7ed144..9f59c17d1d5487 100644 --- a/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py @@ -35,6 +35,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -249,18 +250,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py index aa97d854d7f47a..dcf82e8d1681b6 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py @@ -35,6 +35,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_torch_available, is_torch_tensor, logging @@ -253,19 +254,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index 2e7433fa02b8c7..32bbf3d5d36f56 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -29,6 +29,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging @@ -368,6 +369,8 @@ def preprocess( if segmentation_maps is not None: segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + images = make_list_of_images(images) + if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " @@ -380,14 +383,15 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." ) - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) images = [ self._preprocess_image( diff --git a/src/transformers/models/nougat/image_processing_nougat.py b/src/transformers/models/nougat/image_processing_nougat.py index 882614059f9df6..448c9f21c4a181 100644 --- a/src/transformers/models/nougat/image_processing_nougat.py +++ b/src/transformers/models/nougat/image_processing_nougat.py @@ -38,6 +38,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, 
logging from ...utils.import_utils import is_cv2_available, is_vision_available @@ -446,18 +447,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_pad and size is None: - raise ValueError("Size must be specified if do_pad is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/oneformer/image_processing_oneformer.py b/src/transformers/models/oneformer/image_processing_oneformer.py index 8eb286475cb4ad..23b3fa69569f13 100644 --- a/src/transformers/models/oneformer/image_processing_oneformer.py +++ b/src/transformers/models/oneformer/image_processing_oneformer.py @@ -42,6 +42,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import ( IMAGENET_DEFAULT_MEAN, @@ -708,20 +709,21 @@ def preprocess( ignore_index = ignore_index if ignore_index is not None else self.ignore_index do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels - if do_resize is not None and size is None: - raise ValueError("If `do_resize` is True, `size` must be provided.") - - if do_rescale is not None and rescale_factor is None: - raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.") - - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.") - if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) if segmentation_maps is not None and not valid_images(segmentation_maps): raise ValueError( diff --git a/src/transformers/models/owlv2/image_processing_owlv2.py b/src/transformers/models/owlv2/image_processing_owlv2.py index bb309b40d3141e..21f09060cd0b9e 100644 --- a/src/transformers/models/owlv2/image_processing_owlv2.py +++ b/src/transformers/models/owlv2/image_processing_owlv2.py @@ -37,6 +37,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import ( TensorType, @@ -405,15 +406,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + # Here, pad and resize methods are different from the rest of image processors + # as they don't have any resampling in resize() + # or pad size in pad() (the maximum of (height, width) is taken instead). + # hence, these arguments don't need to be passed in validate_preprocess_arguments. + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + size=size, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py index d190bc1d636ea3..961707725db75c 100644 --- a/src/transformers/models/owlvit/image_processing_owlvit.py +++ b/src/transformers/models/owlvit/image_processing_owlvit.py @@ -38,6 +38,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_torch_available, logging @@ -348,18 +349,6 @@ def preprocess( image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std - if do_resize is not None and size is None: - raise ValueError("Size and max_size must be specified if do_resize is True.") - - if do_center_crop is not None and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale is not None and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - images = make_list_of_images(images) if not valid_images(images): @@ -368,6 +357,19 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." 
) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/perceiver/image_processing_perceiver.py b/src/transformers/models/perceiver/image_processing_perceiver.py index 272cf32fa5eb97..599e48d77a0f0e 100644 --- a/src/transformers/models/perceiver/image_processing_perceiver.py +++ b/src/transformers/models/perceiver/image_processing_perceiver.py @@ -32,6 +32,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -290,18 +291,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_center_crop and crop_size is None: - raise ValueError("If `do_center_crop` is set to `True`, `crop_size` must be provided.") - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and image standard deviation must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/poolformer/image_processing_poolformer.py b/src/transformers/models/poolformer/image_processing_poolformer.py index b5773d3146f437..dab7392fbb08f6 100644 --- a/src/transformers/models/poolformer/image_processing_poolformer.py +++ b/src/transformers/models/poolformer/image_processing_poolformer.py @@ -35,6 +35,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -297,18 +298,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_center_crop and crop_pct is None: - raise ValueError("Crop_pct must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/pvt/image_processing_pvt.py b/src/transformers/models/pvt/image_processing_pvt.py index 37d65778b07356..ada7eaec4aaabd 100644 --- a/src/transformers/models/pvt/image_processing_pvt.py +++ b/src/transformers/models/pvt/image_processing_pvt.py @@ -31,6 +31,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -222,12 +223,16 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/sam/image_processing_sam.py b/src/transformers/models/sam/image_processing_sam.py index 5b208dd34a5a25..911e3fd0ff5a9e 100644 --- a/src/transformers/models/sam/image_processing_sam.py +++ b/src/transformers/models/sam/image_processing_sam.py @@ -34,6 +34,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import ( TensorType, @@ -504,18 +505,18 @@ def preprocess( "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and (size is None or resample is None): - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") - - if do_pad and pad_size is None: - raise ValueError("Pad size must be specified if do_pad is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=pad_size, # Here _preprocess needs do_pad and pad_size. + do_resize=do_resize, + size=size, + resample=resample, + ) images, original_sizes, reshaped_input_sizes = zip( *( diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py index 57f2628a9cd36e..ff12108a301a3b 100644 --- a/src/transformers/models/segformer/image_processing_segformer.py +++ b/src/transformers/models/segformer/image_processing_segformer.py @@ -32,6 +32,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging @@ -387,21 +388,16 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if segmentation_maps is not None and not valid_images(segmentation_maps): - raise ValueError( - "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." 
- ) - - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) images = [ self._preprocess_image( diff --git a/src/transformers/models/siglip/image_processing_siglip.py b/src/transformers/models/siglip/image_processing_siglip.py index 285b6e9e559f32..7796a6e3d290c4 100644 --- a/src/transformers/models/siglip/image_processing_siglip.py +++ b/src/transformers/models/siglip/image_processing_siglip.py @@ -32,6 +32,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -178,13 +179,16 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/swin2sr/image_processing_swin2sr.py b/src/transformers/models/swin2sr/image_processing_swin2sr.py index 95eafb3d01d95c..d86b1e28e8dd50 100644 --- a/src/transformers/models/swin2sr/image_processing_swin2sr.py +++ b/src/transformers/models/swin2sr/image_processing_swin2sr.py @@ -28,6 +28,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -165,9 +166,12 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_pad=do_pad, + size_divisibility=pad_size, # Here the pad function simply requires pad_size. + ) # All transformations expect numpy arrays. 
images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/tvlt/image_processing_tvlt.py b/src/transformers/models/tvlt/image_processing_tvlt.py index f5860b2c1dcca5..618dcf089048f2 100644 --- a/src/transformers/models/tvlt/image_processing_tvlt.py +++ b/src/transformers/models/tvlt/image_processing_tvlt.py @@ -34,6 +34,7 @@ is_valid_image, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -212,17 +213,19 @@ def _preprocess_image( input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """Preprocesses a single image.""" - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
image = to_numpy_array(image) diff --git a/src/transformers/models/tvp/image_processing_tvp.py b/src/transformers/models/tvp/image_processing_tvp.py index 5363d504319520..b14e2ce264f04d 100644 --- a/src/transformers/models/tvp/image_processing_tvp.py +++ b/src/transformers/models/tvp/image_processing_tvp.py @@ -36,6 +36,7 @@ is_valid_image, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -285,20 +286,21 @@ def _preprocess_image( **kwargs, ) -> np.ndarray: """Preprocesses a single image.""" - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_pad and pad_size is None: - raise ValueError("Padding size must be specified if do_pad is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=pad_size, # here the pad() method simply requires the pad_size argument. + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
image = to_numpy_array(image) diff --git a/src/transformers/models/videomae/image_processing_videomae.py b/src/transformers/models/videomae/image_processing_videomae.py index 6df708eec3ea04..dc69a57f59bd94 100644 --- a/src/transformers/models/videomae/image_processing_videomae.py +++ b/src/transformers/models/videomae/image_processing_videomae.py @@ -35,6 +35,7 @@ is_valid_image, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -191,17 +192,18 @@ def _preprocess_image( input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """Preprocesses a single image.""" - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. 
image = to_numpy_array(image) diff --git a/src/transformers/models/vilt/image_processing_vilt.py b/src/transformers/models/vilt/image_processing_vilt.py index 78e44efccf8381..aee5e298630436 100644 --- a/src/transformers/models/vilt/image_processing_vilt.py +++ b/src/transformers/models/vilt/image_processing_vilt.py @@ -32,6 +32,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -421,14 +422,18 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." ) - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + # Here the pad() method does not require any additional argument as it takes the maximum of (height, width). + # Hence, it does not need to be passed to a validate_preprocess_arguments() method. + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/vit/image_processing_vit.py b/src/transformers/models/vit/image_processing_vit.py index be806d94c4d2f2..63886bef4ca466 100644 --- a/src/transformers/models/vit/image_processing_vit.py +++ b/src/transformers/models/vit/image_processing_vit.py @@ -31,6 +31,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -221,12 +222,16 @@ def preprocess( "Invalid image type. 
Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py index 1e4b0652ff5b4e..cf27d204456470 100644 --- a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py @@ -36,6 +36,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, is_vision_available, logging @@ -262,18 +263,18 @@ def preprocess( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - - if do_resize and size is None: - raise ValueError("Size must be specified if do_resize is True.") - - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) # PIL RGBA images are converted to RGB if do_convert_rgb: diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte.py b/src/transformers/models/vitmatte/image_processing_vitmatte.py index 602b1fbefa8cea..fe7767e8e1379b 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte.py @@ -31,6 +31,7 @@ make_list_of_images, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import TensorType, logging @@ -197,25 +198,28 @@ def preprocess( images = make_list_of_images(images) trimaps = make_list_of_images(trimaps, expected_ndims=2) - if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) if not valid_images(trimaps): raise ValueError( "Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." 
) - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_pad and size_divisibility is None: - raise ValueError("Size divisilibyt must be specified if do_pad is True.") + images = make_list_of_images(images) - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisibility, + ) # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] diff --git a/src/transformers/models/vivit/image_processing_vivit.py b/src/transformers/models/vivit/image_processing_vivit.py index f32dd0d3aea415..664ba6d7098ac9 100644 --- a/src/transformers/models/vivit/image_processing_vivit.py +++ b/src/transformers/models/vivit/image_processing_vivit.py @@ -38,6 +38,7 @@ is_valid_image, to_numpy_array, valid_images, + validate_preprocess_arguments, ) from ...utils import logging @@ -240,17 +241,19 @@ def _preprocess_image( input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """Preprocesses a single image.""" - if do_resize and size is None or resample is None: - raise ValueError("Size and resample must be specified if do_resize is True.") - if do_center_crop and crop_size is None: - raise ValueError("Crop size must be specified if do_center_crop is True.") - - if do_rescale and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") - - if do_normalize and (image_mean is None or image_std is None): - raise ValueError("Image mean and std 
must be specified if do_normalize is True.") + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) if offset and not do_rescale: raise ValueError("For offset, do_rescale must also be set to True.") diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index d964f6f02f4187..6ae30d50a1b0f3 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -47,6 +47,7 @@ to_numpy_array, valid_images, validate_annotations, + validate_preprocess_arguments, ) from ...utils import ( TensorType, @@ -1185,16 +1186,25 @@ def preprocess( do_pad = self.do_pad if do_pad is None else do_pad format = self.format if format is None else format - if do_resize is not None and size is None: - raise ValueError("Size and max_size must be specified if do_resize is True.") - - if do_rescale is not None and rescale_factor is None: - raise ValueError("Rescale factor must be specified if do_rescale is True.") + images = make_list_of_images(images) - if do_normalize is not None and (image_mean is None or image_std is None): - raise ValueError("Image mean and std must be specified if do_normalize is True.") + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + # Here the pad() method pads using the max of (width, height) and does not need to be validated. 
+ validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) - images = make_list_of_images(images) if annotations is not None and isinstance(annotations, dict): annotations = [annotations] @@ -1203,12 +1213,6 @@ def preprocess( f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." ) - if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." - ) - format = AnnotationFormat(format) if annotations is not None: validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)