Add FlashAttention2 for XLM-RoBERTa #28713
Conversation
@DavidAfonsoValente Thanks for opening a PR and adding FA2 for this model! Let us know when the PR is ready to be reviewed 🤗
Thanks! I believe it's ready to be reviewed.
The tests that are failing are due to the changes I made in functions that are labeled as copied from other models. What is the standard procedure when I have made changes to these functions in order to accommodate the new feature?
@DavidAfonsoValente If you click on the CI runs, you'll see the error messages relating to the code quality failures; these detail how to resolve the issue and which command to run. If another model is copying this model's attention classes, then it's great because you're adding FA2 for two models 🥳 You just need to apply the equivalent changes to the model that's copying.
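For readers unfamiliar with the convention being discussed: transformers marks duplicated code with a `# Copied from` comment, and the repo's tooling checks and regenerates those copies from their source so they never drift apart. A rough sketch of what such a marker looks like (the class name and replacement mapping are illustrative assumptions, not quoted from this PR):

```python
# Rough sketch of the "Copied from" convention; the class name and mapping below
# are illustrative, not quoted from this PR.
from torch import nn


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta
class XLMRobertaSelfAttention(nn.Module):
    """Placeholder body; the real class holds the query/key/value projections.

    Because of the marker above, any edit to the source class (or to this copy)
    must be propagated so both stay identical, which is what the CI check enforces.
    """
```

Changing a function that carries such a marker therefore means either propagating the change to every copy or, when the copies genuinely need to diverge, adjusting the markers.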
Awesome work - thanks for adding!
Just a few tiny comments and then we're good to merge 🤗
Hello there! I'm working on integrating scaled_dot_product_attention into BERT in #28802, and there might be some merge conflicts with this change. Most of the changes I'm making will propagate through to XLM-RoBERTa (through fix-copies). From what I can tell, our approaches differ in two areas:
Let me know if you have any questions about these. I think we should discuss which approach is better and adopt it. Would you mind looking through #28802 and letting me know what you think?
It seems like your implementation doesn't create complicated conflicts, so it should be simple to merge. What do you think?
It should be simple to merge. I think the main question is whether we want to use the XLM_ROBERTA_SELF_ATTENTION_CLASSES or the XLM_ROBERTA_ATTENTION_CLASSES approach.
I think in order to be consistent with other models we should keep the name XLM_ROBERTA_ATTENTION_CLASSES, since this is how other models have their attention classes named.
Thanks for your input. I am leaning towards using the "SelfAttention" convention in this case, because XLMRobertaFlashAttention2 actually tries to copy the logic of both XLMRobertaAttention and XLMRobertaSelfAttention into one. I think it's cleaner to have a XLMRobertaSelfFlashAttention2 that mirrors XLMRobertaSelfAttention, and then reuse the logic inside XLMRobertaAttention for both types of self attention. Reusing the code in XLMRobertaAttention should help avoid some bugs. Judging from the existing code, I think there's already a bit of an issue with the way self.output() is called in XLMRobertaFlashAttention2 (I think you need more than the call to self.output.dense(), otherwise you're missing the dropout and LayerNorm).
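To make the structural point concrete, here is a self-contained sketch of the wrapper-plus-dispatch pattern being discussed (simplified, with generic names rather than the PR's actual classes): the outer attention module picks a self-attention backend from a dict and always runs the full output block, so the dropout and residual LayerNorm cannot be skipped by any backend.

```python
import torch
from torch import nn


class SimpleSelfAttention(nn.Module):
    """Stand-in for XLMRobertaSelfAttention or a FlashAttention2 variant."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, hidden_states):
        out, _ = self.attn(hidden_states, hidden_states, hidden_states)
        return out


class SelfOutput(nn.Module):
    """Output block shared by every backend (mirrors the role of XLMRobertaSelfOutput)."""

    def __init__(self, hidden_size: int, dropout: float = 0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Residual connection + LayerNorm: calling only `dense` would silently drop these.
        return self.LayerNorm(hidden_states + input_tensor)


# Dispatch dict: only the self-attention math differs between backends, so the
# wrapper below stays backend-agnostic. Names here are hypothetical.
SELF_ATTENTION_CLASSES = {
    "eager": SimpleSelfAttention,
    "flash_attention_2": SimpleSelfAttention,  # stand-in for an FA2 implementation
}


class Attention(nn.Module):
    def __init__(self, hidden_size: int = 64, num_heads: int = 4, impl: str = "eager"):
        super().__init__()
        self.self = SELF_ATTENTION_CLASSES[impl](hidden_size, num_heads)
        self.output = SelfOutput(hidden_size)

    def forward(self, hidden_states):
        # The full output block runs regardless of which backend computed attention.
        return self.output(self.self(hidden_states), hidden_states)


if __name__ == "__main__":
    x = torch.randn(2, 8, 64)
    print(Attention()(x).shape)  # torch.Size([2, 8, 64])
```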
OK, maybe it'll be easier if you merge your PR and I work on top of your version.
Thanks 👍 I'll iterate as fast as possible to get it merged, and will let you know as soon as it's done!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@amyeroberts @DavidAfonsoValente @hackyon Why hasn't this PR been merged yet? How can I help with that?
I have #30510 open to add scaled dot product attention (SDPA) to RoBERTa-based models, and it is still under review. When available, the SDPA path can run on an FA2 backend. You could very well try to add FA2 to it as well.
@hackyon Thank you for your feedback! As far as I'm aware, SDPA isn't Flash Attention 2, and it's currently in beta, correct? So, will HF support it? Or am I missing something? I'd be happy to contribute, but I'm not sure how. Could we discuss this on the HF Discord server, perhaps in the NLP channel?
Newer versions of SDPA actually have FA2 integration: https://pytorch.org/blog/pytorch2-2. With that said, you could probably add a FA2 option as well if you'd like (some of the models have that too), but I'm not an HF maintainer, so take my opinion with a grain of salt. SDPA is already supported in a variety of HF models.
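For reference, here is a minimal sketch of what "SDPA with an FA2 backend" means on the PyTorch side (it assumes a CUDA device and half-precision tensors; the `sdp_kernel` context manager shown is the pre-2.3 API and has since been superseded by `torch.nn.attention.sdpa_kernel`):

```python
import torch
import torch.nn.functional as F

# scaled_dot_product_attention picks a kernel at runtime (FlashAttention,
# memory-efficient, or plain math) based on device, dtype, and mask layout.
# Shapes below are (batch, heads, seq_len, head_dim); fp16 on CUDA is eligible
# for the FlashAttention kernels described in the PyTorch 2.2 release notes.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restricting dispatch to the flash backend makes it fail loudly if the inputs
# are not eligible, which is a handy way to verify the fast path is taken.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 128, 64])
```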
Thank you once more for your input, @hackyon! I'm interested in implementing XLM-RoBERTa with FA2, and I believe integrating it into Hugging Face (HF) could be the way forward. From what I've gathered, this pull request appears to align with my goals. While I understand you're not directly involved in HF maintenance, your prior work suggests you're knowledgeable in this area. Could you advise on the best course of action? Should I wait for your pull request, or would utilizing the flash_attn library be a viable alternative?
I think you should reference how FA2 was implemented for other models, such as modeling_llama, and bring those changes over. I recommend using "Copied From" as much as possible. My 2 cents would be that it'd be better to use XLM_ROBERTA_SELF_ATTENTION_CLASSES instead of XLM_ROBERTA_ATTENTION_CLASSES, otherwise there'd be a conflict with #30510 (and since they have a "Copied From" relationship, it wouldn't be easy to keep both copies in sync). Unfortunately, I won't be able to discuss this on Discord. You might be able to find some of the core maintainers there though.
@hackyon thanks again! You are awesome!
#30510 has a XLMROBERTAXL_SELF_ATTENTION_CLASSES, so it's better to add the FA2 implementation there instead of creating a new XLM_ROBERTA_ATTENTION_CLASSES.
@hackyon, upon revisiting the thread, I've come to the conclusion that relying on the SDPA path for XLMR seems to be the most appropriate course of action. Given that PyTorch has already implemented FA2 within SDPA, I don't see the necessity of replicating it for XLMR. Is there anything I can assist with to facilitate the seamless integration of this pull request?
@michaelshekasta if you're up for it, you can clone #30510 and try to get that submitted. I will be starting a new job in a week and won't really have time to follow through with that PR. If this is your first change, you'll need to read through the contributing guidelines and follow the instructions there. Specifically, pay attention to getting "make fix-copies" and the "RUN_SLOW" tests to pass before opening the PR.
@hackyon Let's talk over email: [email protected]
Hi, I want to know whether this work has been merged into 4.42.2?
@aikangjun This PR wasn't merged - it was closed due to inactivity, it seems. We've recently merged other PRs that add SDPA to RoBERTa-based models, though (#30510, which adds it to this model). This isn't part of 4.42 but will be part of the next release.
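As a usage note (a hedged sketch, assuming a transformers release that includes the SDPA support from #30510), the attention backend can be requested when loading the model:

```python
import torch
from transformers import AutoModel

# The checkpoint name is the standard public XLM-RoBERTa base model; an
# informative error is raised if the requested backend is not supported.
model = AutoModel.from_pretrained(
    "FacebookAI/xlm-roberta-base",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)
```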
What does this PR do?
Fixes #27957
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.