
Add Flash Attention 2 to M2M100 model #30256

Merged: 13 commits, Apr 18, 2024
42 changes: 42 additions & 0 deletions docs/source/en/model_doc/m2m_100.md
@@ -121,3 +121,45 @@ Hindi to French and Chinese to English using the *facebook/m2m100_418M* checkpoint

[[autodoc]] M2M100ForConditionalGeneration
- forward

## Using Flash Attention 2

Flash Attention 2 is a faster, optimized implementation of the attention computation that relies on CUDA kernels.

### Installation

First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).
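
As a quick local check (a sketch rather than part of the official instructions), you can query your GPU's compute capability with PyTorch; Flash Attention 2 currently targets Ampere-class (compute capability 8.0) or newer devices, but defer to the list above for the authoritative answer:

```python
import torch

# Flash Attention 2 currently targets compute capability 8.0 (Ampere) or newer;
# the official compatibility list linked above is authoritative.
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
print("Likely compatible" if (major, minor) >= (8, 0) else "Check the compatibility list")
```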

Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

### Usage

To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.

```python
>>> import torch
>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

>>> # translate Hindi to French
>>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।"
>>> tokenizer.src_lang = "hi"
>>> encoded_hi = tokenizer(hi_text, return_tensors="pt").to("cuda")
>>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr"))
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
"La vie est comme une boîte de chocolat."
```

### Expected speedups

Below is an expected speedup diagram comparing pure inference time between the native implementation and Flash Attention 2.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
</div>
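
If you want to measure the speedup on your own hardware, a minimal timing sketch (the repeat count and input sentence are only illustrative) loads the *facebook/m2m100_418M* checkpoint once with eager attention and once with Flash Attention 2, then times `generate` on a single example:

```python
import time

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


def average_generation_time(attn_implementation, n_runs=10):
    # Load the checkpoint with the requested attention implementation.
    model = M2M100ForConditionalGeneration.from_pretrained(
        "facebook/m2m100_418M",
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
    ).to("cuda").eval()
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
    tokenizer.src_lang = "hi"
    inputs = tokenizer("जीवन एक चॉकलेट बॉक्स की तरह है।", return_tensors="pt").to("cuda")

    with torch.no_grad():
        # Warm-up run so one-time setup costs do not skew the timing.
        model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_runs):
            model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs


print("eager:            ", average_generation_time("eager"))
print("flash_attention_2:", average_generation_time("flash_attention_2"))
```
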
43 changes: 43 additions & 0 deletions docs/source/en/model_doc/nllb.md
@@ -145,3 +145,46 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien
## NllbTokenizerFast

[[autodoc]] NllbTokenizerFast

## Using Flash Attention 2

Flash Attention 2 is a faster, optimized implementation of the attention computation that relies on CUDA kernels.

### Installation

First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features).

Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```
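
As a quick sanity check after installing (a sketch, not an official step), you can confirm that the package imports and that Transformers detects it:

```python
import flash_attn
from transformers.utils import is_flash_attn_2_available

print(flash_attn.__version__)       # installed flash-attn version
print(is_flash_attn_2_available())  # True when Transformers can use Flash Attention 2
```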

### Usage

To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision.

```python
>>> import torch
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval()
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

>>> article = "Şeful ONU spune că nu există o soluţie militară în Siria"
>>> inputs = tokenizer(article, return_tensors="pt").to("cuda")

>>> translated_tokens = model.generate(
... **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"], max_length=30
... )
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
UN-Chef sagt, es gibt keine militärische Lösung in Syrien
```
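
Since `torch.bfloat16` is also supported, a minimal variant of the example above (only the dtype changes, everything else is identical) would look like this:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Same checkpoint as above, loaded in bfloat16 instead of float16.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda").eval()
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

inputs = tokenizer("Şeful ONU spune că nu există o soluţie militară în Siria", return_tensors="pt").to("cuda")
translated_tokens = model.generate(
    **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"], max_length=30
)
print(tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0])
```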

### Expected speedups

Below is an expected speedup diagram comparing pure inference time between the native implementation and Flash Attention 2.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/visheratin/documentation-images/resolve/main/nllb-speedup.webp">
</div>
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -52,11 +52,13 @@ FlashAttention-2 is currently supported for the following architectures:
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)