diff --git a/docs/source/en/model_doc/molmo.md b/docs/source/en/model_doc/molmo.md
index ff0f8fa4571ae8..8c7703133a0b37 100644
--- a/docs/source/en/model_doc/molmo.md
+++ b/docs/source/en/model_doc/molmo.md
@@ -98,6 +98,10 @@ print(processor.decode(output[0], skip_special_tokens=True))
 
 [[autodoc]] MolmoImageProcessor
 
+## MolmoImageProcessorFast
+
+[[autodoc]] MolmoImageProcessorFast
+
 ## MolmoProcessor
 
 [[autodoc]] MolmoProcessor
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 6629abe282699b..62517446acbb45 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1277,6 +1277,7 @@
     _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
     _import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
     _import_structure["models.detr"].append("DetrImageProcessorFast")
+    _import_structure["models.molmo"].append("MolmoImageProcessorFast")
     _import_structure["models.pixtral"].append("PixtralImageProcessorFast")
     _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
     _import_structure["models.vit"].append("ViTImageProcessorFast")
@@ -6250,6 +6251,7 @@
     from .image_processing_utils_fast import BaseImageProcessorFast
     from .models.deformable_detr import DeformableDetrImageProcessorFast
     from .models.detr import DetrImageProcessorFast
+    from .models.molmo import MolmoImageProcessorFast
     from .models.pixtral import PixtralImageProcessorFast
     from .models.rt_detr import RTDetrImageProcessorFast
     from .models.vit import ViTImageProcessorFast
diff --git a/src/transformers/models/molmo/__init__.py b/src/transformers/models/molmo/__init__.py
index f69497707ab6b8..ed0c568ee1077c 100644
--- a/src/transformers/models/molmo/__init__.py
+++ b/src/transformers/models/molmo/__init__.py
@@ -20,6 +20,7 @@
 if TYPE_CHECKING:
     from .configuration_molmo import *
     from .image_processing_molmo import *
+    from .image_processing_molmo_fast import *
     from .modeling_molmo import *
     from .processing_molmo import *
 else:
diff --git a/src/transformers/models/molmo/convert_molmo_weights_to_hf.py b/src/transformers/models/molmo/convert_molmo_weights_to_hf.py
index d64b5ab91f137b..310e6158c83d55 100644
--- a/src/transformers/models/molmo/convert_molmo_weights_to_hf.py
+++ b/src/transformers/models/molmo/convert_molmo_weights_to_hf.py
@@ -23,7 +23,13 @@
 import torch
 from safetensors.torch import load_file
 
-from transformers import GPT2TokenizerFast, Qwen2TokenizerFast
+from transformers import (
+    GPT2TokenizerFast,
+    MolmoImageProcessor,
+    MolmoImageProcessorFast,
+    MolmoProcessor,
+    Qwen2TokenizerFast,
+)
 from transformers.models.molmo import MolmoForConditionalGeneration
 from transformers.models.molmo.configuration_molmo import (
     MolmoConfig,
@@ -31,8 +37,6 @@
     MolmoTextConfig,
     MolmoVisionConfig,
 )
-from transformers.models.molmo.image_processing_molmo import MolmoImageProcessor
-from transformers.models.molmo.processing_molmo import MolmoProcessor
 
 
 CHAT_TEMPLATE = (
@@ -291,7 +295,8 @@ def write_model(
     elif variant == "7B-O":
         tokenizer = GPT2TokenizerFast.from_pretrained(input_base_path, extra_special_tokens=extra_special_tokens)
     tokenizer.save_pretrained(model_path)
-    image_processor = MolmoImageProcessor.from_pretrained(input_base_path)
+    image_processor_class = MolmoImageProcessor if MolmoImageProcessorFast is None else MolmoImageProcessorFast
+    image_processor = image_processor_class.from_pretrained(input_base_path)
     processor = MolmoProcessor(image_processor=image_processor, tokenizer=tokenizer, chat_template=CHAT_TEMPLATE)
     processor.save_pretrained(model_path)
     print("Processor saved successfully.")
diff --git a/src/transformers/models/molmo/image_processing_molmo_fast.py b/src/transformers/models/molmo/image_processing_molmo_fast.py
index a36653d2825c4e..1591e4990eaec7 100644
--- a/src/transformers/models/molmo/image_processing_molmo_fast.py
+++ b/src/transformers/models/molmo/image_processing_molmo_fast.py
@@ -17,7 +17,8 @@
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 from ...feature_extraction_utils import BatchFeature
-from ...image_processing_utils import BaseImageProcessor, get_size_dict
+from ...image_processing_utils import get_size_dict
+from ...image_processing_utils_fast import BaseImageProcessorFast
 from ...image_transforms import convert_to_rgb
 from ...image_utils import (
     OPENAI_CLIP_MEAN,
@@ -89,7 +90,7 @@ def pad_to_bounding_box(
     return padded_image
 
 
-class MolmoImageProcessorFast(BaseImageProcessor):
+class MolmoImageProcessorFast(BaseImageProcessorFast):
     """
     Image processor for the Molmo model.
 
@@ -185,6 +186,11 @@ def __init__(
         self.crop_window_size = self.crop_window_patches * self.image_patch_size
         self.crop_size = size["width"]
 
+        if ((self.patches_per_image_height + 1) // 2 != self.tokens_per_image_height) or (
+            (self.patches_per_image_width + 1) // 2 != self.tokens_per_image_width
+        ):
+            raise ValueError("Number of patches per crop does not fit number of tokens per image dimension.")
+
     def resize(
         self,
         image: torch.Tensor,
@@ -294,12 +300,6 @@ def split_image_into_crops(
         crops = []
         cropped_masks = []
         patch_orderings = []
-
-        if ((self.patches_per_image_height + 1) // 2 != self.tokens_per_image_height) or (
-            (self.patches_per_image_width + 1) // 2 != self.tokens_per_image_width
-        ):
-            raise ValueError("Number of patches per crop does not fit number of tokens per image dimension.")
-
         patch_index = 0
         for row in range(crop_grid[0]):
             crop_y_start = row * self.crop_window_size
@@ -324,7 +324,6 @@ def split_image_into_crops(
 
                 # Correct padding based on margins and offsets
                 crop_x_offset = self.overlap_margins[0] // 2 if column > 0 else 0
-
                 # Track patch ordering: generate an array representing the order of patches (overlaps (on crops))
                 reshaped_image = torch.arange(
                     patch_index,
@@ -356,7 +355,6 @@ def split_image_into_crops(
                 cropped_masks.append(cropped_mask)
                 patch_index += pooled_height * pooled_width
 
-
         crops = torch.stack(crops)
         patch_orderings = torch.stack(patch_orderings)
         cropped_masks = torch.stack(cropped_masks)
@@ -492,10 +490,8 @@ def preprocess(
         do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
         device = kwargs.pop("device", None)
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
-
         images = make_batched_images(images)
         image_type = get_image_type(images[0])
-
         if do_convert_rgb:
             images = [convert_to_rgb(image) for image in images]
 
@@ -503,7 +499,6 @@ def preprocess(
             images = [F.pil_to_tensor(image) for image in images]
         elif image_type == ImageType.NUMPY:
             images = [torch.from_numpy(image).contiguous() for image in images]
-
         all_images = []
         all_crop_grids = []
         all_cropped_masks = []
@@ -512,12 +507,15 @@ def preprocess(
         for image in images:
             if input_data_format is None:
                 input_data_format = infer_channel_dimension_format(image)
+
             if do_resize:
                 global_image_size = get_resize_output_image_size(image, size)
                 global_image = self.resize(
                     image=image, size=global_image_size, resample=resample, input_data_format=input_data_format
                 )
+
                 crop_grid = self.find_best_crop_grid_for_image_size(image)
+
                 new_crop_size = {}
                 new_crop_size["height"] = crop_grid[0] * self.crop_window_size + self.total_margin_pixels
                 new_crop_size["width"] = crop_grid[1] * self.crop_window_size + self.total_margin_pixels
@@ -528,6 +526,7 @@ def preprocess(
                 image = self.resize(
                     image=image, size=crop_output_size, resample=resample, input_data_format=input_data_format
                 )
+            if do_pad:
                 image, image_mask = self.pad(
                     image=image, size=new_crop_size, input_data_format=input_data_format, constant_values=0
                 )
@@ -546,9 +545,10 @@ def preprocess(
                 global_image = (global_image - image_mean_tensor) / image_std_tensor
 
             if do_split_into_crops:
-                crops, patch_orderings, cropped_masks = self.split_image_into_crops(
+                crops, patch_orderings, cropped_masks = self.fully_batched_split_image_into_crops(
                     image=image, image_mask=image_mask, crop_grid=crop_grid, input_data_format=input_data_format
                 )
+
                 patch_orderings = self.transpose_patch_orderings(crop_grid, patch_orderings)
             global_image = self.reshape_into_patches(global_image, input_data_format=input_data_format)
             crops = torch.cat([global_image.unsqueeze(0), crops], dim=0)
@@ -560,6 +560,7 @@ def preprocess(
             all_crop_grids.append(crop_grid)
             all_cropped_masks.append(cropped_masks)
             all_patch_orderings.append(patch_orderings)
+
         data = {
             "pixel_values": all_images,
             "crop_grids": all_crop_grids,
diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py
index 747f75386490fc..a12ddabf58619a 100644
--- a/src/transformers/utils/dummy_torchvision_objects.py
+++ b/src/transformers/utils/dummy_torchvision_objects.py
@@ -23,6 +23,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torchvision"])
 
 
+class MolmoImageProcessorFast(metaclass=DummyObject):
+    _backends = ["torchvision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torchvision"])
+
+
 class PixtralImageProcessorFast(metaclass=DummyObject):
     _backends = ["torchvision"]
 
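
Not part of the diff above: a minimal usage sketch of the new fast processor. It assumes this branch is installed with the torchvision backend, relies on the standard image-processor `__call__`/`from_pretrained` path, and uses a placeholder checkpoint path; only the `pixel_values` and `crop_grids` output keys are taken from the diff, everything else is an assumption.

```python
# Usage sketch only -- assumes this PR branch is installed, torchvision is available,
# and a checkpoint converted with convert_molmo_weights_to_hf.py sits at the
# placeholder path below (hypothetical, replace with a real local path).
import numpy as np

from transformers import MolmoImageProcessorFast

image_processor = MolmoImageProcessorFast.from_pretrained("/path/to/converted/molmo")

# A dummy HWC uint8 image; preprocess() converts numpy inputs to torch tensors internally.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

outputs = image_processor(images=image)
# Expect at least "pixel_values" and "crop_grids" in the data dict built by preprocess().
print(list(outputs.keys()))
```

Without torchvision installed, importing `MolmoImageProcessorFast` resolves to the dummy object added in `dummy_torchvision_objects.py`, which raises on instantiation.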