From fad4111ec55698598c860e69933b0772399e5f4a Mon Sep 17 00:00:00 2001 From: Tibor Reiss Date: Wed, 30 Oct 2024 22:06:53 +0100 Subject: [PATCH] Make kwargs uniform for OneFormer --- .../models/oneformer/processing_oneformer.py | 155 ++++++++++-------- .../oneformer/test_processor_oneformer.py | 64 ++++++-- 2 files changed, 141 insertions(+), 78 deletions(-) diff --git a/src/transformers/models/oneformer/processing_oneformer.py b/src/transformers/models/oneformer/processing_oneformer.py index 9e55be5d6731c5..56867db436522a 100644 --- a/src/transformers/models/oneformer/processing_oneformer.py +++ b/src/transformers/models/oneformer/processing_oneformer.py @@ -16,9 +16,12 @@ Image/Text processor class for OneFormer """ -from typing import List +from typing import Dict, List, Optional, Union -from ...processing_utils import ProcessorMixin +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_torch_available @@ -26,6 +29,24 @@ import torch +class OneFormerImagesKwargs(ImagesKwargs): + segmentation_maps: Optional[ImageInput] + task_inputs: Optional[Union[TextInput, PreTokenizedInput]] + instance_id_to_semantic_id: Optional[Dict[int, int]] + pad_and_return_pixel_mask: Optional[bool] + ignore_index: Optional[int] + do_reduce_labels: bool + repo_path: Optional[str] + class_info_file: Optional[str] + num_text: Optional[int] + num_labels: Optional[int] + + +class OneFormerProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: OneFormerImagesKwargs + _defaults = {} + + class OneFormerProcessor(ProcessorMixin): r""" Constructs an OneFormer processor which wraps [`OneFormerImageProcessor`] and @@ -37,9 +58,9 @@ class OneFormerProcessor(ProcessorMixin): The image processor is a required input. tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]): The tokenizer is a required input. - max_seq_len (`int`, *optional*, defaults to 77)): + max_seq_length (`int`, *optional*, defaults to 77): Sequence length for input text list. - task_seq_len (`int`, *optional*, defaults to 77): + task_seq_length (`int`, *optional*, defaults to 77): Sequence length for input task token. """ @@ -48,22 +69,23 @@ class OneFormerProcessor(ProcessorMixin): tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast") def __init__( - self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs + self, + image_processor=None, + tokenizer=None, + max_seq_length: int = 77, + task_seq_length: int = 77, ): if image_processor is None: raise ValueError("You need to specify an `image_processor`.") if tokenizer is None: raise ValueError("You need to specify a `tokenizer`.") - self.max_seq_length = max_seq_length - self.task_seq_length = task_seq_length - super().__init__(image_processor, tokenizer) - def _preprocess_text(self, text_list=None, max_length=77): - if text_list is None: - raise ValueError("tokens cannot be None.") + self.max_seq_length = max_seq_length + self.task_seq_length = task_seq_length + def _preprocess_text(self, text_list: PreTokenizedInput, max_length: int = 77): tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True) attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"] @@ -76,13 +98,41 @@ def _preprocess_text(self, text_list=None, max_length=77): token_inputs = torch.cat(token_inputs, dim=0) return token_inputs - def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs): + @staticmethod + def _check_args( + images: Optional[ImageInput] = None, + task_inputs: Optional[Union[TextInput, PreTokenizedInput]] = None, + ): + if task_inputs is None: + raise ValueError("You have to specify the task_inputs. Found None.") + elif images is None: + raise ValueError("You have to specify the images. Found None.") + + if isinstance(task_inputs, str): + task_inputs = [task_inputs] + + if not isinstance(task_inputs, List) or not task_inputs: + raise TypeError("task_inputs should be a string or a list of strings.") + + if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs): + raise ValueError("task_inputs must be semantic, instance, or panoptic.") + + return task_inputs + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[OneFormerProcessorKwargs], + ) -> BatchFeature: """ Main method to prepare for the model one or several task input(s) and image(s). This method forwards the `task_inputs` and `kwargs` arguments to CLIPTokenizer's [`~CLIPTokenizer.__call__`] if `task_inputs` is not `None` to encode. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to OneFormerImageProcessor's [`~OneFormerImageProcessor.__call__`] if `images` is not `None`. Please refer to the - doctsring of the above two methods for more information. + docstring of the above two methods for more information. Args: task_inputs (`str`, `List[str]`): @@ -108,36 +158,28 @@ def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwar - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - if task_inputs is None: - raise ValueError("You have to specify the task_input. Found None.") - elif images is None: - raise ValueError("You have to specify the image. Found None.") - - if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs): - raise ValueError("task_inputs must be semantic, instance, or panoptic.") + output_kwargs = self._merge_kwargs( + OneFormerProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + segmentation_maps = output_kwargs["images_kwargs"].pop("segmentation_maps", None) + task_inputs = output_kwargs["images_kwargs"].pop("task_inputs", None) + task_inputs = self._check_args(images, task_inputs) - encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **kwargs) + encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **output_kwargs["images_kwargs"]) - if isinstance(task_inputs, str): - task_inputs = [task_inputs] - - if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs): - task_token_inputs = [] - for task in task_inputs: - task_input = f"the task is {task}" - task_token_inputs.append(task_input) - encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) - else: - raise TypeError("Task Inputs should be a string or a list of strings.") + task_token_inputs = [] + for task in task_inputs: + task_input = f"the task is {task}" + task_token_inputs.append(task_input) + encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) if hasattr(encoded_inputs, "text_inputs"): - texts_list = encoded_inputs.text_inputs - - text_inputs = [] - for texts in texts_list: - text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length) - text_inputs.append(text_input_list.unsqueeze(0)) - + text_inputs = [ + self._preprocess_text(texts, max_length=self.max_seq_length).unsqueeze(0) + for texts in encoded_inputs.text_inputs + ] encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0) return encoded_inputs @@ -148,36 +190,21 @@ def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, * task_inputs. Please refer to the docstring of this method for more information. """ - if task_inputs is None: - raise ValueError("You have to specify the task_input. Found None.") - elif images is None: - raise ValueError("You have to specify the image. Found None.") - - if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs): - raise ValueError("task_inputs must be semantic, instance, or panoptic.") + task_inputs = self._check_args(images, task_inputs) encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs) - if isinstance(task_inputs, str): - task_inputs = [task_inputs] - - if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs): - task_token_inputs = [] - for task in task_inputs: - task_input = f"the task is {task}" - task_token_inputs.append(task_input) - encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) - else: - raise TypeError("Task Inputs should be a string or a list of strings.") + task_token_inputs = [] + for task in task_inputs: + task_input = f"the task is {task}" + task_token_inputs.append(task_input) + encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length) if hasattr(encoded_inputs, "text_inputs"): - texts_list = encoded_inputs.text_inputs - - text_inputs = [] - for texts in texts_list: - text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length) - text_inputs.append(text_input_list.unsqueeze(0)) - + text_inputs = [ + self._preprocess_text(texts, max_length=self.max_seq_length).unsqueeze(0) + for texts in encoded_inputs.text_inputs + ] encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0) return encoded_inputs diff --git a/tests/models/oneformer/test_processor_oneformer.py b/tests/models/oneformer/test_processor_oneformer.py index 3a8a378b49009e..e0ef61b77525da 100644 --- a/tests/models/oneformer/test_processor_oneformer.py +++ b/tests/models/oneformer/test_processor_oneformer.py @@ -222,7 +222,11 @@ def test_call_pil(self): self.assertIsInstance(image, Image.Image) # Test not batched input - encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + encoded_images = processor( + image_inputs[0], + task_inputs=["semantic"], + return_tensors="pt", + ).pixel_values expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( image_inputs @@ -233,7 +237,11 @@ def test_call_pil(self): (1, self.processing_tester.num_channels, expected_height, expected_width), ) - tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs + tokenized_task_inputs = processor( + image_inputs[0], + task_inputs=["semantic"], + return_tensors="pt", + ).task_inputs self.assertEqual( tokenized_task_inputs.shape, @@ -245,7 +253,11 @@ def test_call_pil(self): image_inputs, batched=True ) - encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values + encoded_images = processor( + image_inputs, + task_inputs=["semantic"] * len(image_inputs), + return_tensors="pt", + ).pixel_values self.assertEqual( encoded_images.shape, ( @@ -257,7 +269,7 @@ def test_call_pil(self): ) tokenized_task_inputs = processor( - image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + image_inputs, task_inputs=["semantic"] * len(image_inputs), return_tensors="pt" ).task_inputs self.assertEqual( @@ -274,7 +286,11 @@ def test_call_numpy(self): self.assertIsInstance(image, np.ndarray) # Test not batched input - encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + encoded_images = processor( + image_inputs[0], + task_inputs=["semantic"], + return_tensors="pt", + ).pixel_values expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( image_inputs @@ -285,7 +301,11 @@ def test_call_numpy(self): (1, self.processing_tester.num_channels, expected_height, expected_width), ) - tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs + tokenized_task_inputs = processor( + image_inputs[0], + task_inputs=["semantic"], + return_tensors="pt", + ).task_inputs self.assertEqual( tokenized_task_inputs.shape, @@ -297,7 +317,11 @@ def test_call_numpy(self): image_inputs, batched=True ) - encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values + encoded_images = processor( + image_inputs, + task_inputs=["semantic"] * len(image_inputs), + return_tensors="pt", + ).pixel_values self.assertEqual( encoded_images.shape, ( @@ -309,7 +333,7 @@ def test_call_numpy(self): ) tokenized_task_inputs = processor( - image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + image_inputs, task_inputs=["semantic"] * len(image_inputs), return_tensors="pt" ).task_inputs self.assertEqual( @@ -326,7 +350,11 @@ def test_call_pytorch(self): self.assertIsInstance(image, torch.Tensor) # Test not batched input - encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values + encoded_images = processor( + image_inputs[0], + task_inputs=["semantic"], + return_tensors="pt", + ).pixel_values expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values( image_inputs @@ -337,7 +365,11 @@ def test_call_pytorch(self): (1, self.processing_tester.num_channels, expected_height, expected_width), ) - tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs + tokenized_task_inputs = processor( + image_inputs[0], + task_inputs=["semantic"], + return_tensors="pt", + ).task_inputs self.assertEqual( tokenized_task_inputs.shape, @@ -349,7 +381,11 @@ def test_call_pytorch(self): image_inputs, batched=True ) - encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values + encoded_images = processor( + image_inputs, + task_inputs=["semantic"] * len(image_inputs), + return_tensors="pt", + ).pixel_values self.assertEqual( encoded_images.shape, ( @@ -361,7 +397,7 @@ def test_call_pytorch(self): ) tokenized_task_inputs = processor( - image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt" + image_inputs, task_inputs=["semantic"] * len(image_inputs), return_tensors="pt" ).task_inputs self.assertEqual( @@ -389,8 +425,8 @@ def comm_get_processor_inputs(self, with_segmentation_maps=False, is_instance_ma inputs = processor( image_inputs, - ["semantic"] * len(image_inputs), - annotations, + task_inputs=["semantic"] * len(image_inputs), + segmentation_maps=annotations, return_tensors="pt", instance_id_to_semantic_id=instance_id_to_semantic_id, pad_and_return_pixel_mask=True,