From 5c2efdac4a911118dcef411ea2fc96aae11ff15e Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Sat, 2 Nov 2024 10:31:33 +0100 Subject: [PATCH] Make kwargs uniform for SAM --- src/transformers/models/sam/processing_sam.py | 55 ++++++++++++++----- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/sam/processing_sam.py b/src/transformers/models/sam/processing_sam.py index 9e67be1e1e55c2..473d31947afd51 100644 --- a/src/transformers/models/sam/processing_sam.py +++ b/src/transformers/models/sam/processing_sam.py @@ -17,13 +17,14 @@ """ from copy import deepcopy -from typing import Optional, Union +from typing import List, Optional, Union import numpy as np -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding -from ...utils import TensorType, is_tf_available, is_torch_available +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput +from ...utils import is_tf_available, is_torch_available if is_torch_available(): @@ -33,6 +34,18 @@ import tensorflow as tf +class SamImagesKwargs(ImagesKwargs): + segmentation_maps: Optional[ImageInput] + input_points: Optional[List[List[float]]] + input_labels: Optional[List[List[int]]] + input_boxes: Optional[List[List[List[float]]]] + + +class SamProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: SamImagesKwargs + _defaults = {} + + class SamProcessor(ProcessorMixin): r""" Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a @@ -55,25 +68,41 @@ def __init__(self, image_processor): self.point_pad_value = -10 self.target_size = self.image_processor.size["longest_edge"] + @staticmethod + def _add_args_for_backward_compatibility(args): + """ + Remove this function once support for args is dropped in __call__ + """ + if len(args) > 4: + raise ValueError("Too many positional arguments") + return dict(zip(("segmentation_maps", "input_points", "input_labels", "input_boxes"), args)) + def __call__( self, images=None, - segmentation_maps=None, - input_points=None, - input_labels=None, - input_boxes=None, - return_tensors: Optional[Union[str, TensorType]] = None, + *args, # to be deprecated + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + video=None, **kwargs, ) -> BatchEncoding: """ This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D points and bounding boxes for the model if they are provided. """ + output_kwargs = self._merge_kwargs( + SamProcessorKwargs, + tokenizer_init_kwargs={}, + **kwargs, + **self._add_args_for_backward_compatibility(args), + ) + input_points = output_kwargs["images_kwargs"].pop("input_points", None) + input_labels = output_kwargs["images_kwargs"].pop("input_labels", None) + input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None) + encoding_image_processor = self.image_processor( images, - segmentation_maps=segmentation_maps, - return_tensors=return_tensors, - **kwargs, + **output_kwargs["images_kwargs"], ) # pop arguments that are not used in the foward but used nevertheless @@ -94,7 +123,7 @@ def __call__( input_points=input_points, input_labels=input_labels, input_boxes=input_boxes, - return_tensors=return_tensors, + return_tensors=output_kwargs["common_kwargs"].get("return_tensors"), ) return encoding_image_processor