Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move out input validation into base image processor #30486

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 103 additions & 1 deletion src/transformers/image_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .dynamic_module_utils import custom_object_save
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .image_transforms import center_crop, normalize, rescale
from .image_utils import ChannelDimension
from .image_utils import ChannelDimension, is_scaled_image, to_numpy_array, valid_images
from .utils import (
IMAGE_PROCESSOR_NAME,
PushToHubMixin,
Expand All @@ -44,9 +44,60 @@
if is_vision_available():
from PIL import Image

from .image_utils import PILImageResampling

logger = logging.get_logger(__name__)


def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    """
    Log a warning naming every captured keyword argument that is not a recognized processor key.

    Args:
        valid_processor_keys: Names the image processor accepts.
        captured_kwargs: Names that were actually passed. Any iterable of strings works,
            e.g. `kwargs.keys()`.

    Nothing is raised and nothing is returned; unrecognized names are only reported through
    the module-level `logger`.
    """
    unused_keys = set(captured_kwargs).difference(valid_processor_keys)
    if unused_keys:
        # A set has no stable iteration order; sort so the message is reproducible between runs.
        unused_key_str = ", ".join(sorted(unused_keys))
        # TODO raise a warning here instead of simply logging?
        logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")


def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.

    Each `do_*` flag is only checked when it is truthy; a flag left as `None` or `False` skips the
    corresponding check entirely.

    Args:
        do_rescale: When truthy, `rescale_factor` must not be `None`.
        rescale_factor: Companion value for `do_rescale`.
        do_normalize: When truthy, both `image_mean` and `image_std` must not be `None`.
        image_mean: Companion value for `do_normalize`.
        image_std: Companion value for `do_normalize`.
        do_pad: When truthy, `size_divisibility` must not be `None`.
        size_divisibility: Companion value for `do_pad`. Callers whose model uses another name
            (see above) pass that value through this parameter.
        do_center_crop: When truthy, `crop_size` must not be `None`.
        crop_size: Companion value for `do_center_crop`.
        do_resize: When truthy, both `size` and `resample` must not be `None`.
        size: Companion value for `do_resize`.
        resample: Companion value for `do_resize`.

    Raises:
        ValueError: On the first flag found enabled without its companion value(s).
    """
    if do_rescale and rescale_factor is None:
        raise ValueError("rescale_factor must be specified if do_rescale is True.")

    if do_pad and size_divisibility is None:
        # Here, size_divisor might be passed as the value of size
        raise ValueError(
            "Depending on model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
        )

    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

    if do_center_crop and crop_size is None:
        raise ValueError("crop_size must be specified if do_center_crop is True.")

    if do_resize and (size is None or resample is None):
        raise ValueError("size and resample must be specified if do_resize is True.")


# TODO: Move BatchFeature to be imported by both feature_extraction_utils and image_processing_utils
# We override the class string here, but logic is the same.
class BatchFeature(BaseBatchFeature):
Expand Down Expand Up @@ -543,13 +594,64 @@ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):


class BaseImageProcessor(ImageProcessingMixin):
_valid_processor_keys = None

def __init__(self, **kwargs):
    """Forward every keyword argument unchanged to the parent class initializer."""
    super().__init__(**kwargs)

def __call__(self, images, **kwargs) -> BatchFeature:
    """Validate the call arguments, then preprocess an image or a batch of images."""
    # Validation runs first so that bad arguments surface before any processing work starts.
    self._validate_inputs(images, **kwargs)
    features = self.preprocess(images, **kwargs)
    return features

def _validate_preprocess_arguments(self, **kwargs):
    """
    Check that the arguments passed to the preprocess method have compatible settings, e.g. that
    `size` is defined when `do_resize` is set to True.

    Only the keyword arguments that `validate_preprocess_arguments` declares are forwarded to it.
    Other processor arguments are dropped here, because forwarding them would make the call fail
    with `TypeError` on an unexpected keyword.

    A companion value (any declared argument that is not a `do_*` flag) left unset by the caller
    falls back to the instance attribute of the same name, when one exists. The `do_*` flags are
    taken from the call only, so this fallback can never trigger a check the caller did not ask for.

    Raises:
        ValueError: Propagated from `validate_preprocess_arguments` when an enabled flag has no
            companion value.
    """
    import inspect

    accepted_names = inspect.signature(validate_preprocess_arguments).parameters

    resolved = {}
    for name in accepted_names:
        value = kwargs.get(name)
        if value is None and not name.startswith("do_"):
            # Same resolution rule as `x = x if x is not None else self.x`.
            value = getattr(self, name, None)
        resolved[name] = value

    validate_preprocess_arguments(**resolved)

def _validate_image_inputs(self, images, segmentation_maps=None, do_rescale=False):
    """
    Check that the images and segmentation maps are valid.

    Args:
        images: A single image or a (possibly nested) list/tuple of images.
        segmentation_maps: Optional segmentation maps, checked with the same `valid_images` test.
        do_rescale: When truthy, one sample image is inspected and a one-time warning is logged
            if it already looks rescaled.

    Raises:
        ValueError: If `valid_images` rejects `images` or `segmentation_maps`.
    """
    if not valid_images(images):
        raise ValueError(
            "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray."
        )

    if segmentation_maps is not None and not valid_images(segmentation_maps):
        raise ValueError(
            "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray."
        )

    if not do_rescale:
        return

    # This runs before the input is turned into a list, so `images` may still be a single image.
    # Unwrap list/tuple containers to reach one sample, and stop quietly on an empty container
    # instead of failing on `[0]`.
    sample = images
    while isinstance(sample, (list, tuple)):
        if not sample:
            return
        sample = sample[0]

    if is_scaled_image(to_numpy_array(sample)):
        logger.warning_once(
            "It looks like you are trying to rescale already rescaled images. If the input"
            " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
        )

def _validate_inputs(self, images, segmentation_maps=None, **kwargs):
    """
    Check that the arguments passed to the preprocess method are valid.

    The checks run in this order: `return_tensors`, `input_data_format`, `data_format`, that
    `self._valid_processor_keys` is defined, that every remaining keyword is a valid processor
    key, the argument-compatibility checks, and finally the image inputs themselves.

    Raises:
        ValueError: On the first failed check.
    """
    # `kwargs` is this call's own dict, so popping here does not affect the caller's arguments.
    return_tensors = kwargs.pop("return_tensors", None)
    input_data_format = kwargs.pop("input_data_format", None)
    data_format = kwargs.pop("data_format", None)

    if return_tensors not in (None, "np", "pt", "tf", "jax"):
        raise ValueError("return_tensors should be one of 'np', 'pt', 'tf', 'jax'.")

    if input_data_format not in (None, ChannelDimension.FIRST, ChannelDimension.LAST):
        raise ValueError("input_data_format should be one of 'channels_first' or 'channels_last'.")

    if data_format not in (None, ChannelDimension.FIRST, ChannelDimension.LAST):
        raise ValueError("data_format should be one of 'channels_first' or 'channels_last'.")

    if self._valid_processor_keys is None:
        raise ValueError("Each image processor must define self._valid_processor_keys")

    for key in kwargs:
        if key not in self._valid_processor_keys:
            raise ValueError(f"Invalid argument {key} passed to preprocess method.")

    # No `validate_kwargs` call here: every key has just been checked against
    # `self._valid_processor_keys`, so it could never find anything to report.
    # NOTE(review): raising on unknown keys is stricter than a logged warning — confirm this is intended.
    self._validate_preprocess_arguments(**kwargs)

    # Use the processor's configured `do_rescale` when the caller did not pass one, so the
    # "already rescaled" warning reflects the value that will actually be applied.
    do_rescale = kwargs.get("do_rescale")
    if do_rescale is None:
        do_rescale = getattr(self, "do_rescale", False)
    self._validate_image_inputs(images, segmentation_maps, do_rescale=do_rescale)

def preprocess(self, images, **kwargs) -> BatchFeature:
    """Placeholder that always raises; each concrete image processor supplies its own version."""
    message = "Each image processor must implement its own preprocess method"
    raise NotImplementedError(message)

Expand Down
49 changes: 0 additions & 49 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,47 +337,6 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.

    Each `do_*` flag is only checked when it is truthy; a flag left as `None` or `False` skips the
    corresponding check entirely.

    Raises:
        ValueError: On the first flag found enabled without its companion value(s):
            `do_rescale` needs `rescale_factor`; `do_pad` needs `size_divisibility`;
            `do_normalize` needs `image_mean` and `image_std`; `do_center_crop` needs `crop_size`;
            `do_resize` needs `size` and `resample`.
    """
    if do_rescale and rescale_factor is None:
        raise ValueError("rescale_factor must be specified if do_rescale is True.")

    if do_pad and size_divisibility is None:
        # Here, size_divisor might be passed as the value of size
        raise ValueError(
            "Depending on model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
        )

    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

    if do_center_crop and crop_size is None:
        raise ValueError("crop_size must be specified if do_center_crop is True.")

    if do_resize and (size is None or resample is None):
        raise ValueError("size and resample must be specified if do_resize is True.")


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
Expand Down Expand Up @@ -759,11 +718,3 @@ def validate_annotations(
"(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
"the latter being a list of annotations in the COCO format."
)


def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    """
    Log a warning naming every captured keyword argument that is not a recognized processor key.

    Args:
        valid_processor_keys: Names the image processor accepts.
        captured_kwargs: Names that were actually passed. Any iterable of strings works,
            e.g. `kwargs.keys()`.

    Nothing is raised and nothing is returned; unrecognized names are only reported through
    the module-level `logger`.
    """
    unused_keys = set(captured_kwargs).difference(valid_processor_keys)
    if unused_keys:
        # A set has no stable iteration order; sort so the message is reproducible between runs.
        unused_key_str = ", ".join(sorted(unused_keys))
        # TODO raise a warning here instead of simply logging?
        logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
35 changes: 0 additions & 35 deletions src/transformers/models/beit/image_processing_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging

Expand Down Expand Up @@ -257,11 +253,6 @@ def _preprocess_image(
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
Expand Down Expand Up @@ -418,37 +409,11 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

images = make_list_of_images(images)

if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

images = [
self._preprocess_image(
image=img,
Expand Down
31 changes: 0 additions & 31 deletions src/transformers/models/bit/image_processing_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

Expand Down Expand Up @@ -274,42 +270,15 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

images = make_list_of_images(images)

if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
Expand Down
28 changes: 0 additions & 28 deletions src/transformers/models/blip/image_processing_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

Expand Down Expand Up @@ -250,37 +246,13 @@ def preprocess(

images = make_list_of_images(images)

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
Expand Down
Loading
Loading