
Adding Flash Attention 2 Support for GPT2 #29226

Merged · 33 commits · Mar 28, 2024

Conversation

@EduardoPach (Contributor)

What does this PR do?

Fixes #26350

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Hey @younesbelkada, this PR adds Flash Attention 2 support for GPT2. The only thing missing is the expected speedups section; could you share the code you used for the other models you added support for, so I can keep things consistent?

@younesbelkada (Contributor) left a comment:


Wow, thanks for the great work! At a quick glance it seems you took very good care of the copy mechanism, which is quite a challenge for GPT2!
Please find the benchmarking script here: https://gist.github.com/younesbelkada/02f35734da906cc0f2389ae4f665c58f. I suggest trying it out for prefill only, at large sequence lengths; let us know how it goes (cc @ArthurZucker, @fxmarty).
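For anyone reading along, a prefill-only comparison in the spirit of that gist might look like the sketch below. The checkpoint, sequence length, and timing loop are illustrative assumptions, not the gist's actual code:

import torch
from transformers import AutoModelForCausalLM

# Hypothetical prefill-only benchmark; the real gist may differ.
def benchmark_prefill(attn_implementation: str, seq_len: int = 1024, n_runs: int = 10) -> float:
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,  # "eager" or "flash_attention_2"
    ).to("cuda")
    # Random token ids; GPT2's learned positions cap the context at 1024 tokens.
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")

    with torch.no_grad():
        for _ in range(3):  # warmup
            model(input_ids)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(n_runs):
            model(input_ids)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / n_runs  # milliseconds per prefill forward pass

for impl in ("eager", "flash_attention_2"):
    print(impl, f"{benchmark_prefill(impl):.2f} ms")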

@EduardoPach (Contributor, Author)

> Wow, thanks for the great work! At a quick glance it seems you took very good care of the copy mechanism, which is quite a challenge for GPT2! Please find the benchmarking script here: https://gist.github.com/younesbelkada/02f35734da906cc0f2389ae4f665c58f. I suggest trying it out for prefill only, at large sequence lengths; let us know how it goes (cc @ArthurZucker, @fxmarty).

Hey, I don't have a GPU locally and have been renting an RTX 3090 on RunPod to work on this PR. Is it a problem to benchmark on the 3090, or should I switch to an A100 (which I believe was the GPU used in the other benchmarks, at least the ones I've seen)?

@younesbelkada (Contributor)

Thanks for getting back, @EduardoPach. I think using a 3090 is fine!

@EduardoPach (Contributor, Author)

@ArthurZucker I believe it should be ready for review.

@younesbelkada (Contributor) left a comment:


Great work!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:


LGTM but we need to add a test 😉

Review threads on src/transformers/models/gpt2/modeling_gpt2.py (3 outdated threads, resolved)
@younesbelkada (Contributor)

@EduardoPach thanks again! What @ArthurZucker meant is an integration test similar to:

def test_flash_attn_2_generate_padding_right(self):

but for GPT2 only. Would you be happy to work on that? 🙏
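For context, such an integration test broadly follows the pattern sketched below; the variant eventually added for GPT2 (see the review thread later in this conversation) uses left padding. This is a simplified, hypothetical sketch with placeholder expected strings, not the actual test:

import unittest

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow


class GPT2FlashAttention2IntegrationTest(unittest.TestCase):
    @require_flash_attn
    @require_torch_gpu
    @slow
    def test_flash_attn_2_generate_padding_left(self):
        model = GPT2LMHeadModel.from_pretrained(
            "gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
        ).to("cuda")
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token  # GPT2 ships without a pad token
        tokenizer.padding_side = "left"  # pad on the left so generation continues from real tokens

        # Uneven lengths force padding on the shorter sequence.
        texts = ["hi", "Hello this is a very long sentence"]
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to("cuda")
        output = model.generate(**inputs, max_new_tokens=8, do_sample=False)
        decoded = tokenizer.batch_decode(output, skip_special_tokens=True)

        # Reviewers asked for explicit expected values; placeholders shown here.
        EXPECTED_OUTPUTS = ["<expected continuation 1>", "<expected continuation 2>"]
        self.assertListEqual(decoded, EXPECTED_OUTPUTS)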

@EduardoPach (Contributor, Author)

> @EduardoPach thanks again! What @ArthurZucker meant is an integration test similar to:
>
> def test_flash_attn_2_generate_padding_right(self):
>
> but for GPT2 only. Would you be happy to work on that? 🙏

Yeah, I will add the test in the following hours.

@EduardoPach (Contributor, Author)

The following hours became more like the following days, haha, but it should be good now, @ArthurZucker.

@EduardoPach requested a review from @ArthurZucker on March 1, 2024 at 14:01.
@ArthurZucker (Collaborator) left a comment:


Almost good; left a few nits.

Review threads on docs/source/en/model_doc/gpt2.md (2 outdated threads, resolved)
@@ -346,21 +572,25 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
Collaborator comment:

Suggested change:
- DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
+ DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES = {

@EduardoPach (Contributor, Author) replied:

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
"eager": DecisionTransformerGPT2Attention,
}
Collaborator replied:

Where is DecisionTransformerGPT2FlashAttention2?

@EduardoPach (Contributor, Author) replied:

I hadn't added it there, but I've added it now in 74fb9bd. However, DecisionTransformer does not support flash attention yet; I just had to make these modifications so that nothing would break with the # Copied from statements.
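For context on the copy mechanism: anything marked # Copied from ... is regenerated by the repo's fix-copies tooling, with the listed rename applied throughout. The mirrored classes referenced here plausibly look like the abridged sketch below (an illustration, not the literal diff of 74fb9bd):

from torch import nn

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Attention(nn.Module):
    ...  # body kept in sync with GPT2Attention (modulo renames) by the copy checker


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2FlashAttention2 with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2FlashAttention2(DecisionTransformerGPT2Attention):
    ...  # mirror of the FA2 class, even though this model does not enable FA2


DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES = {
    "eager": DecisionTransformerGPT2Attention,
    "flash_attention_2": DecisionTransformerGPT2FlashAttention2,
}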

Review thread on tests/models/gpt2/test_modeling_gpt2.py (resolved)
@ArthurZucker (Collaborator) left a comment:


LGTM; one final nit: the test should have explicit values.

Review thread on tests/models/gpt2/test_modeling_gpt2.py (resolved)
@younesbelkada (Contributor) left a comment:


Thanks again! We just merged some fixes on main; could you rebase again? 🙏 Then we should finally merge :D Sorry for all the iterations!

@EduardoPach (Contributor, Author)

> Thanks again! We just merged some fixes on main; could you rebase again? 🙏 Then we should finally merge :D Sorry for all the iterations!

No worries! Done

@younesbelkada (Contributor)

Thanks! Hmm, I can't see the rebase commit in the history; could you try again?

@EduardoPach (Contributor, Author)

> Thanks! Hmm, I can't see the rebase commit in the history; could you try again?

I've run git fetch upstream && git merge upstream/main here: dcde56c

@amyeroberts (Collaborator) left a comment:


Thanks for adding this and making our models go brrr 🔥

Just a few small comments. The diffs in the READMEs will need to be resolved before we can merge.

Review thread on README_de.md (outdated)
Collaborator comment:

There shouldn't be README changes here. Can you make sure to rebase on main to include the most recent changes?

Review threads on src/transformers/models/gpt2/modeling_gpt2.py (outdated, resolved) and docs/source/en/perf_infer_gpu_one.md (resolved)
@amyeroberts (Collaborator) left a comment:


Thanks for the continued work on this!

The only thing left to do is make sure Decision Transformer has the updated documentation and tests.

@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_left(self):
Collaborator comment:

The equivalent test for Decision Transformer should also be added.

@EduardoPach (Contributor, Author) replied:

Doesn't the GPT2 test already cover Decision Transformer? Flash Attention usage in Decision Transformer comes precisely from GPT2Model being embedded in its architecture.

@amyeroberts (Collaborator) replied on Mar 18, 2024:

Both models should be tested. This makes sure that if anything changes upstream, they remain correct; for example, input preparation in DecisionTransformerModel.

@EduardoPach (Contributor, Author) replied:

While adding the test for DecisionTransformer, I realized that the model has two distinct xxxPreTrainedModels and that adding flash_attention_2 support would be a bit more complicated, so I believe it would be better to add support in a dedicated PR.

Collaborator replied:

In this case, flash attention shouldn't be added at all for the model. You can use # Ignore copy on the init so the previous attention class's method is used.
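For readers unfamiliar with the convention: # Ignore copy tells the copy checker to skip the marked section inside an otherwise-copied region, letting it diverge from its source. A hypothetical sketch of what that could look like on the init (abridged; names follow the classes discussed above):

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
    # Ignore copy
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Always instantiate the eager attention class, so no FA2 logic
        # leaks into this model even though GPT2Block gained a class mapping.
        self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
        ...  # remainder unchanged from the copied GPT2Block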


@@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).

## Usage example
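The usage example added to gpt2.md presumably resembles the sketch below (the checkpoint and prompt are placeholders; the doc's actual snippet may differ):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # needs flash-attn installed and a supported GPU
).to("cuda")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))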
Collaborator comment:

We should have the equivalent added for decision transformer too

@EduardoPach (Contributor, Author) replied:

See message above

@EduardoPach requested a review from @amyeroberts on March 25, 2024 at 10:10.
@NieShenRuc

Thank you for your hard work! Many of us are excited about the GPT-2 model supporting flash attention. May I ask when the PR is expected to be merged?

@EduardoPach (Contributor, Author)

> Thank you for your hard work! Many of us are excited about the GPT-2 model supporting flash attention. May I ask when the PR is expected to be merged?

Hey, I believe if @amyeroberts agrees with my latest message it should get merged right away 🤞

@amyeroberts (Collaborator) left a comment:


Thanks for iterating! A few final places to tidy up.

@@ -40,8 +40,10 @@ FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [DecisionTransformer](https://huggingface.co/docs/transformers/en/model_doc/decision_transformer)
Collaborator comment:

This should be removed

@@ -548,25 +551,26 @@ def forward(
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)

# GPT2Attention mask.
# Attention mask.
Collaborator comment:

Here I would use # Ignore copy; the model shouldn't have FA2 logic.

@EduardoPach (Contributor, Author) replied:

Addressed

@@ -575,7 +579,8 @@ def forward(
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
Collaborator comment:

Same here above this line
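Concretely, the pattern under discussion gates the mask inversion on the attention backend. A rough sketch, assuming the module carries the library's usual _attn_implementation attribute (an illustration, not the PR's exact diff):

# FA2 consumes the raw 2D padding mask directly, while eager attention
# needs the inverted additive mask built by invert_attention_mask.
if self._attn_implementation != "flash_attention_2":
    encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
# otherwise leave encoder_attention_mask as the 0/1 mask for the FA2 kernel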

@EduardoPach (Contributor, Author) replied:

Addressed

@amyeroberts (Collaborator) left a comment:


Thanks for adding this for GPT2 and iterating on a solution!

@NieShenRuc

It seems everything is okay. May I kindly ask that this PR be merged? I am really looking forward to speeding up my GPT-2. If my request has added to your workload, I apologize for any inconvenience.

@EduardoPach (Contributor, Author)

> It seems everything is okay. May I kindly ask that this PR be merged? I am really looking forward to speeding up my GPT-2. If my request has added to your workload, I apologize for any inconvenience.

cc @amyeroberts

@amyeroberts merged commit 22d159d into huggingface:main on Mar 28, 2024.
19 checks passed
itazap pushed a commit that referenced this pull request on May 14, 2024:
* First commit to add flash attention 2 for GPT-2

* more improvements

* Make GPT2 pass tests and fixed Decision Transformers copies

* Fixed missing arg

* fix copies

* Added expected speedup

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <[email protected]>

* Added test

* Fixed attn attribute

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <[email protected]>

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <[email protected]>

* Update Decision transformer attentions

* More updates

* Passing tests

* Fix copies

* Fix copies part 2

* Decision transformer updates

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: amyeroberts <[email protected]>

* Fix copies

* Decision transformer not supporting flash attn

* Addressed comments

* Addressed comments

* Addressed comments

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Successfully merging this pull request may close these issues:

  • Community contribution: Adding Flash Attention 2 support for more architectures