Adding Flash Attention 2 Support for GPT2 #29226

Merged Mar 28, 2024 · 33 commits · changes shown from 27 commits

Commits
bfd9a8f
First commit to add flash attention 2 for GPT-2
EduardoPach Feb 22, 2024
333a867
more improvements
EduardoPach Feb 22, 2024
b4a053a
Make GPT2 pass tests and fixed Decison Transformers copies
EduardoPach Feb 23, 2024
eaaa6e4
Fixed missing arg
EduardoPach Feb 23, 2024
b5ded2e
fix copies
EduardoPach Feb 23, 2024
774414b
Added expected speedup
EduardoPach Feb 24, 2024
d405771
Update src/transformers/models/gpt2/modeling_gpt2.py
EduardoPach Mar 1, 2024
06bf96e
Update src/transformers/models/gpt2/modeling_gpt2.py
EduardoPach Mar 1, 2024
ad65025
Update src/transformers/models/gpt2/modeling_gpt2.py
EduardoPach Mar 1, 2024
bc7f558
Added test
EduardoPach Mar 1, 2024
19f171d
Fixed attn attribute
EduardoPach Mar 1, 2024
78c78f6
Update docs/source/en/model_doc/gpt2.md
EduardoPach Mar 4, 2024
a895172
Update docs/source/en/model_doc/gpt2.md
EduardoPach Mar 4, 2024
74fb9bd
Update Decision transformer attentions
EduardoPach Mar 4, 2024
fc1cf99
More updates
EduardoPach Mar 13, 2024
cafc0f1
Passing tests
EduardoPach Mar 14, 2024
dcde56c
Merge remote-tracking branch 'upstream/main' into add-flash-attn-gpt2
EduardoPach Mar 14, 2024
2ea5301
Fix copies
EduardoPach Mar 14, 2024
a67557f
Fix copies part 2
EduardoPach Mar 14, 2024
6fe34ab
Merge remote-tracking branch 'upstream/main' into add-flash-attn-gpt2
EduardoPach Mar 14, 2024
c043c27
Merge remote-tracking branch 'upstream/main' into add-flash-attn-gpt2
EduardoPach Mar 15, 2024
c36c60a
Decision transformer updates
EduardoPach Mar 15, 2024
656561b
Update src/transformers/models/gpt2/modeling_gpt2.py
EduardoPach Mar 15, 2024
9be39bf
Merge remote-tracking branch 'upstream/main' into add-flash-attn-gpt2
EduardoPach Mar 15, 2024
54935c5
Merge branch 'add-flash-attn-gpt2' of https://github.com/EduardoPach/…
EduardoPach Mar 15, 2024
97c04ec
Fix copies
EduardoPach Mar 15, 2024
5fb3a3d
Merge remote-tracking branch 'upstream/main' into add-flash-attn-gpt2
EduardoPach Mar 18, 2024
a938fc9
Decision transformer not supporting flash attn
EduardoPach Mar 23, 2024
0c42513
Merge remote-tracking branch 'upstream/main' into add-flash-attn-gpt2
EduardoPach Mar 23, 2024
393d17f
Merge remote-tracking branch 'upstream/main' into add-flash-attn-gpt2
EduardoPach Mar 26, 2024
94c2fe8
Addressed comments
EduardoPach Mar 26, 2024
343c04e
Addressed comments
EduardoPach Mar 26, 2024
2799988
Addressed comments
EduardoPach Mar 26, 2024
67 changes: 67 additions & 0 deletions docs/source/en/model_doc/gpt2.md
@@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).

## Usage example
[Review comment, Collaborator] We should have the equivalent added for decision transformer too

[Reply, Contributor (Author)] See message above


The `generate()` method can be used to generate text with the GPT2 model.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")

>>> prompt = "GPT2 is a model developed by OpenAI."

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

>>> gen_tokens = model.generate(
... input_ids,
... do_sample=True,
... temperature=0.9,
... max_length=100,
... )
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
```
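
GPT2 ships without a padding token, so batched prompts need one extra step. Below is a minimal sketch of one common workaround (our addition, not part of the original example): reuse the EOS token for padding and left-pad so the generated tokens stay adjacent to each prompt.

```python
>>> # Sketch only: reuse EOS as the padding token and left-pad the batch.
>>> tokenizer.pad_token = tokenizer.eos_token
>>> tokenizer.padding_side = "left"

>>> prompts = ["GPT2 is a model developed by OpenAI.", "The quick brown fox"]
>>> inputs = tokenizer(prompts, return_tensors="pt", padding=True)

>>> gen_tokens = model.generate(
...     **inputs,
...     do_sample=True,
...     temperature=0.9,
...     max_length=100,
...     pad_token_id=tokenizer.eos_token_id,
... )
>>> gen_texts = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
```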

## Using Flash Attention 2

Flash Attention 2 is a faster, optimized implementation of the attention computation that relies on CUDA kernels.

### Installation

First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimizations through the Better Transformer support covered [here](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).

Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```
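
As a quick sanity check (assuming a CUDA-enabled environment where the build succeeded), you can confirm the package imports and print its version:

```bash
# Verify the flash-attn install; prints the installed version if the import works.
python -c "import flash_attn; print(flash_attn.__version__)"
```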

### Usage

To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to generation quality but significantly lower memory usage and faster inference:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")

>>> prompt = "def hello_world():"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
```
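
Flash Attention 2 kernels only run in half precision (`torch.float16` or `torch.bfloat16`). As a hedged variant, assuming bfloat16-capable hardware (e.g. Ampere or newer), you can swap the dtype:

```python
>>> # Variant sketch: bfloat16 trades some precision for a wider dynamic range than float16.
>>> model = AutoModelForCausalLM.from_pretrained(
...     "gpt2", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
... ).to(device)
```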


### Expected speedups

Below is an expected speedup diagram comparing pure inference time between the native implementation in transformers using the `gpt2` checkpoint and the Flash Attention 2 version of the model, with a sequence length of 512.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
</div>
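
The script behind this figure is not reproduced here; the snippet below is a rough sketch of how such a comparison could be run (the prompt construction, sequence length, and timing method are assumptions, not the exact benchmark setup):

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Build a prompt that is roughly 512 tokens long after truncation.
prompt = "Flash attention speeds up long sequences. " * 200

def time_generation(attn_implementation: str) -> float:
    """Return the wall-clock time of a single greedy generation pass."""
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to(device)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=64, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start

eager = time_generation("eager")
flash = time_generation("flash_attention_2")
print(f"eager: {eager:.2f}s | flash_attention_2: {flash:.2f}s | speedup: {eager / flash:.2f}x")
```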

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
EduardoPach marked this conversation as resolved.
@@ -40,8 +40,10 @@ FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [DecisionTransformer](https://huggingface.co/docs/transformers/en/model_doc/decision_transformer)
[Review comment, Collaborator] This should be removed

* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)