Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

uniformize kwargs for SAM #34578

Merged
merged 7 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 60 additions & 20 deletions src/transformers/models/sam/processing_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
"""

from copy import deepcopy
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_tf_available, is_torch_available
from ...image_utils import ImageInput, VideoInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
from ...utils import is_tf_available, is_torch_available


if is_torch_available():
Expand All @@ -33,6 +34,23 @@
import tensorflow as tf


class SamImagesKwargs(ImagesKwargs):
segmentation_maps: Optional[ImageInput]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this segmentation_maps is not required for the inference of SAM.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed for backwards compatibility - even if it is not used. E.g. take the following call: processor(images, None, input_points). If I removed it from _add_args_for_backward_compatibility, such calls would break.
If I understood correctly, #31911 is about uniform kwargs, and then at a later step the API would be cleaned up further. In SAM's case, _add_args_for_backward_compatibility will be removed, and then segmentation_maps can be removed as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually have an existing way to handle such positional args, you can take a look at udop processor for example :)

input_points: Optional[List[List[float]]]
input_labels: Optional[List[List[int]]]
input_boxes: Optional[List[List[List[float]]]]
point_pad_value: Optional[int]


class SamProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: SamImagesKwargs
_defaults = {
"images_kwargs": {
"point_pad_value": -10,
}
}


class SamProcessor(ProcessorMixin):
r"""
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
Expand All @@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):

attributes = ["image_processor"]
image_processor_class = "SamImageProcessor"
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = [
"segmentation_maps",
"input_points",
"input_labels",
"input_boxes",
]

def __init__(self, image_processor):
super().__init__(image_processor)
self.current_processor = self.image_processor
self.point_pad_value = -10
self.target_size = self.image_processor.size["longest_edge"]

def __call__(
self,
images=None,
segmentation_maps=None,
input_points=None,
input_labels=None,
input_boxes=None,
Comment on lines -61 to -64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be added to the optional_call_args attribute (see udop processor)

return_tensors: Optional[Union[str, TensorType]] = None,
images: Optional[ImageInput] = None,
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
# arguments that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context:
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args, # to be deprecated
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for the existence of this *args?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, see comment above. Purely for bc and should be removed at a later stage.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*args, # to be deprecated
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes` arguments that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
#
*args,

text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio: Optional[AudioInput] = None,
video: Optional[VideoInput] = None,
**kwargs,
) -> BatchEncoding:
"""
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
points and bounding boxes for the model if they are provided.
"""
output_kwargs = self._merge_kwargs(
SamProcessorKwargs,
tokenizer_init_kwargs={},
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)

encoding_image_processor = self.image_processor(
images,
segmentation_maps=segmentation_maps,
return_tensors=return_tensors,
**kwargs,
**output_kwargs["images_kwargs"],
)

# pop arguments that are not used in the foward but used nevertheless
Expand All @@ -94,7 +130,8 @@ def __call__(
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
return_tensors=return_tensors,
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"),
)

return encoding_image_processor
Expand All @@ -107,6 +144,7 @@ def _normalize_and_convert(
input_labels=None,
input_boxes=None,
return_tensors="pt",
point_pad_value=-10,
):
if input_points is not None:
if len(original_sizes) != len(input_points):
Expand All @@ -121,7 +159,9 @@ def _normalize_and_convert(
# check that all arrays have the same shape
if not all(point.shape == input_points[0].shape for point in input_points):
if input_labels is not None:
input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)
input_points, input_labels = self._pad_points_and_labels(
input_points, input_labels, point_pad_value
)

input_points = np.array(input_points)

Expand Down Expand Up @@ -174,7 +214,7 @@ def _normalize_and_convert(

return encoding_image_processor

def _pad_points_and_labels(self, input_points, input_labels):
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
r"""
The method pads the 2D points and labels to the maximum number of points in the batch.
"""
Expand All @@ -183,9 +223,9 @@ def _pad_points_and_labels(self, input_points, input_labels):
for i, point in enumerate(input_points):
if point.shape[0] != expected_nb_points:
point = np.concatenate(
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
)
input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
input_labels[i] = np.append(input_labels[i], [point_pad_value])
processed_input_points.append(point)
input_points = processed_input_points
return input_points, input_labels
Expand Down
30 changes: 21 additions & 9 deletions tests/models/sam/test_processor_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from transformers.utils import is_tf_available, is_torch_available, is_vision_available

from ...test_processing_common import prepare_image_inputs
from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs


if is_vision_available():
Expand All @@ -43,7 +43,9 @@

@require_vision
@require_torchvision
class SamProcessorTest(unittest.TestCase):
class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = SamProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
Expand All @@ -56,11 +58,6 @@ def get_image_processor(self, **kwargs):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()

def prepare_mask_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
Expand All @@ -69,6 +66,21 @@ def prepare_mask_inputs(self):
mask_inputs = [Image.fromarray(x) for x in mask_inputs]
return mask_inputs

def test_chat_template_save_loading(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_image_processor_defaults_preserved_by_image_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_kwargs_overrides_default_image_processor_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_kwargs_overrides_default_tokenizer_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_tokenizer_defaults_preserved_by_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")

def test_save_load_pretrained_additional_features(self):
processor = SamProcessor(image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
Expand Down Expand Up @@ -165,7 +177,7 @@ def get_image_processor(self, **kwargs):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

# Processor tester class can't use ProcessorTesterMixin as processor is atypical e.g. only contains an image processor and it assumes torch
# This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
Expand Down Expand Up @@ -248,7 +260,7 @@ def get_image_processor(self, **kwargs):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
# This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
Expand Down