From c531e107cb35a5a894b5a000c9a7a9ad6d710967 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 1 May 2024 13:14:22 +0100 Subject: [PATCH] Take and accept all input arguments --- src/transformers/image_processing_utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index acf201daec3cad..f622c4699af349 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -601,7 +601,7 @@ def __init__(self, **kwargs): def __call__(self, images, **kwargs) -> BatchFeature: """Preprocess an image or a batch of images.""" - self._validate_inputs(**kwargs) + self._validate_inputs(images, **kwargs) return self.preprocess(images, **kwargs) def _validate_preprocess_arguments(self, **kwargs): @@ -628,6 +628,19 @@ def _validate_image_inputs(self, images, segmentation_maps=None, do_rescale=Fals def _validate_inputs(self, images, segmentation_maps=None, **kwargs): """Check if the arguments passed to the preprocess method are valid.""" + return_tensors = kwargs.pop("return_tensors", None) + input_data_format = kwargs.pop("input_data_format", None) + data_format = kwargs.pop("data_format", None) + + if return_tensors not in (None, "np", "pt", "tf", "jax"): + raise ValueError("return_tensors should be one of 'np', 'pt', 'tf', 'jax'.") + + if input_data_format not in (None, ChannelDimension.FIRST, ChannelDimension.LAST): + raise ValueError("input_data_format should be one of 'channels_first' or 'channels_last'.") + + if data_format not in (None, ChannelDimension.FIRST, ChannelDimension.LAST): + raise ValueError("data_format should be one of 'channels_first' or 'channels_last'.") + if self._valid_processor_keys is None: raise ValueError("Each image processor must define self._valid_processor_keys")