From a3e7b6a97d97e252c23a976d8afeeba31e4f24ac Mon Sep 17 00:00:00 2001 From: Harisankar Babu <58052269+harisankar95@users.noreply.github.com> Date: Thu, 11 Jan 2024 15:52:14 +0100 Subject: [PATCH] Optionally preprocess segmentation maps for MobileViT (#28420) * optionally preprocess segmentation maps for mobilevit * changed pretrained model name to that of segmentation model * removed voc-deeplabv3 from model archive list * added preprocess_image and preprocess_mask methods for processing images and segmentation masks respectively * added tests for segmentation masks based on segformer feature extractor * use crop_size instead of size * reverting to initial model --- .../mobilevit/image_processing_mobilevit.py | 195 ++++++++++++++---- .../test_image_processing_mobilevit.py | 135 +++++++++++- 2 files changed, 290 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index ee16d439cb6b05..2e7433fa02b8c7 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -19,12 +19,7 @@ import numpy as np from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict -from ...image_transforms import ( - flip_channel_order, - get_resize_output_image_size, - resize, - to_channel_dimension_format, -) +from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format from ...image_utils import ( ChannelDimension, ImageInput, @@ -178,9 +173,126 @@ def flip_channel_order( """ return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format) + def __call__(self, images, segmentation_maps=None, **kwargs): + """ + Preprocesses a batch of images and optionally segmentation maps. + + Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be + passed in as positional arguments. + """ + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + def _preprocess( + self, + image: ImageInput, + do_resize: bool, + do_rescale: bool, + do_center_crop: bool, + do_flip_channel_order: bool, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + rescale_factor: Optional[float] = None, + crop_size: Optional[Dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_flip_channel_order: + image = self.flip_channel_order(image, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_flip_channel_order: bool = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if is_scaled_image(image) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + image = self._preprocess( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_flip_channel_order=do_flip_channel_order, + input_data_format=input_data_format, + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single mask.""" + segmentation_map = to_numpy_array(segmentation_map) + # Add channel dimension if missing - needed for certain transformations + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] + input_data_format = ChannelDimension.FIRST + else: + added_channel_dim = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1) + + segmentation_map = self._preprocess( + image=segmentation_map, + do_resize=do_resize, + size=size, + resample=PILImageResampling.NEAREST, + do_rescale=False, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_flip_channel_order=False, + input_data_format=input_data_format, + ) + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + segmentation_map = segmentation_map.astype(np.int64) + return segmentation_map + def preprocess( self, images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, do_resize: bool = None, size: Dict[str, int] = None, resample: PILImageResampling = None, @@ -201,6 +313,8 @@ def preprocess( images (`ImageInput`): Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + segmentation_maps (`ImageInput`, *optional*): + Segmentation map to preprocess. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): @@ -251,6 +365,8 @@ def preprocess( crop_size = get_size_dict(crop_size, param_name="crop_size") images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) if not valid_images(images): raise ValueError( @@ -258,6 +374,12 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." ) + if segmentation_maps is not None and not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + if do_resize and size is None: raise ValueError("Size must be specified if do_resize is True.") @@ -267,45 +389,40 @@ def preprocess( if do_center_crop and crop_size is None: raise ValueError("Crop size must be specified if do_center_crop is True.") - # All transformations expect numpy arrays. - images = [to_numpy_array(image) for image in images] - - if is_scaled_image(images[0]) and do_rescale: - logger.warning_once( - "It looks like you are trying to rescale already rescaled images. If the input" - " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + images = [ + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_flip_channel_order=do_flip_channel_order, + data_format=data_format, + input_data_format=input_data_format, ) + for img in images + ] - if input_data_format is None: - # We assume that all images have the same channel dimension format. - input_data_format = infer_channel_dimension_format(images[0]) - - if do_resize: - images = [ - self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) - for image in images - ] - - if do_center_crop: - images = [ - self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images - ] + data = {"pixel_values": images} - if do_rescale: - images = [ - self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - for image in images + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_mask( + segmentation_map=segmentation_map, + do_resize=do_resize, + size=size, + do_center_crop=do_center_crop, + crop_size=crop_size, + input_data_format=input_data_format, + ) + for segmentation_map in segmentation_maps ] - # the pretrained checkpoints assume images are BGR, not RGB - if do_flip_channel_order: - images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images] + data["labels"] = segmentation_maps - images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images - ] - - data = {"pixel_values": images} return BatchFeature(data=data, tensor_type=return_tensors) # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT diff --git a/tests/models/mobilevit/test_image_processing_mobilevit.py b/tests/models/mobilevit/test_image_processing_mobilevit.py index edfdf0aed559d1..92e1a55947b168 100644 --- a/tests/models/mobilevit/test_image_processing_mobilevit.py +++ b/tests/models/mobilevit/test_image_processing_mobilevit.py @@ -16,13 +16,20 @@ import unittest +from datasets import load_dataset + from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torch_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs +if is_torch_available(): + import torch + if is_vision_available(): + from PIL import Image + from transformers import MobileViTImageProcessor @@ -79,6 +86,26 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F ) +def prepare_semantic_single_inputs(): + dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + + image = Image.open(dataset[0]["file"]) + map = Image.open(dataset[1]["file"]) + + return image, map + + +def prepare_semantic_batch_inputs(): + dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + + image1 = Image.open(dataset[0]["file"]) + map1 = Image.open(dataset[1]["file"]) + image2 = Image.open(dataset[2]["file"]) + map2 = Image.open(dataset[3]["file"]) + + return [image1, image2], [map1, map2] + + @require_torch @require_vision class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): @@ -107,3 +134,109 @@ def test_image_processor_from_dict_with_kwargs(self): image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) self.assertEqual(image_processor.size, {"shortest_edge": 42}) self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + + def test_call_segmentation_maps(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + maps = [] + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + maps.append(torch.zeros(image.shape[-2:]).long()) + + # Test not batched input + encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test batched + encoding = image_processing(image_inputs, maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test not batched input (PIL images) + image, segmentation_map = prepare_semantic_single_inputs() + + encoding = image_processing(image, segmentation_map, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test batched input (PIL images) + images, segmentation_maps = prepare_semantic_batch_inputs() + + encoding = image_processing(images, segmentation_maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 2, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 2, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255)