Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flash attention for gpt_bigcode #26479

Merged
merged 14 commits into from
Oct 31, 2023
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models:

- Llama
- Falcon
- GPTBigCode
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- GPTBigCode
- GPTBigCode (Starcoder)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*

Expand Down
252 changes: 249 additions & 3 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
Expand All @@ -32,11 +33,17 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
)
from .configuration_gpt_bigcode import GPTBigCodeConfig


if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder"
Expand Down Expand Up @@ -78,6 +85,19 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
return x


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(padding_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


class GPTBigCodeAttention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
Expand Down Expand Up @@ -211,6 +231,8 @@ def forward(
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
encoder_padding_mask: Optional[torch.LongTensor] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should wait again for the padding mask refactor and not pass padding mask! #26792

output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Expand Down Expand Up @@ -262,6 +284,206 @@ def forward(
return outputs # a, present, (attentions)


class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
"""
GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
stay untouched. The only required change is in the forward pass, where it needs to correctly call the public
API of flash attention and deal with padding tokens in case the input contains any of them.
"""

def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be no more padding_mask now since #26792 has been merged, let me know if you want to handle this, otherwise I can have a look!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey I am currently working to fix this.

encoder_padding_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key_value = self.c_attn(encoder_hidden_states)
padding_mask = encoder_padding_mask
elif self.multi_query:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
else:
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
# i.e., the memory layout is not the same as GPT2.
# This makes the concatenation with past_key_value more efficient.
query, key_value = (
self.c_attn(hidden_states)
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
.transpose(1, 2)
.split((self.head_dim, 2 * self.head_dim), dim=3)
)

if layer_past is not None:
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None

key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
if self.multi_query:
batch_size, query_length, _ = query.shape
query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
key = key.unsqueeze(2)
value = value.unsqueeze(2)
else:
query_length = query.shape[2]
batch_size, _, tgt, _ = key.shape
query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)

attn_dropout = self.dropout if self.training else 0.0
susnato marked this conversation as resolved.
Show resolved Hide resolved

softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
upcast = query.dtype != softmax_dtype
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
softmax_scale = softmax_scale**-1
if self.scale_attn_weights:
softmax_scale /= self.head_dim**0.5

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query.dtype
if input_dtype == torch.float32:
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
)
query = query.to(torch.float16)
key = key.to(torch.float16)
value = value.to(torch.float16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be fixed in the global fix I want to apply in #26451, as a follow-up PR that I will take care of

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should I then remove this block? or are we keeping this block for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say we can keep it for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's wait until your PR is merged @younesbelkada 😉


attn_output = self._flash_attention_forward(
query, key, value, padding_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale
)

attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
attn_output = self.c_proj(attn_weights_reshaped)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)

if output_attentions:
if self.multi_query:
# Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2)
else:
attn_weights_reshaped = None

outputs += (attn_weights_reshaped,)

return outputs # a, present, (attentions)

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
    self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
):
    """
    Dispatch to the Flash Attention kernels. When a padding mask is given, the inputs are unpadded first, the
    variable-length kernel is called, and its output is padded back; otherwise the dense kernel is called
    directly. Both kernels are invoked with `causal=True`.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
        padding_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens. `None` selects the dense
            path.
        query_length (`int`):
            Length of the query sequence; forwarded to `_upad_input` and used as the sequence length when
            padding the output back.
        dropout (`float`, *optional*):
            Attention dropout
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)

    Returns:
        The attention output returned by the selected kernel (re-padded on the variable-length path).
    """
    if padding_mask is None:
        # No padding information supplied: the dense kernel can consume the batch as-is.
        return flash_attn_func(
            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
        )

    # A mask was supplied: drop the masked positions before calling the variable-length kernel.
    batch_size = query_states.shape[0]
    (
        query_states,
        key_states,
        value_states,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    ) = self._upad_input(query_states, key_states, value_states, padding_mask, query_length)

    unpadded_output = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=True,
    )

    # Scatter the packed result back into a (batch_size, query_length, ...) layout.
    return pad_input(unpadded_output, indices_q, batch_size, query_length)

