Skip to content

Commit

Permalink
Code review - use existing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
tibor-reiss committed Nov 29, 2024
1 parent f1dff72 commit 67d3d6b
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/transformers/models/sam/processing_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,27 @@ class SamProcessor(ProcessorMixin):

attributes = ["image_processor"]
image_processor_class = "SamImageProcessor"
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = [
"segmentation_maps",
"input_points",
"input_labels",
"input_boxes",
]

def __init__(self, image_processor):
super().__init__(image_processor)
self.target_size = self.image_processor.size["longest_edge"]

@staticmethod
def _add_args_for_backward_compatibility(args):
"""
Remove this function once support for args is dropped in __call__
"""
if len(args) > 4:
raise ValueError("Too many positional arguments")
return dict(zip(("segmentation_maps", "input_points", "input_labels", "input_boxes"), args))

def __call__(
self,
images: Optional[ImageInput] = None,
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
# arguments that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context:
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args, # to be deprecated
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio: Optional[AudioInput] = None,
Expand All @@ -97,7 +101,7 @@ def __call__(
SamProcessorKwargs,
tokenizer_init_kwargs={},
**kwargs,
**self._add_args_for_backward_compatibility(args),
**self.prepare_and_validate_optional_call_args(*args),
)
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
Expand Down

0 comments on commit 67d3d6b

Please sign in to comment.