[ESM] Add support for sdpa. #34954

Open
wants to merge 4 commits into
base: main


@wzf03 wzf03 commented Nov 27, 2024

What does this PR do?

Adds SDPA (scaled dot-product attention) support for ESM. More context is in #28802 (this PR mainly reuses the code from that PR, since ESM is a BERT-based model) and #28005.
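
For readers unfamiliar with the term, the computation behind SDPA can be sketched in plain Python as below. This is a dependency-free illustration of the math only, not the code in this PR: the function names are made up for the example, and it ignores masks, dropout, batching, and multiple heads. In PyTorch the same quantity is exposed as `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to fused kernels.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v):
    """Compute softmax(q @ k^T / sqrt(d)) @ v for plain nested lists.

    q: [num_queries][d], k: [num_keys][d], v: [num_keys][d_v]
    Returns a [num_queries][d_v] nested list.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    output = []
    for q_row in q:
        # Scaled similarity of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted average of the value rows.
        output.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return output
```

A zero query scores every key equally, so the result is the plain average of the value rows.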

This is my first contribution to this project, so please point out any mistakes.

This PR also reverts a change from #29329, since the dtype mismatch under bitsandbytes is actually caused by the rotary embedding.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

@Rocketknight1
Member

Thank you for this! Rather than skipping the SDPA test, though, can you write even a simple test that uses the SDPA path? It's okay if it can't compare hidden states deeply because of issues in the ESM model, but if it could compare that output logits are similar that'd give us a lot more confidence in the SDPA code!
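
A logit-level comparison of the kind requested above could look roughly like the sketch below. It is an illustration, not the test that was added to this PR. The helper names and the tolerance are assumptions, and the checkpoint passed in is up to the caller (a small public ESM-2 checkpoint such as `facebook/esm2_t6_8M_UR50D` would be one option). Loading with `attn_implementation="sdpa"` is expected to work for ESM only with the changes from this PR applied.

```python
def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two equally shaped nested lists."""
    if isinstance(a, (list, tuple)):
        assert len(a) == len(b), "shape mismatch"
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(a - b)


def compare_eager_and_sdpa(checkpoint, sequence, atol=1e-4):
    """Run one sequence through both attention paths and compare the logits.

    Returns (is_close, max_difference). `atol` is an assumed tolerance.
    """
    # Heavy imports stay local so the helper above remains dependency-free.
    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer(sequence, return_tensors="pt")

    logits = {}
    for impl in ("eager", "sdpa"):
        model = AutoModelForMaskedLM.from_pretrained(
            checkpoint, attn_implementation=impl
        ).eval()
        with torch.no_grad():
            logits[impl] = model(**inputs).logits.tolist()

    diff = max_abs_diff(logits["eager"], logits["sdpa"])
    return diff <= atol, diff
```

Comparing only the final logits sidesteps the hidden-state issues mentioned above while still exercising the SDPA path end to end.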

@wzf03
Author

wzf03 commented Nov 28, 2024

Thanks for your reply, I will add relevant test cases soon.

@wzf03
Author

wzf03 commented Nov 28, 2024

@Rocketknight1 Hello, the SDPA inference tests for ESMFold have been added. Could you please review them?

Member

@Rocketknight1 Rocketknight1 left a comment

Overall this looks like a good SDPA addition to me! I'll also set up slow tests in a sec.

src/transformers/models/esm/modeling_esm.py (resolved)
@Rocketknight1
Member

Hi @wzf03, I ran the full test suite for ESM and I'm seeing one or two test failures. Can you see if you can reproduce those locally? They may just be flaky tests, but it might also be caused by changes in this PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@wzf03
Author

wzf03 commented Nov 30, 2024

Hello @Rocketknight1, I found that the test failures come from a device mismatch between the input_ids (on CPU) and the model (on CUDA) in the bitsandbytes setting. I can reproduce them locally with the latest master branch of accelerate (accelerate@29be4788629b772a3b722076e433b5b3b5c85da3), but in my original test environment with accelerate==1.1.1 everything works.

I will report this to accelerate later.
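
One common way to avoid this kind of mismatch in a test is to move the inputs onto the model's device before the forward pass. The helper below is a minimal sketch of that idea; `move_to_device` is a hypothetical name, and this is not necessarily the fix applied in this PR.

```python
def move_to_device(inputs, device):
    """Return a new dict with every value that exposes `.to()` moved to `device`.

    Works with any mapping of tensors, e.g. the output of a tokenizer called
    with return_tensors="pt". Values without a `.to()` method are kept as-is.
    """
    return {
        name: value.to(device) if hasattr(value, "to") else value
        for name, value in inputs.items()
    }
```

In a test this would be used as `inputs = move_to_device(inputs, model.device)` right before calling `model(**inputs)`.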

@wzf03
Author

wzf03 commented Dec 2, 2024

Hello @Rocketknight1, I made a quick fix modeled on other models' tests; the test cases should pass now.

@Rocketknight1
Member

Yes, looks good to me now! cc @ArthurZucker @LysandreJik for core maintainer review

@wzf03
Author

wzf03 commented Dec 9, 2024

@ArthurZucker @LysandreJik Hello! Could you please help review this PR?

Collaborator

@ArthurZucker ArthurZucker left a comment

Hey super sorry for the delay, waited a bit because #35235 changes the interface! Do you mind updating this PR ? Hope it's not too much of a burden! 😿

@wzf03
Author

wzf03 commented Dec 21, 2024

Sure, I will do it soon.
