
🚨🚨🚨 fix(Mask2Former): torch export 🚨🚨🚨 #34393

Merged — 9 commits into main, Nov 19, 2024

Conversation

@philkuz (Contributor) commented Oct 24, 2024

What does this PR do?

Fixes #34390 (issue)

Mask2Former's modeling code had a set of issues that prevented torch.export from working. This PR addresses three separate problems that blocked export. I took direction from a few PRs that made similar changes:

  1. fix(Wav2Vec2ForCTC): torch export #34023 for the attention-mask modification that prevents a RuntimeError during torch.export
  2. Optim deformable detr #33600 for passing spatial_shapes around both as a list and as a tensor, and for the if-statement fix
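
As a rough illustration of the first fix (a hedged sketch, not the exact diff from #34023): in-place boolean-mask index assignment is data-dependent and tends to break torch.export tracing, while `masked_fill` expresses the same zeroing with shapes known at trace time.

```python
import torch

# Hypothetical sketch of the export-unfriendly vs. export-friendly pattern;
# the variable names are illustrative, not the actual model code.
hidden = torch.randn(2, 4)
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]], dtype=torch.bool)

# In-place boolean-index assignment: data-dependent, problematic under torch.export.
zeroed_in_place = hidden.clone()
zeroed_in_place[~attention_mask] = 0.0

# masked_fill computes the same result without data-dependent indexing.
zeroed_traceable = hidden.masked_fill(~attention_mask, 0.0)

assert torch.equal(zeroed_in_place, zeroed_traceable)
```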

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts, @qubvel, @ylacombe (tagging because you looked at #34023)

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Testing:

Run this script once on main and then once on this branch. Ensure the np.testing.assert_array_almost_equal check passes on the second run.

import os

import numpy as np
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation


def get_test_image():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    return Image.open(requests.get(url, stream=True).raw)


processor = AutoImageProcessor.from_pretrained(
    "facebook/mask2former-swin-base-coco-panoptic"
)
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-base-coco-panoptic"
)

results = model(**processor(images=get_test_image(), return_tensors="pt"))
results = processor.post_process_panoptic_segmentation(results)

# Compare against the cached result from main, or cache the current result.
results_file = "segmentation_results.pt"
if os.path.exists(results_file):
    old_segmentation = torch.load(results_file)
    np.testing.assert_array_almost_equal(results[0]["segmentation"], old_segmentation)
else:
    torch.save(results[0]["segmentation"], results_file)

# The actual export check: this raises on main and succeeds on this branch.
scripted_model = torch.export.export(model, args=(torch.randn(1, 3, 800, 1280),))

I also visually compared the output masks, and they look the same.

@yonigozlan (Member) left a comment

Looks good to me, thanks for working on this! I just have a question regarding one of your changes

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@philkuz (Contributor, Author) commented Oct 28, 2024

Hey @yonigozlan, if you have a chance I would love another review of this PR now that I've addressed your comments! Thanks!

@yonigozlan (Member) commented:
Looks great to me, thanks again for fixing this! I think we should add a short comment just above the definition of level_start_index_list explaining why we do this (iterating over a tensor breaks torch.compile/export).
Please rebase on main and then ask a core maintainer for a final review :).
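
For context, a minimal sketch of the pattern that comment describes (illustrative names, not the actual transformers code): with the spatial shapes carried as a plain Python list of ints, the loop that builds the per-level start offsets is resolved at trace time instead of iterating over a tensor's rows.

```python
# Illustrative sketch: compute the flattened start offset of each feature
# level from a Python list of (height, width) ints. Iterating over a tensor's
# rows here would force data-dependent control flow and break
# torch.compile/torch.export; iterating a list of ints does not.
def level_start_indices(spatial_shapes_list):
    indices = [0]
    for height, width in spatial_shapes_list[:-1]:
        indices.append(indices[-1] + height * width)
    return indices

print(level_start_indices([(32, 32), (16, 16), (8, 8)]))  # [0, 1024, 1280]
```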

@philkuz philkuz force-pushed the torch_export_mask2former branch from cfaef32 to 7a460d4 Compare October 28, 2024 18:57
@philkuz (Contributor, Author) commented Oct 28, 2024

Thanks Yoni!

Pinging @amyeroberts and @qubvel for a core maintainer review!

@qubvel (Member) left a comment

Thanks for working on this!

It would also be nice to have a test that export works, maybe something similar to this.

@philkuz philkuz force-pushed the torch_export_mask2former branch from 4c78eca to 78f3848 Compare October 29, 2024 18:08
Signed-off-by: Phillip Kuznetsov <[email protected]>
@philkuz philkuz force-pushed the torch_export_mask2former branch from 3b3562f to 104b16d Compare October 29, 2024 20:15
@philkuz (Contributor, Author) commented Oct 29, 2024

Added an export test and addressed your other comments. Please let me know if you would like me to modify anything else!

@qubvel (Member) left a comment
Thanks for the quick iteration! A few more comments.

Please push an empty commit with the message [run_slow] mask2former at the end to trigger the full slow-test run for the mask2former model.

@philkuz (Contributor, Author) commented Oct 29, 2024

You got it! I think I need maintainer approval to run the slow workflows

Signed-off-by: Phillip Kuznetsov <[email protected]>
@philkuz philkuz force-pushed the torch_export_mask2former branch from 38feb0d to 704b79f Compare October 29, 2024 21:57
@qubvel (Member) left a comment

LGTM, thanks!

@qubvel qubvel requested a review from ArthurZucker October 30, 2024 08:23
@ArthurZucker (Collaborator) left a comment

I am not 100% sure we have to break this. Do you think we can keep supporting the old behaviour, or at least have a deprecation cycle?

@@ -926,7 +926,7 @@ def forward(
     encoder_attention_mask=None,
     position_embeddings: Optional[torch.Tensor] = None,
     reference_points=None,
-    spatial_shapes=None,
+    spatial_shapes_list=None,

Collaborator:

changing the name is a breaking change, we should probably have a deprecation cycle, no?

Collaborator:

if not, we can add 🔴 as I think the motivation is strong enough.

Member:

Hmm, I suppose it is an internal module of the model, not sure if it is intended to be used elsewhere, let me know if I'm wrong

Contributor (Author):

Isn't keeping both spatial_shapes and spatial_shapes_list, without changing the logic, already a breaking change?

Sure, there's some question of whether users can rely on internal modules of transformers models, but if a user doesn't pass a value for spatial_shapes_list, this code will fail because L939 will try to iterate over a None object.

BTW, it seems like the changes in #33600 already violate this contract (see this line)? I followed that PR as a guide for what I should change here.

It seems like the proper way forward is to add a 🔴 here.

Contributor (Author):

What's the path forward, @ArthurZucker?

Contributor (Author):

Bumping this again after the weekend. I'm not familiar with Hugging Face's policies on:

  1. What qualifies as a breaking change
  2. What the release process is

Could you provide a recommendation on where we should take this PR?

Collaborator:

Usually even if a module is Internal people still end up using it 😅
This IS a breaking change, but acceptable IMO. Let's use some 🚨 on the PR title to make sure we communicate about it on the release!

@qubvel qubvel requested a review from ArthurZucker November 4, 2024 16:54
@qubvel (Member) commented Nov 4, 2024

soft ping @ArthurZucker: regarding your last comment, I'm not sure it can be done without breaking changes, but since it is an internal module of the model, I suppose it is fine to change its signature. 🚨 can be added to the PR message if it is still required.

@philkuz (Contributor, Author) commented Nov 7, 2024

Hi @ArthurZucker @qubvel,

Hope you're doing well, could you find some time today or tomorrow to provide guidance on this PR? Would love to check it off of my list!

@philkuz (Contributor, Author) commented Nov 17, 2024
Hi @ArthurZucker @qubvel, any chance we can move forward with this PR this week? Happy to do whatever you would like; I just need guidance on the direction here.

@qubvel (Member) commented Nov 17, 2024

Hi @philkuz, sorry for the delay, the team was at an offsite this week. I will ping Arthur to get it reviewed and merged. Thanks for your patience!

@philkuz (Contributor, Author) commented Nov 18, 2024

Thank you, Pavel!

@ArthurZucker (Collaborator) left a comment

Thanks for this contribution! 🤗 Looks good, will just update the PR name!

@ArthurZucker ArthurZucker changed the title fix(Mask2Former): torch export 🚨🚨🚨 fix(Mask2Former): torch export 🚨🚨🚨 Nov 19, 2024
@ArthurZucker ArthurZucker merged commit 5fa4f64 into huggingface:main Nov 19, 2024
19 checks passed
@qubvel qubvel added the Vision label Nov 19, 2024
@qubvel (Member) commented Nov 19, 2024

cc @guangy10

@guangy10 (Contributor) commented:

@philkuz Awesome! Thanks for expanding the export coverage to more models! 🚀 🚀 🚀

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* fix(Mask2Former): torch export

Signed-off-by: Phillip Kuznetsov <[email protected]>

* revert level_start_index and create a level_start_index_list

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Add a comment to explain the level_start_index_list

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Address comment

Signed-off-by: Phillip Kuznetsov <[email protected]>

* add torch.export.export test

Signed-off-by: Phillip Kuznetsov <[email protected]>

* rename arg

Signed-off-by: Phillip Kuznetsov <[email protected]>

* remove spatial_shapes

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Use the version check from pytorch_utils

Signed-off-by: Phillip Kuznetsov <[email protected]>

* [run_slow] mask2former

Signed-off-by: Phillip Kuznetsov <[email protected]>

---------

Signed-off-by: Phillip Kuznetsov <[email protected]>
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
Successfully merging this pull request may close these issues.

[mask2former] torch.export error for Mask2Former
6 participants