[RoBERTa-based] Add support for sdpa #30510
Conversation
Force-pushed from ecadcfb to f6e5e74
I ran slow tests for the affected models and verified that they all pass, except … I'll also try to run some of the perf benchmarks on RoBERTa over the weekend to see how they behave.
Preliminary perf numbers for RoBERTa (using "roberta-base" with AutoModel/Tokenizer): training and inference benchmark tables.
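For anyone who wants to reproduce numbers along these lines, a rough comparison can be put together with the public API. This is only a sketch (the model name, batch size, and loop counts are arbitrary choices, not the script behind the numbers above), and the `"sdpa"` option for RoBERTa requires a transformers build that includes this PR:

```python
import time

import torch
from transformers import AutoModel, AutoTokenizer

# Rough benchmark sketch: compare eager vs. SDPA attention for roberta-base.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
inputs = tokenizer(["Benchmarking SDPA support."] * 8, padding=True, return_tensors="pt").to(device)

for impl in ("eager", "sdpa"):
    model = AutoModel.from_pretrained("roberta-base", attn_implementation=impl).to(device).eval()
    with torch.inference_mode():
        for _ in range(3):  # warmup
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(20):
            model(**inputs)
        if device == "cuda":
            torch.cuda.synchronize()
        print(f"{impl}: {(time.perf_counter() - start) / 20 * 1000:.2f} ms / forward")
```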
It seems like … EDIT: I added a set_seed(0) to …
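For reference, set_seed is the transformers helper that seeds Python's random, NumPy, and torch (including CUDA) in a single call; the surrounding test code isn't shown in the thread, but the call itself would look like this:

```python
from transformers import set_seed

# Seed random, numpy and torch (CPU and CUDA) so that a test comparing
# attention implementations is deterministic across runs.
set_seed(0)
```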
Force-pushed from c39f457 to 41537e3
@fxmarty @ArthurZucker @amyeroberts This is ready for review! With the exception of the changes to the test and check_support_list.py, all the changes come from "Copied From". Please let me know if you have any questions!
@hackyon, I'm curious whether implementing flash_attn is essential when adding SDPA support. I came across claims that flash_attn can offer roughly a 4x efficiency boost compared to native PyTorch, but your remarks in #30510 suggest the actual improvement is less than 50%. Could you shed some light on this apparent difference?
@michaelshekasta I believe the 4x improvement only applies to certain models, usually larger models with more computationally expensive attention computations.
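For context on that gap: SDPA here means torch.nn.functional.scaled_dot_product_attention, which is itself a dispatcher that can pick a flash-attention kernel, a memory-efficient kernel, or a plain math implementation depending on hardware, dtype, and shapes, so the realized speedup varies a lot with model size, sequence length, and precision. A minimal sketch of forcing the different kernels, assuming a CUDA GPU and fp16 (the tensor shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) in fp16 on CUDA, as required by the flash kernel.
q = k = v = torch.randn(1, 12, 128, 64, device="cuda", dtype=torch.float16)

# Force the flash-attention kernel inside SDPA.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out_flash = F.scaled_dot_product_attention(q, k, v)

# Force the plain math implementation for comparison.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out_math = F.scaled_dot_product_attention(q, k, v)
```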
@fxmarty can you have a look and ping me for the final review? 🤗
@fxmarty, gentle bump
LGTM, just one fix that I think needs to be made for cross attention with the is_causal param.
The test here https://github.com/huggingface/transformers/pull/30138/files#diff-681c988a50a31869d1756f2db71904939c639617569a5168d7b3167fe8da0b48 could also be extended for extra safety, but up to you.
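To illustrate the point about is_causal and cross attention, here is a standalone sketch (not the PR's actual code) of how the flag would be guarded, along the lines of the "add not is_cross_attention" item in the merge commit message further down:

```python
import torch
import torch.nn.functional as F

# is_causal=True asks SDPA to apply a lower-triangular (causal) mask over
# query/key positions, which only makes sense for decoder self-attention.
# In cross-attention the keys/values come from the encoder, so the causal
# mask must be disabled.
batch, heads, q_len, kv_len, head_dim = 2, 12, 5, 9, 64
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

is_decoder, is_cross_attention = True, True
is_causal = is_decoder and not is_cross_attention and q_len > 1

out = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
print(out.shape)  # torch.Size([2, 12, 5, 64])
```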
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@fxmarty what's left? How can I help?
@fxmarty You may want to resolve conflicts.
Sorry, I did not have time before; I will try to get to it today or next week. It's a big PR with lots of changes, so I need to be extra careful!
@ArthurZucker would you have time for this review?
I've also experienced approximately 20% faster training with XLMRoberta using this PR on an RTX4090. I've been testing it for over a week now, and it's been working without any issues. I sincerely hope this can be merged.
@ArthurZucker Is there anything we can help with to get this PR reviewed?
I kept pushing this back, it's on me! I'll solve whatever comes up with this merge.
Thanks @hackyon for your hard work, LGTM!
@ArthurZucker when do you think this change will appear in the transformers package? The next version? P.S. You guys are amazing!
It should be there in at most 2 weeks! 🤗
I would like to thank everyone involved in this Pull Request from the bottom of my heart! 🎉
* Adding SDPA support for RoBERTa-based models
* add not is_cross_attention
* fix copies
* fix test
* add minimal test for camembert and xlm_roberta as their test class does not inherit from ModelTesterMixin
* address some review comments
* use copied from
* style
* consistency
* fix lists

Co-authored-by: fxmarty <[email protected]>
Co-authored-by: Arthur <[email protected]>
@ArthurZucker A gentle reminder ;-)
We are gonna release today / tomorrow! 🤗 Sorry for the delay
@ArthurZucker Thanks!!! I hope it will be released today!!
😐 Really sorry, a big big release is coming on Wednesday, don't worry, the wait is worth it! 👀
It was released on Wednesday, thanks for your patience 🤗
What does this PR do?
Adding support for SDPA (scaled dot product attention) for RoBERTa-based models. More context in #28005 and #28802.
Models: camembert, roberta, xlm_roberta, xlm_roberta_xl.
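As a hedged usage sketch (the checkpoint name here is just an example), once this lands SDPA can be requested explicitly for any of these models via attn_implementation, and the implementation in use can be checked on the loaded config:

```python
from transformers import AutoModelForMaskedLM

# Opt into SDPA explicitly; it also becomes the default where supported.
model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base", attn_implementation="sdpa")
print(model.config._attn_implementation)  # -> "sdpa"
```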
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@fxmarty @ArthurZucker @amyeroberts