diff --git a/docs/source/en/model_doc/gpt_bigcode.md b/docs/source/en/model_doc/gpt_bigcode.md
index 6965d5837d8e74..8cc77a825de75c 100644
--- a/docs/source/en/model_doc/gpt_bigcode.md
+++ b/docs/source/en/model_doc/gpt_bigcode.md
@@ -42,6 +42,45 @@ The main differences compared to GPT2.
You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575)
+## Combining Starcoder and Flash Attention 2
+
+First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
+
+To load and run a model using Flash Attention 2, refer to the snippet below:
+
+```python
+>>> import torch
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+>>> device = "cuda" # the device to load the model onto
+
+>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
+>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
+
+>>> prompt = "def hello_world():"
+
+>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
+>>> model.to(device)
+
+>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False)
+>>> tokenizer.batch_decode(generated_ids)[0]
+'def hello_world():\n print("hello world")\n\nif __name__ == "__main__":\n print("hello world")\n<|endoftext|>'
+```
+
+### Expected speedups
+
+Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using `bigcode/starcoder` checkpoint and the Flash Attention 2 version of the model using two different sequence lengths.
+
+
+
+
+
+
## GPTBigCodeConfig
[[autodoc]] GPTBigCodeConfig
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index d24299012e9fe1..39f2ca22b1f040 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -34,6 +34,7 @@ We natively support Flash Attention 2 for the following models:
- Llama
- Mistral
- Falcon
+- [GPTBigCode (Starcoder)](model_doc/gpt_bigcode#)
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index f8e52b6510a0bd..fcbbfca5cedac7 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -16,6 +16,7 @@
from typing import List, Optional, Tuple, Union
import torch
+import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -32,11 +33,17 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
logging,
)
from .configuration_gpt_bigcode import GPTBigCodeConfig
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder"
@@ -78,11 +85,25 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
return x
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
class GPTBigCodeAttention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
- self.mask_value = None
+ self.config = config
+ self.mask_value = None
self.multi_query = config.multi_query
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@@ -90,6 +111,8 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.kv_heads = 1 if self.multi_query else self.num_heads
self.kv_dim = self.kv_heads * self.head_dim
self.split_size = self.embed_dim
+ self.is_causal = True
+
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
@@ -212,10 +235,16 @@ def forward(
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
+ **kwargs,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
+ if "padding_mask" in kwargs:
+ logger.warning_once(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
@@ -262,6 +291,223 @@ def forward(
return outputs # a, present, (attentions)
+class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
+ """
+ GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
+ stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
+ API of flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_past: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> Union[
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+ ]:
+ if "padding_mask" in kwargs:
+ logger.warning_once(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key_value = self.c_attn(encoder_hidden_states)
+ attention_mask = encoder_attention_mask
+ elif self.multi_query:
+ query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
+ else:
+ # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
+ # i.e., the memory layout is not the same as GPT2.
+ # This makes the concatenation with past_key_value more efficient.
+ query, key_value = (
+ self.c_attn(hidden_states)
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+ .transpose(1, 2)
+ .split((self.head_dim, 2 * self.head_dim), dim=3)
+ )
+
+ if layer_past is not None:
+ key_value = torch.cat((layer_past, key_value), dim=-2)
+ present = key_value if use_cache else None
+
+ key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ if self.multi_query:
+ batch_size, query_length, _ = query.shape
+ query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
+ key = key.unsqueeze(2)
+ value = value.unsqueeze(2)
+ else:
+ query_length = query.shape[2]
+ batch_size, _, tgt, _ = key.shape
+ query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
+ key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
+ value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
+
+ attn_dropout = self.dropout if self.training else 0.0
+
+ softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
+ upcast = query.dtype != softmax_dtype
+ softmax_scale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
+ softmax_scale = softmax_scale**-1
+ if self.scale_attn_weights:
+ softmax_scale /= self.head_dim**0.5
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query.dtype
+ if input_dtype == torch.float32:
+ # Handle the case where the model is quantized
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.c_attn.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query, key, value, attention_mask, query_length, dropout=attn_dropout, softmax_scale=softmax_scale
+ )
+
+ attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+ attn_output = self.c_proj(attn_weights_reshaped)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+
+ if output_attentions:
+ if self.multi_query:
+ # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+ attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2)
+ else:
+ attn_weights_reshaped = None
+
+ outputs += (attn_weights_reshaped,)
+
+ return outputs # a, present, (attentions)
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=self.is_causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
class GPTBigCodeMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@@ -287,13 +533,21 @@ def __init__(self, config, layer_idx=None):
self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = GPTBigCodeAttention(config, layer_idx=layer_idx)
+ self.attn = (
+ GPTBigCodeAttention(config, layer_idx=layer_idx)
+ if not getattr(config, "_flash_attn_2_enabled", False)
+ else GPTBigCodeFlashAttention2(config, layer_idx=layer_idx)
+ )
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
if config.multi_query:
raise NotImplementedError("Cross-attention not implemented for MQA")
- self.crossattention = GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
+ self.crossattention = (
+ GPTBigCodeAttention(config, is_cross_attention=True, layer_idx=layer_idx)
+ if not getattr(config, "_flash_attn_2_enabled", False)
+ else GPTBigCodeFlashAttention2(config, is_cross_attention=True, layer_idx=layer_idx)
+ )
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigCodeMLP(self.inner_dim, config)
@@ -373,6 +627,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -594,28 +849,38 @@ def forward(
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
- if attention_mask is not None:
- self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
- dtype=torch.bool, device=self_attention_mask.device
+ if getattr(self.config, "_flash_attn_2_enabled", False):
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
+ encoder_attention_mask = (
+ encoder_attention_mask.bool()
+ if (encoder_attention_mask is not None and 0 in encoder_attention_mask)
+ else None
)
-
- # MQA models: (batch_size, query_length, n_heads, key_length)
- # MHA models: (batch_size, n_heads, query_length, key_length)
- attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
-
- # If a 2D or 3D attention mask is provided for the cross-attention
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
- if (
- self.config.add_cross_attention
- and encoder_hidden_states is not None
- and encoder_attention_mask is not None
- ):
- if encoder_attention_mask.dim() == 2:
- encoder_attention_mask.unsqueeze(1)
- assert encoder_attention_mask.dim() == 3
- encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
else:
- encoder_attention_mask = None
+ # 4d mask is passed through the layers
+ if attention_mask is not None:
+ self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
+ dtype=torch.bool, device=self_attention_mask.device
+ )
+
+ # MQA models: (batch_size, query_length, n_heads, key_length)
+ # MHA models: (batch_size, n_heads, query_length, key_length)
+ attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if (
+ self.config.add_cross_attention
+ and encoder_hidden_states is not None
+ and encoder_attention_mask is not None
+ ):
+ if encoder_attention_mask.dim() == 2:
+ encoder_attention_mask.unsqueeze(1)
+ assert encoder_attention_mask.dim() == 3
+ encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
+ else:
+ encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head