Adding Flash Attention 2 Support for GPT2 #29226
Conversation
Wow, thanks for the great work! At a quick glance it seems you took very good care of the copy mechanism, which is quite a challenge for GPT2!
Please find the benchmarking script here: https://gist.github.com/younesbelkada/02f35734da906cc0f2389ae4f665c58f. I suggest trying it out for prefill only on a large sequence length - let us know, with @ArthurZucker and @fxmarty, how it goes.
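For readers following along, below is a minimal sketch of the kind of prefill benchmark that script runs; the linked gist is the authoritative version, and the sequence length, dtype, and timing loop here are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM

def time_prefill(attn_implementation: str, seq_len: int = 1024, n_iters: int = 10) -> float:
    """Average forward-pass (prefill) latency in ms for a single long prompt."""
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to("cuda")
    # GPT-2 supports at most 1024 positions, so use that as the "large" prefill length.
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")

    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        model(input_ids)  # warmup
        torch.cuda.synchronize()
        start.record()
        for _ in range(n_iters):
            model(input_ids)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters

for impl in ("eager", "flash_attention_2"):
    print(impl, f"{time_prefill(impl):.2f} ms")
```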
Hey, I don't have a GPU and I was renting an RTX 3090 on RunPod to work on this PR. Is it a problem to use the 3090 for benchmarking, or should I switch to an A100 (which I believe was the GPU used in the other benchmarks, at least the ones I've seen)?
Thanks @EduardoPach for getting back, I think using a 3090 is fine!
@ArthurZucker I believe it should be ready for review
Great work!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM but we need to add a test 😉
@EduardoPach thanks again, what @ArthurZucker meant is an integration test similar to the ones we have for the other Flash Attention 2 models.
Yeah, I will add the test in the following hours
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
The following hours became more like the following days haha, but it should be good now @ArthurZucker
Almost good, left a few nits
@@ -346,21 +572,25 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
    return hidden_states


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
Suggested change:
-DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
+DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES = {
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
    "eager": DecisionTransformerGPT2Attention,
}
Where is DecisionTransformerGPT2FlashAttention2?
I hadn't added it there, but I added it now in 74fb9bd. However, DecisionTransformer does not support flash attention yet; I just had to make these modifications to make sure nothing would break with the # Copied from statements.
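For anyone reading along, the `*_ATTENTION_CLASSES` dicts are consumed roughly as in the sketch below when a block is built. The class names are simplified stand-ins; in this PR the real entries are GPT2Attention / GPT2FlashAttention2 (and only the eager entry for DecisionTransformer).

```python
import torch.nn as nn

class EagerAttention(nn.Module):
    """Stand-in for GPT2Attention (plain softmax attention)."""
    def __init__(self, config):
        super().__init__()

class FlashAttention2(nn.Module):
    """Stand-in for GPT2FlashAttention2 (same weights, flash-attn kernels in forward)."""
    def __init__(self, config):
        super().__init__()

# One mapping per model; the key comes from config._attn_implementation,
# which from_pretrained(..., attn_implementation=...) fills in.
ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Pick the attention module for the configured backend.
        attention_class = ATTENTION_CLASSES[config._attn_implementation]
        self.attn = attention_class(config)
```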
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
LGTM, one final nit: the test should have explicit expected values.
Thanks again! We just merged some fixes on main - could you rebase again 🙏? Then we should finally merge :D Sorry for all the iterations!
No worries! Done
Thanks! Hmm, I can't see the rebase commit in the history, perhaps you can try again?
I've done the rebase again
Thanks for adding this and making our models go brrr 🔥
Just a few small comments. The diffs in the READMEs will need to be resolved before we can merge
README_de.md (Outdated)
There shouldn't be README changes here. Can you make sure to rebase on main to include the most recent changes?
Co-authored-by: amyeroberts <[email protected]>
…transformers into add-flash-attn-gpt2
Thanks for the continued work on this!
Only thing left to do is make sure decision transformer has the updated documentation and tests
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_left(self):
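For context, here is a sketch of what the body of such a test typically looks like on the GPT2 test class; the prompts and the simple equality check are illustrative, and the committed test additionally pins explicit expected strings as requested above.

```python
import pytest
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.testing_utils import require_torch_gpu, slow

@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_left(self):
    # Reference model on the eager attention path.
    model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    texts = ["hi", "Hello this is a very long sentence"]
    inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
    output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    output_native = tokenizer.batch_decode(output_native)

    # Same checkpoint, Flash Attention 2 backend.
    model_fa2 = GPT2LMHeadModel.from_pretrained(
        "gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
    ).to(0)
    output_fa2 = model_fa2.generate(**inputs, max_new_tokens=20, do_sample=False)
    output_fa2 = tokenizer.batch_decode(output_fa2)

    # Left-padded generation must match between the two code paths
    # (the real test also compares against hard-coded expected strings).
    self.assertListEqual(output_native, output_fa2)
```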
The equivalent test for decision transformer should also be added
Doesn't the test in GPT2 already cover Decision Transformer? The use of Flash Attention in Decision Transformer comes precisely from GPT2Model being embedded in its architecture.
Both models should be tested. This makes sure that they remain correct if anything changes upstream, for example the input preparation in DecisionTransformerModel.
While adding the test for DecisionTransformer I realized that the model has two distinct xxxPreTrainedModels, and that adding support for flash_attention_2 would be a bit more complicated, so I believe it would be better to add that support in a separate PR.
In this case, flash attention shouldn't be added at all for the model. You can use # Ignore copy on the init so the previous attention class's method is used.
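If it helps, the suggestion boils down to something like the sketch below in modeling_decision_transformer.py; the class bodies are trimmed stand-ins and the exact placement of the marker is governed by the repo's copy-checking tooling.

```python
import torch.nn as nn

class DecisionTransformerGPT2Attention(nn.Module):
    """Trimmed stand-in for the existing eager attention module."""
    def __init__(self, config=None, layer_idx=None):
        super().__init__()

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
    # Ignore copy
    def __init__(self, config=None, layer_idx=None):
        super().__init__()
        # Keep the eager attention class unconditionally: DecisionTransformer
        # does not support flash_attention_2 yet, so there is no dispatch on
        # config._attn_implementation, and the marker above tells the copy
        # checker not to flag this divergence from GPT2Block.__init__.
        self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
```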
@@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
  improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).


## Usage example
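The usage example added to gpt2.md presumably reads along these lines; the prompt and generation settings here are placeholders, and the committed doc is authoritative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Half precision plus the Flash Attention 2 backend (requires flash-attn to be installed).
model = AutoModelForCausalLM.from_pretrained(
    "gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to(device)

inputs = tokenizer("Hello, my dog is cute and", return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```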
We should have the equivalent added for decision transformer too
See message above
Thank you for your hard work! Many of us are excited about the GPT-2 model supporting flash attention. May I ask when the PR is expected to be merged?
Hey, I believe if @amyeroberts agrees with my latest message it should get merged right away 🤞
Thanks for iterating - a few final places to tidy up.
docs/source/en/perf_infer_gpu_one.md (Outdated)
@@ -40,8 +40,10 @@ FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [DecisionTransformer](https://huggingface.co/docs/transformers/en/model_doc/decision_transformer)
This should be removed
@@ -548,25 +551,26 @@ def forward(
    position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
    position_ids = position_ids.unsqueeze(0)

-    # GPT2Attention mask.
+    # Attention mask.
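For context on the review note that follows: once FA2 is an option, GPT2Model's mask preparation has to branch on the backend, roughly as in the simplified standalone sketch below (not the committed diff). DecisionTransformer should keep only the eager branch, hence the suggestion to use # Ignore copy here.

```python
import torch

def prepare_attention_mask(attention_mask, attn_implementation, dtype=torch.float16):
    """Simplified sketch of the per-backend attention-mask handling."""
    if attn_implementation == "flash_attention_2":
        # The flash-attn kernel consumes the 2D padding mask directly,
        # or no mask at all when nothing is padded.
        if attention_mask is None or not (attention_mask == 0).any():
            return None
        return attention_mask
    # Eager attention expects an additive 4D mask: 0.0 where attended,
    # a large negative value where masked.
    if attention_mask is None:
        return None
    mask = attention_mask[:, None, None, :].to(dtype)
    return (1.0 - mask) * torch.finfo(dtype).min
```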
Here I would use # ignore copy - the model shouldn't have FA2 logic
Addressed
@@ -575,7 +579,8 @@ def forward(
    encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
    if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
    encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
Same here above this line
Addressed
Thanks for adding this for GPT2 and iterating on a solution!
It seems everything is okay. May I kindly request that this PR be merged? I am really looking forward to speeding up my GPT-2. If my request has added to your workload, I apologize for any inconvenience.
cc @amyeroberts
* First commit to add flash attention 2 for GPT-2
* more improvements
* Make GPT2 pass tests and fixed Decison Transformers copies
* Fixed missing arg
* fix copies
* Added expected speedup
* Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <[email protected]>
* Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <[email protected]>
* Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <[email protected]>
* Added test
* Fixed attn attribute
* Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <[email protected]>
* Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <[email protected]>
* Update Decision transformer attentions
* More updates
* Passing tests
* Fix copies
* Fix copies part 2
* Decision transformer updates
* Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: amyeroberts <[email protected]>
* Fix copies
* Decision transformer not supporting flash attn
* Addressed comments
* Addressed comments
* Addressed comments

Co-authored-by: Arthur <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
What does this PR do?
Fixes #26350
Who can review?
Hey @younesbelkada, I added flash attention 2 support for GPT2. The only thing missing is the expected speedups section; could you share the code you used for the other models you added support for, to keep consistency?