Skip to content

Commit

Permalink
Add optimized PixtralImageProcessorFast (#34836)
Browse files Browse the repository at this point in the history
* Add optimized PixtralImageProcessorFast

* make style

* Add dummy_vision_object

* Review comments

* Format

* Fix dummy

* Format

* np.ceil for math.ceil
  • Loading branch information
mgoin authored Nov 28, 2024
1 parent 6300212 commit 9d6f0dd
Show file tree
Hide file tree
Showing 10 changed files with 560 additions and 74 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
"{processor_class}": "FakeProcessorClass",
"{model_class}": "FakeModelClass",
"{object_class}": "FakeObjectClass",
}
}
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/pixtral.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
[[autodoc]] PixtralImageProcessor
- preprocess

## PixtralImageProcessorFast

[[autodoc]] PixtralImageProcessorFast
- preprocess

## PixtralProcessor

[[autodoc]] PixtralProcessor
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,7 @@
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")

Expand Down Expand Up @@ -6189,6 +6190,7 @@
from .image_processing_utils_fast import BaseImageProcessorFast
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.pixtral import PixtralImageProcessorFast
from .models.rt_detr import RTDetrImageProcessorFast
from .models.vit import ViTImageProcessorFast

Expand Down
39 changes: 39 additions & 0 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .utils import (
ExplicitEnum,
TensorType,
is_jax_tensor,
is_numpy_array,
is_tf_tensor,
Expand Down Expand Up @@ -447,6 +448,44 @@ def validate_preprocess_arguments(
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")


def validate_fast_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[Dict[str, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
"""
Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
Raises `ValueError` if arguments incompatibility is caught.
"""
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")

if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")


# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
("paligemma", ("SiglipImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("poolformer", ("PoolFormerImageProcessor",)),
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
Expand Down
24 changes: 23 additions & 1 deletion src/transformers/models/pixtral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_torchvision_available,
is_vision_available,
)


_import_structure = {
Expand Down Expand Up @@ -41,6 +47,14 @@
else:
_import_structure["image_processing_pixtral"] = ["PixtralImageProcessor"]

try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_pixtral_fast"] = ["PixtralImageProcessorFast"]


if TYPE_CHECKING:
from .configuration_pixtral import PixtralVisionConfig
Expand All @@ -65,6 +79,14 @@
else:
from .image_processing_pixtral import PixtralImageProcessor

try:
if not is_torchvision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_pixtral_fast import PixtralImageProcessorFast

else:
import sys

Expand Down
9 changes: 5 additions & 4 deletions src/transformers/models/pixtral/image_processing_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""Image processor class for Pixtral."""

import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -179,7 +180,7 @@ def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int])


def get_resize_output_image_size(
input_image: np.ndarray,
input_image: ImageInput,
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
Expand All @@ -189,7 +190,7 @@ def get_resize_output_image_size(
size.
Args:
input_image (`np.ndarray`):
input_image (`ImageInput`):
The image to resize.
size (`int` or `Tuple[int, int]`):
Max image size an input image can be. Must be a dictionary with the key "longest_edge".
Expand All @@ -210,8 +211,8 @@ def get_resize_output_image_size(

if ratio > 1:
# Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
height = int(np.ceil(height / ratio))
width = int(np.ceil(width / ratio))
height = int(math.ceil(height / ratio))
width = int(math.ceil(width / ratio))

num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
return num_height_tokens * patch_height, num_width_tokens * patch_width
Expand Down
Loading

0 comments on commit 9d6f0dd

Please sign in to comment.