
Add FlashAttention2 for XLM-RoBERTa #28713

Conversation

@DavidAfonsoValente (Contributor) commented Jan 25, 2024

What does this PR do?

Fixes #27957

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts (Collaborator)

@DavidAfonsoValente Thanks for opening a PR and adding FA2 for this model! Let us know when the PR is ready to be reviewed 🤗

@DavidAfonsoValente (Contributor, Author)

@DavidAfonsoValente Thanks for opening a PR and adding FA2 for this model! Let us know when the PR is ready to be reviewed 🤗

Thanks! I believe it's ready to be reviewed.

@DavidAfonsoValente (Contributor, Author) commented Jan 29, 2024

The tests that are failing are due to the changes I made in functions labeled as:
Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta

What is the standard procedure when I need to change these functions to accommodate the new feature?

@amyeroberts (Collaborator)

@DavidAfonsoValente If you click on the CI runs, you'll see the error messages relating to the code quality failures. These detail how to resolve the issue: run make fix-copies and push the changes.

If another model is copying this model's attention classes, then it's great because you're adding FA2 for two models 🥳 You just need to apply the equivalent changes to the model that's copying.
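For context, the code-quality check keys off the "Copied from" comment markers quoted above. A minimal sketch of the pattern (illustrative only, mirroring how modeling_xlm_roberta.py marks classes copied from modeling_roberta.py):

```python
import torch.nn as nn

# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta
class XLMRobertaSelfAttention(nn.Module):
    # The body of this class is kept in sync with RobertaSelfAttention
    # (modulo the Roberta->XLMRoberta rename) by `make fix-copies`.
    # Hand-editing it without also changing the source class (or removing
    # the marker) makes the repo-consistency CI checks fail.
    ...
```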

DavidAfonsoValente force-pushed the flashattention2-XLMRoBERTa branch from 90107b7 to 237c1a3 on January 30, 2024 02:02
DavidAfonsoValente force-pushed the flashattention2-XLMRoBERTa branch from efcae38 to 09f4695 on January 30, 2024 02:22
DavidAfonsoValente force-pushed the flashattention2-XLMRoBERTa branch from 9b03bc7 to 9cff1b0 on January 30, 2024 16:52
@amyeroberts (Collaborator) left a review

Awesome work - thanks for adding!

Just a few tiny comments and then we're good to merge 🤗

Review threads:
  • tests/models/xlm_roberta/test_modeling_xlm_roberta.py (2 comments, outdated, resolved)
  • docs/source/en/model_doc/xlm-roberta.md (1 comment, resolved)
@hackyon (Contributor) commented Feb 7, 2024

Hello there!

I'm working on integrating scaled_dot_product_attention to BERT #28802, and there might be some merge conflicts with this change.

Most of the changes I'm making will propagate through to XLM-RoBERTa (through fix-copies). From what I can tell, our approaches are a bit different in 2 areas:

  1. I've kept the # Copied from line at the class level, so any changes to BERT's __init__() will be propagated through to the fix-copies classes. I've modified the # Copied from line to use the corresponding ATTENTION_CLASS as necessary.
  2. I've created a BertSdpaSelfAttention to go with BertSelfAttention, instead of creating a BertSdpaAttention. The rationale behind this is that the code in BertSdpaSelfAttention is adapted from the code in BertSelfAttention, so I figured it's better to name it as such.

Let me know if you have any questions about these. I think we should discuss which approach is better and adopt it. Would you mind looking through #28802 and letting me know what you think?
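To make the two options above concrete, here is a rough sketch of the dispatch pattern being discussed. The backend class names and stub bodies are placeholders rather than the actual code in either PR; config._attn_implementation is the attribute transformers uses to select the attention backend.

```python
import torch.nn as nn

# Stand-ins for the real classes in modeling_xlm_roberta.py; illustrative only.
class XLMRobertaSelfAttention(nn.Module): ...
class XLMRobertaSdpaSelfAttention(nn.Module): ...    # hypothetical SDPA variant
class XLMRobertaFlashSelfAttention2(nn.Module): ...  # hypothetical FA2 variant
class XLMRobertaSelfOutput(nn.Module): ...           # dense + dropout + LayerNorm

# Option 1 (the #28802 approach): dispatch at the *self*-attention level,
# so XLMRobertaAttention and its output projection stay shared across backends.
XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
    "eager": XLMRobertaSelfAttention,
    "sdpa": XLMRobertaSdpaSelfAttention,
    "flash_attention_2": XLMRobertaFlashSelfAttention2,
}

class XLMRobertaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Only the inner self-attention module changes with the backend.
        self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
        self.output = XLMRobertaSelfOutput(config)
```

Option 2 (an XLM_ROBERTA_ATTENTION_CLASSES dict) would instead swap out XLMRobertaAttention itself, which is roughly why the FA2 class in this PR ends up re-implementing the output-projection call, as discussed further below.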

@DavidAfonsoValente (Contributor, Author)

Hello there!

I'm working on integrating scaled_dot_product_attention to BERT #28802, and there might be some merge conflicts with this change.

Most of the changes I'm making will propagate through to XLM-RoBERTa (through fix-copies). From what I can tell, our approaches are a bit different in 2 areas:

  1. I've kept the # Copied from line at the class level, so any changes to BERT's __init__() will be propagated through to the fix-copies classes. I've modified the # Copied from line to use the corresponding ATTENTION_CLASS as necessary.
  2. I've created a BertSdpaSelfAttention to go with BertSelfAttention, instead of creating a BertSdpaAttention. The rationale behind this is that the code in BertSdpaSelfAttention is adapted from the code in BertSelfAttention, so I figured it's better to name it as such.

Let me know if you have any questions about these. I think we should discuss which approach is better and adopt it. Would you mind looking through #28802 and letting me know what you think?

It seems like your implementation doesn't create complicated conflicts; it should be simple to merge. What do you think?

@hackyon (Contributor) commented Feb 8, 2024

It should be simple to merge. I think the main question is whether we want to use the XLM_ROBERTA_SELF_ATTENTION_CLASSES approach or the XLM_ROBERTA_ATTENTION_CLASSES approach.

@DavidAfonsoValente (Contributor, Author)

I think we should keep the name XLM_ROBERTA_ATTENTION_CLASSES to be consistent with other models, since that is how other models name their attention classes.

@hackyon (Contributor) commented Feb 8, 2024

Thanks for your input.

I am leaning towards using the "SelfAttention" convention in this case, because XLMRobertaFlashAttention2 actually folds the logic of both XLMRobertaAttention and XLMRobertaSelfAttention into one class. I think it's cleaner to have a XLMRobertaSelfFlashAttention2 that mirrors XLMRobertaSelfAttention, and then reuse the logic inside XLMRobertaAttention for both self-attention variants.

Reusing the code in XLMRobertaAttention should help avoid some bugs. Judging from the existing code, I think there's already a bit of an issue with the way self.output() is called in XLMRobertaFlashAttention2 (I think you need more than the call to self.output.dense(), otherwise you're missing the dropout and LayerNorm).
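For reference, the output projection being discussed does more than the dense layer. It roughly mirrors RobertaSelfOutput in modeling_roberta.py (reproduced here as a sketch, not copied verbatim from the library):

```python
import torch
import torch.nn as nn

class XLMRobertaSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Calling only self.output.dense(...) from an attention class would
        # skip the dropout and this residual + LayerNorm step.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
```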

@DavidAfonsoValente (Contributor, Author)

OK, maybe it'll be easier if you merge your PR first and I work on top of your version.

@hackyon (Contributor) commented Feb 8, 2024

Thanks 👍 I'll iterate as fast as possible to get it merged, and will let you know as soon as it's done!

github-actions bot commented Mar 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@michaelshekasta

@amyeroberts @DavidAfonsoValente @hackyon Why hasn't this PR been merged yet? How can I help with that?

@hackyon (Contributor) commented May 8, 2024

@michaelshekasta

I have #30510 open to add scaled dot product attention (SDPA) to RoBERTa-based models, and it's still under review. When available, SDPA could run on an FA2 backend. You could very well try to add FA2 to it as well.

@michaelshekasta

@michaelshekasta

I have #30510 open to add scaled dot product attention (SDPA) to RoBERTa-based models, and it's still under review. When available, SDPA could run on an FA2 backend. You could very well try to add FA2 to it as well.

@hackyon Thank you for your feedback! As far as I'm aware, SDPA isn't Flash Attention 2, and it's currently in beta, correct? So, will HF support it? Or am I missing something?

I'd be happy to contribute, but I'm not sure how. Could we discuss this on the HF Discord server, perhaps in the NLP channel?

@hackyon (Contributor) commented May 8, 2024

Newer versions of SDPA actually have FA2 integration: https://pytorch.org/blog/pytorch2-2. With that said, you could probably add an FA2 option as well if you'd like (some of the models have that too), but I'm not an HF maintainer, so take my opinion with a grain of salt.

SDPA is already supported in a variety of HF models.
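A rough illustration of the PyTorch-level point, assuming a CUDA GPU and fp16 inputs; the sdp_kernel context manager is the PyTorch 2.x way to restrict which SDPA backend may be used:

```python
import torch
import torch.nn.functional as F

# Q/K/V in the layout SDPA expects: (batch, num_heads, seq_len, head_dim).
q = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 12, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the FlashAttention backend; if it can't handle this input
# (dtype, head_dim, masking, ...), the call errors out instead of silently
# falling back to the math implementation.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 12, 128, 64])
```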

@michaelshekasta

Thank you once more for your input, @hackyon!

I'm interested in implementing XLM-RoBERTa with FA2, and I believe integrating it into Hugging Face (HF) could be the way forward. From what I've gathered, this pull request appears to align with my goals.

While I understand you're not directly involved in HF maintenance, your prior work suggests you're knowledgeable in this area. Could you advise on the best course of action? Should I wait for your pull request, or would utilizing the flash_attn library be a viable alternative?

@hackyon (Contributor) commented May 8, 2024

I think you should reference how FA2 was implemented for other models, such as modeling_llama, and bring those changes over. I recommend using "Copied from" as much as possible.

My 2 cents would be that it'd be better to use XLM_ROBERTA_SELF_ATTENTION_CLASSES instead of XLM_ROBERTA_ATTENTION_CLASSES, otherwise there'd be a conflict with #30510 (and since they have a "Copied from" relationship, it wouldn't be easy to keep both copies in sync).

Unfortunately, I won't be able to discuss this on Discord. You might be able to find some of the core maintainers there, though.

@michaelshekasta

@hackyon thanks again! You are awesome!
I don't think I really understand the issue with "Copied from". Can you explain a bit more or point me to a resource?

@hackyon (Contributor) commented May 8, 2024

I don't think I really understand the issue with "Copied from". Can you explain a bit more or point me to a resource?

#30510 has an XLMROBERTAXL_SELF_ATTENTION_CLASSES, so it's better to add the FA2 implementation there instead of creating a new XLM_ROBERTA_ATTENTION_CLASSES.

@michaelshekasta

@hackyon, upon revisiting the thread, I've come to the conclusion that bringing SDPA to XLMR seems to be the most appropriate course of action. Given that PyTorch has already implemented FA2 within SDPA, I don't see the necessity of replicating it for XLMR. Is there anything I can assist with to facilitate the seamless integration of this pull request?

@hackyon (Contributor) commented May 19, 2024

@michaelshekasta if you're up for it, you can clone #30510 and try to get that submitted. I will be starting a new job in a week and won't really have time to follow through with that PR.

If this is your first change, you'll need to read through their contributions page and follow the instructions there. Specifically, pay attention to getting "make fix-copies" and "RUN_SLOW" tests to pass before opening the PR.

@michaelshekasta commented May 19, 2024

@hackyon Let's talk over email: [email protected]

@aikangjun

Hi, I'd like to know whether this work has been merged into 4.42.2?

@amyeroberts (Collaborator) commented Aug 30, 2024

@aikangjun This PR wasn't merged; it seems it was closed because of inactivity. We have since merged other PRs that add SDPA to RoBERTa-based models, including #30510, which adds it to this model. This isn't part of 4.42 but will be part of the next release.
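For anyone landing here later: once a release containing #30510 is installed, the SDPA path can be requested at load time. A minimal sketch (the checkpoint name is just an example; from_pretrained raises if the installed version doesn't support SDPA for this architecture):

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "FacebookAI/xlm-roberta-base"

# attn_implementation="sdpa" selects torch.nn.functional.scaled_dot_product_attention,
# which can itself dispatch to a FlashAttention kernel on supported GPUs/dtypes.
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
```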

Linked issue: XLMRoberta with Flash Attention 2 (#27957)
5 participants