Commit
push clean Fast (x3!) image processor
molbap committed Dec 12, 2024
1 parent 6e0634b commit 8569fd0
Showing 6 changed files with 38 additions and 18 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/molmo.md
@@ -98,6 +98,10 @@ print(processor.decode(output[0], skip_special_tokens=True))

[[autodoc]] MolmoImageProcessor

## MolmoImageProcessorFast

[[autodoc]] MolmoImageProcessorFast

## MolmoProcessor

[[autodoc]] MolmoProcessor
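As a quick orientation for the new doc entry above, here is a minimal usage sketch of the fast processor, assuming torchvision is installed, that a converted Molmo checkpoint exists at the placeholder path, and that the class follows the standard `from_pretrained` / `__call__` image-processor API:

```python
from PIL import Image

from transformers import MolmoImageProcessorFast

# Placeholder path: substitute a real converted Molmo checkpoint.
processor = MolmoImageProcessorFast.from_pretrained("/path/to/converted-molmo")

image = Image.new("RGB", (640, 480))  # any RGB PIL image
outputs = processor(images=image, return_tensors="pt")

# The diff further down shows at least "pixel_values" and "crop_grids" among the returned features.
print(list(outputs.keys()))
```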
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1277,6 +1277,7 @@
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.molmo"].append("MolmoImageProcessorFast")
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")
@@ -6250,6 +6251,7 @@
from .image_processing_utils_fast import BaseImageProcessorFast
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.molmo import MolmoImageProcessorFast
from .models.pixtral import PixtralImageProcessorFast
from .models.rt_detr import RTDetrImageProcessorFast
from .models.vit import ViTImageProcessorFast
1 change: 1 addition & 0 deletions src/transformers/models/molmo/__init__.py
@@ -20,6 +20,7 @@
if TYPE_CHECKING:
from .configuration_molmo import *
from .image_processing_molmo import *
from .image_processing_molmo_fast import *
from .modeling_molmo import *
from .processing_molmo import *
else:
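The star import added above only executes for type checkers; at runtime the `else` branch (elided in this hunk) swaps in a lazy module. For context, a sketch of the convention used by new-style model `__init__.py` files in transformers; the helper names are taken from that convention and may differ here:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to static analysis and IDEs only.
    from .image_processing_molmo import *
    from .image_processing_molmo_fast import *
else:
    import sys

    from ...utils import _LazyModule
    from ...utils.import_utils import define_import_structure

    # At runtime the package becomes a lazy module that imports
    # image_processing_molmo_fast only when one of its names is accessed.
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], define_import_structure(__file__), module_spec=__spec__
    )
```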
13 changes: 9 additions & 4 deletions src/transformers/models/molmo/convert_molmo_weights_to_hf.py
@@ -23,16 +23,20 @@
import torch
from safetensors.torch import load_file

from transformers import GPT2TokenizerFast, Qwen2TokenizerFast
from transformers import (
GPT2TokenizerFast,
MolmoImageProcessor,
MolmoImageProcessorFast,
MolmoProcessor,
Qwen2TokenizerFast,
)
from transformers.models.molmo import MolmoForConditionalGeneration
from transformers.models.molmo.configuration_molmo import (
MolmoConfig,
MolmoPoolingConfig,
MolmoTextConfig,
MolmoVisionConfig,
)
from transformers.models.molmo.image_processing_molmo import MolmoImageProcessor
from transformers.models.molmo.processing_molmo import MolmoProcessor


CHAT_TEMPLATE = (
@@ -291,7 +295,8 @@ def write_model(
elif variant == "7B-O":
tokenizer = GPT2TokenizerFast.from_pretrained(input_base_path, extra_special_tokens=extra_special_tokens)
tokenizer.save_pretrained(model_path)
image_processor = MolmoImageProcessor.from_pretrained(input_base_path)
image_processor_class = MolmoImageProcessor if MolmoImageProcessorFast is None else MolmoImageProcessorFast
image_processor = image_processor_class.from_pretrained(input_base_path)
processor = MolmoProcessor(image_processor=image_processor, tokenizer=tokenizer, chat_template=CHAT_TEMPLATE)
processor.save_pretrained(model_path)
print("Processor saved successfully.")
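The class selection above relies on `MolmoImageProcessorFast` resolving to `None` when the torchvision backend is unavailable. An equivalent, more explicit sketch of the same intent using the availability helper from `transformers.utils` (an illustration, not what the commit does):

```python
from transformers import MolmoImageProcessor
from transformers.utils import is_torchvision_available

if is_torchvision_available():
    from transformers import MolmoImageProcessorFast

    image_processor_class = MolmoImageProcessorFast
else:
    # Fall back to the slow image processor.
    image_processor_class = MolmoImageProcessor

# input_base_path comes from the conversion script's write_model() arguments.
image_processor = image_processor_class.from_pretrained(input_base_path)
```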
29 changes: 15 additions & 14 deletions src/transformers/models/molmo/image_processing_molmo_fast.py
@@ -17,7 +17,8 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import BaseImageProcessor, get_size_dict
from ...image_processing_utils import get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_transforms import convert_to_rgb
from ...image_utils import (
OPENAI_CLIP_MEAN,
@@ -89,7 +90,7 @@ def pad_to_bounding_box(
return padded_image


class MolmoImageProcessorFast(BaseImageProcessor):
class MolmoImageProcessorFast(BaseImageProcessorFast):
"""
Image processor for the Molmo model.
@@ -185,6 +186,11 @@ def __init__(
self.crop_window_size = self.crop_window_patches * self.image_patch_size
self.crop_size = size["width"]

if ((self.patches_per_image_height + 1) // 2 != self.tokens_per_image_height) or (
(self.patches_per_image_width + 1) // 2 != self.tokens_per_image_width
):
raise ValueError("Number of patches per crop does not fit number of tokens per image dimension.")

def resize(
self,
image: torch.Tensor,
@@ -294,12 +300,6 @@ def split_image_into_crops(
crops = []
cropped_masks = []
patch_orderings = []

if ((self.patches_per_image_height + 1) // 2 != self.tokens_per_image_height) or (
(self.patches_per_image_width + 1) // 2 != self.tokens_per_image_width
):
raise ValueError("Number of patches per crop does not fit number of tokens per image dimension.")

patch_index = 0
for row in range(crop_grid[0]):
crop_y_start = row * self.crop_window_size
@@ -324,7 +324,6 @@

# Correct padding based on margins and offsets
crop_x_offset = self.overlap_margins[0] // 2 if column > 0 else 0

# Track patch ordering: generate an array representing the order of patches (overlaps (on crops))
reshaped_image = torch.arange(
patch_index,
Expand Down Expand Up @@ -356,7 +355,6 @@ def split_image_into_crops(
cropped_masks.append(cropped_mask)

patch_index += pooled_height * pooled_width

crops = torch.stack(crops)
patch_orderings = torch.stack(patch_orderings)
cropped_masks = torch.stack(cropped_masks)
@@ -492,18 +490,15 @@ def preprocess(
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
device = kwargs.pop("device", None)
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

images = make_batched_images(images)
image_type = get_image_type(images[0])

if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]

if image_type == ImageType.PIL:
images = [F.pil_to_tensor(image) for image in images]
elif image_type == ImageType.NUMPY:
images = [torch.from_numpy(image).contiguous() for image in images]

all_images = []
all_crop_grids = []
all_cropped_masks = []
@@ -512,12 +507,15 @@
for image in images:
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)

if do_resize:
global_image_size = get_resize_output_image_size(image, size)
global_image = self.resize(
image=image, size=global_image_size, resample=resample, input_data_format=input_data_format
)

crop_grid = self.find_best_crop_grid_for_image_size(image)

new_crop_size = {}
new_crop_size["height"] = crop_grid[0] * self.crop_window_size + self.total_margin_pixels
new_crop_size["width"] = crop_grid[1] * self.crop_window_size + self.total_margin_pixels
@@ -528,6 +526,7 @@
image = self.resize(
image=image, size=crop_output_size, resample=resample, input_data_format=input_data_format
)

if do_pad:
image, image_mask = self.pad(
image=image, size=new_crop_size, input_data_format=input_data_format, constant_values=0
@@ -546,9 +545,10 @@
global_image = (global_image - image_mean_tensor) / image_std_tensor

if do_split_into_crops:
crops, patch_orderings, cropped_masks = self.split_image_into_crops(
crops, patch_orderings, cropped_masks = self.fully_batched_split_image_into_crops(
image=image, image_mask=image_mask, crop_grid=crop_grid, input_data_format=input_data_format
)

patch_orderings = self.transpose_patch_orderings(crop_grid, patch_orderings)
global_image = self.reshape_into_patches(global_image, input_data_format=input_data_format)
crops = torch.cat([global_image.unsqueeze(0), crops], dim=0)
@@ -560,6 +560,7 @@
all_crop_grids.append(crop_grid)
all_cropped_masks.append(cropped_masks)
all_patch_orderings.append(patch_orderings)

data = {
"pixel_values": all_images,
"crop_grids": all_crop_grids,
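The consistency check that this diff moves from `split_image_into_crops` into `__init__` encodes the relation between the patch grid and the token grid: tokens are produced by 2x2 pooling of patches, so each token dimension is the ceiling of half the patch dimension. A small worked sketch with illustrative numbers (the real Molmo defaults may differ):

```python
def tokens_per_side(patches_per_side: int) -> int:
    # Ceiling division by 2: 2x2 pooling keeps the extra row/column of an odd patch grid.
    return (patches_per_side + 1) // 2

# A 24x24 patch grid pools to 12x12 tokens; a 23x23 grid also pools to 12x12.
assert tokens_per_side(24) == 12
assert tokens_per_side(23) == 12

# With the check in __init__, a mismatched configuration now fails at construction
# time rather than midway through split_image_into_crops.
```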
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_torchvision_objects.py
@@ -23,6 +23,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])


class MolmoImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])


class PixtralImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

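The dummy class registered above is what `from transformers import MolmoImageProcessorFast` resolves to when torchvision is not installed; constructing it raises an error naming the missing backend. A minimal sketch of that behaviour, assuming the standard `requires_backends` mechanism and an environment without torchvision:

```python
from transformers import MolmoImageProcessorFast

try:
    MolmoImageProcessorFast()
except ImportError as err:
    # requires_backends raises with a message explaining that torchvision is required.
    print(err)
```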
