Skip to content

Commit

Permalink
Take and accept all input arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed May 1, 2024
1 parent 19732f1 commit c531e10
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/transformers/image_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def __init__(self, **kwargs):

def __call__(self, images, **kwargs) -> BatchFeature:
"""Preprocess an image or a batch of images."""
self._validate_inputs(**kwargs)
self._validate_inputs(images, **kwargs)
return self.preprocess(images, **kwargs)

def _validate_preprocess_arguments(self, **kwargs):
Expand All @@ -628,6 +628,19 @@ def _validate_image_inputs(self, images, segmentation_maps=None, do_rescale=Fals

def _validate_inputs(self, images, segmentation_maps=None, **kwargs):
"""Check if the arguments passed to the preprocess method are valid."""
return_tensors = kwargs.pop("return_tensors", None)
input_data_format = kwargs.pop("input_data_format", None)
data_format = kwargs.pop("data_format", None)

if return_tensors not in (None, "np", "pt", "tf", "jax"):
raise ValueError("return_tensors should be one of 'np', 'pt', 'tf', 'jax'.")

if input_data_format not in (None, ChannelDimension.FIRST, ChannelDimension.LAST):
raise ValueError("input_data_format should be one of 'channels_first' or 'channels_last'.")

if data_format not in (None, ChannelDimension.FIRST, ChannelDimension.LAST):
raise ValueError("data_format should be one of 'channels_first' or 'channels_last'.")

if self._valid_processor_keys is None:
raise ValueError("Each image processor must define self._valid_processor_keys")

Expand Down

0 comments on commit c531e10

Please sign in to comment.