def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
    """
    Remove masked-out positions from the query/key/value tensors and build the metadata required by the
    variable-length Flash Attention kernel.

    `key_layer` is unpacked as `(batch_size, kv_seq_len, kv_num_heads, head_dim)`; keys and values are flattened
    over their first two dimensions and gathered with the indices derived from `padding_mask`. The query is
    handled in one of three ways depending on `query_length`:

    - equal to `kv_seq_len`: the query reuses the key indices and sequence-length metadata;
    - equal to `1`: every row contributes exactly one query token;
    - otherwise: the last `query_length` columns of `padding_mask` are used to unpad the query.

    Returns:
        `tuple`: `(query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
    batch_size, kv_seq_len, kv_num_heads, head_dim = key_layer.shape
    query_num_heads = query_layer.shape[2]

    # Keys and values share the same flattened layout and the same gather indices.
    flat_kv_shape = (batch_size * kv_seq_len, kv_num_heads, head_dim)
    key_layer = index_first_axis(key_layer.reshape(flat_kv_shape), indices_k)
    value_layer = index_first_axis(value_layer.reshape(flat_kv_shape), indices_k)

    if query_length == kv_seq_len:
        # Query and key cover the same positions, so the key metadata applies to the query unchanged.
        flat_query = query_layer.reshape(batch_size * kv_seq_len, query_num_heads, head_dim)
        query_layer = index_first_axis(flat_query, indices_k)
        indices_q, cu_seqlens_q, max_seqlen_in_batch_q = indices_k, cu_seqlens_k, max_seqlen_in_batch_k
    elif query_length == 1:
        # One query token per row: boundaries are simply 0, 1, ..., batch_size.
        # Note carried over from the original code: "There is a memcpy here, that is very bad."
        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
        indices_q = cu_seqlens_q[:-1]
        max_seqlen_in_batch_q = 1
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        query_padding_mask = padding_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
            query_layer, query_padding_mask
        )

    seqlen_boundaries = (cu_seqlens_q, cu_seqlens_k)
    max_seqlens = (max_seqlen_in_batch_q, max_seqlen_in_batch_k)
    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        seqlen_boundaries,
        max_seqlens,
    )


class GPTBigCodeMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
Expand All @@ -287,13 +509,21 @@ def __init__(self, config, layer_idx=None):
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
self.attn = (
GPTBigCodeAttention(config, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx)
)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

if config.add_cross_attention:
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
self.crossattention = (
GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
if not getattr(config, "_flash_attn_2_enabled", False)
else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx)
)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

self.mlp = GPTBigCodeMLP(self.inner_dim, config)
Expand All @@ -307,6 +537,8 @@ def forward(
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
encoder_padding_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Expand All @@ -320,6 +552,8 @@ def forward(
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
padding_mask=padding_mask,
encoder_padding_mask=encoder_padding_mask,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
Expand All @@ -342,6 +576,8 @@ def forward(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
padding_mask=padding_mask,
encoder_padding_mask=encoder_padding_mask,
)
attn_output = cross_attn_outputs[0]
# residual connection
Expand Down Expand Up @@ -373,6 +609,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down Expand Up @@ -586,6 +823,13 @@ def forward(
else:
past_length = past_key_values[0].size(-2)

padding_mask = None
if attention_mask is not None and 0 in attention_mask:
padding_mask = attention_mask
encoder_padding_mask = None
if encoder_attention_mask is not None and 0 in encoder_attention_mask:
encoder_padding_mask = encoder_attention_mask

if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
Expand Down Expand Up @@ -656,7 +900,7 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return module(*inputs, use_cache, output_attentions, padding_mask, encoder_padding_mask)

return custom_forward

Expand All @@ -679,6 +923,8 @@ def custom_forward(*inputs):
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
padding_mask=padding_mask,
encoder_padding_mask=encoder_padding_mask,
)

hidden_states = outputs[0]
Expand Down