
Add optimized PixtralImageProcessorFast #34836

Merged
merged 15 commits into huggingface:main Nov 28, 2024

Conversation

@mgoin (Contributor) commented Nov 20, 2024

What does this PR do?

This PR implements a fast image processor for Pixtral, following issue #33810.

The key acceleration comes from replacing the Pillow/NumPy arrays and functions (resize, rescale, normalize) with torch tensors and torchvision v2 functional transforms. It also adds support for torch.compile and for passing device="cuda" during inference to run the preprocessing on GPU. One limitation is that only return_tensors="pt" is supported.

Usage

import torch

from transformers import AutoImageProcessor

# Slow (PIL/NumPy-based) processor, for comparison
slow_processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=False)
# Fast (torch/torchvision-based) processor
fast_processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=True)
# The fast processor is compatible with torch.compile
compiled_processor = torch.compile(fast_processor, mode="reduce-overhead")
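
Calling the processor afterwards works as usual; here is a minimal sketch (the sample image and the device argument on the call are based on the description above, not code taken from this PR):

from PIL import Image

image = Image.open("example.jpg")  # hypothetical input image
# Only return_tensors="pt" is supported by the fast processor;
# device="cuda" runs the preprocessing on GPU as described above
inputs = fast_processor(images=image, return_tensors="pt", device="cuda")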

From simple benchmarking with a single image of size [3, 876, 1300], I see a 6x to 10x speedup.


All statistics are in milliseconds:

| Processor (input type) | Mean | Median | Std Dev | Min | Max |
|------------------------|------|--------|---------|-----|-----|
| Slow (PIL Image) | 23.680 | 23.098 | 2.240 | 21.824 | 36.064 |
| Fast (PIL Image) | 3.759 | 3.762 | 0.133 | 3.556 | 4.223 |
| Compiled (PIL Image) | 4.632 | 4.794 | 1.086 | 3.488 | 11.707 |
| Slow (Torch Image) | 22.331 | 21.878 | 1.821 | 21.316 | 36.603 |
| Fast (Torch Image) | 2.242 | 2.209 | 0.164 | 2.182 | 3.803 |
| Compiled (Torch Image) | 2.125 | 2.117 | 0.073 | 2.062 | 2.594 |
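
The exact benchmark script is not included in the PR description; the following is a minimal sketch of a timing harness that could produce statistics like the above (warmup and iteration counts are assumptions):

import statistics
import time

def benchmark(processor, image, warmup=5, iters=100):
    # Warm up caches (and trigger compilation for the compiled processor)
    for _ in range(warmup):
        processor(images=image, return_tensors="pt")
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        processor(images=image, return_tensors="pt")
        times.append((time.perf_counter() - start) * 1000)  # milliseconds
    print(f"Mean: {statistics.mean(times):.3f}  Median: {statistics.median(times):.3f}  "
          f"Std Dev: {statistics.stdev(times):.3f}  Min: {min(times):.3f}  Max: {max(times):.3f}")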

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@mgoin mgoin marked this pull request as draft November 20, 2024 21:19
@mgoin mgoin changed the title to Add optimized PixtralImageProcessorFast Nov 20, 2024
@mgoin mgoin marked this pull request as ready for review November 20, 2024 21:26
@qubvel (Member) commented Nov 20, 2024

Hi @mgoin! Sounds great! Thanks for working on this 🤗

cc @yonigozlan maybe if you have bandwidth

@yonigozlan (Member) left a comment

Thanks a lot @mgoin for working on this! Looks great to me; I just mentioned some minor things to fix.

To be transparent, the current plan for fast image processors is to build a strong BaseImageProcessorFast plus several image-processing Mixins, in order to make adding new fast image processors much simpler. All that to say that this processor might change again in the future. Meanwhile, I think it would be great to have it as is, because as you can see it makes a huge performance difference, and this fast image processor in particular doesn't require a huge diff (compared to the DETR ones, for example).

Also, this would be the first fast image processor used inside a processor, and I don't think there is a mechanism to use it with AutoProcessor yet (I might be wrong), so this is also something that will need to be added soon.
But I'm curious to know how you are using it with a PixtralProcessor; I'm guessing you manually instantiate it with:

from transformers import PixtralProcessor, AutoImageProcessor, AutoTokenizer

fast_image_processor = AutoImageProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=True)
tokenizer = AutoTokenizer.from_pretrained("mistral-community/pixtral-12b")

processor = PixtralProcessor(fast_image_processor, tokenizer)

Is that correct?

Thanks again!

Review comments were left on the following files:
  • docs/source/en/_config.py
  • src/transformers/__init__.py (two comments, outdated)
  • src/transformers/models/pixtral/__init__.py (two comments)
  • src/transformers/utils/dummy_vision_objects.py (outdated)
  • tests/models/pixtral/test_image_processing_pixtral.py (outdated)
@mgoin (Contributor, Author) commented Nov 21, 2024

Thanks for the review and context @yonigozlan! I will look into it later today. Yes, you are correct about using it within a Processor; however, I have tested that this works within vLLM simply by adding use_fast=True to our AutoProcessor.from_pretrained() call here. There is no need to manually specify the Processor class.

One bug I noticed is that if I specify use_fast=True and there isn't a Fast version of the ImageProcessor available, I get an exception. I can look into this, but it would be good to get clarity on whether this is unintended behavior.
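
For reference, a minimal sketch of the AutoProcessor call described above (the checkpoint name matches the earlier examples; nothing here is vLLM-specific):

from transformers import AutoProcessor

# use_fast=True makes AutoProcessor pick up PixtralImageProcessorFast for the
# image-processing component; no manual PixtralProcessor construction needed
processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b", use_fast=True)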

@yonigozlan (Member) commented

Oh, great news that it already works with AutoProcessor. As I said, this is the first fast image processor used in a processor, so it was not guaranteed :).

> One bug I noticed is that if I specify use_fast=True and there isn't a Fast version of the ImageProcessor available, I get an exception. I can look into this, but it would be good to get clarity on whether this is unintended behavior.

Yes, this is the same right now when using ImageProcessingAuto. I don't think it should be that way though, especially as more and more people will want to use fast image processors by default. I'll open a PR to fix this.

Current plan is:

  • Keep use_fast set to False by default in the Auto classes during a deprecation cycle, and fall back on the slow image processor (with a warning) if use_fast is set to True and no fast image processor exists.
  • Set use_fast to True by default in the Auto classes after the deprecation cycle (by then most models will hopefully have a fast image processor), still falling back on the slow one if no fast image processor exists.

The deprecation cycle is needed because there are slight differences in outputs when using torchvision vs. PIL; see PR #34785 for more info.
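
A minimal sketch of the fallback behavior described in the plan above (an illustration of the intent, not the actual transformers implementation; the function name and signature are hypothetical):

import warnings

def resolve_image_processor_class(slow_cls, fast_cls, use_fast=False):
    # During the deprecation cycle use_fast defaults to False; afterwards, True
    if use_fast:
        if fast_cls is not None:
            return fast_cls
        # Planned behavior: warn and fall back instead of raising an exception
        warnings.warn("No fast image processor available; falling back to the slow one.")
    return slow_cls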

@ArthurZucker (Collaborator) commented

Feel free to ping us for another round of review! 🚀

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@mgoin (Contributor, Author) commented Nov 25, 2024

Thanks @yonigozlan and @ArthurZucker, this PR is ready for more review! I believe the failing test is unrelated.

@yonigozlan (Member) left a comment

Thanks for iterating!
LGTM after adding a get_resize_output_image_size that uses torch instead of numpy.

Review comment on src/transformers/image_utils.py (resolved)
from .image_processing_pixtral import (
    BatchMixFeature,
    convert_to_rgb,
    get_resize_output_image_size,
@yonigozlan (Member) commented:

I still think it would be better to rewrite this function with torch.ceil instead of np.ceil

@mgoin (Contributor, Author) commented Nov 26, 2024

Sorry, I missed your comment. Since height, width, and ratio are Python scalars, it doesn't make sense to use torch here. I can replace the np.ceil with just math.ceil though.

@yonigozlan (Member) commented:

Just saw this, indeed my mistake :)

@mgoin (Contributor, Author) commented Nov 26, 2024

Hi @yonigozlan, what do you think about the use of math.ceil?

@yonigozlan (Member) commented

> Hi @yonigozlan, what do you think about the use of math.ceil?

Sorry to be annoying with this, but I'd prefer torch.ceil. I'd say we leave the get_resize_output_image_size function as it is (with np.ceil) in image_processing_pixtral, and rewrite one with torch in image_processing_pixtral_fast.

@mgoin (Contributor, Author) commented Nov 26, 2024

Sorry if I am misunderstanding something @yonigozlan, but torch.ceil only works on torch.Tensor. Since we are working with height/width as int primitives, I cannot substitute torch.ceil.

Are you suggesting that I wrap the values in a Tensor just to extract them back out as primitives? This is inefficient and unnecessary for torch.compile to work:

height = int(np.ceil(height / ratio))
# versus
height = int(torch.ceil(torch.tensor(height / ratio)).item())

Considering torchvision.transforms.v2.functional.resize takes size as a list of ints (size: Optional[List[int]]), I don't see the benefit of using torch.Tensor for the height and width. If using numpy is the concern, I think using math.ceil makes sense.
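
For context, a minimal sketch of the scalar-only size computation being discussed (a simplified illustration of Pixtral-style resizing, not the verbatim transformers implementation; the function name and signature are assumptions):

import math

def compute_resize_output_size(height: int, width: int, max_height: int, max_width: int, patch_size: int):
    # Shrink so the image fits within (max_height, max_width), preserving aspect ratio
    ratio = max(height / max_height, width / max_width)
    if ratio > 1:
        # Plain math.ceil suffices here because height and width are Python scalars
        height = math.ceil(height / ratio)
        width = math.ceil(width / ratio)
    # Round each dimension up to a whole number of patches
    num_height_tokens = math.ceil(height / patch_size)
    num_width_tokens = math.ceil(width / patch_size)
    return num_height_tokens * patch_size, num_width_tokens * patch_size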

@yonigozlan (Member) commented

Hi @mgoin,
Sorry about that, I went over the code too quickly and completely missed that we were dealing with floats and not arrays in the np.ceil. So what you did with ImageInput and math.ceil in image_processing_pixtral.get_resize_output_image_size makes perfect sense; you can ignore my last few messages 🤗.
LGTM then! :)

@mgoin (Contributor, Author) commented Nov 27, 2024

Nice, thanks for the clarification! Hi @ArthurZucker, would you mind signing off?

@ArthurZucker (Collaborator) commented

Yep having a look!

@ArthurZucker (Collaborator) left a comment

LGTM, sorry for the delay! My main question is for @yonigozlan: do we have compile tests for fast image processors?

@ArthurZucker ArthurZucker merged commit 9d6f0dd into huggingface:main Nov 28, 2024
26 checks passed
@ArthurZucker (Collaborator) commented

Thanks a lot @mgoin for this contribution! 🔥

@yonigozlan (Member) commented Nov 29, 2024

> LGTM, sorry for the delay! My main question is for @yonigozlan: do we have compile tests for fast image processors?

I don't think so, I'll open a PR for that!
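
As an illustration, a compile test for a fast image processor could compare eager and compiled outputs; a hedged sketch only (the actual test added to transformers may look different, and the "pixel_values" key is an assumption):

import torch

def check_compile_matches_eager(fast_processor, image):
    # Eager (uncompiled) reference output
    eager = fast_processor(images=image, return_tensors="pt")
    compiled_processor = torch.compile(fast_processor, mode="reduce-overhead")
    compiled = compiled_processor(images=image, return_tensors="pt")
    # Outputs should be numerically identical between eager and compiled runs
    torch.testing.assert_close(eager["pixel_values"], compiled["pixel_values"])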

@ArthurZucker (Collaborator) commented

Cool, thanks!

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* Add optimized PixtralImageProcessorFast

* make style

* Add dummy_vision_object

* Review comments

* Format

* Fix dummy

* Format

* np.ceil for math.ceil
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024 (same commit list as above)