diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py
index 9e67be1e1e55c2..7ea1d573544e4d 100644
--- a/src/transformers/models/sam/processing_sam.py
+++ b/src/transformers/models/sam/processing_sam.py
@@ -17,13 +17,14 @@
 """

 from copy import deepcopy
-from typing import Optional, Union
+from typing import List, Optional, Union

 import numpy as np

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding
-from ...utils import TensorType, is_tf_available, is_torch_available
+from ...image_utils import ImageInput, VideoInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
+from ...utils import is_tf_available, is_torch_available


 if is_torch_available():
@@ -33,6 +34,23 @@
     import tensorflow as tf


+class SamImagesKwargs(ImagesKwargs):
+    segmentation_maps: Optional[ImageInput]
+    input_points: Optional[List[List[float]]]
+    input_labels: Optional[List[List[int]]]
+    input_boxes: Optional[List[List[List[float]]]]
+    point_pad_value: Optional[int]
+
+
+class SamProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: SamImagesKwargs
+    _defaults = {
+        "images_kwargs": {
+            "point_pad_value": -10,
+        }
+    }
+
+
 class SamProcessor(ProcessorMixin):
     r"""
     Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
@@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):

     attributes = ["image_processor"]
     image_processor_class = "SamImageProcessor"
+    # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
+    optional_call_args = [
+        "segmentation_maps",
+        "input_points",
+        "input_labels",
+        "input_boxes",
+    ]

     def __init__(self, image_processor):
         super().__init__(image_processor)
-        self.current_processor = self.image_processor
-        self.point_pad_value = -10
         self.target_size = self.image_processor.size["longest_edge"]

     def __call__(
         self,
-        images=None,
-        segmentation_maps=None,
-        input_points=None,
-        input_labels=None,
-        input_boxes=None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
+        images: Optional[ImageInput] = None,
+        # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
+        # arguments that may be passed as a positional argument.
+        # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
+        # or this conversation for more context:
+        # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
+        # This behavior is only needed for backward compatibility and will be removed in future versions.
+        *args,  # to be deprecated
+        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        audio: Optional[AudioInput] = None,
+        video: Optional[VideoInput] = None,
         **kwargs,
     ) -> BatchEncoding:
         """
         This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
         points and bounding boxes for the model if they are provided.
""" + output_kwargs = self._merge_kwargs( + SamProcessorKwargs, + tokenizer_init_kwargs={}, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + input_points = output_kwargs["images_kwargs"].pop("input_points", None) + input_labels = output_kwargs["images_kwargs"].pop("input_labels", None) + input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None) + encoding_image_processor = self.image_processor( images, - segmentation_maps=segmentation_maps, - return_tensors=return_tensors, - **kwargs, + **output_kwargs["images_kwargs"], ) # pop arguments that are not used in the foward but used nevertheless @@ -94,7 +130,8 @@ def __call__( input_points=input_points, input_labels=input_labels, input_boxes=input_boxes, - return_tensors=return_tensors, + return_tensors=output_kwargs["common_kwargs"].get("return_tensors"), + point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"), ) return encoding_image_processor @@ -107,6 +144,7 @@ def _normalize_and_convert( input_labels=None, input_boxes=None, return_tensors="pt", + point_pad_value=-10, ): if input_points is not None: if len(original_sizes) != len(input_points): @@ -121,7 +159,9 @@ def _normalize_and_convert( # check that all arrays have the same shape if not all(point.shape == input_points[0].shape for point in input_points): if input_labels is not None: - input_points, input_labels = self._pad_points_and_labels(input_points, input_labels) + input_points, input_labels = self._pad_points_and_labels( + input_points, input_labels, point_pad_value + ) input_points = np.array(input_points) @@ -174,7 +214,7 @@ def _normalize_and_convert( return encoding_image_processor - def _pad_points_and_labels(self, input_points, input_labels): + def _pad_points_and_labels(self, input_points, input_labels, point_pad_value): r""" The method pads the 2D points and labels to the maximum number of points in the batch. """ @@ -183,9 +223,9 @@ def _pad_points_and_labels(self, input_points, input_labels): for i, point in enumerate(input_points): if point.shape[0] != expected_nb_points: point = np.concatenate( - [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0 + [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0 ) - input_labels[i] = np.append(input_labels[i], [self.point_pad_value]) + input_labels[i] = np.append(input_labels[i], [point_pad_value]) processed_input_points.append(point) input_points = processed_input_points return input_points, input_labels diff --git a/tests/models/sam/test_processor_sam.py b/tests/models/sam/test_processor_sam.py index 22eb88d03d6b04..654f892062625a 100644 --- a/tests/models/sam/test_processor_sam.py +++ b/tests/models/sam/test_processor_sam.py @@ -26,7 +26,7 @@ ) from transformers.utils import is_tf_available, is_torch_available, is_vision_available -from ...test_processing_common import prepare_image_inputs +from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs if is_vision_available(): @@ -43,7 +43,9 @@ @require_vision @require_torchvision -class SamProcessorTest(unittest.TestCase): +class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = SamProcessor + def setUp(self): self.tmpdirname = tempfile.mkdtemp() image_processor = SamImageProcessor() @@ -56,11 +58,6 @@ def get_image_processor(self, **kwargs): def tearDown(self): shutil.rmtree(self.tmpdirname) - # Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. 
only contains an image processor - def prepare_image_inputs(self): - """This function prepares a list of PIL images.""" - return prepare_image_inputs() - def prepare_mask_inputs(self): """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, or a list of PyTorch tensors if one specifies torchify=True. @@ -69,6 +66,21 @@ def prepare_mask_inputs(self): mask_inputs = [Image.fromarray(x) for x in mask_inputs] return mask_inputs + def test_chat_template_save_loading(self): + self.skipTest("SamProcessor does not have a tokenizer") + + def test_image_processor_defaults_preserved_by_image_kwargs(self): + self.skipTest("SamProcessor does not have a tokenizer") + + def test_kwargs_overrides_default_image_processor_kwargs(self): + self.skipTest("SamProcessor does not have a tokenizer") + + def test_kwargs_overrides_default_tokenizer_kwargs(self): + self.skipTest("SamProcessor does not have a tokenizer") + + def test_tokenizer_defaults_preserved_by_kwargs(self): + self.skipTest("SamProcessor does not have a tokenizer") + def test_save_load_pretrained_additional_features(self): processor = SamProcessor(image_processor=self.get_image_processor()) processor.save_pretrained(self.tmpdirname) @@ -165,7 +177,7 @@ def get_image_processor(self, **kwargs): def tearDown(self): shutil.rmtree(self.tmpdirname) - # Processor tester class can't use ProcessorTesterMixin as processor is atypical e.g. only contains an image processor and it assumes torch + # This is to avoid repeating the skipping of the common tests def prepare_image_inputs(self): """This function prepares a list of PIL images.""" return prepare_image_inputs() @@ -248,7 +260,7 @@ def get_image_processor(self, **kwargs): def tearDown(self): shutil.rmtree(self.tmpdirname) - # Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor + # This is to avoid repeating the skipping of the common tests def prepare_image_inputs(self): """This function prepares a list of PIL images.""" return prepare_image_inputs()
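
For reviewers, a minimal sketch of the call path after this patch. The checkpoint name and the point/label values are illustrative, not part of the diff:

    import numpy as np
    from PIL import Image

    from transformers import SamProcessor

    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))

    # Prompt inputs are now routed through SamProcessorKwargs: _merge_kwargs
    # folds them into images_kwargs, __call__ pops them back out, and the
    # per-call point_pad_value (default -10) replaces self.point_pad_value.
    inputs = processor(
        images=image,
        input_points=[[[450.0, 600.0]]],
        input_labels=[[1]],
        point_pad_value=-10,  # optional override of the new default
        return_tensors="pt",
    )
    print(inputs["input_points"].shape)  # torch.Size([1, 1, 1, 2])

    # Legacy positional form, still captured via *args and
    # prepare_and_validate_optional_call_args (to be deprecated):
    legacy = processor(image, None, [[[450.0, 600.0]]], [[1]], return_tensors="pt")

Moving point_pad_value out of instance state and into the merged kwargs is what makes the padding value overridable per call without mutating the processor.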
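
And a small self-contained sketch of the padding rule that point_pad_value controls. pad_points is a hypothetical helper written for intuition only, not the in-tree code; note that it pads one label per missing point, whereas _pad_points_and_labels above appends a single pad label per shortened entry:

    import numpy as np

    def pad_points(points, labels, pad_value=-10):
        """Pad ragged per-image point arrays up to the largest point count."""
        target = max(p.shape[0] for p in points)
        padded_pts, padded_lbls = [], []
        for pts, lbls in zip(points, labels):
            missing = target - pts.shape[0]
            # Padded coordinates and the matching labels are filled with pad_value.
            padded_pts.append(np.concatenate([pts, np.full((missing, 2), pad_value, dtype=float)], axis=0))
            padded_lbls.append(np.concatenate([lbls, np.full(missing, pad_value)], axis=0))
        return padded_pts, padded_lbls

    points = [np.array([[450.0, 600.0]]), np.array([[100.0, 200.0], [300.0, 400.0]])]
    labels = [np.array([1]), np.array([1, 1])]
    padded_pts, padded_lbls = pad_points(points, labels)
    assert padded_pts[0].shape == (2, 2) and padded_lbls[0][-1] == -10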