Skip to content

Commit

Permalink
feat: add eager lckv attention implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
why-in-Shanghaitech committed Nov 8, 2024
1 parent b3412a3 commit cf50b6c
Showing 1 changed file with 179 additions and 4 deletions.
183 changes: 179 additions & 4 deletions models/modeling_lckv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LLAMA_INPUTS_DOCSTRING,
LlamaAttention,
LlamaDecoderLayer,
LlamaFlashAttention2,
LlamaForCausalLM,
LlamaModel,
LlamaPreTrainedModel,
_prepare_4d_causal_attention_mask_with_cache_position,
logger,
repeat_kv,
rotate_half,
)
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.utils import add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10

from .cache_utils import AutoLayerCache, LayerCache
from .configuration_lckv import LCKVLlamaConfig
Expand All @@ -49,7 +52,8 @@ def apply_rotary(q, cos, sin, unsqueeze_dim=1):
q_embed = (q * cos) + (rotate_half(q) * sin)
return q_embed

class LCKVLlamaAttention(LlamaFlashAttention2):

class LCKVLlamaAttention(LlamaAttention):
"""
LCKV Attention may not need to initialize weights for the key and value projections.
"""
Expand All @@ -64,6 +68,104 @@ def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None):
del self.k_proj
del self.v_proj

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
cos, sin = position_embeddings

query_states = self.q_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = apply_rotary(query_states, cos, sin)

# compute key and value states
if self.layer_type.computes_kv:
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
key_states = apply_rotary(key_states, cos, sin)

if isinstance(past_key_value, Cache):
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

past_key_value.layer_set(self.layer_idx, key_states, value_states)

# get the cached key and value states
key_states, value_states = past_key_value.layer_get(
self.layer_type.attends_to,
zerofill=self.layer_type.attends_top and q_len == 1,
)

# handle GQA
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# diagonal mask from the right bottom corner
if self.layer_type.attends_top:
kv_len = key_states.size(2)
mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min)
mask = mask.tril(diagonal=kv_len - q_len).triu(diagonal=kv_len - q_len)
attn_weights = attn_weights + mask

# sliding window mask
if self.sliding_window:
kv_len = key_states.size(2)
mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min)
mask = mask.tril(diagonal=kv_len - q_len - self.sliding_window)
attn_weights = attn_weights + mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


class LCKVLlamaFlashAttention2(LCKVLlamaAttention):
"""
LCKV Attention may not need to initialize weights for the key and value projections.
"""

def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -172,10 +274,16 @@ def forward(
return attn_output, attn_weights, past_key_value


LCKV_LLAMA_ATTENTION_CLASSES = {
"eager": LCKVLlamaAttention,
"flash_attention_2": LCKVLlamaFlashAttention2,
}


class LCKVLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LCKVLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = LCKVLlamaAttention(config=config, layer_idx=layer_idx)
self.self_attn = LCKV_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)


class LCKVLlamaPreTrainedModel(LlamaPreTrainedModel):
Expand Down Expand Up @@ -534,6 +642,73 @@ def _modeling_sequential(
attentions=all_self_attns,
)

def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
"""fix this function to handle layer cache"""
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask


class LCKVLlamaForCausalLM(LCKVLlamaPreTrainedModel, LlamaForCausalLM):
def __init__(self, config):
Expand Down

0 comments on commit cf50b6c

Please sign in to comment.