diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index bf5db846..4ce4d286 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -1,6 +1,7 @@ from typing import Tuple import torch import torch.functional as F +import torch.nn.functional as FN import torch.distributed as dist import numpy as np from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight @@ -8,6 +9,8 @@ from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd, context_attention_fwd_no_prompt_cache, + context_attention_fwd_with_v, + context_attention_fwd_no_prompt_cache_with_v, ) from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding @@ -18,6 +21,7 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale +import os class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer): @@ -55,6 +59,11 @@ def __init__( self.softmax_scale = self.softmax_scale * mscale * mscale super().__init__(layer_num, tp_rank, world_size, network_config, mode) self.tp_o_head_num_ = self.tp_q_head_num_ + + self.num_heads = network_config["num_attention_heads"] + self.num_kv_heads = network_config["num_key_value_heads"] + self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"] + return def _bind_attention(self): @@ -97,7 +106,8 @@ def _get_qkv( q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + if layer_weight.mla_type == "ACCM": + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)) @@ -123,11 +133,157 @@ def _get_o( input = input.view(-1, self.tp_q_head_num_ * self.kv_lora_rank) o_tensor = layer_weight.fuse_vo_weight_.mm(input) else: - input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) + if layer_weight.mla_type == "ACCM": + input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1) o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim)) return o_tensor + def _CC_method( + self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight + ): + num_local_heads = self.num_heads + num_local_kv_heads = self.num_kv_heads + if self.world_size_ > 1: + num_local_heads //= self.world_size_ + num_local_kv_heads //= self.world_size_ + if infer_state.use_dynamic_prompt_cache: + compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + # CC + compressed_kv, k_pe = torch.split( # (b*s, 1, kv_lora + qk_r) + compressed_kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1 + ) + compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank) + k = torch.empty( + k_pe.shape[0], + num_local_kv_heads, + layer_weight.qk_nope_head_dim + layer_weight.qk_rope_head_dim, + dtype=q[0].dtype, + device=q[0].device, + ) + k[..., layer_weight.qk_nope_head_dim :] = k_pe + k[..., : layer_weight.qk_nope_head_dim] = FN.linear( + compressed_kv, layer_weight.k_b_proj_.weight.view(-1, layer_weight.k_b_proj_.weight.shape[-1]) + ).view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim) + trans_weight = layer_weight.v_b_proj_.weight.transpose(1, 2) + v = FN.linear(compressed_kv, trans_weight.view(-1, trans_weight.shape[-1])).view( + -1, num_local_kv_heads, layer_weight.qk_nope_head_dim + ) # (b*s, h, vo_d) + return self._context_attention_kernel_with_v(q, k, v, infer_state, layer_weight) + + def _ACC_method( + self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight + ): + q_ne, q_pe = q + num_local_heads = self.num_heads + num_local_kv_heads = self.num_kv_heads + if self.world_size_ > 1: + num_local_heads //= self.world_size_ + num_local_kv_heads //= self.world_size_ + # ACC + q = torch.empty( + q_ne.shape[0], + num_local_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + dtype=q_ne.dtype, + device=q_ne.device, + ) + q[..., self.kv_lora_rank :] = q_pe + q[..., : self.kv_lora_rank] = torch.bmm( # TODO: 转换成einsum 或者 cublas + q_ne.transpose(0, 1), # (h, b*s, qk_n) + layer_weight.k_b_proj_.weight.view( + num_local_kv_heads, self.qk_nope_head_dim, self.kv_lora_rank + ), # (h, qk_n, kv_lora) + ).transpose( + 0, 1 + ) # (b*s, h, kv_lora) + q_nope, q_rope = torch.split( # (b*s, h, qk_n + qk_r) -> (b*s, h, qk_n), (b*s, h, qk_r) + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + if self.enable_opt_decoding_mha: + import lightllm_ppl_mla + + o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) + kvstarts = torch.zeros(infer_state.batch_size + 1, dtype=torch.int, device=q.device) + kvstarts[1:] = infer_state.b_seq_len.clone().detach().cumsum(dim=0) + lightllm_ppl_mla.decode_mla( + o_tensor, + q, + compressed_kv[: infer_state.mem_end, :, :], + infer_state.b_start_loc, + kvstarts, + self.softmax_scale, + q.shape[-1], + q_nope.shape[-1], + ) + output_parallel = o_tensor + else: + output_parallel = self._token_gqa_decode_attention_flashdecoding_origin( + (q_nope, q_rope), infer_state, layer_weight + ) + trans_weight = layer_weight.v_b_proj_.weight.transpose(1, 2) + output_parallel = torch.bmm( # TODO: 转换成einsum 或者 cublas + output_parallel.transpose(0, 1), # (h, b*s, kv_lora) + trans_weight.view(num_local_kv_heads, layer_weight.qk_nope_head_dim, self.kv_lora_rank).transpose( + 1, 2 + ), # (h, kv_lora, vo_d) + ).transpose( + 0, 1 + ) # (b*s, h, vo_d) + return output_parallel + def _context_attention_kernel( + self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ) -> torch.Tensor: + if layer_weight.mla_type == "MIX": + return self._context_attention_kernel_with_CC(q, kv, infer_state, layer_weight, out) + else: + return self._context_attention_kernel_origin(q, kv, infer_state, layer_weight, out) + + def _context_attention_kernel_with_CC( + self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ) -> torch.Tensor: + return self._CC_method(q, kv, infer_state, layer_weight) + + def _context_attention_kernel_with_v( + self, q: Tuple[torch.Tensor, torch.Tensor], kv, v, infer_state: LlamaInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + q_nope, q_rope = q + nope_head_dim = q_nope.shape[-1] + o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out + if infer_state.use_dynamic_prompt_cache: + context_attention_fwd_with_v( + q_nope, + q_rope, + kv[:, :, :nope_head_dim], + kv[:, :, nope_head_dim:], + v, + o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim), + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + self.softmax_scale, + ) + else: + context_attention_fwd_no_prompt_cache_with_v( + q_nope, + q_rope, + kv[:, :, :nope_head_dim], + kv[:, :, nope_head_dim:], + v, + o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim), + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + self.softmax_scale, + ) + q_nope = None + q_rope = None + return o_tensor + + def _context_attention_kernel_origin( self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: q_nope, q_rope = q @@ -166,6 +322,20 @@ def _context_attention_kernel( return o_tensor def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): + if layer_weight.mla_type == "MIX": + return self._token_gqa_decode_attention_flashdecoding_with_ACC(q, infer_state, layer_weight, out) + else: + return self._token_gqa_decode_attention_flashdecoding_origin(q, infer_state, layer_weight, out) + + def _token_gqa_decode_attention_flashdecoding_with_ACC( + self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None + ): + compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][: infer_state.mem_end, :, :] + return self._ACC_method(q, compressed_kv, infer_state, layer_weight) + + def _token_gqa_decode_attention_flashdecoding_origin( + self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None + ): q_nope, q_rope = q kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank] kv_rope = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :] diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 614df1ce..995ad1f1 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -75,6 +75,11 @@ def __init__( self.disable_qk_absorb = disable_qk_absorb self.disable_vo_absorb = disable_vo_absorb super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg) + # mla_type = "ACCM", "MIX" + # MIX是prefilled CC,decoding ACC + self.mla_type = "MIX" + if not disable_vo_absorb or not disable_qk_absorb: + self.mla_type = "ACCM" return def _parse_config(self): diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py index 99e2b7b9..ed88ab08 100644 --- a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py @@ -9,6 +9,428 @@ CUDA_CAPABILITY = torch.cuda.get_device_capability() +@triton.jit +def _fwd_kernel_with_v( + Q_nope, + Q_rope, + K_nope, + K_rope, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 + Out, + Req_to_tokens, + B_req_idx, + stride_q_bs, + stride_q_h, + stride_q_d, + stride_q_rope_bs, + stride_q_rope_h, + stride_q_rope_d, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_k_rope_bs, + stride_k_rope_h, + stride_k_rope_d, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_prompt_cache_len, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_ROPE_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_k_head = cur_head + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + + cur_head * stride_q_h + + offs_d[None, :] * stride_q_d + ) + off_q_rope = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs + + cur_head * stride_q_rope_h + + offs_rope_d[None, :] * stride_q_rope_d + ) + + q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) + + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), + mask=(start_n + offs_n) < block_end_loc, + other=0, + ) + off_k = k_loc[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] * stride_k_d + off_k_rope = ( + k_loc[None, :] * stride_k_rope_bs + cur_k_head * stride_k_rope_h + offs_rope_d[:, None] * stride_k_rope_d + ) + k = tl.load(K_nope + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) + k_rope = tl.load(K_rope + off_k_rope, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk += tl.dot(q_rope, k_rope) + + qk *= sm_scale + qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0) + acc = acc * acc_scale[:, None] + # update acc + off_v = k_loc[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + +@torch.no_grad() +def context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + softmax_scale, +): + + BLOCK = 128 if not TESLA else 64 + q_nope_dim = q_nope.shape[-1] + q_rope_dim = q_rope.shape[-1] + assert q_nope_dim == k_nope.shape[-1] + assert q_rope_dim == k_rope.shape[-1] + assert q_nope_dim in {16, 32, 64, 128, 256, 512} + assert q_rope_dim in {16, 32, 64, 128, 256} + assert q_nope_dim == v.shape[-1] + + if q_nope_dim >= 512: + BLOCK = 64 if not TESLA else 32 + else: + BLOCK = 128 if not TESLA else 64 + + if q_nope.dtype == torch.float32: + BLOCK = BLOCK // 4 + + sm_scale = softmax_scale + batch, head = b_seq_len.shape[0], q_nope.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + num_warps = 4 if q_nope_dim <= 64 else 8 + + _fwd_kernel_with_v[grid]( + q_nope, + q_rope, + k_nope, + k_rope, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + req_to_token_indexs, + b_req_idx, + q_nope.stride(0), + q_nope.stride(1), + q_nope.stride(2), + q_rope.stride(0), + q_rope.stride(1), + q_rope.stride(2), + k_nope.stride(0), + k_nope.stride(1), + k_nope.stride(2), + k_rope.stride(0), + k_rope.stride(1), + k_rope.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + b_prompt_cache_len=b_prompt_cache_len, + BLOCK_M=BLOCK, + BLOCK_DMODEL=q_nope_dim, + BLOCK_ROPE_DMODEL=q_rope_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_kernel_no_prompt_cache_with_v( + Q_nope, + Q_rope, + K_nope, + K_rope, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 + Out, + stride_q_bs, + stride_q_h, + stride_q_d, + stride_q_rope_bs, + stride_q_rope_h, + stride_q_rope_d, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_k_rope_bs, + stride_k_rope_h, + stride_k_rope_d, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_ROPE_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_k_head = cur_head + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + + cur_head * stride_q_h + + offs_d[None, :] * stride_q_d + ) + off_rope_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs + + cur_head * stride_q_rope_h + + offs_rope_d[None, :] * stride_q_rope_d + ) + off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] * stride_k_d + off_rope_k = ( + offs_n[None, :] * stride_k_rope_bs + cur_k_head * stride_k_rope_h + offs_rope_d[:, None] * stride_k_rope_d + ) + off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q_rope = tl.load(Q_rope + off_rope_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K_nope + off_k + k_rope_ptrs = K_rope + off_rope_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + k_rope = tl.load( + k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk += tl.dot(q_rope, k_rope) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + +@torch.no_grad() +def context_attention_fwd_no_prompt_cache_with_v( + q_nope, q_rope, k_nope, k_rope, v, o, b_start_loc, b_seq_len, max_input_len, softmax_scale +): + q_nope_dim = q_nope.shape[-1] + q_rope_dim = q_rope.shape[-1] + assert q_nope_dim == k_nope.shape[-1] + assert q_rope_dim == k_rope.shape[-1] + assert q_nope_dim in {16, 32, 64, 128, 256, 512} + assert q_rope_dim in {16, 32, 64, 128, 256} + assert q_nope_dim == v.shape[-1] + + if q_nope_dim >= 512: + BLOCK = 64 if not TESLA else 32 + else: + BLOCK = 128 if not TESLA else 64 + + if q_nope.dtype == torch.float32: + BLOCK = BLOCK // 4 + + sm_scale = softmax_scale + batch, head = b_seq_len.shape[0], q_nope.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 4 if q_nope_dim <= 64 else 8 + _fwd_kernel_no_prompt_cache_with_v[grid]( + q_nope, + q_rope, + k_nope, + k_rope, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q_nope.stride(0), + q_nope.stride(1), + q_nope.stride(2), + q_rope.stride(0), + q_rope.stride(1), + q_rope.stride(2), + k_nope.stride(0), + k_nope.stride(1), + k_nope.stride(2), + k_rope.stride(0), + k_rope.stride(1), + k_rope.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=q_nope_dim, + BLOCK_ROPE_DMODEL=q_rope_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit def _fwd_kernel( Q_nope,