Fuyu Multi-image interleaved processor #27587

Closed
wants to merge 7 commits into from
1 change: 1 addition & 0 deletions .gitignore
@@ -167,3 +167,4 @@ tags

# ruff
.ruff_cache
test.py
144 changes: 49 additions & 95 deletions src/transformers/models/fuyu/image_processing_fuyu.py
@@ -53,21 +53,6 @@
logger = logging.get_logger(__name__)


def make_list_of_list_of_images(
    images: Union[List[List[ImageInput]], List[ImageInput], ImageInput]
) -> List[List[ImageInput]]:
    if is_valid_image(images):
        return [[images]]

    if isinstance(images, list) and all(isinstance(image, list) for image in images):
        return images

    if isinstance(images, list):
        return [make_list_of_images(image) for image in images]

    raise ValueError("images must be a list of list of images or a list of images or an image.")


class FuyuBatchFeature(BatchFeature):
    """
    BatchFeature class for Fuyu image processor and processor.
@@ -356,7 +341,7 @@ def pad_image(
            input_data_format=input_data_format,
        )
        return padded_image

    def preprocess(
        self,
        images,
@@ -441,93 +426,62 @@ def preprocess(
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        patch_size = patch_size if patch_size is not None else self.patch_size

        if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images):
            raise ValueError("Multiple images for a single sample are not yet supported.")

        batch_images = make_list_of_list_of_images(images)

        if do_resize and size is None:
            raise ValueError("Size must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("image_mean and image_std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        batch_images = [[to_numpy_array(image) for image in images] for images in batch_images]

        if is_scaled_image(batch_images[0][0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(batch_images[0][0])

        original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]

        if do_resize:
            batch_images = [
                [self.resize(image, size=size, input_data_format=input_data_format) for image in images]
                for images in batch_images
            ]

        image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
        image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
        image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]

        # scale_h is the same as scale_w
        image_scale_factors = [
            [resized_size[0] / original_size[0]]
            for original_size, resized_size in zip(original_image_sizes, image_sizes)
        ]
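To make the bookkeeping above concrete: with the processor's default target size of 1080×1920 and an aspect-ratio-preserving resize, the height and width scale factors coincide, so only one is stored. Illustrative numbers (not from the source):

```python
original_size = (2160, 3840)  # (height, width) of a hypothetical input
resized_size = (1080, 1920)   # after fitting into the default 1080x1920 target

scale_h = resized_size[0] / original_size[0]  # 0.5
scale_w = resized_size[1] / original_size[1]  # 0.5 -- equal to scale_h, hence a single factor per image
```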
        # `images` is expected to be nested as one list of images per sample; work on
        # shallow copies so the placeholder insertion and in-place updates below do not
        # mutate the caller's lists.
        batch_images = [list(image_list) for image_list in images]
        original_image_sizes = []
        batch_image_sizes = []
        image_unpadded_heights = []
        image_unpadded_widths = []
        image_scale_factors = []

        for image_list in batch_images:
            original_sizes_per_list = []
            batch_sizes_per_list = []
            unpadded_heights_per_list = []
            unpadded_widths_per_list = []
            scale_factors_per_list = []

            # If the sample contains no image, substitute a placeholder of the target
            # size so it still yields one entry, then preprocess it like any other image.
            if not image_list:
                image_list.append(np.zeros((size["height"], size["width"], 3), dtype=np.uint8))

            for idx, image in enumerate(image_list):
                image = to_numpy_array(image)

                original_size = get_image_size(image, channel_dim=input_data_format)
                original_sizes_per_list.append(original_size)

                if do_resize:
                    image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)

                # Record the resized, pre-padding size: padding below fills the image up
                # to `size`, and the unpadded extent is what the patch bookkeeping needs.
                resized_size = get_image_size(image, channel_dim=input_data_format)
                batch_sizes_per_list.append(resized_size)
                unpadded_heights_per_list.append([resized_size[0]])
                unpadded_widths_per_list.append([resized_size[1]])
                # scale_h is the same as scale_w because resizing preserves the aspect ratio
                scale_factors_per_list.append([resized_size[0] / original_size[0]])

                if do_pad:
                    image = self.pad_image(
                        image,
                        size=size,
                        mode=padding_mode,
                        constant_values=padding_value,
                        input_data_format=input_data_format,
                    )
                if do_rescale:
                    image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
                if do_normalize:
                    image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                if data_format is not None:
                    image = to_channel_dimension_format(image, data_format, input_data_format)

                image_list[idx] = image

        if do_pad:
            batch_images = [
                [
                    self.pad_image(
                        image,
                        size=size,
                        mode=padding_mode,
                        constant_values=padding_value,
                        input_data_format=input_data_format,
                    )
                    for image in images
                ]
                for images in batch_images
            ]

        if do_rescale:
            batch_images = [
                [self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images]
                for images in batch_images
            ]

        if do_normalize:
            batch_images = [
                [
                    self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                    for image in images
                ]
                for images in batch_images
            ]

        if data_format is not None:
            batch_images = [
                [to_channel_dimension_format(image, data_format, input_data_format) for image in images]
                for images in batch_images
            ]
            original_image_sizes.append(original_sizes_per_list)
            batch_image_sizes.append(batch_sizes_per_list)
            image_unpadded_heights.append(unpadded_heights_per_list)
            image_unpadded_widths.append(unpadded_widths_per_list)
            image_scale_factors.append(scale_factors_per_list)

        data = {
            "images": batch_images,
            "image_unpadded_heights": image_unpadded_heights,
            "image_unpadded_widths": image_unpadded_widths,
            "image_scale_factors": image_scale_factors,
        }

        return FuyuBatchFeature(data=data, tensor_type=return_tensors)
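A minimal usage sketch of the reworked entry point, assuming a default `FuyuImageProcessor` and NumPy inputs (illustrative, not taken from this branch's tests). Note `return_tensors=None`: with a varying number of images per sample, the outputs are ragged and may not stack into a single tensor:

```python
import numpy as np
from transformers import FuyuImageProcessor

processor = FuyuImageProcessor()

batch = [
    [np.zeros((480, 640, 3), dtype=np.uint8), np.zeros((2160, 3840, 3), dtype=np.uint8)],  # two images
    [],  # empty sample: a zero placeholder of the target size is substituted
]
outputs = processor.preprocess(batch, return_tensors=None)

print(len(outputs["images"]))             # 2 samples
print(outputs["image_unpadded_heights"])  # e.g. [[[480], [1080]], [[1080]]]
```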

    def get_num_patches(self, image_height: int, image_width: int, patch_size: Dict[str, int] = None) -> int:
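The diff is truncated here. For orientation, `get_num_patches` returns the number of non-overlapping patches the padded image splits into; the following is a hedged sketch consistent with the signature above, assuming Fuyu's default 30×30 patch size and the divisibility requirement (all names other than the signature's are illustrative):

```python
def get_num_patches_sketch(image_height: int, image_width: int, patch_size=None) -> int:
    # Default to Fuyu's 30x30 patches when no patch size is given.
    patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30}
    if image_height % patch_size["height"] != 0 or image_width % patch_size["width"] != 0:
        raise ValueError("image dimensions must be divisible by the patch size")
    return (image_height // patch_size["height"]) * (image_width // patch_size["width"])
```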