Skip to content

Commit

Permalink
Abstract image processor arg checks. (#28843)
Browse files Browse the repository at this point in the history
* abstract image processor arg checks.

* fix signatures and quality

* add validate_ method to rescale-prone processors

* add more validations

* quality

* quality

* fix formatting

Co-authored-by: amyeroberts <[email protected]>

* fix formatting

Co-authored-by: amyeroberts <[email protected]>

* fix formatting

Co-authored-by: amyeroberts <[email protected]>

* Fix formatting mishap

Co-authored-by: amyeroberts <[email protected]>

* fix crop_size compatibility

* fix default mutable arg

* fix segmentation map + image arg validity

* remove segmentation check from arg validation

* fix quality

* fix missing segmap

* protect PILImageResampling type

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* add back segmentation maps check

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent f909e23 commit 2ea15af
Show file tree
Hide file tree
Showing 49 changed files with 685 additions and 505 deletions.
41 changes: 41 additions & 0 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,47 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.

    Each `do_*` flag requires its companion argument(s) to be set: `do_rescale` needs
    `rescale_factor`; `do_pad` needs `size_divisibility` (model-specific — callers may pass
    `size_divisor`, `pad_size` or `size` as its value); `do_normalize` needs both `image_mean`
    and `image_std`; `do_center_crop` needs `crop_size`; `do_resize` needs both `size` and
    `resample`.

    Raises:
        `ValueError`: If a `do_*` flag is truthy while a required companion argument is `None`.

    Many incompatibilities are model-specific. New models and processors added should follow
    existing argument conventions when possible.
    """
    if do_rescale and rescale_factor is None:
        raise ValueError("rescale_factor must be specified if do_rescale is True.")

    if do_pad and size_divisibility is None:
        # Here, size_divisor might be passed as the value of size
        raise ValueError(
            "Depending on the model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
        )

    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

    if do_center_crop and crop_size is None:
        raise ValueError("crop_size must be specified if do_center_crop is True.")

    if do_resize and (size is None or resample is None):
        raise ValueError("size and resample must be specified if do_resize is True.")


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
Expand Down
34 changes: 18 additions & 16 deletions src/transformers/models/beit/image_processing_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging

Expand Down Expand Up @@ -396,32 +397,33 @@ def preprocess(
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

images = make_list_of_images(images)

if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

if not valid_images(images):
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if segmentation_maps is not None and not valid_images(segmentation_maps):
if not valid_images(images):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

images = [
self._preprocess_image(
Expand Down
24 changes: 13 additions & 11 deletions src/transformers/models/bit/image_processing_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

Expand Down Expand Up @@ -263,17 +264,18 @@ def preprocess(
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

# PIL RGBA images are converted to RGB
if do_convert_rgb:
Expand Down
20 changes: 11 additions & 9 deletions src/transformers/models/blip/image_processing_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

Expand Down Expand Up @@ -239,15 +240,16 @@ def preprocess(
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
Expand Down
46 changes: 33 additions & 13 deletions src/transformers/models/bridgetower/image_processing_bridgetower.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
is_scaled_image,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

Expand Down Expand Up @@ -128,7 +129,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to 288):
size (`Dict[str, int]` *optional*, defaults to `{'shortest_edge': 288}`):
Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
`int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
`do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
Expand Down Expand Up @@ -158,6 +159,9 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image. Can be overridden by the `do_center_crop` parameter in the `preprocess`
method.
crop_size (`Dict[str, int]`, *optional*):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
Can be overridden by the `crop_size` parameter in the `preprocess` method. If unset, defaults to `size`.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
the `do_pad` parameter in the `preprocess` method.
Expand All @@ -168,7 +172,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = 288,
size: Dict[str, int] = None,
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
Expand All @@ -177,6 +181,7 @@ def __init__(
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_pad: bool = True,
**kwargs,
) -> None:
Expand All @@ -198,6 +203,7 @@ def __init__(
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_pad = do_pad
self.do_center_crop = do_center_crop
self.crop_size = crop_size

# Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.resize
def resize(
Expand Down Expand Up @@ -378,6 +384,7 @@ def preprocess(
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
do_center_crop: Optional[bool] = None,
crop_size: Dict[str, int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
Expand Down Expand Up @@ -417,6 +424,9 @@ def preprocess(
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be
padded with zeros and then cropped
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
Expand Down Expand Up @@ -446,6 +456,11 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
do_center_crop if do_center_crop is not None else self.do_center_crop
# For backwards compatibility. Initial version of this processor was cropping to the "size" argument, which
# it should default to if crop_size is undefined.
crop_size = (
crop_size if crop_size is not None else (self.crop_size if self.crop_size is not None else self.size)
)

size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
Expand All @@ -458,16 +473,21 @@ def preprocess(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

# Here, crop_size is used only if it is set, else size will be used.
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

Expand All @@ -491,7 +511,7 @@ def preprocess(

if do_center_crop:
images = [
self.center_crop(image=image, size=size, input_data_format=input_data_format) for image in images
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]

if do_rescale:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

Expand Down Expand Up @@ -251,20 +252,18 @@ def preprocess(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

# PIL RGBA images are converted to RGB
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

Expand Down
26 changes: 13 additions & 13 deletions src/transformers/models/clip/image_processing_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

Expand Down Expand Up @@ -265,20 +266,19 @@ def preprocess(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

Expand Down
Loading

0 comments on commit 2ea15af

Please sign in to comment.