diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index d40b0165e33257..a84e098cb25fb6 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -147,6 +147,37 @@ def divide_to_patches(
     return patches
 
 
+def group_images_by_shape(
+    images: List["torch.Tensor"],
+) -> Tuple[Dict[Tuple[int, int], List["torch.Tensor"]], Dict[int, Tuple[Tuple[int, int], int]]]:
+    """
+    Groups images by shape.
+    Returns a dictionary with the shape as key and a list of images with that shape as value,
+    and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.
+    """
+    grouped_images = {}
+    grouped_images_index = {}
+    for i, image in enumerate(images):
+        shape = image.shape[1:]
+        if shape not in grouped_images:
+            grouped_images[shape] = []
+        grouped_images[shape].append(image)
+        grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
+    return grouped_images, grouped_images_index
+
+
+def reconstruct_images(
+    processed_images: Dict[Tuple[int, int], "torch.Tensor"], grouped_images_index: Dict[int, Tuple[int, int]]
+) -> List["torch.Tensor"]:
+    """
+    Reconstructs a list of images in the original order.
+    """
+    return [
+        processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
+        for i in range(len(grouped_images_index))
+    ]
+
+
 @dataclass(frozen=True)
 class SizeDict:
     """
@@ -277,7 +308,7 @@ def resize(
         self,
         image: "torch.Tensor",
         size: Dict[str, int],
-        resample: "F.InterpolationMode" = None,
+        interpolation: "F.InterpolationMode" = None,
         **kwargs,
     ) -> "torch.Tensor":
         """
@@ -294,7 +325,7 @@ def resize(
         Returns:
             `np.ndarray`: The resized image.
         """
-        resample = resample if resample is not None else F.InterpolationMode.BILINEAR
+        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
         if size.shortest_edge and size.longest_edge:
             # Resize the image so that the shortest edge or the longest edge is of the given size
             # while maintaining the aspect ratio of the original image.
@@ -319,7 +350,7 @@ def resize(
                 "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                 f" {size.keys()}."
             )
-        return F.resize(image, new_size, interpolation=resample)
+        return F.resize(image, new_size, interpolation=interpolation)
 
     def rescale(
         self,
@@ -564,7 +595,6 @@ def preprocess(
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
-        return_tensors = "pt" if return_tensors is None else return_tensors
         device = kwargs.pop("device", None)
 
         images, image_mean, image_std, size, crop_size, interpolation = self.prepare_process_arguments(
@@ -587,30 +617,37 @@ def preprocess(
             **kwargs,
         )
 
-        processed_images = []
-        for image in images:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images)
+        resized_images_grouped = {}
+        for shape, images in grouped_images.items():
+            stacked_images = torch.stack(images, dim=0)
             if do_resize:
-                image = self.resize(
-                    image=image,
-                    size=size,
-                    resample=interpolation,
-                )
-
+                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reconstruct_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        processed_images_grouped = {}
+        for shape, images in grouped_images.items():
+            stacked_images = torch.stack(images, dim=0)
             if do_center_crop:
-                image = self.center_crop(image, crop_size)
-
+                stacked_images = self.center_crop(stacked_images, crop_size)
+            # Fused rescale and normalize
             if do_rescale and do_normalize:
-                # fused rescale and normalize
-                image = self.normalize(image.to(dtype=torch.float32), image_mean, image_std)
+                stacked_images = self.normalize(stacked_images.to(dtype=torch.float32), image_mean, image_std)
             elif do_rescale:
-                image = image * rescale_factor
+                stacked_images = stacked_images * rescale_factor
             elif do_normalize:
-                image = self.normalize(image, image_mean, image_std)
+                stacked_images = self.normalize(stacked_images, image_mean, image_std)
+            processed_images_grouped[shape] = stacked_images
 
-            processed_images.append(image)
-        images = processed_images
+        processed_images = reconstruct_images(processed_images_grouped, grouped_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
-        return BatchFeature(data={"pixel_values": torch.stack(images, dim=0)}, tensor_type=return_tensors)
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
     def to_dict(self):
         encoder_dict = super().to_dict()
@@ -672,7 +709,11 @@ def prepare_images_structure(
         return make_batched_images(images)
 
     def _resize_for_patching(
-        self, image: "torch.Tensor", target_resolution: tuple, resample, input_data_format: ChannelDimension
+        self,
+        image: "torch.Tensor",
+        target_resolution: tuple,
+        interpolation: "F.InterpolationMode",
+        input_data_format: ChannelDimension,
     ) -> "torch.Tensor":
         """
         Resizes an image to a target resolution while maintaining aspect ratio.
@@ -682,7 +723,7 @@ def _resize_for_patching(
                 The input image.
             target_resolution (tuple):
                 The target resolution (height, width) of the image.
-            resample (`PILImageResampling`):
+            interpolation (`InterpolationMode`):
                 Resampling filter to use if resizing the image.
             input_data_format (`ChannelDimension` or `str`):
                 The channel dimension format of the input image.
@@ -693,7 +734,7 @@ def _resize_for_patching(
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
 
         # Resize the image
-        resized_image = F.resize(image, (new_height, new_width), interpolation=resample)
+        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
 
         return resized_image
 
@@ -719,7 +760,7 @@ def _get_image_patches(
         grid_pinpoints,
         size: tuple,
         patch_size: int,
-        resample: "F.InterpolationMode",
+        interpolation: "F.InterpolationMode",
     ) -> List[np.array]:
         """
         Process an image with variable resolutions by dividing it into patches.
@@ -733,7 +774,7 @@ def _get_image_patches(
                 Size to resize the original image to.
             patch_size (`int`):
                 Size of the patches to divide the image into.
-            resample (`"F.InterpolationMode"`):
+            interpolation (`"InterpolationMode"`):
                 Resampling filter to use if resizing the image.
 
         Returns:
@@ -747,11 +788,11 @@ def _get_image_patches(
         image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
         best_resolution = select_best_resolution(image_size, possible_resolutions)
         resized_image = self._resize_for_patching(
-            image, best_resolution, resample=resample, input_data_format=ChannelDimension.FIRST
+            image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST
        )
         padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST)
         patches = divide_to_patches(padded_image, patch_size=patch_size)
-        resized_original_image = F.resize(image, size=size, interpolation=resample)
+        resized_original_image = F.resize(image, size=size, interpolation=interpolation)
 
         image_patches = [resized_original_image] + patches
 
@@ -907,54 +948,44 @@ def preprocess(
                 image_grid_pinpoints,
                 size=size_tuple,
                 patch_size=patch_size,
-                resample=interpolation,
+                interpolation=interpolation,
             )
 
-            processed_image_patches = {}
-            # stack image patches of the same size
-            grouped_image_patches = {}
-            grouped_image_patches_index = {}
-            for i, image_patch in enumerate(image_patches):
-                if image_patch.shape[1:] not in grouped_image_patches:
-                    grouped_image_patches[image_patch.shape[1:]] = [image_patch]
-
-                else:
-                    grouped_image_patches[image_patch.shape[1:]].append(image_patch)
-                grouped_image_patches_index[i] = (
-                    image_patch.shape[1:],
-                    len(grouped_image_patches[image_patch.shape[1:]]) - 1,
-                )
-            for key, image_patch in grouped_image_patches.items():
-                image_patch = torch.stack(image_patch, dim=0)
+            # Group images by size for batched processing
+            processed_image_patches_grouped = {}
+            grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
+            for shape, image_patches in grouped_image_patches.items():
+                stacked_image_patches = torch.stack(image_patches, dim=0)
                 if do_resize:
-                    image_patch = self.resize(
-                        image=image_patch,
+                    stacked_image_patches = self.resize(
+                        image=stacked_image_patches,
                         size=size,
-                        resample=interpolation,
+                        interpolation=interpolation,
                     )
                 if do_center_crop:
-                    image_patch = self.center_crop(image_patch, crop_size)
+                    stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
+                # Fused rescale and normalize
                 if do_rescale and do_normalize:
-                    # fused rescale and normalize
-                    image_patch = self.normalize(image_patch.to(dtype=torch.float32), image_mean, image_std)
+                    stacked_image_patches = self.normalize(
+                        stacked_image_patches.to(dtype=torch.float32), image_mean, image_std
+                    )
                 elif do_rescale:
-                    image_patch = image_patch * rescale_factor
+                    stacked_image_patches = stacked_image_patches * rescale_factor
                 elif do_normalize:
-                    image_patch = self.normalize(image_patch, image_mean, image_std)
-                processed_image_patches[key] = image_patch
-            processed_image_patches = [
-                processed_image_patches[grouped_image_patches_index[i][0]][grouped_image_patches_index[i][1]]
-                for i in range(len(grouped_image_patches_index))
-            ]
-
-            processed_image_patches = torch.stack(processed_image_patches, dim=0)
+                    stacked_image_patches = self.normalize(stacked_image_patches, image_mean, image_std)
+                processed_image_patches_grouped[shape] = stacked_image_patches
+            processed_image_patches = reconstruct_images(processed_image_patches_grouped, grouped_image_patches_index)
+            processed_image_patches = (
+                torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
+            )
             processed_images.append(processed_image_patches)
             image_sizes.append(get_image_size(image, input_data_format))
 
-        images = processed_images
+
         if do_pad:
-            images = self._pad_for_batching(images)
+            processed_images = self._pad_for_batching(processed_images)
 
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(
-            data={"pixel_values": torch.stack(images, dim=0), "image_sizes": image_sizes}, tensor_type=return_tensors
+            data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
         )
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index d7d909e977ec5e..fec7f96e3815cf 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -512,7 +512,7 @@ def validate_fast_preprocess_arguments(
         resample=resample,
     )
     # Extra checks for ImageProcessorFast
-    if return_tensors != "pt":
+    if return_tensors is not None and return_tensors != "pt":
         raise ValueError("Only returning PyTorch tensors is currently supported.")
 
     if data_format != ChannelDimension.FIRST:
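
Note (not part of the patch): a minimal usage sketch of the two helpers introduced above, group_images_by_shape and reconstruct_images. It assumes they are importable from transformers.image_processing_utils_fast exactly as added in this diff; adjust the import if your installed version organizes these helpers differently. The per-group multiplication below is only a stand-in for the real resize/rescale/normalize calls done by the processors.

import torch

from transformers.image_processing_utils_fast import group_images_by_shape, reconstruct_images

# Three channels-first images with two distinct (H, W) shapes.
images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]

# grouped maps each (H, W) shape to the images sharing it; index maps each
# original position to (shape, position within that shape's group).
grouped, index = group_images_by_shape(images)

processed = {}
for shape, same_shape in grouped.items():
    batch = torch.stack(same_shape, dim=0)  # one batched op per shape group
    processed[shape] = batch * (1.0 / 255.0)  # stand-in for rescale/normalize

# Restore the original ordering as a flat list of processed images.
restored = reconstruct_images(processed, index)
assert [tuple(img.shape[1:]) for img in restored] == [(224, 224), (336, 336), (224, 224)]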