Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally preprocess segmentation maps for MobileViT #28420

Merged
merged 7 commits into from
Jan 11, 2024
195 changes: 156 additions & 39 deletions src/transformers/models/mobilevit/image_processing_mobilevit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@
import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
flip_channel_order,
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from ...image_transforms import flip_channel_order, get_resize_output_image_size, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
Expand Down Expand Up @@ -178,9 +173,126 @@ def flip_channel_order(
"""
return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)

def __call__(self, images, segmentation_maps=None, **kwargs):
"""
Preprocesses a batch of images and optionally segmentation maps.
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
passed in as positional arguments.
"""
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

def _preprocess(
self,
image: ImageInput,
do_resize: bool,
do_rescale: bool,
do_center_crop: bool,
do_flip_channel_order: bool,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
rescale_factor: Optional[float] = None,
crop_size: Optional[Dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

if do_flip_channel_order:
image = self.flip_channel_order(image, input_data_format=input_data_format)

return image

def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
do_flip_channel_order: bool = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)

image = self._preprocess(
image=image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=do_flip_channel_order,
input_data_format=input_data_format,
)

image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

return image

def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
do_center_crop: bool = None,
crop_size: Dict[str, int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)

segmentation_map = self._preprocess(
image=segmentation_map,
do_resize=do_resize,
size=size,
resample=PILImageResampling.NEAREST,
do_rescale=False,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=False,
input_data_format=input_data_format,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map

def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
Expand All @@ -201,6 +313,8 @@ def preprocess(
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
segmentation_maps (`ImageInput`, *optional*):
Segmentation map to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Expand Down Expand Up @@ -251,13 +365,21 @@ def preprocess(
crop_size = get_size_dict(crop_size, param_name="crop_size")

images = make_list_of_images(images)
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)

if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")

Expand All @@ -267,45 +389,40 @@ def preprocess(
if do_center_crop and crop_size is None:
raise ValueError("Crop size must be specified if do_center_crop is True.")

# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
images = [
self._preprocess_image(
image=img,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_flip_channel_order=do_flip_channel_order,
data_format=data_format,
input_data_format=input_data_format,
)
for img in images
]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_center_crop:
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]
data = {"pixel_values": images}

if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
if segmentation_maps is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add two new methods to keep in line with other models e.g. SAM here:

  • preprocess_image
  • preprocess_mask

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added additional methods to preprocess images and masks seperately in 877b81a

segmentation_maps = [
self._preprocess_mask(
segmentation_map=segmentation_map,
do_resize=do_resize,
size=size,
do_center_crop=do_center_crop,
crop_size=crop_size,
input_data_format=input_data_format,
)
for segmentation_map in segmentation_maps
]

# the pretrained checkpoints assume images are BGR, not RGB
if do_flip_channel_order:
images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images]
data["labels"] = segmentation_maps

images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileViT
Expand Down
Loading
Loading