Abstract image processor arg checks. #28843

Merged: 23 commits, Feb 20, 2024. Showing changes from 19 commits.
41 changes: 41 additions & 0 deletions src/transformers/image_utils.py
@@ -337,6 +337,47 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
):
"""
Checks the validity of the typically used arguments in an `ImageProcessor`'s `preprocess` method.
Raises a `ValueError` if any incompatibility between arguments is found.
Many incompatibilities are model-specific: `do_pad` sometimes needs `size_divisor`,
sometimes `size_divisibility`, and sometimes `size`. New models and processors should reuse
existing argument names where possible.

"""
if do_rescale and rescale_factor is None:
raise ValueError("rescale_factor must be specified if do_rescale is True.")

if do_pad and size_divisibility is None:
# Depending on the model, size_divisor may have been passed in as the value of size.
raise ValueError(
"Depending on the model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
)

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

if do_center_crop and crop_size is None:
raise ValueError("crop_size must be specified if do_center_crop is True.")

if do_resize and (size is None or resample is None):
raise ValueError("size and resample must be specified if do_resize is True.")


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
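The new helper centralizes checks that were previously duplicated across image processors. A minimal usage sketch follows, assuming a transformers version that includes this function; `PILImageResampling` is imported from the same module:

```python
from transformers.image_utils import PILImageResampling, validate_preprocess_arguments

# Passes silently: every enabled step has the arguments it needs.
validate_preprocess_arguments(
    do_resize=True,
    size={"shortest_edge": 224},
    resample=PILImageResampling.BILINEAR,
    do_rescale=True,
    rescale_factor=1 / 255,
)

# Rescaling is enabled without a factor, so a ValueError is raised.
try:
    validate_preprocess_arguments(do_rescale=True)
except ValueError as err:
    print(err)  # rescale_factor must be specified if do_rescale is True.

# Padding is enabled without any divisibility/size information, so a ValueError is raised.
try:
    validate_preprocess_arguments(do_pad=True)
except ValueError as err:
    print(err)
```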
34 changes: 18 additions & 16 deletions src/transformers/models/beit/image_processing_beit.py
@@ -32,6 +32,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging

@@ -396,32 +397,33 @@ def preprocess(
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

images = make_list_of_images(images)

if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

if not valid_images(images):
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if segmentation_maps is not None and not valid_images(segmentation_maps):
if not valid_images(images):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

images = [
self._preprocess_image(
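Note that the inline resize check removed above, `if do_resize and size is None or resample is None:`, had an operator-precedence bug: Python parses it as `(do_resize and size is None) or (resample is None)`, so a missing `resample` raised even when resizing was disabled. The shared helper uses the intended grouping, `do_resize and (size is None or resample is None)`. A small illustration:

```python
do_resize, size, resample = False, None, None

# Old inline form: fires even though resizing is disabled.
old_check = do_resize and size is None or resample is None
assert old_check is True

# Grouping used by validate_preprocess_arguments: only fires when resizing is enabled.
new_check = do_resize and (size is None or resample is None)
assert new_check is False
```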
24 changes: 13 additions & 11 deletions src/transformers/models/bit/image_processing_bit.py
@@ -36,6 +36,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

@@ -263,17 +264,18 @@ def preprocess(
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

# PIL RGBA images are converted to RGB
if do_convert_rgb:
20 changes: 11 additions & 9 deletions src/transformers/models/blip/image_processing_blip.py
@@ -31,6 +31,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

@@ -239,15 +240,16 @@ def preprocess(
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
46 changes: 33 additions & 13 deletions src/transformers/models/bridgetower/image_processing_bridgetower.py
@@ -32,6 +32,7 @@
is_scaled_image,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

@@ -128,7 +129,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to 288):
size (`Dict[str, int]` *optional*, defaults to `{'shortest_edge': 288}`):
Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
`int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
`do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
@@ -158,6 +159,9 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image. Can be overridden by the `do_center_crop` parameter in the `preprocess`
method.
crop_size (`Dict[str, int]`, *optional*):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
Can be overridden by the `crop_size` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
the `do_pad` parameter in the `preprocess` method.
@@ -168,7 +172,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = 288,
size: Dict[str, int] = None,
size_divisor: int = 32,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
@@ -177,6 +181,7 @@ def __init__(
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_pad: bool = True,
**kwargs,
) -> None:
@@ -198,6 +203,7 @@ def __init__(
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_pad = do_pad
self.do_center_crop = do_center_crop
self.crop_size = crop_size

# Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor.resize
def resize(
@@ -378,6 +384,7 @@ def preprocess(
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
do_center_crop: Optional[bool] = None,
crop_size: Dict[str, int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -417,6 +424,9 @@ def preprocess(
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the
image is padded with 0's and then center cropped.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the image after center cropping. If one edge of the image is smaller than `crop_size`, it will be
padded with zeros and then cropped.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
@@ -446,6 +456,11 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
# For backwards compatibility: the initial version of this processor cropped to the "size" argument,
# which crop_size should default to when it is undefined.
crop_size = (
crop_size if crop_size is not None else (self.crop_size if self.crop_size is not None else self.size)
)

size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
@@ -458,16 +473,21 @@ def preprocess(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

# Here, crop_size is used only if it is set, else size will be used.
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

@@ -491,7 +511,7 @@ def preprocess(

if do_center_crop:
images = [
self.center_crop(image=image, size=size, input_data_format=input_data_format) for image in images
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]

if do_rescale:
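The `crop_size` fallback added above keeps the processor backwards compatible: existing configurations without a `crop_size` continue to center-crop to `size`, as the first release of this processor did. A standalone sketch of the resolution order (`resolve_crop_size` is a hypothetical helper for illustration, not part of the PR):

```python
from typing import Dict, Optional


def resolve_crop_size(
    crop_size: Optional[Dict[str, int]],
    default_crop_size: Optional[Dict[str, int]],
    size: Dict[str, int],
) -> Dict[str, int]:
    # Mirrors the fallback in BridgeTowerImageProcessor.preprocess:
    # explicit argument > processor attribute > the resize `size` (legacy behaviour).
    if crop_size is not None:
        return crop_size
    return default_crop_size if default_crop_size is not None else size


# Legacy configs without crop_size still crop to `size`.
assert resolve_crop_size(None, None, {"shortest_edge": 288}) == {"shortest_edge": 288}
# A configured crop_size takes precedence over `size`.
assert resolve_crop_size(None, {"height": 224, "width": 224}, {"shortest_edge": 288}) == {
    "height": 224,
    "width": 224,
}
```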
@@ -36,6 +36,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

@@ -251,20 +252,18 @@ def preprocess(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

# PIL RGBA images are converted to RGB
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

26 changes: 13 additions & 13 deletions src/transformers/models/clip/image_processing_clip.py
@@ -36,6 +36,7 @@
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

@@ -265,20 +266,19 @@ def preprocess(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")

# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

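Taken together, each migrated `preprocess` now resolves per-call arguments against instance defaults and validates them in a single shared call. A hedged sketch of the pattern for a hypothetical new processor (`MyImageProcessor` is illustrative, not a real transformers class):

```python
from typing import Dict, Optional

from transformers.image_utils import PILImageResampling, validate_preprocess_arguments


class MyImageProcessor:
    """Illustrative only: shows the shared validation pattern, not a real processor."""

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: "PILImageResampling" = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
    ) -> None:
        self.do_resize = do_resize
        self.size = size if size is not None else {"shortest_edge": 224}
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor

    def preprocess(self, images, do_resize: Optional[bool] = None, size: Optional[Dict[str, int]] = None):
        # Resolve per-call arguments against instance defaults, as the migrated processors do.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        # One shared call replaces the per-processor chain of if/raise checks.
        validate_preprocess_arguments(
            do_rescale=self.do_rescale,
            rescale_factor=self.rescale_factor,
            do_resize=do_resize,
            size=size,
            resample=self.resample,
        )
        # ... actual resizing/rescaling would follow here ...
        return images
```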