Skip to content

Commit

Permalink
Add Flash Attention 2 to M2M100 model (huggingface#30256)
Browse files Browse the repository at this point in the history
* Added flash attention 2.

* Fixes.

* Fix inheritance.

* Fixed init.

* Remove stuff.

* Added documentation.

* Add FA2 to M2M100 documentation.

* Add test.

* Fixed documentation.

* Update src/transformers/models/m2m_100/modeling_m2m_100.py

Co-authored-by: Younes Belkada <[email protected]>

* Update docs/source/en/model_doc/nllb.md

Co-authored-by: amyeroberts <[email protected]>

* Fixed variable name.

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
3 people authored and zucchini-nlp committed Apr 18, 2024
1 parent 4da1d81 commit 8bb292b
Show file tree
Hide file tree
Showing 5 changed files with 379 additions and 14 deletions.
42 changes: 42 additions & 0 deletions docs/source/en/model_doc/m2m_100.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,45 @@ Hindi to French and Chinese to English using the *facebook/m2m100_418M* checkpoi

[[autodoc]] M2M100ForConditionalGeneration
- forward

## Using Flash Attention 2

Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.

### Installation

First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).

Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

### Usage

To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.

```python
>>> import torch
>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

>>> # translate Hindi to French
>>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।"
>>> tokenizer.src_lang = "hi"
>>> encoded_hi = tokenizer(hi_text, return_tensors="pt").to("cuda")
>>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr"))
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
"La vie est comme une boîte de chocolat."
```

### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
</div>
43 changes: 43 additions & 0 deletions docs/source/en/model_doc/nllb.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,46 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien
## NllbTokenizerFast

[[autodoc]] NllbTokenizerFast

## Using Flash Attention 2

Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.

### Installation

First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).

Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

### Usage

To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.

```python
>>> import torch
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

>>> article = "Şeful ONU spune că nu există o soluţie militară în Siria"
>>> inputs = tokenizer(article, return_tensors="pt").to("cuda")

>>> translated_tokens = model.generate(
... **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"], max_length=30
... )
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
"UN-Chef sagt, es gibt keine militärische Lösung in Syrien"
```

### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
</div>
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ FlashAttention-2 is currently supported for the following architectures:
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
Expand Down
Loading

0 comments on commit 8bb292b

Please sign in to comment.