Adding Flash Attention 2 Support for GPT2 (#29226)
* First commit to add flash attention 2 for GPT-2

* more improvements

* Make GPT2 pass tests and fixed Decision Transformers copies

* Fixed missing arg

* fix copies

* Added expected speedup

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <[email protected]>

* Added test

* Fixed attn attribute

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <[email protected]>

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <[email protected]>

* Update Decision transformer attentions

* More updates

* Passing tests

* Fix copies

* Fix copies part 2

* Decision transformer updates

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: amyeroberts <[email protected]>

* Fix copies

* Decision transformer not supporting flash attn

* Addressed comments

* Addressed comments

* Addressed comments

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
3 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 4612d83 commit 0222c0d
Showing 5 changed files with 377 additions and 25 deletions.
67 changes: 67 additions & 0 deletions docs/source/en/model_doc/gpt2.md
@@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).
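
For example, a minimal sketch of enabling both flags on a freshly initialized model (illustrative only; the flags change how attention is computed, so they are intended for training a model from scratch rather than for the released `gpt2` checkpoint):

```python
>>> from transformers import GPT2Config, GPT2LMHeadModel

>>> config = GPT2Config(scale_attn_by_inverse_layer_idx=True, reorder_and_upcast_attn=True)
>>> model = GPT2LMHeadModel(config)  # randomly initialized, ready for training
```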

## Usage example

The `generate()` method can be used to generate text with the GPT2 model.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")

>>> prompt = "GPT2 is a model developed by OpenAI."

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

>>> gen_tokens = model.generate(
... input_ids,
... do_sample=True,
... temperature=0.9,
... max_length=100,
... )
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
```

## Using Flash Attention 2

Flash Attention 2 is a faster, optimized implementation of the attention computation that relies on `cuda` kernels.

### Installation

First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimizations through the Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
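
As a quick sanity check before installing, you can query the GPU's compute capability with PyTorch (a minimal sketch, assuming an NVIDIA GPU is visible and that Flash Attention 2 targets Ampere-or-newer devices, i.e. compute capability 8.0 or higher):

```python
>>> import torch

>>> major, minor = torch.cuda.get_device_capability()
>>> print(f"Compute capability {major}.{minor} -> Flash Attention 2 likely supported: {major >= 8}")
```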

Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

### Usage

To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation in generation quality while significantly lowering memory usage and speeding up inference:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")

>>> prompt = "def hello_world():"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
```
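
Flash Attention 2 also works with padded batches as long as the attention mask is passed along. The following is a hedged continuation of the snippet above (the prompts and padding settings are illustrative choices; GPT-2 has no dedicated pad token, so the EOS token is reused):

```python
>>> tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a padding token
>>> tokenizer.padding_side = "left"  # decoder-only models are padded on the left for generation

>>> prompts = ["def hello_world():", "import torch"]
>>> model_inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
>>> generated_ids = model.generate(
...     **model_inputs, max_new_tokens=50, do_sample=True, pad_token_id=tokenizer.eos_token_id
... )
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```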


### Expected speedups

Below is an expected speedup diagram comparing pure inference time between the native implementation in transformers using the `gpt2` checkpoint and the Flash Attention 2 version of the model, for a sequence length of 512.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
</div>
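
The exact numbers depend on the GPU, batch size, and generation length. The sketch below is a minimal, illustrative way to time such a comparison yourself (the prompt, batch size, and iteration count are arbitrary choices, not the setup used for the diagram):

```python
>>> import time
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")

>>> def benchmark(attn_implementation, n_iters=5):
...     model = AutoModelForCausalLM.from_pretrained(
...         "gpt2", torch_dtype=torch.float16, attn_implementation=attn_implementation
...     ).to("cuda")
...     inputs = tokenizer(["Hello, my name is"] * 8, return_tensors="pt").to("cuda")
...     model.generate(**inputs, max_new_tokens=512, do_sample=False)  # warmup
...     torch.cuda.synchronize()
...     start = time.perf_counter()
...     for _ in range(n_iters):
...         model.generate(**inputs, max_new_tokens=512, do_sample=False)
...     torch.cuda.synchronize()
...     return (time.perf_counter() - start) / n_iters

>>> print("eager:", benchmark("eager"))
>>> print("flash_attention_2:", benchmark("flash_attention_2"))
```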

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -42,6 +42,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
+* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -108,7 +108,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
class DecisionTransformerGPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
-
+        self.config = config
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
@@ -146,6 +146,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.is_causal = True

        self.pruned_heads = set()

@@ -346,6 +347,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
+    # Ignore copy
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
@@ -497,7 +499,6 @@ def get_input_embeddings(self):
    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

-    # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -548,7 +549,7 @@ def forward(
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

-        # GPT2Attention mask.
+        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")