Addition of Flash Attention 2 to MPT #26471

Closed · wants to merge 4 commits
docs/source/en/perf_infer_gpu_one.md: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ We natively support Flash Attention 2 for the following models:
- Llama
- Mistral
- Falcon
- MPT

You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for the `BetterTransformer` API below.*
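
As an illustration, here is a minimal sketch of how the new code path would be enabled for MPT once this PR is merged, assuming the same `use_flash_attention_2` flag that `from_pretrained` exposes for the other supported models (it requires the `flash-attn` package, a compatible GPU, and half-precision weights):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "mosaicml/mpt-7b" is used here only as an example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)

inputs = tokenizer("MPT with Flash Attention 2:", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```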

src/transformers/models/mpt/modeling_mpt.py: 112 additions & 2 deletions
@@ -32,9 +32,17 @@
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils import (
    is_flash_attn_available,
    is_torch_fx_available,
    logging,
)
from .configuration_mpt import MptConfig

if is_flash_attn_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input


logger = logging.get_logger(__name__)

@@ -55,6 +63,34 @@
]


# TODO: @younesbelkada move this to pytorch_utils and document it
if is_torch_fx_available():

    @torch.fx.wrap
    def check_padding_in_attention_mask(attention_mask):
        # Return the mask only when it actually contains padding (a zero entry), otherwise None.
        if 0 in attention_mask:
            return attention_mask
        return None

else:

    def check_padding_in_attention_mask(attention_mask):
        # Return the mask only when it actually contains padding (a zero entry), otherwise None.
        if 0 in attention_mask:
            return attention_mask
        return None


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(padding_mask):
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
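

# Hypothetical usage note (not part of this diff): for a padding mask
#     [[1, 1, 1, 0],
#      [1, 1, 0, 0]]
# `_get_unpad_data` returns the flattened indices of the non-padded tokens (tensor([0, 1, 2, 4, 5])),
# the cumulative sequence lengths tensor([0, 3, 5]) consumed by `flash_attn_varlen_func`, and the
# longest sequence length in the batch (3).
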
# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
@@ -135,6 +171,8 @@ def forward(
        position_bias: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.LongTensor] = None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

@@ -181,6 +219,71 @@

        return attn_output, attn_weights, past_key_value

# Flash Attention 2 variant of the MPT (MosaicML Pretrained Transformer) attention, inheriting from MptAttention above
class MptFlashAttention2(MptAttention):
    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.alibi_bias_max = config.attn_config.alibi_bias_max

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_bias: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.LongTensor] = None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        mixed_qkv = self.Wqkv(hidden_states)
        query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
        query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            if len(past_key_value) != 0:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            past_key_value = (key_states, value_states)
        else:
            past_key_value = (key_states, value_states)

        # Build the alibi tensor used to bias the attention scores
        alibi = build_mpt_alibi_tensor(
            self.n_heads, seq_length, alibi_bias_max=self.alibi_bias_max, device=query_states.device
        )

        # Compute scaled attention scores with the alibi bias
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale + alibi

        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

        if position_bias is not None:
            if len(position_bias.shape) != 3:
                raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
            key_length = key_states.shape[-2]

            position_bias_query_index = max(0, position_bias.size(1) - query_length)
            position_bias_key_index = max(0, position_bias.size(2) - key_length)

            position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]

            attention_scores = attention_scores + position_bias

        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)

        context_states = torch.matmul(attn_weights, value_states)
        context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
        attn_output = self.out_proj(context_states)

        return attn_output, attn_weights, past_key_value
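

# Illustrative sketch only (not part of this diff): how the flash-attn kernels imported at the top
# of this file are typically invoked on the unpadded path. `flash_attn_func` expects tensors of
# shape (batch, seq_len, num_heads, head_dim) in fp16/bf16; how MPT's alibi bias is folded into the
# kernel is left open here and depends on the installed flash-attn version. For padded batches the
# inputs are first unpadded (`unpad_input` / `_get_unpad_data`), passed to `flash_attn_varlen_func`,
# and re-padded with `pad_input`. The helper name below is hypothetical.
def _flash_attention_no_padding_sketch(query_states, key_states, value_states, dropout_p, softmax_scale):
    return flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=True,
    )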



class MptMLP(nn.Module):
    def __init__(self, config: MptConfig):
@@ -213,7 +316,11 @@ def __init__(self, config: MptConfig):
        self.norm_1.bias = None

        self.num_heads = config.n_heads
        self.attn = MptAttention(config)
        # `config._flash_attn_2_enabled` is set at load time when Flash Attention 2 is requested
        # (e.g. with `use_flash_attention_2=True` in `from_pretrained`).
        self.attn = (
            MptAttention(config)
            if not getattr(config, "_flash_attn_2_enabled", False)
            else MptFlashAttention2(config)
        )

        self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        # backward compatibility with weights on the Hub
@@ -232,6 +339,7 @@ def forward(
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]
        # Layer norm at the beginning of the transformer layer.
@@ -245,6 +353,7 @@ def forward(
            position_bias=position_bias,
            attention_mask=attention_mask,
            past_key_value=layer_past,
            padding_mask=padding_mask,
        )

        hidden_states = self.resid_attn_dropout(attn_outputs) + residual
@@ -273,6 +382,7 @@ class MptPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["MptBlock"]
    _keys_to_ignore_on_load_missing = [r"lm_head.*."]
    # Advertises that this architecture provides a Flash Attention 2 implementation (checked at load time).
    _supports_flash_attn_2 = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)