Add Flash Attention 2 to M2M100 model #30256

Merged
merged 13 commits on Apr 18, 2024
281 changes: 265 additions & 16 deletions src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch M2M100 model."""

"""PyTorch M2M100 model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss

@@ -37,12 +37,19 @@
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_m2m_100 import M2M100Config


if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "M2M100Config"
@@ -317,6 +324,226 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
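
As an illustrative aside (not part of the diff), a minimal sketch of what this helper returns for a toy padding mask with two sequences of lengths 3 and 1 padded to length 4; it only uses the torch and F imports already present in this module.

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]], dtype=torch.int32)
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 2, 4]): flat positions of the non-padding tokens
# cu_seqlens -> tensor([0, 3, 4], dtype=torch.int32): cumulative sequence lengths
# max_seqlen -> 3: length of the longest unpadded sequence in the batch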


class M2M100FlashAttention2(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[M2M100Config] = None,
):
super().__init__()
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, q_len, _ = hidden_states.size()

# get query proj
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2)
value_states = past_key_value[1].transpose(1, 2)
elif is_cross_attention:
# cross_attentions
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
else:
# self_attention
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout, softmax_scale=None
)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value

def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpad the input, then compute the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1

# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)

return attn_output

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
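
A hedged usage sketch of the new module in isolation (not part of the diff): it assumes a CUDA device, an installed flash_attn package, and fp16 activations, since Flash Attention 2 does not run on CPU or in fp32; the default M2M100Config gives d_model=1024 with 16 attention heads.

import torch

from transformers import M2M100Config

config = M2M100Config()
attn = M2M100FlashAttention2(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=0.0,  # disable attention dropout for the illustration
    config=config,
).to("cuda", dtype=torch.float16)

hidden_states = torch.randn(2, 7, config.d_model, device="cuda", dtype=torch.float16)
padding_mask = torch.ones(2, 7, dtype=torch.long, device="cuda")
padding_mask[1, 5:] = 0  # the second sequence carries two padding positions

with torch.no_grad():
    attn_output, attn_weights, past_key_value = attn(hidden_states, attention_mask=padding_mask)
# attn_output: (2, 7, 1024). attn_weights is None because Flash Attention never
# materializes the attention matrix; past_key_value is None since is_decoder=False.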


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
class M2M100EncoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
@@ -388,7 +615,10 @@ def forward(
return outputs


M2M100_ATTENTION_CLASSES = {"eager": M2M100Attention}
M2M100_ATTENTION_CLASSES = {
"eager": M2M100Attention,
"flash_attention_2": M2M100FlashAttention2,
}
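
A short, hedged sketch of how this registry is consumed: each encoder and decoder layer looks up its attention class from config._attn_implementation, which from_pretrained leaves at "eager" by default and sets to "flash_attention_2" when requested.

from transformers import M2M100Config

config = M2M100Config()
config._attn_implementation = "flash_attention_2"  # normally set via from_pretrained(..., attn_implementation=...)
attn_class = M2M100_ATTENTION_CLASSES[config._attn_implementation]
assert attn_class is M2M100FlashAttention2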


# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
@@ -517,6 +747,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["M2M100Attention"]
_supports_flash_attn_2 = True

def _init_weights(self, module):
std = self.config.init_std
@@ -687,6 +918,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -767,8 +999,11 @@ def forward(

# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
if self._use_flash_attention_2:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
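
To make the two branches concrete, a small hedged sketch of the two mask formats, using the helper from transformers.modeling_attn_mask_utils that this file already relies on:

import torch

from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

mask_2d = torch.tensor([[1, 1, 1, 0]])  # (bsz=1, seq_len=4) padding mask
# Eager path: expand to an additive 4d mask of shape (1, 1, 4, 4),
# 0.0 where attention is allowed and the dtype minimum where it is masked.
mask_4d = _prepare_4d_attention_mask(mask_2d, torch.float16)

# Flash Attention 2 path: keep the 2d mask as-is; when nothing is padded
# ("0 in mask" is False) pass None and skip the unpad/pad round trip.
no_padding = torch.ones(1, 4, dtype=torch.long)
fa2_mask = no_padding if 0 in no_padding else None  # -> None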

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -857,6 +1092,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
self.padding_idx,
)
self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layer_norm = nn.LayerNorm(config.d_model)

self.gradient_checkpointing = False
@@ -967,18 +1203,24 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

Collaborator: Why rename here from combined_attention_mask to attention_mask?
Contributor Author: This is an artifact of debugging. I returned the old name.

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
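
Analogously, a hedged sketch of what the eager decoder branch builds; the flash-attention branch simply forwards the 2d mask (or None) instead.

import torch

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

bsz, tgt_len, past_len = 1, 3, 0
inputs_embeds = torch.zeros(bsz, tgt_len, 1024, dtype=torch.float16)
mask_2d = torch.ones(bsz, tgt_len, dtype=torch.long)
causal_4d = _prepare_4d_causal_attention_mask(mask_2d, (bsz, tgt_len), inputs_embeds, past_len)
# causal_4d: (1, 1, 3, 3) with the strictly upper-triangular entries set to the
# dtype minimum, so each position attends only to itself and earlier positions.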

# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
if self._use_flash_attention_2:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)

# embed positions
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
@@ -1028,7 +1270,8 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
# combined_attention_mask,
attention_mask,

Collaborator: It would be better to keep the old name though.
Suggested change: remove the commented-out # combined_attention_mask, line.

encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -1040,7 +1283,8 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
# attention_mask=combined_attention_mask,
attention_mask=attention_mask,

Collaborator: Same here.
Suggested change: remove the commented-out # attention_mask=combined_attention_mask, line.
Contributor: Nice catch!

encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
@@ -1102,6 +1346,11 @@ def __init__(self, config: M2M100Config):
self.encoder = M2M100Encoder(config, self.shared)
self.decoder = M2M100Decoder(config, self.shared)

if config._attn_implementation == "flash_attention_2":
logger.warning_once(
"Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention."
)

# Initialize weights and apply final processing
self.post_init()
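
As a closing illustration (a hedged sketch, not part of the diff): with this PR merged, Flash Attention 2 can be requested at load time. The example assumes a CUDA GPU, an installed flash_attn package, and half precision, and uses the public facebook/m2m100_418M checkpoint.

import torch

from transformers import AutoTokenizer, M2M100ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
model = M2M100ForConditionalGeneration.from_pretrained(
    "facebook/m2m100_418M",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

tokenizer.src_lang = "en"
inputs = tokenizer("Life is like a box of chocolates.", return_tensors="pt").to("cuda")
generated = model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
print(tokenizer.batch_decode(generated, skip_special_tokens=True))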
