Skip to content

Commit

Permalink
Optionally preprocess segmentation maps for MobileViT (#28420)
Browse files Browse the repository at this point in the history
* optionally preprocess segmentation maps for mobilevit

* changed pretrained model name to that of segmentation model

* removed voc-deeplabv3 from model archive list

* added preprocess_image and preprocess_mask methods for processing images and segmentation masks respectively

* added tests for segmentation masks based on segformer feature extractor

* use crop_size instead of size

* reverting to initial model
  • Loading branch information
harisankar95 authored Jan 11, 2024
1 parent 95091e1 commit d560637
Show file tree
Hide file tree
Showing 2 changed files with 290 additions and 40 deletions.
195 changes: 156 additions & 39 deletions src/transformers/models/mobilevit/image_processing_mobilevit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@
import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
flip_channel_order,
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
Expand Down Expand Up @@ -178,9 +173,126 @@ def flip_channel_order(
"""
return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)

def __call__(self, images, segmentation_maps=None, **kwargs):
"""
Preprocesses a batch of images and optionally segmentation maps.
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
passed in as positional arguments.
"""
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

def _preprocess(
self,
image: ImageInput,
do_resize: bool,
do_rescale: bool,
do_center_crop: bool,
do_flip_channel_order: bool,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
rescale_factor: Optional[float] = None,
crop_size: Optional[Dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

if do_flip_channel_order:
image = self.flip_channel_order(image, input_data_format=input_data_format)

return image

def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_flip_channel_order: bool = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)

image = self._preprocess(
image=image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=do_flip_channel_order,
input_data_format=input_data_format,
)

image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

return image

def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)

segmentation_map = self._preprocess(
image=segmentation_map,
do_resize=do_resize,
size=size,
resample=PILImageResampling.NEAREST,
do_rescale=False,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=False,
input_data_format=input_data_format,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map

def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
Expand All @@ -201,6 +313,8 @@ def preprocess(
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
segmentation_maps (`ImageInput`, *optional*):
Segmentation map to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Expand Down Expand Up @@ -251,13 +365,21 @@ def preprocess(
crop_size = get_size_dict(crop_size, param_name="crop_size")

images = make_list_of_images(images)
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

Expand All @@ -267,45 +389,40 @@ def preprocess(
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
images = [
self._preprocess_image(
image=img,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=do_flip_channel_order,
data_format=data_format,
input_data_format=input_data_format,
)
for img in images
]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_center_crop:
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
data = {"pixel_values": images}

if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
if segmentation_maps is not None:
segmentation_maps = [
self._preprocess_mask(
segmentation_map=segmentation_map,
do_resize=do_resize,
size=size,
do_center_crop=do_center_crop,
crop_size=crop_size,
input_data_format=input_data_format,
)
for segmentation_map in segmentation_maps
]

# the pretrained checkpoints assume images are BGR, not RGB
if do_flip_channel_order:
images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images]
data["labels"] = segmentation_maps

images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
Expand Down
Loading

0 comments on commit d560637

Please sign in to comment.