From cf50b6cd14b11465128227b13b7752455bec270d Mon Sep 17 00:00:00 2001
From: whywhy-rtx3090 <43395692+why-in-Shanghaitech@users.noreply.github.com>
Date: Fri, 8 Nov 2024 22:54:39 +0800
Subject: [PATCH] feat: add eager lckv attention implementation

---
 models/modeling_lckv.py | 183 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 179 insertions(+), 4 deletions(-)

diff --git a/models/modeling_lckv.py b/models/modeling_lckv.py
index 0934d1c..62be1bb 100644
--- a/models/modeling_lckv.py
+++ b/models/modeling_lckv.py
@@ -18,25 +18,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch LLaMA model."""
 
+import math
 from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers.cache_utils import Cache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LLAMA_INPUTS_DOCSTRING,
+    LlamaAttention,
     LlamaDecoderLayer,
-    LlamaFlashAttention2,
     LlamaForCausalLM,
     LlamaModel,
     LlamaPreTrainedModel,
     _prepare_4d_causal_attention_mask_with_cache_position,
     logger,
+    repeat_kv,
     rotate_half,
 )
-from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10
 
 from .cache_utils import AutoLayerCache, LayerCache
 from .configuration_lckv import LCKVLlamaConfig
@@ -49,7 +52,8 @@ def apply_rotary(q, cos, sin, unsqueeze_dim=1):
     q_embed = (q * cos) + (rotate_half(q) * sin)
     return q_embed
 
-class LCKVLlamaAttention(LlamaFlashAttention2):
+
+class LCKVLlamaAttention(LlamaAttention):
     """
     LCKV Attention may not need to initialize weights for the key and value projections.
""" @@ -64,6 +68,104 @@ def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None): del self.k_proj del self.v_proj + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = apply_rotary(query_states, cos, sin) + + # compute key and value states + if self.layer_type.computes_kv: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = apply_rotary(key_states, cos, sin) + + if isinstance(past_key_value, Cache): + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + past_key_value.layer_set(self.layer_idx, key_states, value_states) + + # get the cached key and value states + key_states, value_states = past_key_value.layer_get( + self.layer_type.attends_to, + zerofill=self.layer_type.attends_top and q_len == 1, + ) + + # handle GQA + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # diagonal mask from the right bottom corner + if self.layer_type.attends_top: + kv_len = key_states.size(2) + mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min) + mask = mask.tril(diagonal=kv_len - q_len).triu(diagonal=kv_len - q_len) + attn_weights = attn_weights + mask + + # sliding window mask + if self.sliding_window: + kv_len = key_states.size(2) + mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min) + mask = mask.tril(diagonal=kv_len - q_len - self.sliding_window) + attn_weights = attn_weights + mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = 
self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LCKVLlamaFlashAttention2(LCKVLlamaAttention): + """ + LCKV Attention may not need to initialize weights for the key and value projections. + """ + + def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, hidden_states: torch.Tensor, @@ -172,10 +274,16 @@ def forward( return attn_output, attn_weights, past_key_value +LCKV_LLAMA_ATTENTION_CLASSES = { + "eager": LCKVLlamaAttention, + "flash_attention_2": LCKVLlamaFlashAttention2, +} + + class LCKVLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LCKVLlamaConfig, layer_idx: int): super().__init__(config, layer_idx) - self.self_attn = LCKVLlamaAttention(config=config, layer_idx=layer_idx) + self.self_attn = LCKV_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) class LCKVLlamaPreTrainedModel(LlamaPreTrainedModel): @@ -534,6 +642,73 @@ def _modeling_sequential( attentions=all_self_attns, ) + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + """fix this function to handle layer cache""" + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
+        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            min_dtype=min_dtype,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
 
 class LCKVLlamaForCausalLM(LCKVLlamaPreTrainedModel, LlamaForCausalLM):
     def __init__(self, config):
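
Below is a minimal standalone sketch (not part of the patch) of the two LCKV-specific masks built in the eager LCKVLlamaAttention.forward above: the bottom-right-aligned diagonal mask used when a layer attends to a higher layer's KV, and the sliding-window mask. Toy sizes are assumed, and the regular causal mask that the model adds via `attention_mask` is omitted.

import torch

# Toy sizes: 3 queries attending to a cache of 5 key/value positions,
# with an assumed sliding window of 2 positions.
q_len, kv_len, sliding_window = 3, 5, 2
min_val = torch.finfo(torch.float32).min

# Diagonal mask "from the right bottom corner": with bottom-right alignment,
# query i sits at key position kv_len - q_len + i, and exactly that position is
# masked, since the attended (higher) layer has not produced the KV for the
# current token yet.
diag = torch.full((q_len, kv_len), min_val)
diag = diag.tril(diagonal=kv_len - q_len).triu(diagonal=kv_len - q_len)
print(diag == 0)  # True where attention is still allowed

# Sliding-window mask: masks keys that are `sliding_window` or more positions
# behind the query, so together with the causal mask each query sees at most
# the last `sliding_window` positions.
win = torch.full((q_len, kv_len), min_val)
win = win.tril(diagonal=kv_len - q_len - sliding_window)
print(win == 0)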
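
For completeness, a hypothetical usage sketch of how the new dispatch table selects the eager path; the checkpoint path is a placeholder, and it is assumed that the standard `transformers` `attn_implementation` argument is forwarded to `config._attn_implementation` as usual.

import torch

from models.modeling_lckv import LCKVLlamaForCausalLM

# "eager" routes LCKVLlamaDecoderLayer to the new LCKVLlamaAttention through
# LCKV_LLAMA_ATTENTION_CLASSES; "flash_attention_2" keeps LCKVLlamaFlashAttention2.
model = LCKVLlamaForCausalLM.from_pretrained(
    "path/to/lckv-checkpoint",  # placeholder checkpoint path
    attn_implementation="eager",
    torch_dtype=torch.float32,
)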