From 0222c0d4316671105f14ee47c07d794da34c23c8 Mon Sep 17 00:00:00 2001
From: Eduardo Pacheco <69953243+EduardoPach@users.noreply.github.com>
Date: Thu, 28 Mar 2024 10:31:24 +0100
Subject: [PATCH] Adding Flash Attention 2 Support for GPT2 (#29226)
* First commit to add flash attention 2 for GPT-2
* more improvements
* Make GPT2 pass tests and fixed Decison Transformers copies
* Fixed missing arg
* fix copies
* Added expected speedup
* Update src/transformers/models/gpt2/modeling_gpt2.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/gpt2/modeling_gpt2.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/gpt2/modeling_gpt2.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Added test
* Fixed attn attribute
* Update docs/source/en/model_doc/gpt2.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update docs/source/en/model_doc/gpt2.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update Decision transformer attentions
* More updates
* Passing tests
* Fix copies
* Fix copies part 2
* Decision transformer updates
* Update src/transformers/models/gpt2/modeling_gpt2.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Fix copies
* Decision transformer not supporting flash attn
* Addressed comments
* Addressed comments
* Addressed comments
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
docs/source/en/model_doc/gpt2.md | 67 +++++
docs/source/en/perf_infer_gpu_one.md | 1 +
.../modeling_decision_transformer.py | 7 +-
src/transformers/models/gpt2/modeling_gpt2.py | 279 ++++++++++++++++--
tests/models/gpt2/test_modeling_gpt2.py | 48 ++-
5 files changed, 377 insertions(+), 25 deletions(-)
diff --git a/docs/source/en/model_doc/gpt2.md b/docs/source/en/model_doc/gpt2.md
index 4708edde0b65d4..b2afbbd3b2ec40 100644
--- a/docs/source/en/model_doc/gpt2.md
+++ b/docs/source/en/model_doc/gpt2.md
@@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).
+## Usage example
+
+The `generate()` method can be used to generate text using GPT2 model.
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+>>> prompt = "GPT2 is a model developed by OpenAI."
+
+>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+>>> gen_tokens = model.generate(
+... input_ids,
+... do_sample=True,
+... temperature=0.9,
+... max_length=100,
+... )
+>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
+```
+
+## Using Flash Attention 2
+
+Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
+
+### Installation
+
+First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
+
+Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+### Usage
+
+To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
+
+```python
+>>> import torch
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+>>> device = "cuda" # the device to load the model onto
+
+>>> model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
+>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+>>> prompt = "def hello_world():"
+
+>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
+>>> model.to(device)
+
+>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
+>>> tokenizer.batch_decode(generated_ids)[0]
+```
+
+
+### Expected speedups
+
+Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `gpt2` checkpoint and the Flash Attention 2 version of the model using a sequence length of 512.
+
+
+
+
+
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 90409b1c21bc5b..0fbea1cd8d3d03 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -42,6 +42,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
+* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 9dd9d95c387968..6f939460aab86f 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -108,7 +108,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
class DecisionTransformerGPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
-
+ self.config = config
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
@@ -146,6 +146,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.is_causal = True
self.pruned_heads = set()
@@ -346,6 +347,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
+ # Ignore copy
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
@@ -497,7 +499,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
- # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -548,7 +549,7 @@ def forward(
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
- # GPT2Attention mask.
+ # Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 9511baafca36ac..1409a3fc3f0fcb 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -22,6 +22,7 @@
from typing import Optional, Tuple, Union
import torch
+import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
@@ -42,6 +43,8 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
@@ -49,6 +52,11 @@
from .configuration_gpt2 import GPT2Config
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "openai-community/gpt2"
@@ -58,6 +66,19 @@
from ..deprecated._archive_maps import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
"""Load tf checkpoints in a pytorch model"""
try:
@@ -117,7 +138,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
class GPT2Attention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__()
-
+ self.config = config
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
@@ -155,6 +176,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.is_causal = True
self.pruned_heads = set()
@@ -335,6 +357,210 @@ def forward(
return outputs # a, present, (attentions)
+class GPT2FlashAttention2(GPT2Attention):
+ """
+ GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ bsz, _, _ = hidden_states.size()
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ if layer_past is not None:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+
+ present = None
+ if use_cache is True:
+ present = (key, value)
+
+ query_length = query.shape[2]
+ tgt_len = key.shape[2]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
+ key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+ value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+
+ attn_dropout = self.attn_dropout.p if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ if query.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.c_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query, key, value, attention_mask, query_length, dropout=attn_dropout
+ )
+
+ attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
+ attn_output = self.c_proj(attn_weights_reshaped)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights_reshaped,)
+
+ return outputs
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
@@ -352,18 +578,25 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states
+GPT2_ATTENTION_CLASSES = {
+ "eager": GPT2Attention,
+ "flash_attention_2": GPT2FlashAttention2,
+}
+
+
class GPT2Block(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+ attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = GPT2Attention(config, layer_idx=layer_idx)
+ self.attn = attention_class(config=config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
- self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
+ self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
@@ -443,6 +676,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPT2Block"]
_skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -673,6 +907,7 @@ def __init__(self, config):
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
+ self._attn_implementation = config._attn_implementation
# Initialize weights and apply final processing
self.post_init()
@@ -790,25 +1025,26 @@ def forward(
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
- # GPT2Attention mask.
+ # Attention mask.
if attention_mask is not None:
- if batch_size <= 0:
- raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
-
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+ if self._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ else:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -817,7 +1053,8 @@ def forward(
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
- encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ if self._attn_implementation != "flash_attention_2":
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
index d2b9ce8dcf0d16..cde28cbc58617e 100644
--- a/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -19,8 +19,17 @@
import math
import unittest
+import pytest
+
from transformers import GPT2Config, is_torch_available
-from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
+from transformers.testing_utils import (
+ backend_empty_cache,
+ require_flash_attn,
+ require_torch,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -858,3 +867,40 @@ def test_contrastive_search_gpt2(self):
"but said in a statement to The Associated Press that"
],
)
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_generate_padding_left(self):
+ """
+ Overwritting the common test as the test is flaky on tiny models
+ """
+ model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0)
+
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "left"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = GPT2LMHeadModel.from_pretrained(
+ "gpt2", device_map={"": 0}, attn_implementation="flash_attention_2", torch_dtype=torch.float16
+ )
+
+ output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_2 = tokenizer.batch_decode(output_fa_2)
+
+ expected_output = [
+ "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata",
+ "Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry",
+ ]
+
+ self.assertListEqual(output_native, output_fa_2)
+ self.assertListEqual(output_native, expected_output)