Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Flash Attention 2 to M2M100 model #30256

Merged
merged 13 commits into from
Apr 18, 2024
Merged

Conversation

visheratin
Copy link
Contributor

What does this PR do?

This PR adds support for Flash Attention 2 in M2M100 models (e.g., NLLB). Here is the Colab notebook with a working demo.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @younesbelkada

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @visheratin
Thanks for this great addition ! I see in the PR you used some old / deprecated variables such as _use_flash_attention_2, please see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L405 - you need to inherit from M2M100Attention. There is also sligtly more work to be done on the documentation side to add expected speedups, check out this most recent PR: #29226 to see what are the required changes and let me know if you have any question - thanks !

@visheratin
Copy link
Contributor Author

The _use_flash_attention_2 is needed to handle different attention mask formats. As far as I can see, the same flags/logic is used across many other models (e.g., DistilBERT). Is there another better way to handle this?

@younesbelkada
Copy link
Contributor

@visheratin correct, for llama it's because the attention mask logic has been refactored in favor of AttentionMaskConverter. OK to use _use_flash_attention_2, the other points are still valid though, can you check what has been done for GPT2 and make the changes accordingly? 🙏

@visheratin
Copy link
Contributor Author

I fixed inheritance and added the sections about FA2 along with speedup image to the NLLB and M2M100 doc pages. I also added an integration test. Let me know if there is anything that needs to be done.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks very clean ! Thanks for working on this ! I left one single comment - what do you think?

@visheratin
Copy link
Contributor Author

Sure! I committed the change.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for the smooth integration!

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding!

Just some small nits to resolve before merge

docs/source/en/model_doc/nllb.md Outdated Show resolved Hide resolved
@@ -967,18 +1185,24 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = _prepare_4d_causal_attention_mask(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why rename here from combined_attention_mask to attention_mask?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an artifact of debugging. I returned the old name.

@@ -1028,7 +1252,8 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
# combined_attention_mask,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be better to keep the old name though

Suggested change
# combined_attention_mask,

@@ -1040,7 +1265,8 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
# attention_mask=combined_attention_mask,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Suggested change
# attention_mask=combined_attention_mask,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch !

"I think there are two levels of response from the French government.",
"When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
" Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
" communications in France.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same examples as in the original tests.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry, I didn't mean to be confusing, It's just that it was talking about surveillance so I thought I'd drop a wee spy

@visheratin visheratin requested a review from amyeroberts April 17, 2024 19:25
Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for adding and iterating!

@visheratin
Copy link
Contributor Author

My pleasure! Thank you both, @amyeroberts and @younesbelkada, for the fast review!

@younesbelkada younesbelkada merged commit b65df51 into huggingface:main Apr 18, 2024
18 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Apr 18, 2024
* Added flash attention 2.

* Fixes.

* Fix inheritance.

* Fixed init.

* Remove stuff.

* Added documentation.

* Add FA2 to M2M100 documentation.

* Add test.

* Fixed documentation.

* Update src/transformers/models/m2m_100/modeling_m2m_100.py

Co-authored-by: Younes Belkada <[email protected]>

* Update docs/source/en/model_doc/nllb.md

Co-authored-by: amyeroberts <[email protected]>

* Fixed variable name.

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
* Added flash attention 2.

* Fixes.

* Fix inheritance.

* Fixed init.

* Remove stuff.

* Added documentation.

* Add FA2 to M2M100 documentation.

* Add test.

* Fixed documentation.

* Update src/transformers/models/m2m_100/modeling_m2m_100.py

Co-authored-by: Younes Belkada <[email protected]>

* Update docs/source/en/model_doc/nllb.md

Co-authored-by: amyeroberts <[email protected]>

* Fixed variable name.

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
* Added flash attention 2.

* Fixes.

* Fix inheritance.

* Fixed init.

* Remove stuff.

* Added documentation.

* Add FA2 to M2M100 documentation.

* Add test.

* Fixed documentation.

* Update src/transformers/models/m2m_100/modeling_m2m_100.py

Co-authored-by: Younes Belkada <[email protected]>

* Update docs/source/en/model_doc/nllb.md

Co-authored-by: amyeroberts <[email protected]>

* Fixed variable name.

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
itazap pushed a commit that referenced this pull request May 14, 2024
* Added flash attention 2.

* Fixes.

* Fix inheritance.

* Fixed init.

* Remove stuff.

* Added documentation.

* Add FA2 to M2M100 documentation.

* Add test.

* Fixed documentation.

* Update src/transformers/models/m2m_100/modeling_m2m_100.py

Co-authored-by: Younes Belkada <[email protected]>

* Update docs/source/en/model_doc/nllb.md

Co-authored-by: amyeroberts <[email protected]>

* Fixed variable name.

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants