Skip to content

Commit

Permalink
Make kwargs uniform for SAM
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Nov 2, 2024
1 parent 203e270 commit 5c2efda
Showing 1 changed file with 42 additions and 13 deletions.
55 changes: 42 additions & 13 deletions src/transformers/models/sam/processing_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
"""

from copy import deepcopy
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_tf_available, is_torch_available
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import is_tf_available, is_torch_available


if is_torch_available():
Expand All @@ -33,6 +34,18 @@
import tensorflow as tf


class SamImagesKwargs(ImagesKwargs, total=False):
    """
    Image-related keyword arguments understood by `SamProcessor.__call__`, on top of the generic `ImagesKwargs`.

    The class is declared with `total=False` because TypedDict totality only applies to the keys declared in the
    class body: without it, the four SAM-specific keys below would be typed as *required*, although callers may
    omit any of them.
    """

    # Not popped by `SamProcessor.__call__`, so it is handed to the image processor together with the
    # remaining image kwargs.
    segmentation_maps: Optional[ImageInput]
    # Prompt inputs: `SamProcessor.__call__` pops these three keys from the merged image kwargs before
    # calling the image processor.
    # NOTE(review): the nesting depth of these annotations is taken as-is; confirm it matches the
    # shapes accepted by the point/box handling code.
    input_points: Optional[List[List[float]]]
    input_labels: Optional[List[List[int]]]
    input_boxes: Optional[List[List[List[float]]]]


class SamProcessorKwargs(ProcessingKwargs, total=False):
    """
    Keyword-argument schema that `SamProcessor.__call__` passes to `_merge_kwargs`.
    """

    # Type the image kwargs group with the SAM-specific `SamImagesKwargs`.
    images_kwargs: SamImagesKwargs
    # No processor-level default values are declared for SAM.
    _defaults = {}


class SamProcessor(ProcessorMixin):
r"""
Constructs a SAM processor which wraps a SAM image processor and a 2D points & bounding boxes processor into a
Expand All @@ -55,25 +68,41 @@ def __init__(self, image_processor):
self.point_pad_value = -10
self.target_size = self.image_processor.size["longest_edge"]

@staticmethod
def _add_args_for_backward_compatibility(args):
"""
Remove this function once support for args is dropped in __call__
"""
if len(args) > 4:
raise ValueError("Too many positional arguments")
return dict(zip(("segmentation_maps", "input_points", "input_labels", "input_boxes"), args))

def __call__(
self,
images=None,
segmentation_maps=None,
input_points=None,
input_labels=None,
input_boxes=None,
return_tensors: Optional[Union[str, TensorType]] = None,
*args, # to be deprecated
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
video=None,
**kwargs,
) -> BatchEncoding:
"""
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
points and bounding boxes for the model if they are provided.
"""
output_kwargs = self._merge_kwargs(
SamProcessorKwargs,
tokenizer_init_kwargs={},
**kwargs,
**self._add_args_for_backward_compatibility(args),
)
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)

encoding_image_processor = self.image_processor(
images,
segmentation_maps=segmentation_maps,
return_tensors=return_tensors,
**kwargs,
**output_kwargs["images_kwargs"],
)

# pop arguments that are not used in the forward but used nevertheless
Expand All @@ -94,7 +123,7 @@ def __call__(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
return_tensors=return_tensors,
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
)

return encoding_image_processor
Expand Down

0 comments on commit 5c2efda

Please sign in to comment.