Make kwargs uniform for OneFormer
tibor-reiss committed Oct 31, 2024
1 parent 9c4ac3e commit fad4111
Showing 2 changed files with 141 additions and 78 deletions.
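
For orientation before the diff: the commit's point is that OneFormer's extra inputs now travel through the shared grouped-kwargs machinery instead of ad-hoc positional parameters. A minimal sketch of the new calling convention, assuming a `processor` instance and an `image` object are already in hand (both are placeholders, not part of this commit):

# old style: extras passed positionally and handled ad hoc inside __call__
#     processor(image, ["semantic"], segmentation_map, return_tensors="pt")
# new style: extras passed as keywords and routed through OneFormerProcessorKwargs
inputs = processor(
    images=image,
    task_inputs=["semantic"],   # must be "semantic", "instance", or "panoptic"
    segmentation_maps=None,     # optional; popped from images_kwargs inside __call__
    return_tensors="pt",
)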
155 changes: 91 additions & 64 deletions src/transformers/models/oneformer/processing_oneformer.py
@@ -16,16 +16,37 @@
Image/Text processor class for OneFormer
"""

from typing import List
from typing import Dict, List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_available


if is_torch_available():
import torch


class OneFormerImagesKwargs(ImagesKwargs):
segmentation_maps: Optional[ImageInput]
task_inputs: Optional[Union[TextInput, PreTokenizedInput]]
instance_id_to_semantic_id: Optional[Dict[int, int]]
pad_and_return_pixel_mask: Optional[bool]
ignore_index: Optional[int]
do_reduce_labels: bool
repo_path: Optional[str]
class_info_file: Optional[str]
num_text: Optional[int]
num_labels: Optional[int]


class OneFormerProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: OneFormerImagesKwargs
_defaults = {}


class OneFormerProcessor(ProcessorMixin):
r"""
Constructs an OneFormer processor which wraps [`OneFormerImageProcessor`] and
@@ -37,9 +37,9 @@ class OneFormerProcessor(ProcessorMixin):
The image processor is a required input.
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
The tokenizer is a required input.
max_seq_len (`int`, *optional*, defaults to 77)):
max_seq_length (`int`, *optional*, defaults to 77):
Sequence length for input text list.
task_seq_len (`int`, *optional*, defaults to 77):
task_seq_length (`int`, *optional*, defaults to 77):
Sequence length for input task token.
"""

@@ -48,22 +69,23 @@ class OneFormerProcessor(ProcessorMixin):
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")

def __init__(
self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
self,
image_processor=None,
tokenizer=None,
max_seq_length: int = 77,
task_seq_length: int = 77,
):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")

self.max_seq_length = max_seq_length
self.task_seq_length = task_seq_length

super().__init__(image_processor, tokenizer)

def _preprocess_text(self, text_list=None, max_length=77):
if text_list is None:
raise ValueError("tokens cannot be None.")
self.max_seq_length = max_seq_length
self.task_seq_length = task_seq_length

def _preprocess_text(self, text_list: PreTokenizedInput, max_length: int = 77):
tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)

attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
@@ -76,13 +98,41 @@ def _preprocess_text(self, text_list=None, max_length=77):
token_inputs = torch.cat(token_inputs, dim=0)
return token_inputs

def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
@staticmethod
def _check_args(
images: Optional[ImageInput] = None,
task_inputs: Optional[Union[TextInput, PreTokenizedInput]] = None,
):
if task_inputs is None:
raise ValueError("You have to specify the task_inputs. Found None.")
elif images is None:
raise ValueError("You have to specify the images. Found None.")

if isinstance(task_inputs, str):
task_inputs = [task_inputs]

if not isinstance(task_inputs, List) or not task_inputs:
raise TypeError("task_inputs should be a string or a list of strings.")

if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
raise ValueError("task_inputs must be semantic, instance, or panoptic.")

return task_inputs

def __call__(
self,
images: Optional[ImageInput] = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[OneFormerProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several task input(s) and image(s). This method forwards the
`task_inputs` and `kwargs` arguments to CLIPTokenizer's [`~CLIPTokenizer.__call__`] if `task_inputs` is not
`None` to encode. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
OneFormerImageProcessor's [`~OneFormerImageProcessor.__call__`] if `images` is not `None`. Please refer to the
doctsring of the above two methods for more information.
docstring of the above two methods for more information.
Args:
task_inputs (`str`, `List[str]`):
@@ -108,36 +158,28 @@ def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""

if task_inputs is None:
raise ValueError("You have to specify the task_input. Found None.")
elif images is None:
raise ValueError("You have to specify the image. Found None.")

if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
raise ValueError("task_inputs must be semantic, instance, or panoptic.")
output_kwargs = self._merge_kwargs(
OneFormerProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
segmentation_maps = output_kwargs["images_kwargs"].pop("segmentation_maps", None)
task_inputs = output_kwargs["images_kwargs"].pop("task_inputs", None)
task_inputs = self._check_args(images, task_inputs)

encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **kwargs)
encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **output_kwargs["images_kwargs"])

if isinstance(task_inputs, str):
task_inputs = [task_inputs]

if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
task_token_inputs = []
for task in task_inputs:
task_input = f"the task is {task}"
task_token_inputs.append(task_input)
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
else:
raise TypeError("Task Inputs should be a string or a list of strings.")
task_token_inputs = []
for task in task_inputs:
task_input = f"the task is {task}"
task_token_inputs.append(task_input)
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)

if hasattr(encoded_inputs, "text_inputs"):
texts_list = encoded_inputs.text_inputs

text_inputs = []
for texts in texts_list:
text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
text_inputs.append(text_input_list.unsqueeze(0))

text_inputs = [
self._preprocess_text(texts, max_length=self.max_seq_length).unsqueeze(0)
for texts in encoded_inputs.text_inputs
]
encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)

return encoded_inputs
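
As an aside on the mechanism used above: `_merge_kwargs` sorts the flat `**kwargs` into per-modality buckets declared by the `ProcessingKwargs` TypedDicts, which is why `OneFormerImagesKwargs` lets `segmentation_maps` and `task_inputs` ride along in `images_kwargs`. The snippet below is a toy illustration of that routing idea only, not the actual transformers implementation; every name ending in `Sketch` is invented here.

from typing import TypedDict, get_type_hints


class ImagesKwargsSketch(TypedDict, total=False):
    segmentation_maps: object
    task_inputs: object
    do_reduce_labels: bool


class TextKwargsSketch(TypedDict, total=False):
    padding: str
    max_length: int


def merge_kwargs_sketch(**kwargs):
    # Route each flat kwarg into the bucket whose TypedDict declares that key.
    buckets = {"images_kwargs": ImagesKwargsSketch, "text_kwargs": TextKwargsSketch}
    merged = {name: {} for name in buckets}
    for key, value in kwargs.items():
        for name, hint_cls in buckets.items():
            if key in get_type_hints(hint_cls):
                merged[name][key] = value
                break
        else:
            raise TypeError(f"Unexpected processor kwarg: {key}")
    return merged


merge_kwargs_sketch(task_inputs=["semantic"], max_length=77)
# -> {'images_kwargs': {'task_inputs': ['semantic']}, 'text_kwargs': {'max_length': 77}}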
@@ -148,36 +190,21 @@ def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
task_inputs. Please refer to the docstring of this method for more information.
"""

if task_inputs is None:
raise ValueError("You have to specify the task_input. Found None.")
elif images is None:
raise ValueError("You have to specify the image. Found None.")

if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
raise ValueError("task_inputs must be semantic, instance, or panoptic.")
task_inputs = self._check_args(images, task_inputs)

encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs)

if isinstance(task_inputs, str):
task_inputs = [task_inputs]

if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
task_token_inputs = []
for task in task_inputs:
task_input = f"the task is {task}"
task_token_inputs.append(task_input)
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
else:
raise TypeError("Task Inputs should be a string or a list of strings.")
task_token_inputs = []
for task in task_inputs:
task_input = f"the task is {task}"
task_token_inputs.append(task_input)
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)

if hasattr(encoded_inputs, "text_inputs"):
texts_list = encoded_inputs.text_inputs

text_inputs = []
for texts in texts_list:
text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
text_inputs.append(text_input_list.unsqueeze(0))

text_inputs = [
self._preprocess_text(texts, max_length=self.max_seq_length).unsqueeze(0)
for texts in encoded_inputs.text_inputs
]
encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)

return encoded_inputs
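Before the test changes, a runnable end-to-end example of the updated call (hedged: the checkpoint id is just a commonly used OneFormer checkpoint and the dummy image is a placeholder; neither is part of this commit):

import numpy as np
from PIL import Image
from transformers import OneFormerProcessor

processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
image = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))  # dummy RGB image
inputs = processor(
    images=image,
    task_inputs=["panoptic"],  # keyword argument, validated by _check_args
    return_tensors="pt",
)
print(sorted(inputs.keys()))  # expect at least pixel_values and task_inputs
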
64 changes: 50 additions & 14 deletions tests/models/oneformer/test_processor_oneformer.py
@@ -222,7 +222,11 @@ def test_call_pil(self):
self.assertIsInstance(image, Image.Image)

# Test not batched input
encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
encoded_images = processor(
image_inputs[0],
task_inputs=["semantic"],
return_tensors="pt",
).pixel_values

expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs
@@ -233,7 +237,11 @@ def test_call_pil(self):
(1, self.processing_tester.num_channels, expected_height, expected_width),
)

tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs
tokenized_task_inputs = processor(
image_inputs[0],
task_inputs=["semantic"],
return_tensors="pt",
).task_inputs

self.assertEqual(
tokenized_task_inputs.shape,
@@ -245,7 +253,11 @@ def test_call_pil(self):
image_inputs, batched=True
)

encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values
encoded_images = processor(
image_inputs,
task_inputs=["semantic"] * len(image_inputs),
return_tensors="pt",
).pixel_values
self.assertEqual(
encoded_images.shape,
(
@@ -257,7 +269,7 @@ def test_call_pil(self):
)

tokenized_task_inputs = processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
image_inputs, task_inputs=["semantic"] * len(image_inputs), return_tensors="pt"
).task_inputs

self.assertEqual(
@@ -274,7 +286,11 @@ def test_call_numpy(self):
self.assertIsInstance(image, np.ndarray)

# Test not batched input
encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
encoded_images = processor(
image_inputs[0],
task_inputs=["semantic"],
return_tensors="pt",
).pixel_values

expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs
@@ -285,7 +301,11 @@ def test_call_numpy(self):
(1, self.processing_tester.num_channels, expected_height, expected_width),
)

tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs
tokenized_task_inputs = processor(
image_inputs[0],
task_inputs=["semantic"],
return_tensors="pt",
).task_inputs

self.assertEqual(
tokenized_task_inputs.shape,
@@ -297,7 +317,11 @@ def test_call_numpy(self):
image_inputs, batched=True
)

encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values
encoded_images = processor(
image_inputs,
task_inputs=["semantic"] * len(image_inputs),
return_tensors="pt",
).pixel_values
self.assertEqual(
encoded_images.shape,
(
@@ -309,7 +333,7 @@ def test_call_numpy(self):
)

tokenized_task_inputs = processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
image_inputs, task_inputs=["semantic"] * len(image_inputs), return_tensors="pt"
).task_inputs

self.assertEqual(
@@ -326,7 +350,11 @@ def test_call_pytorch(self):
self.assertIsInstance(image, torch.Tensor)

# Test not batched input
encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
encoded_images = processor(
image_inputs[0],
task_inputs=["semantic"],
return_tensors="pt",
).pixel_values

expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs
@@ -337,7 +365,11 @@ def test_call_pytorch(self):
(1, self.processing_tester.num_channels, expected_height, expected_width),
)

tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs
tokenized_task_inputs = processor(
image_inputs[0],
task_inputs=["semantic"],
return_tensors="pt",
).task_inputs

self.assertEqual(
tokenized_task_inputs.shape,
@@ -349,7 +381,11 @@ def test_call_pytorch(self):
image_inputs, batched=True
)

encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values
encoded_images = processor(
image_inputs,
task_inputs=["semantic"] * len(image_inputs),
return_tensors="pt",
).pixel_values
self.assertEqual(
encoded_images.shape,
(
@@ -361,7 +397,7 @@ def test_call_pytorch(self):
)

tokenized_task_inputs = processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
image_inputs, task_inputs=["semantic"] * len(image_inputs), return_tensors="pt"
).task_inputs

self.assertEqual(
@@ -389,8 +425,8 @@ def comm_get_processor_inputs(self, with_segmentation_maps=False, is_instance_ma

inputs = processor(
image_inputs,
["semantic"] * len(image_inputs),
annotations,
task_inputs=["semantic"] * len(image_inputs),
segmentation_maps=annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
