From 6fe4d870411ba612852d88bea6aa9a212b1d3688 Mon Sep 17 00:00:00 2001 From: Shirin Yamani Date: Tue, 12 Nov 2024 15:25:58 -0700 Subject: [PATCH 1/4] snapkv cache implementation added to cache --- src/transformers/cache_utils.py | 80 +++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0f696cc3ac6a4d..3d76c8b9d5b20f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2120,3 +2120,83 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) + +#------------------SnapKV Cache-------------------------------------- +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +class SnapKVCluster(): + def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'): + self.window_size = window_size + self.max_capacity_prompt = max_capacity_prompt + assert self.max_capacity_prompt - self.window_size > 0 + self.kernel_size = kernel_size + self.pooling = pooling + + def reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'): + self.window_size = window_size + self.max_capacity_prompt = max_capacity_prompt + assert self.max_capacity_prompt - self.window_size > 0 + self.kernel_size = kernel_size + self.pooling = pooling + + def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups): + # check if prefix phase + assert key_states.shape[-2] == query_states.shape[-2] + bsz, num_heads, q_len, head_dim = query_states.shape + if q_len < self.max_capacity_prompt: + return key_states, value_states + else: + attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim) + mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + mask_cond = torch.arange(mask.size(-1), device=attn_weights.device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(attn_weights.device) + attention_mask = mask[None, None, :, :] + + attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim = -2) + if self.pooling == 'avgpool': + attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1) + elif self.pooling == 'maxpool': + attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1) + else: + raise ValueError('Pooling method not 
supported') + indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices + indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim) + k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) + v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) + k_cur = key_states[:, :, -self.window_size:, :] + v_cur = value_states[:, :, -self.window_size:, :] + key_states = torch.cat([k_past_compress, k_cur], dim = 2) + value_states = torch.cat([v_past_compress, v_cur], dim = 2) + return key_states, value_states + +#initiate snapkv with window_size, etc if not given +def init_snapkv(self): + if not hasattr(self, "kv_cluster"): + if not hasattr(self.config, 'window_size'): + self.config.window_size = 32 + if not hasattr(self.config, 'max_capacity_prompt'): + self.config.max_capacity_prompt = 2048 + if not hasattr(self.config, 'kernel_size'): + self.config.kernel_size = 5 + if not hasattr(self.config, 'pooling'): + self.config.pooling = 'avgpool' + self.kv_cluster = SnapKVCluster( + window_size = self.config.window_size, + max_capacity_prompt = self.config.max_capacity_prompt, + kernel_size = self.config.kernel_size, + pooling = self.config.pooling + ) \ No newline at end of file From 80d3362cbc17ab69e7a452dc3b1ce119af85c088 Mon Sep 17 00:00:00 2001 From: Shirin Yamani Date: Tue, 12 Nov 2024 15:38:14 -0700 Subject: [PATCH 2/4] flash attention2 for snapkv cache added --- src/transformers/models/llama/llama_snapkv.py | 197 ++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 src/transformers/models/llama/llama_snapkv.py diff --git a/src/transformers/models/llama/llama_snapkv.py b/src/transformers/models/llama/llama_snapkv.py new file mode 100644 index 00000000000000..d391833742116f --- /dev/null +++ b/src/transformers/models/llama/llama_snapkv.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Optional, Tuple, Union +import warnings +from transformers.cache_utils import Cache, DynamicCache, SnapKVCluster, init_snapkv +from transformers.models.llama.modeling_llama import ( + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import ( + logging, +) + + +logger = logging.get_logger(__name__) + +#flash attention changes for using snapKV approach +# https://github.com/huggingface/transformers/blob/v4.37-release/src/transformers/models/llama/modeling_llama.py +def llama_flash_attn2_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # [SnapKV] register kv_cluster + init_snapkv(self) + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) #(B,T,nh, hs) ---> (B,nh,T,hs) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] #(T) + # if past_key_value is not None: + # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len + if self.kv_seq_len != 0: + kv_seq_len += self.kv_seq_len + else: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [SnapKV] move to ahead + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # print('kv_seq_len:', kv_seq_len) + # print('key_states.shape:', key_states.shape) + if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster + self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len + key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups) + past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs) + else: + self.kv_seq_len += q_len + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. 
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +def prepare_inputs_for_generation_llama( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs +): + if past_key_values is None: # [SnapKV] + for layer in self.model.layers: + layer.self_attn.kv_seq_len = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + # cache_length = past_length = past_key_values[0][0].shape[2] + # max_cache_length = None + cache_length = past_length = self.model.layers[0].self_attn.kv_seq_len + max_cache_length = None + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs \ No newline at end of file From f4e0f15a56593a783468a2aac3d577fe16aa7936 Mon Sep 17 00:00:00 2001 From: Shirin Yamani Date: Tue, 12 Nov 2024 16:59:25 -0700 Subject: [PATCH 3/4] cache_utils updated accofing to ruff --- src/transformers/cache_utils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 3d76c8b9d5b20f..edacb90e1e67d9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2,10 +2,13 @@ import importlib.metadata import json import os +import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn +from torch.nn import functional as F from packaging import version from .configuration_utils import PretrainedConfig @@ -2182,8 +2185,8 @@ def update_kv(self, key_states, query_states, value_states, attention_mask, num_ key_states = torch.cat([k_past_compress, k_cur], dim = 2) value_states = torch.cat([v_past_compress, v_cur], dim = 2) return key_states, value_states - -#initiate snapkv with window_size, etc if not given + +#initiate snapkv with window_size, etc if not given def init_snapkv(self): if not hasattr(self, "kv_cluster"): if not hasattr(self.config, 'window_size'): @@ -2194,9 +2197,9 @@ def init_snapkv(self): self.config.kernel_size = 5 if not hasattr(self.config, 'pooling'): self.config.pooling = 'avgpool' - self.kv_cluster = SnapKVCluster( - window_size = self.config.window_size, - max_capacity_prompt = self.config.max_capacity_prompt, + self.kv_cluster = SnapKVCluster( + window_size = self.config.window_size, + max_capacity_prompt = self.config.max_capacity_prompt, kernel_size = self.config.kernel_size, pooling = self.config.pooling - ) \ No newline at end of file + ) From debb5ad66348d444edf0412f8c96875c80ba9524 Mon Sep 17 00:00:00 2001 From: Shirin Yamani Date: Tue, 12 Nov 2024 17:01:11 -0700 Subject: [PATCH 4/4] ruff errors solved --- src/transformers/cache_utils.py | 75 +++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index edacb90e1e67d9..629dfc0b3d8423 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,15 +1,15 @@ import copy import importlib.metadata import json -import os import math +import os from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn -from 
torch.nn import functional as F from packaging import version +from torch.nn import functional as F from .configuration_utils import PretrainedConfig from .utils import ( @@ -2124,7 +2124,8 @@ def _prefetch_layer_in_context(self, layer_idx: int) -> None: self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True) -#------------------SnapKV Cache-------------------------------------- + +# ------------------SnapKV Cache-------------------------------------- # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -2137,15 +2138,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class SnapKVCluster(): - def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'): + +class SnapKVCluster: + def __init__(self, window_size=64, max_capacity_prompt=256 + 64, kernel_size=5, pooling="avgpool"): self.window_size = window_size self.max_capacity_prompt = max_capacity_prompt assert self.max_capacity_prompt - self.window_size > 0 self.kernel_size = kernel_size self.pooling = pooling - def reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'): + def reset(self, window_size=64, max_capacity_prompt=256 + 64, kernel_size=5, pooling="avgpool"): self.window_size = window_size self.max_capacity_prompt = max_capacity_prompt assert self.max_capacity_prompt - self.window_size > 0 @@ -2159,47 +2161,56 @@ def update_kv(self, key_states, query_states, value_states, attention_mask, num_ if q_len < self.max_capacity_prompt: return key_states, value_states else: - attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim) - mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + attn_weights = torch.matmul( + query_states[..., -self.window_size :, :], key_states.transpose(2, 3) + ) / math.sqrt(head_dim) + mask = torch.full( + (self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device + ) mask_cond = torch.arange(mask.size(-1), device=attn_weights.device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(attn_weights.device) attention_mask = mask[None, None, :, :] - attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask + attn_weights[:, :, -self.window_size :, -self.window_size :] += attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights_sum = attn_weights[:, :, -self.window_size:, : -self.window_size].sum(dim = -2) - if self.pooling == 'avgpool': - attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1) - elif self.pooling == 'maxpool': - attn_cache = F.max_pool1d(attn_weights_sum, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1) + attn_weights_sum = attn_weights[:, :, -self.window_size :, : -self.window_size].sum(dim=-2) + if self.pooling == "avgpool": + attn_cache = F.avg_pool1d( + attn_weights_sum, 
kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1 + ) + elif self.pooling == "maxpool": + attn_cache = F.max_pool1d( + attn_weights_sum, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1 + ) else: - raise ValueError('Pooling method not supported') + raise ValueError("Pooling method not supported") indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim) - k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) - v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) - k_cur = key_states[:, :, -self.window_size:, :] - v_cur = value_states[:, :, -self.window_size:, :] - key_states = torch.cat([k_past_compress, k_cur], dim = 2) - value_states = torch.cat([v_past_compress, v_cur], dim = 2) + k_past_compress = key_states[:, :, : -self.window_size, :].gather(dim=2, index=indices) + v_past_compress = value_states[:, :, : -self.window_size, :].gather(dim=2, index=indices) + k_cur = key_states[:, :, -self.window_size :, :] + v_cur = value_states[:, :, -self.window_size :, :] + key_states = torch.cat([k_past_compress, k_cur], dim=2) + value_states = torch.cat([v_past_compress, v_cur], dim=2) return key_states, value_states -#initiate snapkv with window_size, etc if not given + +# initiate snapkv with window_size, etc if not given def init_snapkv(self): if not hasattr(self, "kv_cluster"): - if not hasattr(self.config, 'window_size'): + if not hasattr(self.config, "window_size"): self.config.window_size = 32 - if not hasattr(self.config, 'max_capacity_prompt'): + if not hasattr(self.config, "max_capacity_prompt"): self.config.max_capacity_prompt = 2048 - if not hasattr(self.config, 'kernel_size'): + if not hasattr(self.config, "kernel_size"): self.config.kernel_size = 5 - if not hasattr(self.config, 'pooling'): - self.config.pooling = 'avgpool' + if not hasattr(self.config, "pooling"): + self.config.pooling = "avgpool" self.kv_cluster = SnapKVCluster( - window_size = self.config.window_size, - max_capacity_prompt = self.config.max_capacity_prompt, - kernel_size = self.config.kernel_size, - pooling = self.config.pooling - ) + window_size=self.config.window_size, + max_capacity_prompt=self.config.max_capacity_prompt, + kernel_size=self.config.kernel_size, + pooling=self.config.pooling, + )
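
Reviewer note (appended after the series, not part of any patch): the snippet below is a minimal, self-contained sketch for sanity-checking the compression path in `SnapKVCluster.update_kv` with random tensors. The batch/head/dim sizes are arbitrary, and the import assumes the `cache_utils.py` changes from this series have been applied.

    import torch
    from transformers.cache_utils import SnapKVCluster

    bsz, num_heads, head_dim = 1, 8, 64
    q_len = 1024  # prompt longer than max_capacity_prompt, so compression triggers

    cluster = SnapKVCluster(window_size=64, max_capacity_prompt=320, kernel_size=5, pooling="avgpool")

    # During prefill, query/key/value all cover the full prompt.
    query_states = torch.randn(bsz, num_heads, q_len, head_dim)
    key_states = torch.randn(bsz, num_heads, q_len, head_dim)
    value_states = torch.randn(bsz, num_heads, q_len, head_dim)

    k, v = cluster.update_kv(
        key_states, query_states, value_states, attention_mask=None, num_key_value_groups=1
    )
    print(k.shape, v.shape)  # both torch.Size([1, 8, 320, 64])

Because `q_len` (1024) exceeds `max_capacity_prompt` (320), the prefix keys/values outside the observation window are scored by the pooled attention of the last `window_size` queries, the top `max_capacity_prompt - window_size` positions are gathered, and the last `window_size` entries are kept verbatim, so both returned tensors have 320 positions along the sequence axis. Note that the `attention_mask` and `num_key_value_groups` arguments are currently unused inside `update_kv` (the former is overwritten by the locally built causal mask).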
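
The new `llama_snapkv.py` module defines `llama_flash_attn2_forward` and `prepare_inputs_for_generation_llama`, but the series does not show how they get attached to a model. The sketch below is one hypothetical wiring via monkey-patching; it assumes the v4.37-style `LlamaFlashAttention2` class referenced in the module is still importable, that flash-attention 2 is installed, and that the checkpoint name and config values are purely illustrative.

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.models.llama import modeling_llama
    from transformers.models.llama.llama_snapkv import (
        llama_flash_attn2_forward,
        prepare_inputs_for_generation_llama,
    )

    # Route every Llama flash-attention layer through the SnapKV-aware forward and
    # let generation reset/track the per-layer kv_seq_len bookkeeping it relies on.
    modeling_llama.LlamaFlashAttention2.forward = llama_flash_attn2_forward
    modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation_llama

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",               # illustrative checkpoint
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",  # the patched forward targets FA2
    )

    # Optional overrides: init_snapkv falls back to window_size=32 and
    # max_capacity_prompt=2048 when these attributes are missing from the config.
    model.config.window_size = 32
    model.config.max_capacity_prompt = 1024

With this wiring the prompt is compressed once at prefill (the `key_states.shape[-2] == kv_seq_len` branch of the patched forward), and later decode steps append to the already-compressed cache through the normal `past_key_value.update` path.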