-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
uniformize kwargs for SAM #34578
uniformize kwargs for SAM #34578
Changes from all commits
c3b52af
ff3c3e9
f9c3fe1
3da5ddc
96dd98f
d92cd4b
6f24d6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,13 +17,14 @@ | |||||||||||||||
""" | ||||||||||||||||
|
||||||||||||||||
from copy import deepcopy | ||||||||||||||||
from typing import Optional, Union | ||||||||||||||||
from typing import List, Optional, Union | ||||||||||||||||
|
||||||||||||||||
import numpy as np | ||||||||||||||||
|
||||||||||||||||
from ...processing_utils import ProcessorMixin | ||||||||||||||||
from ...tokenization_utils_base import BatchEncoding | ||||||||||||||||
from ...utils import TensorType, is_tf_available, is_torch_available | ||||||||||||||||
from ...image_utils import ImageInput, VideoInput | ||||||||||||||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin | ||||||||||||||||
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput | ||||||||||||||||
from ...utils import is_tf_available, is_torch_available | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
if is_torch_available(): | ||||||||||||||||
|
@@ -33,6 +34,23 @@ | |||||||||||||||
import tensorflow as tf | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class SamImagesKwargs(ImagesKwargs): | ||||||||||||||||
segmentation_maps: Optional[ImageInput] | ||||||||||||||||
input_points: Optional[List[List[float]]] | ||||||||||||||||
input_labels: Optional[List[List[int]]] | ||||||||||||||||
input_boxes: Optional[List[List[List[float]]]] | ||||||||||||||||
point_pad_value: Optional[int] | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class SamProcessorKwargs(ProcessingKwargs, total=False): | ||||||||||||||||
images_kwargs: SamImagesKwargs | ||||||||||||||||
_defaults = { | ||||||||||||||||
"images_kwargs": { | ||||||||||||||||
"point_pad_value": -10, | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class SamProcessor(ProcessorMixin): | ||||||||||||||||
r""" | ||||||||||||||||
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a | ||||||||||||||||
|
@@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin): | |||||||||||||||
|
||||||||||||||||
attributes = ["image_processor"] | ||||||||||||||||
image_processor_class = "SamImageProcessor" | ||||||||||||||||
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. | ||||||||||||||||
optional_call_args = [ | ||||||||||||||||
"segmentation_maps", | ||||||||||||||||
"input_points", | ||||||||||||||||
"input_labels", | ||||||||||||||||
"input_boxes", | ||||||||||||||||
] | ||||||||||||||||
|
||||||||||||||||
def __init__(self, image_processor): | ||||||||||||||||
super().__init__(image_processor) | ||||||||||||||||
self.current_processor = self.image_processor | ||||||||||||||||
self.point_pad_value = -10 | ||||||||||||||||
self.target_size = self.image_processor.size["longest_edge"] | ||||||||||||||||
|
||||||||||||||||
def __call__( | ||||||||||||||||
self, | ||||||||||||||||
images=None, | ||||||||||||||||
segmentation_maps=None, | ||||||||||||||||
input_points=None, | ||||||||||||||||
input_labels=None, | ||||||||||||||||
input_boxes=None, | ||||||||||||||||
Comment on lines
-61
to
-64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These should be added to the |
||||||||||||||||
return_tensors: Optional[Union[str, TensorType]] = None, | ||||||||||||||||
images: Optional[ImageInput] = None, | ||||||||||||||||
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes` | ||||||||||||||||
# arguments that may be passed as a positional argument. | ||||||||||||||||
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, | ||||||||||||||||
# or this conversation for more context: | ||||||||||||||||
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 | ||||||||||||||||
# This behavior is only needed for backward compatibility and will be removed in future versions. | ||||||||||||||||
*args, # to be deprecated | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a reason for the existence of this `*args`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, see the comment above. It is purely for backward compatibility (bc) and should be removed at a later stage. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, | ||||||||||||||||
audio: Optional[AudioInput] = None, | ||||||||||||||||
video: Optional[VideoInput] = None, | ||||||||||||||||
**kwargs, | ||||||||||||||||
) -> BatchEncoding: | ||||||||||||||||
""" | ||||||||||||||||
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D | ||||||||||||||||
points and bounding boxes for the model if they are provided. | ||||||||||||||||
""" | ||||||||||||||||
output_kwargs = self._merge_kwargs( | ||||||||||||||||
SamProcessorKwargs, | ||||||||||||||||
tokenizer_init_kwargs={}, | ||||||||||||||||
**kwargs, | ||||||||||||||||
**self.prepare_and_validate_optional_call_args(*args), | ||||||||||||||||
) | ||||||||||||||||
input_points = output_kwargs["images_kwargs"].pop("input_points", None) | ||||||||||||||||
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None) | ||||||||||||||||
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None) | ||||||||||||||||
|
||||||||||||||||
encoding_image_processor = self.image_processor( | ||||||||||||||||
images, | ||||||||||||||||
segmentation_maps=segmentation_maps, | ||||||||||||||||
return_tensors=return_tensors, | ||||||||||||||||
**kwargs, | ||||||||||||||||
**output_kwargs["images_kwargs"], | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
# pop arguments that are not used in the foward but used nevertheless | ||||||||||||||||
|
@@ -94,7 +130,8 @@ def __call__( | |||||||||||||||
input_points=input_points, | ||||||||||||||||
input_labels=input_labels, | ||||||||||||||||
input_boxes=input_boxes, | ||||||||||||||||
return_tensors=return_tensors, | ||||||||||||||||
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"), | ||||||||||||||||
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"), | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
return encoding_image_processor | ||||||||||||||||
|
@@ -107,6 +144,7 @@ def _normalize_and_convert( | |||||||||||||||
input_labels=None, | ||||||||||||||||
input_boxes=None, | ||||||||||||||||
return_tensors="pt", | ||||||||||||||||
point_pad_value=-10, | ||||||||||||||||
): | ||||||||||||||||
if input_points is not None: | ||||||||||||||||
if len(original_sizes) != len(input_points): | ||||||||||||||||
|
@@ -121,7 +159,9 @@ def _normalize_and_convert( | |||||||||||||||
# check that all arrays have the same shape | ||||||||||||||||
if not all(point.shape == input_points[0].shape for point in input_points): | ||||||||||||||||
if input_labels is not None: | ||||||||||||||||
input_points, input_labels = self._pad_points_and_labels(input_points, input_labels) | ||||||||||||||||
input_points, input_labels = self._pad_points_and_labels( | ||||||||||||||||
input_points, input_labels, point_pad_value | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
input_points = np.array(input_points) | ||||||||||||||||
|
||||||||||||||||
|
@@ -174,7 +214,7 @@ def _normalize_and_convert( | |||||||||||||||
|
||||||||||||||||
return encoding_image_processor | ||||||||||||||||
|
||||||||||||||||
def _pad_points_and_labels(self, input_points, input_labels): | ||||||||||||||||
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value): | ||||||||||||||||
r""" | ||||||||||||||||
The method pads the 2D points and labels to the maximum number of points in the batch. | ||||||||||||||||
""" | ||||||||||||||||
|
@@ -183,9 +223,9 @@ def _pad_points_and_labels(self, input_points, input_labels): | |||||||||||||||
for i, point in enumerate(input_points): | ||||||||||||||||
if point.shape[0] != expected_nb_points: | ||||||||||||||||
point = np.concatenate( | ||||||||||||||||
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0 | ||||||||||||||||
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0 | ||||||||||||||||
) | ||||||||||||||||
input_labels[i] = np.append(input_labels[i], [self.point_pad_value]) | ||||||||||||||||
input_labels[i] = np.append(input_labels[i], [point_pad_value]) | ||||||||||||||||
processed_input_points.append(point) | ||||||||||||||||
input_points = processed_input_points | ||||||||||||||||
return input_points, input_labels | ||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this `segmentation_maps` is not required for the inference of SAM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is needed for backwards compatibility, even if it is not used. E.g. take the following call: `processor(images, None, input_points)`. If I removed it from `_add_args_for_backward_compatibility`, such calls would break.
If I understood correctly, #31911 is about uniform kwargs, and the API would be cleaned up further at a later step. In SAM's case, `_add_args_for_backward_compatibility` will be removed, and then `segmentation_maps` can be removed as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We actually already have a way to handle such positional args; you can take a look at the UDOP processor for an example :)