From 8cf5ae21dfd56d572949623426bf42f0311e6cd5 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Thu, 28 Sep 2023 20:47:07 +0530
Subject: [PATCH 1/3] Flash attention

---
 docs/source/en/perf_infer_gpu_one.md        | 1 +
 src/transformers/models/mpt/modeling_mpt.py | 7 ++++++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index f0c0bf0b107154..673409461e8c37 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models:
 
 - Llama
 - Falcon
+- MPT
 
 You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*
 
diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index 0c608dbd2a93bc..c469cccf49dfc5 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -32,9 +32,13 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import logging
+from ...utils import is_flash_attn_available, logging
 from .configuration_mpt import MptConfig
 
+if is_flash_attn_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+
 
 logger = logging.get_logger(__name__)
 
@@ -273,6 +277,7 @@ class MptPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["MptBlock"]
     _keys_to_ignore_on_load_missing = [r"lm_head.*."]
+    _supports_flash_attn_2 = True
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
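Usage sketch (editorial note, not part of the patch series): PATCH 1/3 only flips `_supports_flash_attn_2` and guards the `flash_attn` imports, but that flag is what lets users opt in through `from_pretrained` once the attention port in the later patches is complete. A minimal sketch of that opt-in path, assuming a CUDA GPU, a working flash-attn 2.x install, and the `use_flash_attention_2` flag exposed by transformers at the time; the checkpoint name is illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mosaicml/mpt-7b"  # illustrative MPT checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Flash Attention 2 kernels only run in fp16/bf16 on GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)

inputs = tokenizer("MosaicML was founded in", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```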
From d7f7e3dba3d15ecf677a30ccd2c1bfddeda9baba Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Sat, 30 Sep 2023 16:22:39 +0530
Subject: [PATCH 2/3] Add class of FA2

---
 src/transformers/models/mpt/modeling_mpt.py | 40 ++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index c469cccf49dfc5..266e39db3c9185 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -32,7 +32,11 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import is_flash_attn_available, logging
+from ...utils import (
+    is_flash_attn_available,
+    is_torch_fx_available,
+    logging,
+)
 from .configuration_mpt import MptConfig
 
 if is_flash_attn_available():
@@ -59,6 +63,34 @@
 ]
 
 
+# TODO: @younesbelkada move that in pytorch_utils and document it
+if is_torch_fx_available():
+
+    @torch.fx.wrap
+    def check_padding_in_attention_mask(attention_mask):
+        if 0 in attention_mask:
+            return attention_mask
+        return None
+
+else:
+
+    def check_padding_in_attention_mask(attention_mask):
+        if 0 in attention_mask:
+            return attention_mask
+        return None
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(padding_mask):
+    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 # Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
@@ -139,6 +171,8 @@ def forward(
         position_bias: torch.Tensor,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        encoder_padding_mask: Optional[torch.LongTensor] = None,
     ):
         batch_size, seq_length = hidden_states.shape[:2]
 
@@ -185,6 +219,10 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
+# Flash Attention 2 variant of MPT attention, inheriting from MptAttention above
+class MPTAttention2(MptAttention):
+    pass
+
 
 class MptMLP(nn.Module):
     def __init__(self, config: MptConfig):
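A quick check of the helpers added above (editorial note, not part of the patch): `check_padding_in_attention_mask` and `_get_unpad_data` are not called anywhere yet, but their outputs — token indices, cumulative sequence lengths, and the longest sequence length — are exactly what `flash_attn_varlen_func` consumes for padded batches. A self-contained sketch of what `_get_unpad_data` returns on a toy padding mask; the helper body is repeated so the snippet runs standalone.

```python
import torch
import torch.nn.functional as F


def _get_unpad_data(padding_mask):
    # Same logic as the helper added in the diff above.
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# Two sequences of lengths 3 and 2, right-padded to length 4 (1 = real token, 0 = padding).
padding_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(padding_mask)
print(indices)     # tensor([0, 1, 2, 4, 5])  -> flattened positions of the real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)  -> cumulative sequence lengths
print(max_seqlen)  # 3
```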
From b6ec7da72c5c9b85f826fd0136dbce74cc512b02 Mon Sep 17 00:00:00 2001
From: Rockerz <64583161+rajveer43@users.noreply.github.com>
Date: Wed, 4 Oct 2023 06:47:57 +0000
Subject: [PATCH 3/3] Add FA2 class for MPT

---
 src/transformers/models/mpt/modeling_mpt.py | 73 ++++++++++++++++++++-
 1 file changed, 70 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index 266e39db3c9185..69111d0e07b9cb 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -220,8 +220,69 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
 # Flash Attention 2 variant of MPT attention, inheriting from MptAttention above
-class MPTAttention2(MptAttention):
-    pass
+class MptFlashAttention2(MptAttention):
+    def __init__(self, config: MptConfig):
+        super().__init__(config)
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_bias: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        encoder_padding_mask: Optional[torch.LongTensor] = None,
+    ):
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        mixed_qkv = self.Wqkv(hidden_states)
+        query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
+        query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
+
+        if past_key_value is not None:
+            if len(past_key_value) != 0:
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            past_key_value = (key_states, value_states)
+        else:
+            past_key_value = (key_states, value_states)
+
+        # Build the ALiBi bias tensor used by MPT
+        alibi = build_mpt_alibi_tensor(self.n_heads, seq_length, alibi_bias_max=self.alibi_bias_max, device=query_states.device)
+
+        # Compute attention scores with the ALiBi bias added
+        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale + alibi
+
+        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
+
+        if position_bias is not None:
+            if len(position_bias.shape) != 3:
+                raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
+            key_length = key_states.shape[-2]
+
+            position_bias_query_index = max(0, position_bias.size(1) - query_length)
+            position_bias_key_index = max(0, position_bias.size(2) - key_length)
+
+            position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
+
+            attention_scores = attention_scores + position_bias
+
+        if attention_mask is not None:
+            attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
+
+        # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
+
+        context_states = torch.matmul(attn_weights, value_states)
+        context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
+        attn_output = self.out_proj(context_states)
+
+        return attn_output, attn_weights, past_key_value
+
 
 class MptMLP(nn.Module):
@@ -255,7 +316,11 @@ def __init__(self, config: MptConfig):
         self.norm_1.bias = None
 
         self.num_heads = config.n_heads
-        self.attn = MptAttention(config)
+        self.attn = (
+            MptAttention(config)
+            if not getattr(config, "_flash_attn_2_enabled", False)
+            else MptFlashAttention2(config)
+        )
         self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
         # backward compatibility with weights on the Hub
@@ -274,6 +339,7 @@ def forward(
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        padding_mask: Optional[torch.LongTensor] = None,
     ):
         # hidden_states: [batch_size, seq_length, hidden_size]
         # Layer norm at the beginning of the transformer layer.
@@ -287,6 +353,7 @@ def forward(
             position_bias=position_bias,
             attention_mask=attention_mask,
             past_key_value=layer_past,
+            padding_mask=padding_mask,
         )
 
         hidden_states = self.resid_attn_dropout(attn_outputs) + residual
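A note on the state of `MptFlashAttention2` (editorial, not part of the patch): despite the name, the forward above still follows the eager path — `torch.matmul` plus softmax — and never calls the `flash_attn_func`/`flash_attn_varlen_func` kernels imported in PATCH 1/3; it also adds a freshly built `alibi` tensor on top of `position_bias`, which `MptModel` already fills with the same ALiBi bias. For reference, below is a minimal, self-contained sketch of the kernel call itself, on random tensors and under stated assumptions: a CUDA device, a flash-attn 2.x install, fp16/bf16 inputs in the (batch, seq_len, n_heads, head_dim) layout, and no padding (padded batches would go through `_get_unpad_data` and `flash_attn_varlen_func` instead). Whether ALiBi can be applied inside the kernel depends on the installed flash-attn version, so it is left out here.

```python
import torch
from flash_attn import flash_attn_func  # same import guarded by is_flash_attn_available() in PATCH 1/3

batch_size, seq_length, n_heads, head_dim = 2, 16, 8, 64

# flash_attn_func expects (batch, seq_len, n_heads, head_dim), i.e. without the
# transpose(1, 2) that the eager MptAttention path applies.
q = torch.randn(batch_size, seq_length, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,                      # would be self.attn_dropout_p during training
    softmax_scale=1.0 / head_dim**0.5,  # MptAttention's default softmax scale
    causal=True,                        # MPT is a causal decoder-only model
)
print(out.shape)  # torch.Size([2, 16, 8, 64]); reshape to (batch, seq_len, hidden) before out_proj
```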