feat: add cc, acc methods and ppl decoding mha kernel for deepseek2
niushengxiao committed Nov 28, 2024
1 parent bd27a96 commit a1fe340
Showing 3 changed files with 599 additions and 2 deletions.
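For orientation before the diff: the new optimized decode path is gated by the ENABLE_OPT_DECODE_MHA environment variable, which the layer parses in __init__ (first file below). A minimal sketch, assuming only the flag-parsing behavior shown in the diff; the standalone helper and the launch hint are illustrative, not part of this commit:

import os

# Mirrors the parsing added in Deepseek2TransformerLayerInfer.__init__:
# any of "ON", "TRUE", or "1" (case-insensitive) enables the PPL MLA decode kernel.
def opt_decode_mha_enabled() -> bool:
    return os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]

if __name__ == "__main__":
    os.environ["ENABLE_OPT_DECODE_MHA"] = "1"  # e.g. export ENABLE_OPT_DECODE_MHA=1 before launching
    print(opt_decode_mha_enabled())  # True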
174 changes: 172 additions & 2 deletions lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -1,13 +1,16 @@
from typing import Tuple
import torch
import torch.functional as F
import torch.nn.functional as FN
import torch.distributed as dist
import numpy as np
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd,
context_attention_fwd_no_prompt_cache,
context_attention_fwd_with_v,
context_attention_fwd_no_prompt_cache_with_v,
)

from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
@@ -18,6 +21,7 @@
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from functools import partial
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
import os


class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -55,6 +59,11 @@ def __init__(
self.softmax_scale = self.softmax_scale * mscale * mscale
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
self.tp_o_head_num_ = self.tp_q_head_num_

self.num_heads = network_config["num_attention_heads"]
self.num_kv_heads = network_config["num_key_value_heads"]
self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]

return

def _bind_attention(self):
@@ -97,7 +106,8 @@ def _get_qkv(

q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
if layer_weight.mla_type == "ACCM":
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))

@@ -123,11 +133,157 @@ def _get_o(
input = input.view(-1, self.tp_q_head_num_ * self.kv_lora_rank)
o_tensor = layer_weight.fuse_vo_weight_.mm(input)
else:
input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
if layer_weight.mla_type == "ACCM":
input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
return o_tensor

def _CC_method(
self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
):
num_local_heads = self.num_heads
num_local_kv_heads = self.num_kv_heads
if self.world_size_ > 1:
num_local_heads //= self.world_size_
num_local_kv_heads //= self.world_size_
if infer_state.use_dynamic_prompt_cache:
compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
# CC
compressed_kv, k_pe = torch.split( # (b*s, 1, kv_lora + qk_r)
compressed_kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
)
compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
k = torch.empty(
k_pe.shape[0],
num_local_kv_heads,
layer_weight.qk_nope_head_dim + layer_weight.qk_rope_head_dim,
dtype=q[0].dtype,
device=q[0].device,
)
k[..., layer_weight.qk_nope_head_dim :] = k_pe
k[..., : layer_weight.qk_nope_head_dim] = FN.linear(
compressed_kv, layer_weight.k_b_proj_.weight.view(-1, layer_weight.k_b_proj_.weight.shape[-1])
).view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim)
trans_weight = layer_weight.v_b_proj_.weight.transpose(1, 2)
v = FN.linear(compressed_kv, trans_weight.view(-1, trans_weight.shape[-1])).view(
-1, num_local_kv_heads, layer_weight.qk_nope_head_dim
) # (b*s, h, vo_d)
return self._context_attention_kernel_with_v(q, k, v, infer_state, layer_weight)

def _ACC_method(
self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
):
q_ne, q_pe = q
num_local_heads = self.num_heads
num_local_kv_heads = self.num_kv_heads
if self.world_size_ > 1:
num_local_heads //= self.world_size_
num_local_kv_heads //= self.world_size_
# ACC
q = torch.empty(
q_ne.shape[0],
num_local_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
dtype=q_ne.dtype,
device=q_ne.device,
)
q[..., self.kv_lora_rank :] = q_pe
q[..., : self.kv_lora_rank] = torch.bmm(  # TODO: convert to einsum or cuBLAS
q_ne.transpose(0, 1), # (h, b*s, qk_n)
layer_weight.k_b_proj_.weight.view(
num_local_kv_heads, self.qk_nope_head_dim, self.kv_lora_rank
), # (h, qk_n, kv_lora)
).transpose(
0, 1
) # (b*s, h, kv_lora)
q_nope, q_rope = torch.split(  # (b*s, h, kv_lora + qk_r) -> (b*s, h, kv_lora), (b*s, h, qk_r)
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
if self.enable_opt_decoding_mha:
import lightllm_ppl_mla

o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
kvstarts = torch.zeros(infer_state.batch_size + 1, dtype=torch.int, device=q.device)
kvstarts[1:] = infer_state.b_seq_len.clone().detach().cumsum(dim=0)
lightllm_ppl_mla.decode_mla(
o_tensor,
q,
compressed_kv[: infer_state.mem_end, :, :],
infer_state.b_start_loc,
kvstarts,
self.softmax_scale,
q.shape[-1],
q_nope.shape[-1],
)
output_parallel = o_tensor
else:
output_parallel = self._token_gqa_decode_attention_flashdecoding_origin(
(q_nope, q_rope), infer_state, layer_weight
)
trans_weight = layer_weight.v_b_proj_.weight.transpose(1, 2)
output_parallel = torch.bmm(  # TODO: convert to einsum or cuBLAS
output_parallel.transpose(0, 1), # (h, b*s, kv_lora)
trans_weight.view(num_local_kv_heads, layer_weight.qk_nope_head_dim, self.kv_lora_rank).transpose(
1, 2
), # (h, kv_lora, vo_d)
).transpose(
0, 1
) # (b*s, h, vo_d)
return output_parallel

def _context_attention_kernel(
self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
) -> torch.Tensor:
if layer_weight.mla_type == "MIX":
return self._context_attention_kernel_with_CC(q, kv, infer_state, layer_weight, out)
else:
return self._context_attention_kernel_origin(q, kv, infer_state, layer_weight, out)

def _context_attention_kernel_with_CC(
self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
) -> torch.Tensor:
return self._CC_method(q, kv, infer_state, layer_weight)

def _context_attention_kernel_with_v(
self, q: Tuple[torch.Tensor, torch.Tensor], kv, v, infer_state: LlamaInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
q_nope, q_rope = q
nope_head_dim = q_nope.shape[-1]
o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
if infer_state.use_dynamic_prompt_cache:
context_attention_fwd_with_v(
q_nope,
q_rope,
kv[:, :, :nope_head_dim],
kv[:, :, nope_head_dim:],
v,
o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
infer_state.req_manager.req_to_token_indexs,
self.softmax_scale,
)
else:
context_attention_fwd_no_prompt_cache_with_v(
q_nope,
q_rope,
kv[:, :, :nope_head_dim],
kv[:, :, nope_head_dim:],
v,
o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
self.softmax_scale,
)
q_nope = None
q_rope = None
return o_tensor

def _context_attention_kernel_origin(
self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
q_nope, q_rope = q
@@ -166,6 +322,20 @@ def _context_attention_kernel(
return o_tensor

def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
if layer_weight.mla_type == "MIX":
return self._token_gqa_decode_attention_flashdecoding_with_ACC(q, infer_state, layer_weight, out)
else:
return self._token_gqa_decode_attention_flashdecoding_origin(q, infer_state, layer_weight, out)

def _token_gqa_decode_attention_flashdecoding_with_ACC(
self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
):
compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][: infer_state.mem_end, :, :]
return self._ACC_method(q, compressed_kv, infer_state, layer_weight)

def _token_gqa_decode_attention_flashdecoding_origin(
self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
):
q_nope, q_rope = q
kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
kv_rope = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :]
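To make the CC/ACC split above concrete: the CC path (_CC_method) expands the compressed KV cache back to full per-head K and V through k_b_proj_ / v_b_proj_ and runs ordinary attention, while the ACC path (_ACC_method) absorbs k_b_proj_ into the query and scores directly against the latent cache, applying v_b_proj_ only to the attention output. Both orderings produce identical attention scores. A self-contained toy check with made-up shapes (plain einsum, no RoPE component, not the Triton/PPL kernels used in the diff):

import torch

h, d, r, s_q, s_k = 4, 16, 8, 2, 5        # heads, nope head dim, kv_lora_rank, query/cached tokens
q_nope = torch.randn(s_q, h, d)
c = torch.randn(s_k, r)                   # compressed KV cache, one latent vector per token
w_kb = torch.randn(h, d, r)               # per-head up-projection (k_b_proj_)

# CC: expand the latent cache to full K, then score q against k per head
k = torch.einsum("sr,hdr->shd", c, w_kb)            # (s_k, h, d)
scores_cc = torch.einsum("qhd,shd->hqs", q_nope, k)

# ACC: absorb the up-projection into q, then score directly in the latent space
q_lat = torch.einsum("qhd,hdr->qhr", q_nope, w_kb)  # (s_q, h, r)
scores_acc = torch.einsum("qhr,sr->hqs", q_lat, c)

assert torch.allclose(scores_cc, scores_acc, atol=1e-4)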
lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -75,6 +75,11 @@ def __init__(
self.disable_qk_absorb = disable_qk_absorb
self.disable_vo_absorb = disable_vo_absorb
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
# mla_type: "ACCM" or "MIX"
# MIX uses CC for prefill and ACC for decoding
self.mla_type = "MIX"
if not disable_vo_absorb or not disable_qk_absorb:
self.mla_type = "ACCM"
return

def _parse_config(self):
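The weight-side switch above decides which path the infer layer dispatches to: mla_type stays "MIX" (CC for prefill, ACC for decoding) only when both absorb options are disabled, and falls back to "ACCM" as soon as either absorb is left enabled. A minimal sketch of that selection, mirroring the two added lines; the helper name is mine, not part of the diff:

def select_mla_type(disable_qk_absorb: bool, disable_vo_absorb: bool) -> str:
    # "MIX": run CC during prefill and ACC during decoding.
    # "ACCM": chosen whenever either absorb option stays enabled.
    if not disable_vo_absorb or not disable_qk_absorb:
        return "ACCM"
    return "MIX"

assert select_mla_type(True, True) == "MIX"
assert select_mla_type(False, True) == "ACCM"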
(diff for the third changed file not shown)
