group images by sizes and add batch processing
yonigozlan committed Dec 11, 2024
1 parent 8e7e910 commit d5e23ea
Showing 2 changed files with 95 additions and 64 deletions.
157 changes: 94 additions & 63 deletions src/transformers/image_processing_utils_fast.py
@@ -147,6 +147,37 @@ def divide_to_patches(
return patches


def group_images_by_shape(
images: List["torch.Tensor"],
) -> Tuple[Dict[Tuple[int, int], List["torch.Tensor"]], Dict[int, Tuple[Tuple[int, int], int]]]:
"""
    Groups images by shape.
    Returns two dictionaries: one mapping each shape to the list of images with that shape, and one
    mapping each original image index to its (shape, index within the group) location, so that
    `reconstruct_images` can restore the original order.
"""
grouped_images = {}
grouped_images_index = {}
for i, image in enumerate(images):
shape = image.shape[1:]
if shape not in grouped_images:
grouped_images[shape] = []
grouped_images[shape].append(image)
grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
return grouped_images, grouped_images_index


def reconstruct_images(
    processed_images: Dict[Tuple[int, int], "torch.Tensor"], grouped_images_index: Dict[int, Tuple[Tuple[int, int], int]]
) -> List["torch.Tensor"]:
"""
Reconstructs a list of images in the original order.
"""
return [
processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
for i in range(len(grouped_images_index))
]
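
For intuition, here is how the two helpers above round-trip a mixed-size batch (an illustrative sketch, assuming `group_images_by_shape` and `reconstruct_images` are in scope; the image shapes are arbitrary):

    import torch

    # Three images, two distinct (H, W) shapes.
    images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]

    grouped, index = group_images_by_shape(images)
    # grouped: {torch.Size([224, 224]): [img0, img2], torch.Size([336, 336]): [img1]}
    # index:   {0: (torch.Size([224, 224]), 0), 1: (torch.Size([336, 336]), 0), 2: (torch.Size([224, 224]), 1)}

    # Each shape group can now be stacked and processed in one batched call;
    # identity processing here for the sake of the example.
    processed = {shape: torch.stack(group, dim=0) for shape, group in grouped.items()}

    restored = reconstruct_images(processed, index)
    assert all(torch.equal(a, b) for a, b in zip(restored, images))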


@dataclass(frozen=True)
class SizeDict:
"""
@@ -277,7 +308,7 @@ def resize(
self,
image: "torch.Tensor",
size: Dict[str, int],
resample: "F.InterpolationMode" = None,
interpolation: "F.InterpolationMode" = None,
**kwargs,
) -> "torch.Tensor":
"""
@@ -294,7 +325,7 @@ def resize(
Returns:
`torch.Tensor`: The resized image.
"""
resample = resample if resample is not None else F.InterpolationMode.BILINEAR
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if size.shortest_edge and size.longest_edge:
# Resize the image so that the shortest edge or the longest edge is of the given size
# while maintaining the aspect ratio of the original image.
@@ -319,7 +350,7 @@ def resize(
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size.keys()}."
)
return F.resize(image, new_size, interpolation=resample)
return F.resize(image, new_size, interpolation=interpolation)
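
The lines collapsed above compute `new_size`; a minimal sketch of the aspect-preserving computation for the `shortest_edge`/`longest_edge` case (an illustration of the idea, not the exact library code):

    def fit_to_edges(height: int, width: int, shortest_edge: int, longest_edge: int) -> tuple:
        # Scale so the shortest edge reaches `shortest_edge`, then shrink further
        # if that would push the longest edge past `longest_edge`.
        scale = shortest_edge / min(height, width)
        if max(height, width) * scale > longest_edge:
            scale = longest_edge / max(height, width)
        return round(height * scale), round(width * scale)

    fit_to_edges(480, 640, shortest_edge=336, longest_edge=672)  # -> (336, 448)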

def rescale(
self,
@@ -564,7 +595,6 @@ def preprocess(
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
return_tensors = "pt" if return_tensors is None else return_tensors
device = kwargs.pop("device", None)

images, image_mean, image_std, size, crop_size, interpolation = self.prepare_process_arguments(
@@ -587,30 +617,37 @@ def preprocess(
**kwargs,
)

processed_images = []
for image in images:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, images in grouped_images.items():
stacked_images = torch.stack(images, dim=0)
if do_resize:
image = self.resize(
image=image,
size=size,
resample=interpolation,
)

stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reconstruct_images(resized_images_grouped, grouped_images_index)

# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, images in grouped_images.items():
stacked_images = torch.stack(images, dim=0)
if do_center_crop:
image = self.center_crop(image, crop_size)

stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
if do_rescale and do_normalize:
# fused rescale and normalize
image = self.normalize(image.to(dtype=torch.float32), image_mean, image_std)
stacked_images = self.normalize(stacked_images.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
stacked_images = stacked_images * rescale_factor
elif do_normalize:
image = self.normalize(image, image_mean, image_std)
stacked_images = self.normalize(stacked_images, image_mean, image_std)
processed_images_grouped[shape] = stacked_images

processed_images.append(image)
images = processed_images
processed_images = reconstruct_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

return BatchFeature(data={"pixel_values": torch.stack(images, dim=0)}, tensor_type=return_tensors)
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
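
A note on the "fused rescale and normalize" branch: it works because the rescale factor can be folded into the mean and std, so one pass over the tensor replaces two. A sketch of the underlying identity (where exactly the folding happens — presumably in `prepare_process_arguments` or `normalize` — is assumed here, not shown in this diff):

    import torch

    x = torch.randint(0, 256, (3, 4, 4), dtype=torch.float32)
    r, mean, std = 1 / 255, 0.5, 0.5

    two_pass = ((x * r) - mean) / std   # rescale, then normalize
    fused = (x - mean / r) / (std / r)  # one pass with folded constants
    assert torch.allclose(two_pass, fused)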

def to_dict(self):
encoder_dict = super().to_dict()
@@ -672,7 +709,11 @@ def prepare_images_structure(
return make_batched_images(images)

def _resize_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, resample, input_data_format: ChannelDimension
self,
image: "torch.Tensor",
target_resolution: tuple,
interpolation: "F.InterpolationMode",
input_data_format: ChannelDimension,
) -> "torch.Tensor":
"""
Resizes an image to a target resolution while maintaining aspect ratio.
@@ -682,7 +723,7 @@ def _resize_for_patching(
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
resample (`PILImageResampling`):
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
@@ -693,7 +734,7 @@ def _resize_for_patching(
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)

# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=resample)
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)

return resized_image
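
`get_patch_output_size` is not part of this diff; a plausible sketch of what `_resize_for_patching` relies on — fitting the image inside the target resolution while preserving aspect ratio — looks like this (the name suffix and rounding details are assumptions):

    def get_patch_output_size_sketch(height: int, width: int, target_height: int, target_width: int) -> tuple:
        # Choose the scale that fits the whole image inside the target box.
        scale = min(target_height / height, target_width / width)
        return int(height * scale), int(width * scale)

    get_patch_output_size_sketch(500, 1000, 672, 672)  # -> (336, 672)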

@@ -719,7 +760,7 @@ def _get_image_patches(
grid_pinpoints,
size: tuple,
patch_size: int,
resample: "F.InterpolationMode",
interpolation: "F.InterpolationMode",
) -> List["torch.Tensor"]:
"""
Process an image with variable resolutions by dividing it into patches.
@@ -733,7 +774,7 @@ def _get_image_patches(
Size to resize the original image to.
patch_size (`int`):
Size of the patches to divide the image into.
resample (`"F.InterpolationMode"`):
interpolation (`"InterpolationMode"`):
Resampling filter to use if resizing the image.
Returns:
@@ -747,11 +788,11 @@ def _get_image_patches(
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
best_resolution = select_best_resolution(image_size, possible_resolutions)
resized_image = self._resize_for_patching(
image, best_resolution, resample=resample, input_data_format=ChannelDimension.FIRST
image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST
)
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST)
patches = divide_to_patches(padded_image, patch_size=patch_size)
resized_original_image = F.resize(image, size=size, interpolation=resample)
resized_original_image = F.resize(image, size=size, interpolation=interpolation)

image_patches = [resized_original_image] + patches
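
`divide_to_patches` (defined earlier in this file, see the first hunk header) tiles the padded image into non-overlapping squares. A standalone sketch of the same idea, assuming a (C, H, W) tensor whose spatial dims are already multiples of `patch_size`:

    import torch

    def divide_to_patches_sketch(image: torch.Tensor, patch_size: int) -> list:
        _, height, width = image.shape
        return [
            image[:, i : i + patch_size, j : j + patch_size]
            for i in range(0, height, patch_size)
            for j in range(0, width, patch_size)
        ]

    patches = divide_to_patches_sketch(torch.rand(3, 672, 672), patch_size=336)
    assert len(patches) == 4  # a 2 x 2 grid of 336 x 336 tiles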

@@ -907,54 +948,44 @@ def preprocess(
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
resample=interpolation,
interpolation=interpolation,
)
processed_image_patches = {}
# stack image patches of the same size
grouped_image_patches = {}
grouped_image_patches_index = {}
for i, image_patch in enumerate(image_patches):
if image_patch.shape[1:] not in grouped_image_patches:
grouped_image_patches[image_patch.shape[1:]] = [image_patch]

else:
grouped_image_patches[image_patch.shape[1:]].append(image_patch)
grouped_image_patches_index[i] = (
image_patch.shape[1:],
len(grouped_image_patches[image_patch.shape[1:]]) - 1,
)

for key, image_patch in grouped_image_patches.items():
image_patch = torch.stack(image_patch, dim=0)
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
for shape, image_patches in grouped_image_patches.items():
stacked_image_patches = torch.stack(image_patches, dim=0)
if do_resize:
image_patch = self.resize(
image=image_patch,
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
resample=interpolation,
interpolation=interpolation,
)
if do_center_crop:
image_patch = self.center_crop(image_patch, crop_size)
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
if do_rescale and do_normalize:
# fused rescale and normalize
image_patch = self.normalize(image_patch.to(dtype=torch.float32), image_mean, image_std)
stacked_image_patches = self.normalize(
stacked_image_patches.to(dtype=torch.float32), image_mean, image_std
)
elif do_rescale:
image_patch = image_patch * rescale_factor
stacked_image_patches = stacked_image_patches * rescale_factor
elif do_normalize:
image_patch = self.normalize(image_patch, image_mean, image_std)
processed_image_patches[key] = image_patch
processed_image_patches = [
processed_image_patches[grouped_image_patches_index[i][0]][grouped_image_patches_index[i][1]]
for i in range(len(grouped_image_patches_index))
]

processed_image_patches = torch.stack(processed_image_patches, dim=0)
stacked_image_patches = self.normalize(stacked_image_patches, image_mean, image_std)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reconstruct_images(processed_image_patches_grouped, grouped_image_patches_index)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, input_data_format))
images = processed_images

if do_pad:
images = self._pad_for_batching(images)
processed_images = self._pad_for_batching(processed_images)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": torch.stack(images, dim=0), "image_sizes": image_sizes}, tensor_type=return_tensors
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
)
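
End to end, the new return path means `return_tensors="pt"` stacks the padded patch tensors into one batch, while `return_tensors=None` leaves a plain list. A hypothetical usage sketch — the checkpoint name and the availability of a fast processor for it are assumptions, not taken from this commit:

    from PIL import Image
    from transformers import AutoImageProcessor

    processor = AutoImageProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", use_fast=True)
    images = [Image.new("RGB", (640, 480)), Image.new("RGB", (1024, 768))]

    batch = processor(images=images, return_tensors="pt")
    print(batch["pixel_values"].shape, batch["image_sizes"])  # stacked tensor + per-image sizes

    batch = processor(images=images)  # now allowed by the relaxed check below
    print(type(batch["pixel_values"]))  # list of per-image patch tensors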


2 changes: 1 addition & 1 deletion src/transformers/image_utils.py
@@ -512,7 +512,7 @@ def validate_fast_preprocess_arguments(
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors != "pt":
if return_tensors is not None and return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")

if data_format != ChannelDimension.FIRST:
