Add optimized PixtralImageProcessorFast #34836

Merged · 15 commits · Nov 28, 2024
Changes from 4 commits
2 changes: 1 addition & 1 deletion docs/source/en/_config.py
@@ -11,4 +11,4 @@
"{processor_class}": "FakeProcessorClass",
"{model_class}": "FakeModelClass",
"{object_class}": "FakeObjectClass",
}
}
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/pixtral.md
@@ -88,6 +88,11 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
[[autodoc]] PixtralImageProcessor
- preprocess

## PixtralImageProcessorFast

[[autodoc]] PixtralImageProcessorFast
- preprocess

## PixtralProcessor

[[autodoc]] PixtralProcessor
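
As context for the new docs entry, here is a minimal usage sketch of the fast processor (not part of this diff; the checkpoint name and image path are assumptions):

from PIL import Image
from transformers import PixtralImageProcessorFast

# Assumed checkpoint; any Pixtral checkpoint that ships an image processor config should work.
processor = PixtralImageProcessorFast.from_pretrained("mistral-community/pixtral-12b")

image = Image.open("example.jpg")  # placeholder local image
inputs = processor(image, return_tensors="pt")
# Pixtral processors return pixel_values nested per sample and per image, hence the double indexing.
print(inputs.pixel_values[0][0].shape)
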
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1226,7 +1226,7 @@
_import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
_import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
_import_structure["models.pixtral"].append("PixtralImageProcessor")
_import_structure["models.pixtral"].extend(["PixtralImageProcessor", "PixtralImageProcessorFast"])
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
@@ -6157,7 +6157,7 @@
from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
from .models.pix2struct import Pix2StructImageProcessor
from .models.pixtral import PixtralImageProcessor
from .models.pixtral import PixtralImageProcessor, PixtralImageProcessorFast
from .models.poolformer import (
PoolFormerFeatureExtractor,
PoolFormerImageProcessor,
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
@@ -117,7 +117,7 @@
("paligemma", ("SiglipImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("poolformer", ("PoolFormerImageProcessor",)),
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
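
With both classes registered in this mapping, the fast variant is also reachable through the auto class. A rough sketch of selecting between them (checkpoint name assumed as above; use_fast is the standard AutoImageProcessor switch):

from transformers import AutoImageProcessor

# use_fast=True requests the torchvision-backed processor when one is registered for the model type.
fast_processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=True)
slow_processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=False)
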
2 changes: 2 additions & 0 deletions src/transformers/models/pixtral/__init__.py
@@ -40,6 +40,7 @@
pass
else:
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]
_import_structure["image_processing_pixtral_fast"] = ["PixtralImageProcessorFast"]


if TYPE_CHECKING:
@@ -64,6 +65,7 @@
pass
else:
from .image_processing_pixtral import PixtralImageProcessor
from .image_processing_pixtral_fast import PixtralImageProcessorFast

else:
import sys
356 changes: 356 additions & 0 deletions src/transformers/models/pixtral/image_processing_pixtral_fast.py

Large diffs are not rendered by default.
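
Although the new module is not rendered here, both the slow and fast processors have to agree on the Pixtral sizing rule: scale an image down (never up) so its longest side fits within size["longest_edge"], then round each side up to a whole number of patches. A simplified sketch of that rule, with the helper name and exact rounding as illustrative assumptions rather than code copied from the file:

import math

def pixtral_output_size(height, width, longest_edge, patch_height, patch_width):
    # Downscale only: shrink so the larger side fits within `longest_edge`.
    ratio = max(height / longest_edge, width / longest_edge)
    if ratio > 1:
        height = math.floor(height / ratio)
        width = math.floor(width / ratio)
    # Round each side up to a multiple of the patch size so the result tiles cleanly into patches.
    num_height_tokens = math.ceil(height / patch_height)
    num_width_tokens = math.ceil(width / patch_width)
    return num_height_tokens * patch_height, num_width_tokens * patch_width

# Example with plausible defaults (longest_edge=1024, 16x16 patches):
print(pixtral_output_size(1024, 768, 1024, 16, 16))  # (1024, 768)

The fast path can then apply resize, rescale, and normalization with torchvision/torch ops instead of NumPy, which is presumably where the speed-up exercised by the new benchmark test below comes from.
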

7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_vision_objects.py
@@ -541,6 +541,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


class PixtralImageProcessorFast(metaclass=DummyObject):
_backends = ["vision"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])


class PoolFormerFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]

164 changes: 110 additions & 54 deletions tests/models/pixtral/test_image_processing_pixtral.py
@@ -14,12 +14,14 @@
# limitations under the License.

import random
import time
import unittest

import numpy as np
import requests

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs

@@ -30,7 +32,7 @@
if is_vision_available():
from PIL import Image

from transformers import PixtralImageProcessor
from transformers import PixtralImageProcessor, PixtralImageProcessorFast


class PixtralImageProcessingTester(unittest.TestCase):
@@ -51,6 +53,7 @@ def __init__(
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
super().__init__()
size = size if size is not None else {"longest_edge": 24}
patch_size = patch_size if patch_size is not None else {"height": 8, "width": 8}
self.parent = parent
@@ -128,6 +131,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F
@require_vision
class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = PixtralImageProcessor if is_vision_available() else None
fast_image_processing_class = PixtralImageProcessorFast if is_torchvision_available() else None

def setUp(self):
super().setUp()
@@ -138,25 +142,27 @@ def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()

def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "patch_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "patch_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, Image.Image)

# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
@@ -171,46 +177,96 @@ def test_call_pil(self):
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)

def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)

# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)

# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)

# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
image_inputs_list[0][0]
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)

# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)

def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)

# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
image_inputs_list[0][0]
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)

# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)

@require_vision
@require_torch
def test_fast_is_faster_than_slow(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping speed test")

if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping speed test as one of the image processors is not defined")

def measure_time(image_processor, image):
start = time.time()
_ = image_processor(image, return_tensors="pt")
return time.time() - start

image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0][0])
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
fast_time = measure_time(image_processor_fast, image_inputs_list)
slow_time = measure_time(image_processor_slow, image_inputs_list)

# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
self.assertLessEqual(fast_time, slow_time)

@require_vision
@require_torch
def test_slow_fast_equivalence(self):
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")

if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], atol=1e-2))

@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
def test_call_numpy_4_channels(self):