diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index 443be0cf87f5de..5dae5d928f9427 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect from typing import List, Optional, Tuple, Union import torch @@ -40,10 +39,7 @@ if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + from ...flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -89,9 +85,8 @@ class Gemma2Attention(GemmaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - self.scaling = config.query_pre_attn_scalar**-0.5 - super().__init__(config, layer_idx) + self.scaling = config.query_pre_attn_scalar**-0.5 class Gemma2FlashAttention2(Gemma2Attention): @@ -177,7 +172,7 @@ def forward( value_states = value_states.to(target_dtype) ########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING - attn_output = self._flash_attention_forward( + attn_output = _flash_attention_forward( query_states, key_states, value_states, @@ -185,6 +180,8 @@ def forward( q_len, dropout=dropout_rate, softmax_scale=self.scaling, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -195,116 +192,6 @@ def forward( return attn_output, attn_weights, past_key_value - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - cache_position=0, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in Gemma2FlashAttention2 __init__. 
- causal = self.is_causal and query_length != 1 - - use_sliding_windows = ( - _flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window - ) - flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {} - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - class Gemma2SdpaAttention(Gemma2Attention): """ diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 6b2b47b5159e28..b0974b954799f5 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect from typing import List, Optional, Tuple, Union import torch @@ -30,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache +from ...flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -49,13 +49,7 @@ if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) + from ...flash_attention_utils import _flash_attention_forward def _get_unpad_data(attention_mask): @@ -70,6 +64,9 @@ def _get_unpad_data(attention_mask): ) +logger = logging.get_logger(__name__) + + class Gemma2RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -374,7 +371,8 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( + ########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING + attn_output = _flash_attention_forward( query_states, key_states, value_states, @@ -382,6 +380,8 @@ def forward( q_len, dropout=dropout_rate, softmax_scale=self.scaling, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -392,117 +392,6 @@ def forward( return attn_output, attn_weights, past_key_value - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - cache_position=0, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in Gemma2FlashAttention2 __init__. 
- causal = self.is_causal and query_length != 1 - - # TODO this is not compile compatible - use_sliding_windows = ( - _flash_supports_window_size and self.sliding_window is not None and cache_position > self.sliding_window - ) - flash_kwargs = {"window_size": (self.sliding_window, self.sliding_window)} if use_sliding_windows else {} - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - class Gemma2SdpaAttention(Gemma2Attention): """
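
The patch above removes the per-model copies of `_flash_attention_forward` and `_upad_input` from both diff_gemma2.py and modeling_gemma2.py and routes the attention call through the shared helper imported from `...flash_attention_utils`, forwarding `is_causal` and `use_top_left_mask` so the helper can make the causal-mask decision that the deleted code made locally. A minimal sketch of the resulting call pattern, assuming flash-attn 2.x and a CUDA device with fp16 support; the tensor shapes, the standalone module path, and the scale value are illustrative, not the exact library code:

import torch
# The diff imports the helper relatively (`from ...flash_attention_utils import
# _flash_attention_forward`); recent released transformers expose it under
# transformers.modeling_flash_attention_utils.
from transformers.modeling_flash_attention_utils import _flash_attention_forward

bsz, q_len, num_heads, head_dim = 2, 16, 8, 64
device, dtype = "cuda", torch.float16

# query/key/value as they look right before the call site in the patch:
# already projected, RoPE-rotated and transposed to (batch, seq_len, num_heads, head_dim).
query_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
key_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
value_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)

attn_output = _flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask=None,           # a (batch, seq_len) padding mask here triggers the unpad path
    query_length=q_len,
    is_causal=True,                # was self.is_causal at the call site
    dropout=0.0,
    softmax_scale=head_dim**-0.5,  # Gemma2 passes config.query_pre_attn_scalar ** -0.5
    use_top_left_mask=False,       # was self._flash_attn_uses_top_left_mask
)
print(attn_output.shape)           # torch.Size([2, 16, 8, 64]); the model then reshapes to (bsz, q_len, -1)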
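
The `query_length != 1` special case in the deleted blocks exists because flash-attn builds older than 2.1 (and the RoCm port at the time) use a top-left-aligned causal mask, which would wrongly mask the single query token during decoding. After the patch the model no longer computes `causal` itself; it forwards both flags and the shared helper applies the same rule. A pure-Python sketch of that decision, with names taken from the deleted code:

def resolve_causal(is_causal: bool, use_top_left_mask: bool, query_length: int) -> bool:
    # use_top_left_mask corresponds to self._flash_attn_uses_top_left_mask in the model.
    if not use_top_left_mask:
        return is_causal
    # Top-left-aligned mask: one-token decoding must not be masked causally,
    # hence the `query_length != 1` check the deleted code carried.
    return is_causal and query_length != 1

print(resolve_causal(True, False, 1))   # True  -> bottom-right aligned mask, always causal
print(resolve_causal(True, True, 1))    # False -> workaround path for one-token decode
print(resolve_causal(True, True, 32))   # True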
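
The deleted code only requested flash-attn's windowed attention once the cache position grew past the sliding window, and only when the installed flash-attn build had been probed (via `inspect.signature`) to accept `window_size`; the `# TODO this is not compile compatible` note refers to that data-dependent branch. A small sketch of the kwargs decision, with the capability probe replaced by a plain boolean assumption:

# Assumption: flash-attn builds that ship `window_size` (2.3 and later) set this to True;
# the deleted code probed inspect.signature(flash_attn_func) for it at import time.
_flash_supports_window_size = True

def build_flash_kwargs(sliding_window, cache_position):
    # Windowed attention is only worth requesting once keys can fall outside the window.
    use_sliding_windows = (
        _flash_supports_window_size
        and sliding_window is not None
        and cache_position > sliding_window
    )
    return {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

print(build_flash_kwargs(sliding_window=4096, cache_position=10))    # {}
print(build_flash_kwargs(sliding_window=4096, cache_position=5000))  # {'window_size': (4096, 4096)}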
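
Both deleted `_upad_input` copies implemented the padded path: gather the real tokens into a packed `(total_tokens, heads, head_dim)` layout, call `flash_attn_varlen_func` with cumulative sequence lengths, then re-pad the output. The bookkeeping came from the module-level `_get_unpad_data`, whose definition is only partially visible in the hunk above. A CPU-runnable sketch of that bookkeeping, whose body mirrors the transformers helper; the function name and the example mask are made up:

import torch
import torch.nn.functional as F

def get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # Cumulative sequence lengths with a leading zero, the format expected by
    # flash_attn_varlen_func's cu_seqlens_q / cu_seqlens_k arguments.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # two sequences of length 3 and 2
indices, cu_seqlens, max_len = get_unpad_data(mask)
print(indices)     # tensor([0, 1, 2, 4, 5]) -> flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)
print(max_len)     # 3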