From 0222c0d4316671105f14ee47c07d794da34c23c8 Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco <69953243+EduardoPach@users.noreply.github.com> Date: Thu, 28 Mar 2024 10:31:24 +0100 Subject: [PATCH] Adding Flash Attention 2 Support for GPT2 (#29226) * First commit to add flash attention 2 for GPT-2 * more improvements * Make GPT2 pass tests and fixed Decison Transformers copies * Fixed missing arg * fix copies * Added expected speedup * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Added test * Fixed attn attribute * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update Decision transformer attentions * More updates * Passing tests * Fix copies * Fix copies part 2 * Decision transformer updates * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix copies * Decision transformer not supporting flash attn * Addressed comments * Addressed comments * Addressed comments --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/model_doc/gpt2.md | 67 +++++ docs/source/en/perf_infer_gpu_one.md | 1 + .../modeling_decision_transformer.py | 7 +- src/transformers/models/gpt2/modeling_gpt2.py | 279 ++++++++++++++++-- tests/models/gpt2/test_modeling_gpt2.py | 48 ++- 5 files changed, 377 insertions(+), 25 deletions(-) diff --git a/docs/source/en/model_doc/gpt2.md b/docs/source/en/model_doc/gpt2.md index 4708edde0b65d4..b2afbbd3b2ec40 100644 --- a/docs/source/en/model_doc/gpt2.md +++ b/docs/source/en/model_doc/gpt2.md @@ -60,6 +60,73 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o - Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only). +## Usage example + +The `generate()` method can be used to generate text using GPT2 model. + +```python +>>> from transformers import AutoModelForCausalLM, AutoTokenizer + +>>> model = AutoModelForCausalLM.from_pretrained("gpt2") +>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + +>>> prompt = "GPT2 is a model developed by OpenAI." + +>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +>>> gen_tokens = model.generate( +... input_ids, +... do_sample=True, +... temperature=0.9, +... max_length=100, +... ) +>>> gen_text = tokenizer.batch_decode(gen_tokens)[0] +``` + +## Using Flash Attention 2 + +Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels. + +### Installation + +First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer). + +Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2: + +```bash +pip install -U flash-attn --no-build-isolation +``` + +### Usage + +To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference: + +```python +>>> import torch +>>> from transformers import AutoModelForCausalLM, AutoTokenizer +>>> device = "cuda" # the device to load the model onto + +>>> model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2") +>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + +>>> prompt = "def hello_world():" + +>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device) +>>> model.to(device) + +>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) +>>> tokenizer.batch_decode(generated_ids)[0] +``` + + +### Expected speedups + +Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `gpt2` checkpoint and the Flash Attention 2 version of the model using a sequence length of 512. + +
+ +
+ ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 90409b1c21bc5b..0fbea1cd8d3d03 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -42,6 +42,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) +* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 9dd9d95c387968..6f939460aab86f 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -108,7 +108,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): class DecisionTransformerGPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() - + self.config = config max_positions = config.max_position_embeddings self.register_buffer( "bias", @@ -146,6 +146,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True self.pruned_heads = set() @@ -346,6 +347,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2 class DecisionTransformerGPT2Block(nn.Module): + # Ignore copy def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -497,7 +499,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings - # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -548,7 +549,7 @@ def forward( position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) - # GPT2Attention mask. + # Attention mask. if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 9511baafca36ac..1409a3fc3f0fcb 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -22,6 +22,7 @@ from typing import Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.cuda.amp import autocast @@ -42,6 +43,8 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -49,6 +52,11 @@ from .configuration_gpt2 import GPT2Config +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "openai-community/gpt2" @@ -58,6 +66,19 @@ from ..deprecated._archive_maps import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402 +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): """Load tf checkpoints in a pytorch model""" try: @@ -117,7 +138,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): class GPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() - + self.config = config max_positions = config.max_position_embeddings self.register_buffer( "bias", @@ -155,6 +176,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True self.pruned_heads = set() @@ -335,6 +357,210 @@ def forward( return outputs # a, present, (attentions) +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class GPT2MLP(nn.Module): def __init__(self, intermediate_size, config): super().__init__() @@ -352,18 +578,25 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + class GPT2Block(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.attn = attention_class(config=config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: - self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config) @@ -443,6 +676,7 @@ class GPT2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPT2Block"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -673,6 +907,7 @@ def __init__(self, config): self.model_parallel = False self.device_map = None self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation # Initialize weights and apply final processing self.post_init() @@ -790,25 +1025,26 @@ def forward( position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) - # GPT2Attention mask. + # Attention mask. if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -817,7 +1053,8 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_attention_mask = None diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index d2b9ce8dcf0d16..cde28cbc58617e 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -19,8 +19,17 @@ import math import unittest +import pytest + from transformers import GPT2Config, is_torch_available -from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device +from transformers.testing_utils import ( + backend_empty_cache, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -858,3 +867,40 @@ def test_contrastive_search_gpt2(self): "but said in a statement to The Associated Press that" ], ) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_left(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0) + + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + + texts = ["hi", "Hello this is a very long sentence"] + + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = tokenizer.batch_decode(output_native) + + model = GPT2LMHeadModel.from_pretrained( + "gpt2", device_map={"": 0}, attn_implementation="flash_attention_2", torch_dtype=torch.float16 + ) + + output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_2 = tokenizer.batch_decode(output_fa_2) + + expected_output = [ + "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata", + "Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry", + ] + + self.assertListEqual(output_native, output_fa_2) + self.assertListEqual(output_native, expected_output)