From 6196c46553b95229283f981c9e101096d50d0012 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 18 Dec 2024 03:21:14 +0000 Subject: [PATCH] fix convnext --- .../models/convnext/image_processing_convnext_fast.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/convnext/image_processing_convnext_fast.py b/src/transformers/models/convnext/image_processing_convnext_fast.py index d2ffefb6bd3380..c4459c2bcd03c7 100644 --- a/src/transformers/models/convnext/image_processing_convnext_fast.py +++ b/src/transformers/models/convnext/image_processing_convnext_fast.py @@ -275,14 +275,12 @@ def preprocess( do_resize=do_resize, size=size, resample=resample, - do_center_crop=do_center_crop, crop_size=crop_size, do_rescale=do_rescale, rescale_factor=rescale_factor, do_normalize=do_normalize, image_mean=image_mean, image_std=image_std, - do_convert_rgb=do_convert_rgb, return_tensors=return_tensors, data_format=data_format, **kwargs, @@ -291,8 +289,7 @@ def preprocess( # Group images by size for batched resizing grouped_images, grouped_images_index = group_images_by_shape(images) resized_images_grouped = {} - for shape, images in grouped_images.items(): - stacked_images = torch.stack(images, dim=0) + for shape, stacked_images in grouped_images.items(): if do_resize: stacked_images = self.resize( image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation @@ -304,8 +301,7 @@ def preprocess( # Needed in case do_resize is False, or resize returns images with different sizes grouped_images, grouped_images_index = group_images_by_shape(resized_images) processed_images_grouped = {} - for shape, images in grouped_images.items(): - stacked_images = torch.stack(images, dim=0) + for shape, stacked_images in grouped_images.items(): if do_center_crop: stacked_images = self.center_crop(stacked_images, crop_size) # Fused rescale and normalize