From c52e0eb3e480f9b244c78224b846de6ffcba4a7f Mon Sep 17 00:00:00 2001 From: Alexander Visheratin Date: Thu, 18 Apr 2024 04:27:58 -0400 Subject: [PATCH] Add Flash Attention 2 to M2M100 model (#30256) * Added flash attention 2. * Fixes. * Fix inheritance. * Fixed init. * Remove stuff. * Added documentation. * Add FA2 to M2M100 documentation. * Add test. * Fixed documentation. * Update src/transformers/models/m2m_100/modeling_m2m_100.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update docs/source/en/model_doc/nllb.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fixed variable name. --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/model_doc/m2m_100.md | 42 +++ docs/source/en/model_doc/nllb.md | 43 +++ docs/source/en/perf_infer_gpu_one.md | 2 + .../models/m2m_100/modeling_m2m_100.py | 257 +++++++++++++++++- tests/models/m2m_100/test_modeling_m2m_100.py | 49 ++++ 5 files changed, 379 insertions(+), 14 deletions(-) diff --git a/docs/source/en/model_doc/m2m_100.md b/docs/source/en/model_doc/m2m_100.md index fa808c2e94bbfd..449e06ec30c29b 100644 --- a/docs/source/en/model_doc/m2m_100.md +++ b/docs/source/en/model_doc/m2m_100.md @@ -121,3 +121,45 @@ Hindi to French and Chinese to English using the *facebook/m2m100_418M* checkpoi [[autodoc]] M2M100ForConditionalGeneration - forward + +## Using Flash Attention 2 + +Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels. + +### Installation + +First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). + +Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2: + +```bash +pip install -U flash-attn --no-build-isolation +``` + +### Usage + +To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision. + +```python +>>> import torch +>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer + +>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval() +>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M") + +>>> # translate Hindi to French +>>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।" +>>> tokenizer.src_lang = "hi" +>>> encoded_hi = tokenizer(hi_text, return_tensors="pt").to("cuda") +>>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr")) +>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) +"La vie est comme une boîte de chocolat." +``` + +### Expected speedups + +Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2. + +
+ +
diff --git a/docs/source/en/model_doc/nllb.md b/docs/source/en/model_doc/nllb.md index 3f272129d2f8f0..00a069e86af176 100644 --- a/docs/source/en/model_doc/nllb.md +++ b/docs/source/en/model_doc/nllb.md @@ -145,3 +145,46 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien ## NllbTokenizerFast [[autodoc]] NllbTokenizerFast + +## Using Flash Attention 2 + +Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels. + +### Installation + +First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). + +Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2: + +```bash +pip install -U flash-attn --no-build-isolation +``` + +### Usage + +To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). You can use either `torch.float16` or `torch.bfloat16` precision. + +```python +>>> import torch +>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + +>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda").eval() +>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M") + +>>> article = "Şeful ONU spune că nu există o soluţie militară în Siria" +>>> inputs = tokenizer(article, return_tensors="pt").to("cuda") + +>>> translated_tokens = model.generate( +... **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"], max_length=30 +... ) +>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] +"UN-Chef sagt, es gibt keine militärische Lösung in Syrien" +``` + +### Expected speedups + +Below is an expected speedup diagram that compares pure inference time between the native implementation and the Flash Attention 2. + +
+ +
\ No newline at end of file diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 2a6c2e2b136ca5..5522885df7bd88 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -53,11 +53,13 @@ FlashAttention-2 is currently supported for the following architectures: * [Llava](https://huggingface.co/docs/transformers/model_doc/llava) * [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next) * [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava) +* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100) * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) +* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 9e2ff11ad88184..1517610b06111d 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch M2M100 model.""" - +"""PyTorch M2M100 model.""" import math from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss @@ -37,12 +37,19 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_m2m_100 import M2M100Config +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "M2M100Config" @@ -317,6 +324,208 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class M2M100FlashAttention2(M2M100Attention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[M2M100Config] = None, + ): + super().__init__(embed_dim, num_heads, dropout, is_decoder, bias, is_causal, config) + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout, softmax_scale=None + ) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 class M2M100EncoderLayer(nn.Module): def __init__(self, config: M2M100Config): @@ -388,7 +597,10 @@ def forward( return outputs -M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention} +M2M100_ATTENTION_CLASSES = { + "eager": M2M100Attention, + "flash_attention_2": M2M100FlashAttention2, +} # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 @@ -517,6 +729,7 @@ class M2M100PreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["M2M100Attention"] + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std @@ -687,6 +900,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -767,8 +981,11 @@ def forward( # expand attention_mask if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + if self._use_flash_attention_2: + attention_mask = attention_mask if 0 in attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -857,6 +1074,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.padding_idx, ) self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -967,18 +1185,24 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + combined_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) @@ -1102,6 +1326,11 @@ def __init__(self, config: M2M100Config): self.encoder = M2M100Encoder(config, self.shared) self.decoder = M2M100Decoder(config, self.shared) + if config._attn_implementation == "flash_attention_2": + logger.warning_once( + "Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention." + ) + # Initialize weights and apply final processing self.post_init() diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 39790917488d76..c280f698ba6cd7 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -19,12 +19,16 @@ import tempfile import unittest +import pytest + from transformers import M2M100Config, is_torch_available from transformers.testing_utils import ( + require_flash_attn, require_sentencepiece, require_tokenizers, require_torch, require_torch_fp16, + require_torch_gpu, slow, torch_device, ) @@ -412,3 +416,48 @@ def test_seq_to_seq_generation(self): hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True ) assert generated == expected_en + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_seq_to_seq_generation(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + model = M2M100ForConditionalGeneration.from_pretrained( + "facebook/m2m100_418M", attn_implementation="flash_attention_2" + ).to(torch_device) + + tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en") + + src_fr = [ + "L'affaire NSA souligne l'absence totale de débat sur le renseignement", + "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.", + "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent" + " Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de" + " l'ampleur de la surveillance américaine sur l'ensemble des communications en France.", + ] + + # The below article tests that we don't add any hypotheses outside of the top n_beams + dct = tokenizer(src_fr, padding=True, return_tensors="pt") + + hypotheses_batch = model.generate( + input_ids=dct["input_ids"].to(torch_device), + attention_mask=dct["attention_mask"].to(torch_device), + num_beams=5, + forced_bos_token_id=tokenizer.get_lang_id("en"), + ) + + expected_en = [ + "The NSA case highlights the total absence of intelligence debate", + "I think there are two levels of response from the French government.", + "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S." + " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all" + " communications in France.", + ] + + generated = tokenizer.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + assert generated == expected_en