
Make VideoMAEImageProcessor much faster #28221

Closed
wants to merge 9 commits

Conversation

ikergarcia1996
Contributor

@ikergarcia1996 ikergarcia1996 commented Dec 23, 2023

What does this PR do?

Currently, VideoMAEImageProcessor is extremely slow. In fact, during inference, it takes longer to preprocess a video than to run the model. After investigating the code, I discovered that this issue can be easily fixed.

Currently, the preprocess() function in VideoMAEImageProcessor creates a list of lists of ndarrays (as self._preprocess_image() returns an ndarray), which is then sent to BatchFeature to be converted into a torch.tensor. The issue arises because creating a tensor from a list of ndarrays is extremely slow. Additional information on this problem can be found here: pytorch/pytorch#13918.
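The underlying slowdown can be reproduced entirely outside of transformers. Here is a minimal sketch; the frame count and resolution mirror the benchmark further down, and the exact timings will vary by machine:

```python
import time

import numpy as np
import torch

# A "video": 128 frames of 200x200 RGB, stored as a Python list of ndarrays.
frames = [np.random.rand(200, 200, 3).astype(np.float32) for _ in range(128)]

# Slow path: torch.tensor walks the Python list and copies frame by frame.
# (Recent PyTorch versions even emit a UserWarning recommending np.array here.)
t0 = time.perf_counter()
slow = torch.tensor(frames)
t_slow = time.perf_counter() - t0

# Fast path: stack into one contiguous ndarray first, then convert once.
t0 = time.perf_counter()
fast = torch.tensor(np.asarray(frames))
t_fast = time.perf_counter() - t0

print(f"list of ndarrays: {t_slow:.4f}s, single ndarray: {t_fast:.4f}s")
```

Both paths produce identical tensors; only the conversion cost differs.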

By converting the list of ndarrays into a single ndarray, a significant speedup can be achieved. Here is a minimal example for demonstration.

We create two processors: one using the current code and another with the modified code.

from transformers import VideoMAEImageProcessor as original_image_processor
from image_processing_videomae import VideoMAEImageProcessor as new_image_processor
import numpy as np
import torch


IMAGE_MEAN: list[float] = [0.33363932, 0.32581538, 0.31566033]
IMAGE_STD: list[float] = [0.1914285, 0.18449214, 0.1853477]

image_processor_og = original_image_processor(
    do_resize=False,
    do_center_crop=False,
    do_rescale=True,
    do_normalize=True,
    image_mean=IMAGE_MEAN,
    image_std=IMAGE_STD,
)

image_processor_new = new_image_processor(
    do_resize=False,
    do_center_crop=False,
    do_rescale=True,
    do_normalize=True,
    image_mean=IMAGE_MEAN,
    image_std=IMAGE_STD,
)

We then create a video of 128 frames with a 200x200 resolution:

image_sequences = np.asarray(
    [np.random.rand(200, 200, 3) * 255] * 128,
    dtype=np.uint8,
)

image_sequences = list(image_sequences)
print(len(image_sequences))
print(image_sequences[0].shape)
print(image_sequences[0][0][0][:10])

## OUTPUT
128
(200, 200, 3)
[119 129 132]

I have run both processors in a Jupyter notebook:

%%timeit
model_inputs_og = image_processor_og(
    images=image_sequences,
    input_data_format="channels_last",
    return_tensors="pt",
)
# 1.11 s ± 123 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
model_inputs_new = image_processor_new(
    images=image_sequences,
    input_data_format="channels_last",
    return_tensors="pt",
)
# 154 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Just to be sure that nothing changes:

model_inputs_og["pixel_values"].size()
# torch.Size([1, 128, 3, 200, 200])
model_inputs_new["pixel_values"].size()
# torch.Size([1, 128, 3, 200, 200])
torch.all(model_inputs_new["pixel_values"]==model_inputs_og["pixel_values"])
# tensor(True)

With this small change, we reduce the video preprocessing time from 1.11 seconds to 154 ms 😃

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts

Creating a tensor from a list of numpy.ndarrays is extremely slow. This small change makes the preprocessing step much faster.
Collaborator

@amyeroberts amyeroberts left a comment


Hi @ikergarcia1996 - thanks for opening this PR and for the detailed write-up! And apologies for the delay in review.

Could you extend this logic to the other image processors in the library?

@ikergarcia1996
Contributor Author

ikergarcia1996 commented Jan 4, 2024

@amyeroberts The suggestion makes sense, although we also need to define data if return_tensors is None. I have slightly modified the suggestion.

Regarding extending this logic to other image processors. I think that the root of the issue is here:

def as_tensor(value):
    if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
        value = np.array(value)
    return torch.tensor(value)

The as_tensor function in _get_is_as_tensor_fns from the BatchFeature class already takes into account that if the data is a list of np.ndarray it should first be converted into a single np.array. However, this is not the case with the VideoMAE data because it is a list of lists of ndarrays. Doing a recursive type check would fix the issue, although I am not sure if it would break any other functionality. Maybe I should open a separate PR for this?

def recursive_ndarray_check(value):
    if isinstance(value, (list, tuple)) and len(value) > 0:
        return recursive_ndarray_check(value[0])
    return isinstance(value, np.ndarray)

if recursive_ndarray_check(value):
    value = np.array(value)
return torch.tensor(value)
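Folded into `as_tensor`, the recursive check could look like the following. This is a hypothetical sketch of the combined function, not the exact code merged into BatchFeature; the helper name `_is_nested_ndarray` is made up for illustration:

```python
import numpy as np
import torch


def _is_nested_ndarray(value):
    # Descend into the first element of nested lists/tuples: a list of
    # lists of np.ndarray (as VideoMAE produces) should also be stacked.
    if isinstance(value, (list, tuple)) and len(value) > 0:
        return _is_nested_ndarray(value[0])
    return isinstance(value, np.ndarray)


def as_tensor(value):
    # Collapsing to a single contiguous np.array first avoids torch's
    # slow element-by-element copy from a Python list of ndarrays.
    if _is_nested_ndarray(value):
        value = np.array(value)
    return torch.tensor(value)


# One "video" of 4 frames, each a (3, 8, 8) ndarray, nested one level deep:
batch = [[np.zeros((3, 8, 8), dtype=np.float32) for _ in range(4)]]
print(as_tensor(batch).shape)  # torch.Size([1, 4, 3, 8, 8])
```

Flat lists of ndarrays and plain Python lists still take the same paths as before; only the nested case gains the early stacking step.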

@amyeroberts
Collaborator

@ikergarcia1996 Thanks for the detailed explanation! In principle, I'd be pro adding in the recursive check. However, there are other models, e.g. audio models, which rely on this logic and might also be affected by this change, so we'd have to make sure it's well tested for these models too.

Let's add it now, we can iron out any issues it might flag for the vision models, and then I can ask the audio team if they think there'd be any issues.

@huggingface huggingface deleted a comment from github-actions bot Jan 30, 2024
@amyeroberts
Collaborator

Hi @ikergarcia1996, are you still working on this? It would be great to have this contribution! The next steps would be adding in the recursive logic you proposed.

@ikergarcia1996
Contributor Author

Sorry, @amyeroberts, I had other urgent matters to attend to and forgot about this. I have updated the code; it works for VideoMAEImageProcessor, although I am not sure if it may cause issues with other models. The tests are failing, but the error seems unrelated to the changes in the code.

RuntimeError: Failed to import transformers.models.nat.modeling_nat because of the following error (look up to see its traceback):
E   Failed to import NATTEN's CPP backend. This could be due to an invalid/incomplete install. Please uninstall NATTEN (pip uninstall natten) and re-install with the correct torch build: shi-labs.com/natten

@amyeroberts
Collaborator

@ikergarcia1996 Thanks for updating!

The natten issues aren't related to this PR - we recently had issues on our CI runs because of recent package releases and incompatible versions. A fix has been pushed to the main branch. Rebasing to include these changes should resolve them.

Merge pull request #1 from huggingface/main
@ikergarcia1996
Contributor Author

@amyeroberts I have updated my branch, but the tests still fail.

@amyeroberts
Collaborator

@ikergarcia1996 Apologies. There have been some ongoing issues with handling compatibility between packages. A final fix should have been merged into main now. Could you try rebasing again?

@ikergarcia1996
Contributor Author

@amyeroberts done 😃

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@amyeroberts amyeroberts left a comment


Thanks for iterating on this!

Just a few small comments on what's there. Only thing that needs to be added are some tests to make sure that the inputs and outputs are handled as expected. In particular, it would be good to test for a vision model which outputs more than just pixel_values, for example annotations for DETR.

Comment on lines 342 to 343
# Speeds up tensor conversion - see: https://github.com/huggingface/transformers/pull/28221/files
data = {"pixel_values": np.asarray(videos) if return_tensors == "pt" else videos}
Collaborator


Is this still needed?

Contributor Author


Not really, although it doesn't hurt either. Should we remove it?

@huggingface huggingface deleted a comment from github-actions bot Mar 8, 2024
@amyeroberts
Collaborator

@ikergarcia1996 Any update on this? It would be great to have this included in the library!

@ikergarcia1996
Contributor Author

Hi @amyeroberts!
Unfortunately, right now I don't have the time to implement the tests for other models. My expertise is not in vision models, so it would take me some time to understand how the other models work and to implement the tests to ensure that this PR doesn't change the behavior of any model. I would appreciate some help if possible.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jun 26, 2024