From 08a09e3528974f980593f5340c6ceee2ba8e25f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Wed, 26 Jun 2024 16:24:17 +0200
Subject: [PATCH] Attention WIP

---
 .../flash_deepseek_v2_modeling.py | 85 ++++++++++++++-----
 1 file changed, 64 insertions(+), 21 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index cb19a9b00a4..f8357316cbf 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -250,10 +250,16 @@ def __init__(
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
+        self.kv_lora_rank = config.kv_lora_rank
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.q_head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.head_size,
+            dim=self.qk_rope_head_dim,
             base=config.rope_theta,
             device=weights.device,
         )
@@ -318,23 +324,60 @@ def forward(
         input_lengths,
         max_s,
     ):
-        qkv = self.query_key_value(hidden_states)
-        if self.clip_qkv is not None:
-            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+        if self.q_lora_rank is None:
+            query = self.query(hidden_states)
+        else:
+            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
+        query = query.view(-1, self.num_heads, self.q_head_size)
+        from loguru import logger
 
-        query, kv = qkv.split(
-            [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
-            ],
-            dim=1,
+        query_nope, query_pe = torch.split(
+            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
+                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            # .transpose(1, 2)
+        )
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+
+        self.rotary_emb(query_pe, k_pe, cos, sin)
+
+        query[..., self.qk_nope_head_dim :] = query_pe
+        key = torch.empty_like(query)
+        key[..., : self.qk_nope_head_dim] = k_nope
+        key[..., self.qk_nope_head_dim :] = k_pe
+
+        logger.warning(
+            f"query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
+        )
+
+        reshape_and_cache(key, value_states, kv_cache[0], kv_cache[1], slots)
+
+        logger.warning(
+            f"after reshape_and_cache -> query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
         )
-        query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        # We need to pad the heads because Flash Attention does not support
+        # qk and v with different head sizes.
+        query = torch.nn.functional.pad(query, (0, 256 - self.q_head_size), value=0)
+        key = torch.nn.functional.pad(key, (0, 256 - self.q_head_size), value=0)
+        value_states = torch.nn.functional.pad(
+            value_states, (0, 256 - self.v_head_dim), value=0
+        )
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        logger.warning(
+            f"after pad -> query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
+        )
 
         # output tensor
         attn_output = torch.empty_like(query)
@@ -344,8 +387,8 @@ def forward(
             # flash attention
             attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                key,
+                value_states,
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
@@ -356,8 +399,8 @@ def forward(
             paged_attention(
                 attn_output,
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                key,
+                value_states,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -665,9 +708,9 @@ def forward(
             attn_output, res
         )
 
-        moe_output = self.moe(normed_attn_res_output)
+        output = self.mlp(normed_attn_res_output)
 
-        return moe_output, attn_res
+        return output, attn_res
 
 
 class DeepseekV2Model(torch.nn.Module):
@@ -711,7 +754,7 @@ def forward(
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
         )
 
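
The zero-padding in the hunk above works because padded query/key dimensions contribute
nothing to the attention scores, and the padded value dimensions come back as zeros that
can be sliced off after attention. Below is a minimal standalone sketch of that trick,
not code from this patch: it assumes DeepSeek-V2's default head sizes (128 nope + 64 rope
for q/k, 128 for v) and uses plain softmax attention in place of the flash/paged kernels.

    import torch
    import torch.nn.functional as F

    # Assumed DeepSeek-V2 default sizes: qk head = 128 (nope) + 64 (rope), v head = 128.
    num_tokens, num_heads = 4, 2
    q_head_size, v_head_dim, padded_size = 128 + 64, 128, 256

    query = torch.randn(num_tokens, num_heads, q_head_size)
    key = torch.randn(num_tokens, num_heads, q_head_size)
    value = torch.randn(num_tokens, num_heads, v_head_dim)

    # Zero-pad the last dimension so q, k and v all share one head size.
    query = F.pad(query, (0, padded_size - q_head_size), value=0)
    key = F.pad(key, (0, padded_size - q_head_size), value=0)
    value = F.pad(value, (0, padded_size - v_head_dim), value=0)

    # Plain attention standing in for the flash/paged kernels; the padded q/k
    # dimensions are zero, so the attention scores are unchanged.
    scores = torch.einsum("qhd,khd->hqk", query, key) * q_head_size**-0.5
    attn = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), value)

    # Padded value dimensions come out as zeros; slice them off to recover the output.
    attn = attn[..., :v_head_dim]
    print(attn.shape)  # torch.Size([4, 2, 128])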