Move out input validation into base image processor
amyeroberts committed Apr 26, 2024
1 parent 20081c7 commit fc226af
Showing 55 changed files with 481 additions and 1,559 deletions.
89 changes: 88 additions & 1 deletion src/transformers/image_processing_utils.py
@@ -26,7 +26,7 @@
from .dynamic_module_utils import custom_object_save
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .image_transforms import center_crop, normalize, rescale
from .image_utils import ChannelDimension
from .image_utils import ChannelDimension, PILImageResampling, is_scaled_image, to_numpy_array, valid_images
from .utils import (
IMAGE_PROCESSOR_NAME,
PushToHubMixin,
@@ -47,6 +47,55 @@
logger = logging.get_logger(__name__)


def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
if unused_keys:
unused_key_str = ", ".join(unused_keys)
# TODO raise a warning here instead of simply logging?
logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")


def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
):
"""
Checks the validity of commonly used arguments in an `ImageProcessor` `preprocess` method.
Raises a `ValueError` if an argument incompatibility is detected.
Many incompatibilities are model-specific: `do_pad` sometimes needs `size_divisor`,
sometimes `size_divisibility`, and sometimes `size`. New models and processors should follow
the existing argument conventions whenever possible.
"""
if do_rescale and rescale_factor is None:
raise ValueError("rescale_factor must be specified if do_rescale is True.")

if do_pad and size_divisibility is None:
# Here, size_divisor might be passed as the value of size
raise ValueError(
"Depending on moel, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
)

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

if do_center_crop and crop_size is None:
raise ValueError("crop_size must be specified if do_center_crop is True.")

if do_resize and (size is None or resample is None):
raise ValueError("size and resample must be specified if do_resize is True.")


# TODO: Move BatchFeature to be imported by both feature_extraction_utils and image_processing_utils
# We override the class string here, but logic is the same.
class BatchFeature(BaseBatchFeature):
@@ -543,13 +592,51 @@ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):


class BaseImageProcessor(ImageProcessingMixin):
_valid_processor_keys = None

def __init__(self, **kwargs):
super().__init__(**kwargs)

def __call__(self, images, **kwargs) -> BatchFeature:
"""Preprocess an image or a batch of images."""
self._validate_inputs(images, **kwargs)
return self.preprocess(images, **kwargs)

def _validate_preprocess_arguments(self, **kwargs):
"""Check if the arguments passed to the preprocess method have compatible settings e.g. if `size` is defined when `do_resize` is set to True."""
validate_preprocess_arguments(**kwargs)

def _validate_image_inputs(self, images, segmentation_maps=None, do_rescale=False):
"""Check if the images and segmentation maps are valid."""
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray."
)

if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_rescale and is_scaled_image(to_numpy_array(images[0])):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)

def _validate_inputs(self, images, segmentation_maps=None, **kwargs):
"""Check if the arguments passed to the preprocess method are valid."""
if self._valid_processor_keys is None:
raise ValueError("Each image processor must define self._valid_processor_keys")

for key in kwargs:
if key not in self._valid_processor_keys:
raise ValueError(f"Invalid argument {key} passed to preprocess method.")

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
self._validate_preprocess_arguments(**kwargs)
self._validate_image_inputs(images, segmentation_maps, do_rescale=kwargs.get("do_rescale", False))

def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")
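To show how a concrete processor is expected to hook into this base class, here is a hedged, self-contained sketch; `ToyImageProcessor` and its keys are hypothetical and not part of the repository:

import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature


class ToyImageProcessor(BaseImageProcessor):
    # Hypothetical toy processor used only for illustration.
    # _validate_inputs checks incoming kwargs against this list before preprocess runs.
    _valid_processor_keys = ["do_rescale", "rescale_factor", "return_tensors"]

    def preprocess(self, images, do_rescale=True, rescale_factor=1 / 255, return_tensors=None, **kwargs):
        images = [np.asarray(image, dtype=np.float32) for image in images]
        if do_rescale:
            images = [image * rescale_factor for image in images]
        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)


processor = ToyImageProcessor()
# __call__ validates the kwarg names, the argument combination and the image types,
# then delegates to preprocess().
batch = processor([np.full((3, 8, 8), 128, dtype=np.uint8)], do_rescale=True, rescale_factor=1 / 255)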

49 changes: 0 additions & 49 deletions src/transformers/image_utils.py
@@ -337,47 +337,6 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
):
"""
Checks the validity of commonly used arguments in an `ImageProcessor` `preprocess` method.
Raises a `ValueError` if an argument incompatibility is detected.
Many incompatibilities are model-specific: `do_pad` sometimes needs `size_divisor`,
sometimes `size_divisibility`, and sometimes `size`. New models and processors should follow
the existing argument conventions whenever possible.
"""
if do_rescale and rescale_factor is None:
raise ValueError("rescale_factor must be specified if do_rescale is True.")

if do_pad and size_divisibility is None:
# Here, size_divisor might be passed as the value of size
raise ValueError(
"Depending on moel, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
)

if do_normalize and (image_mean is None or image_std is None):
raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

if do_center_crop and crop_size is None:
raise ValueError("crop_size must be specified if do_center_crop is True.")

if do_resize and (size is None or resample is None):
raise ValueError("size and resample must be specified if do_resize is True.")


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
@@ -759,11 +718,3 @@ def validate_annotations(
"(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
"the latter being a list of annotations in the COCO format."
)


def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
if unused_keys:
unused_key_str = ", ".join(unused_keys)
# TODO raise a warning here instead of simply logging?
logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
35 changes: 0 additions & 35 deletions src/transformers/models/beit/image_processing_beit.py
@@ -28,12 +28,8 @@
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_torch_tensor, is_vision_available, logging

@@ -257,11 +253,6 @@ def _preprocess_image(
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
@@ -418,37 +409,11 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

images = make_list_of_images(images)

if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

images = [
self._preprocess_image(
image=img,
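The inline already-rescaled check removed here from the BEiT processor now lives in `BaseImageProcessor._validate_image_inputs`. A small sketch of the condition it guards against, with arrays invented for the example:

import numpy as np

from transformers.image_utils import is_scaled_image

# Float pixel values already in [0, 1]: rescaling them again by 1/255 would silently
# shrink them, which is exactly what the centralized warning flags.
already_scaled = np.random.rand(3, 8, 8).astype(np.float32)
print(is_scaled_image(already_scaled))  # True

# A standard 8-bit image still needs rescaling, so no warning is emitted.
uint8_image = np.random.randint(0, 256, size=(3, 8, 8), dtype=np.uint8)
print(is_scaled_image(uint8_image))  # False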
31 changes: 0 additions & 31 deletions src/transformers/models/bit/image_processing_bit.py
@@ -32,12 +32,8 @@
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

@@ -274,42 +270,15 @@ def preprocess(
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

images = make_list_of_images(images)

if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)

# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
28 changes: 0 additions & 28 deletions src/transformers/models/blip/image_processing_blip.py
@@ -27,12 +27,8 @@
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging

@@ -250,37 +246,13 @@ def preprocess(

images = make_list_of_images(images)

validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
