diff --git a/lightllm/models/mistral/infer_struct.py b/lightllm/models/mistral/infer_struct.py
index e64df83c4..b53a92d61 100644
--- a/lightllm/models/mistral/infer_struct.py
+++ b/lightllm/models/mistral/infer_struct.py
@@ -2,23 +2,25 @@
 import numpy as np
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.common.req_manager import ReqManager
+from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd
+

 class MistralInferStateInfo(LlamaInferStateInfo):
     def __init__(self):
         super().__init__()
         self.sliding_window = None
-        self.b_start_loc_window = None
         self.b_att_seq_len = None
         self.b_att_start_loc = None
         self.total_cache_num = None
         # self.window_postion = None

-    def init_some_extra_state(self, model, input_ids : torch.Tensor):
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
         self.sliding_window = model.config["sliding_window"]
         if self.is_prefill:
             b_seq_len_numpy = self.b_seq_len.cpu().numpy()
-            position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
-                                                            for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
+            position_ids = torch.from_numpy(
+                np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)
+            ).cuda()
             self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
             self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
             position_ids = None
@@ -30,17 +32,8 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor):
             # b_loc[0, max_len_in_batch - 1].item()

             # [SYM] still reserve all kv cache
-            self.b_att_seq_len = self.b_seq_len.clone()
-            self.b_att_start_loc = self.b_start_loc.clone()
-            self.b_start_loc_window = self.b_start_loc.clone()
-            self.total_cache_num = 0
-            for i in range(0, self.batch_size):
-                if self.sliding_window < self.b_seq_len[i]:
-                    self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window
-                    self.b_att_seq_len[i] = self.sliding_window
-                else:
-                    self.b_start_loc_window[i] = 0
-                    self.b_att_seq_len[i] = self.b_seq_len[i]
-                self.b_att_start_loc[i] = self.total_cache_num
-                self.total_cache_num += self.b_att_seq_len[i]
-        return
\ No newline at end of file
+            self.b_att_seq_len = torch.zeros_like(self.b_seq_len)
+            init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window)
+            self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len
+            self.total_cache_num = torch.sum(self.b_att_seq_len).item()
+        return
diff --git a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py
index 8ed5a7fb5..51cb4817d 100755
--- a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py
@@ -59,7 +59,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
             infer_state.b_req_idx,
             infer_state.b_start_loc,
             infer_state.b_seq_len,
-            infer_state.b_start_loc_window,
             infer_state.b_att_start_loc,
             infer_state.b_att_seq_len,
             infer_state.sliding_window,
@@ -79,9 +78,9 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
             infer_state.b_req_idx,
             infer_state.b_start_loc,
             infer_state.b_seq_len,
-            infer_state.b_start_loc_window,
             infer_state.b_att_start_loc,
             infer_state.b_att_seq_len,
             infer_state.other_kv_index,
+            infer_state.sliding_window,
         )
         return o_tensor
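The infer-state change above replaces the per-request Python loop with the new init_att_window_info_fwd kernel plus two tensor ops: b_att_seq_len becomes min(seq_len, sliding_window) per request, b_att_start_loc its exclusive prefix sum, and total_cache_num the number of cache slots the decode kernels read. A minimal CPU-side sketch of that bookkeeping, with illustrative helper names that are not part of the patch:

import torch

def window_state_reference(b_seq_len: torch.Tensor, sliding_window: int):
    # Vectorized form, mirroring the new code path: clamp, exclusive cumsum, total.
    b_att_seq_len = torch.clamp_max(b_seq_len, sliding_window)
    b_att_start_loc = torch.cumsum(b_att_seq_len, 0) - b_att_seq_len
    total_cache_num = int(b_att_seq_len.sum().item())
    return b_att_seq_len, b_att_start_loc, total_cache_num

def window_state_loop(b_seq_len: torch.Tensor, sliding_window: int):
    # The removed per-request loop, kept here only for comparison.
    b_att_seq_len = b_seq_len.clone()
    b_att_start_loc = torch.zeros_like(b_seq_len)
    total = 0
    for i in range(b_seq_len.shape[0]):
        b_att_seq_len[i] = min(int(b_seq_len[i]), sliding_window)
        b_att_start_loc[i] = total
        total += int(b_att_seq_len[i])
    return b_att_seq_len, b_att_start_loc, total

if __name__ == "__main__":
    lens = torch.tensor([3, 4096, 5000, 1], dtype=torch.int32)
    ref, loop = window_state_reference(lens, 4096), window_state_loop(lens, 4096)
    assert ref[0].tolist() == loop[0].tolist()  # per-request window lengths
    assert ref[1].tolist() == loop[1].tolist()  # per-request start offsets
    assert ref[2] == loop[2]                    # total cache slots read during decode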
diff --git a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py b/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py
new file mode 100644
index 000000000..a60fe970b
--- /dev/null
+++ b/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py
@@ -0,0 +1,45 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_init_att_window_info(
+    b_seq_len,
+    b_att_seq_len,
+    batch_size,
+    sliding_window,
+    BLOCK_SIZE: tl.constexpr,
+):
+    cur_index = tl.program_id(0)
+    cur_start = cur_index * BLOCK_SIZE
+    offsets = cur_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < batch_size
+
+    cur_seq_len = tl.load(b_seq_len + offsets, mask=mask)
+    b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window)
+
+    tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask)
+    return
+
+
+@torch.no_grad()
+def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window):
+    # shape constraints
+    assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0]
+
+    BLOCK_SIZE = 32
+    num_warps = 1
+    grid = (triton.cdiv(batch_size, BLOCK_SIZE),)
+
+    _fwd_kernel_init_att_window_info[grid](
+        b_seq_len,
+        b_att_seq_len,
+        batch_size=batch_size,
+        sliding_window=sliding_window,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
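For reference, a usage sketch of the new helper: it fills b_att_seq_len with min(seq_len, sliding_window) element-wise. This assumes a CUDA device and that the module path added by this patch is importable; torch.clamp_max is used only as an independent cross-check and is not part of the patch.

import torch
from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd

def fill_att_seq_len(b_seq_len: torch.Tensor, sliding_window: int) -> torch.Tensor:
    # b_seq_len: int32 CUDA tensor of per-request cached lengths.
    b_att_seq_len = torch.zeros_like(b_seq_len)
    init_att_window_info_fwd(b_seq_len.shape[0], b_seq_len, b_att_seq_len, sliding_window)
    return b_att_seq_len

if __name__ == "__main__":
    lens = torch.tensor([17, 4096, 9000], dtype=torch.int32, device="cuda")
    out = fill_att_seq_len(lens, sliding_window=4096)
    assert torch.equal(out, torch.clamp_max(lens, 4096))  # elementwise min(seq_len, window)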
diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py
index 1dec71d42..09ce9d2ab 100644
--- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py
+++ b/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py
@@ -7,49 +7,68 @@
 @triton.jit
 def _fwd_kernel_token_att1(
-    Q, K, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
-    B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen,
+    Q,
+    K,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    B_Att_Start_Loc,
+    B_Att_Seqlen,
     Att_Out,
-    stride_req_to_tokens_b, stride_req_to_tokens_s,
-    stride_qbs, stride_qh, stride_qd,
-    stride_kbs, stride_kh, stride_kd,
-    att_stride_h, att_stride_bs,
-    kv_group_num, sliding_window,
+    stride_req_to_tokens_b,
+    stride_req_to_tokens_s,
+    stride_qbs,
+    stride_qh,
+    stride_qd,
+    stride_kbs,
+    stride_kh,
+    stride_kd,
+    att_stride_h,
+    att_stride_bs,
+    kv_group_num,
+    sliding_window,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_N: tl.constexpr
+    BLOCK_N: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
     start_n = tl.program_id(2)
-
+
     cur_kv_head = cur_head // kv_group_num

-    offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
+    offs_d = tl.arange(0, BLOCK_DMODEL)  # [D]
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index
+    cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch)  # use window index
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
     cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)

     # use new start index of k value
-    cur_batch_start_index = tl.load(B_Start_Loc_Window + cur_batch)
+    cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0)
     cur_batch_end_index = cur_batch_seq_len

-    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D]
+    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd  # [D]

-    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32]
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)  # [32]

     # use new value to decide block mask
     block_stard_index = start_n * BLOCK_N
-    block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number
+    block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0)  # a number

     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark
-        offs_n_new = cur_batch_start_index + offs_n # the latest window of token
-        k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
-                        mask=offs_n_new < cur_batch_end_index, other=0)
-        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd # [32, D], find token index
+        q = tl.load(Q + off_q + start_mark)  # [SYM] why here add start_mark
+        offs_n_new = cur_batch_start_index + offs_n  # the latest window of token
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
+        off_k = (
+            k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
+        )  # [32, D], find token index
         k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
-        att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32]
+        att_value = tl.sum(q[None, :] * k, 1)  # [1, D] * [32, D] = [32, D] -> [32]
         att_value *= sm_scale
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
     return
@@ -58,8 +77,8 @@ def _fwd_kernel_token_att1(

 @torch.no_grad()
 def token_att_fwd(
-    q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
-    B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, sliding_window):
+    q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window
+):
     BLOCK = 32
     # shape constraints
     Lq, Lk = q.shape[-1], k.shape[-1]
@@ -71,20 +90,33 @@ def token_att_fwd(

     grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK))
     kv_group_num = q.shape[1] // k.shape[1]
-
+
     if kv_group_num == 1:
         num_warps = 4
     else:
         num_warps = 2

     _fwd_kernel_token_att1[grid](
-        q, k, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
-        B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen,
+        q,
+        k,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Start_Loc,
+        B_Seqlen,
+        B_Att_Start_Loc,
+        B_Att_Seqlen,
         att_out,
-        Req_to_tokens.stride(0), Req_to_tokens.stride(1),
-        q.stride(0), q.stride(1), q.stride(2),
-        k.stride(0), k.stride(1), k.stride(2),
-        att_out.stride(0), att_out.stride(1),
+        Req_to_tokens.stride(0),
+        Req_to_tokens.stride(1),
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        att_out.stride(0),
+        att_out.stride(1),
         kv_group_num=kv_group_num,
         sliding_window=sliding_window,
         BLOCK_DMODEL=Lk,
@@ -92,4 +124,4 @@ def token_att_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    return
\ No newline at end of file
+    return
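The kernel above (and the two reduce kernels that follow) now derives the per-request window start in-kernel as max(seq_len - sliding_window, 0) instead of loading a precomputed B_Start_Loc_Window entry. A small plain-Python sketch of the arithmetic being relied on, illustrative only:

# For a request with `seq_len` cached tokens and window size `window`:
#   old: start = b_start_loc_window[i], precomputed on the host as
#        (seq_len - window) if seq_len > window else 0
#   new: start = max(seq_len - window, 0), computed inside the kernel.
def window_start(seq_len: int, window: int) -> int:
    return max(seq_len - window, 0)

for seq_len, window in [(3, 4096), (4096, 4096), (10000, 4096)]:
    start = window_start(seq_len, window)
    att_len = min(seq_len, window)
    # The kernel reads cache positions start .. start + att_len - 1,
    # i.e. exactly the last min(seq_len, window) tokens of the request.
    assert start + att_len == seq_len
    assert (seq_len > window) == (start > 0)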
diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py
index 60e3d13b9..acf4923f8 100644
--- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py
+++ b/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py
@@ -13,7 +13,6 @@ def _fwd_kernel_token_att2(
     B_req_idx,
     B_Start_Loc,
     B_Seqlen,
-    B_Start_Loc_Window,
     B_Att_Start_Loc,
     B_Att_Seqlen,
     stride_req_to_tokens_b,
@@ -27,6 +26,7 @@ def _fwd_kernel_token_att2(
     stride_oh,
     stride_od,
     kv_group_num,
+    sliding_window,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -38,7 +38,7 @@ def _fwd_kernel_token_att2(
     offs_n = tl.arange(0, BLOCK_N)  # [64]
     offs_d = tl.arange(0, BLOCK_DMODEL)  # [D]
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_start_index = tl.load(B_Start_Loc_Window + cur_batch)  # new index
+    cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0)  # new index
     # cur_batch_end_index = cur_batch_seq_len
     cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch)  # new index
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
@@ -75,7 +75,7 @@ def _fwd_kernel_token_att2(

 @torch.no_grad()
 def token_att_fwd2(
-    prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen
+    prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window
 ):
     BLOCK = 128
     # BLOCK = 64 # for triton 2.0.0dev
@@ -94,7 +94,6 @@ def token_att_fwd2(
         B_req_idx,
         B_Start_Loc,
         B_Seqlen,
-        B_Start_Loc_Window,
         B_Att_Start_Loc,
         B_Att_Seqlen,
         Req_to_tokens.stride(0),
@@ -108,6 +107,7 @@ def token_att_fwd2(
         out.stride(1),
         out.stride(2),
         kv_group_num=kv_group_num,
+        sliding_window=sliding_window,
         BLOCK_DMODEL=dim,
         BLOCK_N=BLOCK,
         num_warps=num_warps,
diff --git a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py
index a620706eb..c37013f18 100644
--- a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py
+++ b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py
@@ -6,15 +6,28 @@

 @triton.jit
 def _fwd_kernel(
-    Logics, V, Out,
-    Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
-    B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen,
-    stride_logic_h, stride_logic_bs,
-    stride_vbs, stride_vh, stride_vd,
-    stride_obs, stride_oh, stride_od,
-    stride_req_to_token_b, stride_req_to_token_s,
-    other_kv_index, # 避免读取到nan的数据
+    Logics,
+    V,
+    Out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    B_Att_Start_Loc,
+    B_Att_Seqlen,
+    stride_logic_h,
+    stride_logic_bs,
+    stride_vbs,
+    stride_vh,
+    stride_vd,
+    stride_obs,
+    stride_oh,
+    stride_od,
+    stride_req_to_token_b,
+    stride_req_to_token_s,
+    other_kv_index,  # 避免读取到nan的数据
     kv_group_num,
+    sliding_window,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
@@ -24,36 +37,43 @@ def _fwd_kernel(
     cur_kv_head = cur_head // kv_group_num

     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index
+    cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch)  # new index
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
-    cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index
-    cur_cache_start_loc = tl.load(B_Start_Loc_Window + cur_batch) # new index
+    cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)  # new index
+    cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0)  # new index

-    offs_n = tl.arange(0, BLOCK_N) # [64]
-    offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
+    offs_n = tl.arange(0, BLOCK_N)  # [64]
+    offs_d = tl.arange(0, BLOCK_DMODEL)  # [D]

-    off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D]
+    off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd  # [1, D]
     v_ptrs = V + off_v

     e_max = float("-inf")
     e_sum = 0.0
-    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D]
+    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)  # [D]

     for start_n in range(0, cur_att_seq_len, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N) # check
-        v_index = tl.load(Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b +
-                          (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s,
-                          mask=(cur_cache_start_loc + start_n + offs_n) < cur_batch_seq_len, other=other_kv_index) # [64]
+        start_n = tl.multiple_of(start_n, BLOCK_N)  # check
+        v_index = tl.load(
+            Req_to_tokens
+            + cur_batch_req_idx * stride_req_to_token_b
+            + (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s,
+            mask=(start_n + offs_n) < cur_att_seq_len,
+            other=other_kv_index,
+        )  # [64]

-        qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
-                     mask=start_n + offs_n < cur_batch_seq_len, other=float("-inf")) # [64]
-
-        n_e_max = tl.maximum(tl.max(qk, 0), e_max)
+        qk = tl.load(
+            Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
+            mask=(start_n + offs_n) < cur_att_seq_len,
+            other=float("-inf"),
+        )  # [64]
+
+        n_e_max = tl.maximum(tl.max(qk, 0), e_max)
         old_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max)
         e_sum = e_sum * old_scale + tl.sum(p, 0)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D]
-        acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D]
+        v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)  # [1, D] + [64, 1] = [64, D]
+        acc = acc * old_scale + tl.sum(p[:, None] * v, 0)  # [64, 1] * [64, D] = [64, D] -> [D]
         e_max = n_e_max

     acc = acc / e_sum
@@ -65,8 +85,18 @@ def _fwd_kernel(

 @torch.no_grad()
 def token_softmax_reducev_fwd(
-    logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len,
-    b_start_loc_window, b_att_start_loc, b_att_seq_len, other_kv_index):
+    logics,
+    v,
+    o,
+    req_to_tokens,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+    b_att_start_loc,
+    b_att_seq_len,
+    other_kv_index,
+    sliding_window,
+):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
     grid = (batch, head)
@@ -74,17 +104,31 @@ def token_softmax_reducev_fwd(
     num_warps = 1

     _fwd_kernel[grid](
-        logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len,
-        b_start_loc_window, b_att_start_loc, b_att_seq_len,
-        logics.stride(0), logics.stride(1),
-        v.stride(0), v.stride(1), v.stride(2),
-        o.stride(0), o.stride(1), o.stride(2),
-        req_to_tokens.stride(0), req_to_tokens.stride(1),
+        logics,
+        v,
+        o,
+        req_to_tokens,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        b_att_start_loc,
+        b_att_seq_len,
+        logics.stride(0),
+        logics.stride(1),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        o.stride(0),
+        o.stride(1),
+        o.stride(2),
+        req_to_tokens.stride(0),
+        req_to_tokens.stride(1),
         other_kv_index,
         kv_group_num,
+        sliding_window,
         BLOCK_DMODEL=v.shape[-1],
         BLOCK_N=BLOCK,
         num_warps=num_warps,
-        num_stages=3
+        num_stages=3,
     )
-    return
\ No newline at end of file
+    return
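In _fwd_kernel above, the load masks switch from bounding (cur_cache_start_loc + start_n + offs_n) against cur_batch_seq_len to bounding (start_n + offs_n) against cur_att_seq_len. Because the window start equals seq_len minus the attended length (clamped at 0), the two bounds select the same elements; the sketch below checks that equivalence numerically and is illustrative only, not part of the patch:

# Old v_index mask: (cache_start + i) < seq_len
# New v_index mask: i < att_seq_len, with cache_start = max(seq_len - window, 0)
#                   and att_seq_len = min(seq_len, window)
def masks_agree(seq_len: int, window: int, block: int = 64) -> bool:
    cache_start = max(seq_len - window, 0)
    att_seq_len = min(seq_len, window)
    padded = ((att_seq_len + block - 1) // block) * block  # offsets touched by full blocks
    return all(((cache_start + i) < seq_len) == (i < att_seq_len) for i in range(padded))

assert all(masks_agree(s, w) for s in (1, 63, 64, 4095, 4096, 12345) for w in (64, 4096))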
diff --git a/lightllm/models/mixtral/infer_struct.py b/lightllm/models/mixtral/infer_struct.py
index 19303be39..426b28c5a 100644
--- a/lightllm/models/mixtral/infer_struct.py
+++ b/lightllm/models/mixtral/infer_struct.py
@@ -3,30 +3,32 @@
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.common.req_manager import ReqManager
 from lightllm.models.mistral.infer_struct import MistralInferStateInfo
+from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)

+
 class MixtralInferStateInfo(MistralInferStateInfo):
     def __init__(self):
         super().__init__()
         self.sliding_window = None
-        self.b_start_loc_window = None
         self.b_att_seq_len = None
         self.b_att_start_loc = None
         self.total_cache_num = None
         self.experts_topk = None
         self.num_local_experts = None

-    def init_some_extra_state(self, model, input_ids : torch.Tensor):
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
         # sliding_window is not used in Mixtral 8x7b, ignore it
         self.sliding_window = 4096 if model.config["sliding_window"] is None else model.config["sliding_window"]
         self.experts_topk = model.config["num_experts_per_tok"]
         self.num_local_experts = model.config["num_local_experts"]
         if self.is_prefill:
             b_seq_len_numpy = self.b_seq_len.cpu().numpy()
-            position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
-                                                            for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
+            position_ids = torch.from_numpy(
+                np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)
+            ).cuda()
             self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
             self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
             position_ids = None
@@ -38,17 +40,8 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor):
             # b_loc[0, max_len_in_batch - 1].item()

             # [SYM] still reserve all kv cache
-            self.b_att_seq_len = self.b_seq_len.clone()
-            self.b_att_start_loc = self.b_start_loc.clone()
-            self.b_start_loc_window = self.b_start_loc.clone()
-            self.total_cache_num = 0
-            for i in range(0, self.batch_size):
-                if self.sliding_window < self.b_seq_len[i]:
-                    self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window
-                    self.b_att_seq_len[i] = self.sliding_window
-                else:
-                    self.b_start_loc_window[i] = 0
-                    self.b_att_seq_len[i] = self.b_seq_len[i]
-                self.b_att_start_loc[i] = self.total_cache_num
-                self.total_cache_num += self.b_att_seq_len[i]
-        return
\ No newline at end of file
+            self.b_att_seq_len = torch.zeros_like(self.b_seq_len)
+            init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window)
+            self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len
+            self.total_cache_num = torch.sum(self.b_att_seq_len).item()
+        return
diff --git a/lightllm/models/qwen2/infer_struct.py b/lightllm/models/qwen2/infer_struct.py
index 074c457dc..4cb1b61ab 100644
--- a/lightllm/models/qwen2/infer_struct.py
+++ b/lightllm/models/qwen2/infer_struct.py
@@ -1,6 +1,7 @@
 import torch
 import numpy as np
 from lightllm.common.basemodel import InferStateInfo
+from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd
 from lightllm.common.req_manager import ReqManager


@@ -8,7 +9,6 @@ class Qwen2InferStateInfo(InferStateInfo):
     def __init__(self):
         super().__init__()
         self.sliding_window = None
-        self.b_start_loc_window = None
         self.b_att_seq_len = None
         self.b_att_start_loc = None
         self.total_cache_num = None
@@ -30,17 +30,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
             self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1)
             self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item()

-            self.b_att_seq_len = self.b_seq_len.clone()
-            self.b_att_start_loc = self.b_start_loc.clone()
-            self.b_start_loc_window = self.b_start_loc.clone()
-            self.total_cache_num = 0
-            for i in range(0, self.batch_size):
-                if self.sliding_window < self.b_seq_len[i]:
-                    self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window
-                    self.b_att_seq_len[i] = self.sliding_window
-                else:
-                    self.b_start_loc_window[i] = 0
-                    self.b_att_seq_len[i] = self.b_seq_len[i]
-                self.b_att_start_loc[i] = self.total_cache_num
-                self.total_cache_num += self.b_att_seq_len[i]
+            self.b_att_seq_len = torch.zeros_like(self.b_seq_len)
+            init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window)
+            self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len
+            self.total_cache_num = torch.sum(self.b_att_seq_len).item()
         return
diff --git a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py
index 003f38dc5..018085136 100644
--- a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py
@@ -85,7 +85,6 @@ def _token_decode_attention_normal(self, q, infer_state: Qwen2InferStateInfo, la
             infer_state.b_req_idx,
             infer_state.b_start_loc,
             infer_state.b_seq_len,
-            infer_state.b_start_loc_window,
             infer_state.b_att_start_loc,
             infer_state.b_att_seq_len,
             infer_state.sliding_window,
@@ -107,9 +106,9 @@ def _token_decode_attention_normal(self, q, infer_state: Qwen2InferStateInfo, la
             infer_state.b_req_idx,
             infer_state.b_start_loc,
             infer_state.b_seq_len,
-            infer_state.b_start_loc_window,
             infer_state.b_att_start_loc,
             infer_state.b_att_seq_len,
             infer_state.other_kv_index,
+            infer_state.sliding_window,
         )
         return o_tensor
diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
index 5ddc8faa2..5aed58a59 100644
--- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
@@ -119,7 +119,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
             infer_state.b_req_idx,
             infer_state.b_start_loc,
             infer_state.b_seq_len,
-            infer_state.b_start_loc_window,
             infer_state.b_att_start_loc,
             infer_state.b_att_seq_len,
             infer_state.sliding_window,
@@ -141,9 +140,9 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
             infer_state.b_req_idx,
             infer_state.b_start_loc,
             infer_state.b_seq_len,
-            infer_state.b_start_loc_window,
             infer_state.b_att_start_loc,
             infer_state.b_att_seq_len,
             infer_state.other_kv_index,
+            infer_state.sliding_window,
         )
         return o_tensor
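Across all of these call sites the decode path now threads sliding_window down to the kernels instead of a per-request b_start_loc_window tensor; functionally each request still attends over at most the last sliding_window cached tokens. A dense single-request PyTorch reference of that computation, as a sanity sketch only (the real kernels operate on the paged req_to_tokens layout, grouped KV heads, and a separately materialized logits buffer):

import torch

def sliding_window_decode_reference(q, k_cache, v_cache, sliding_window):
    # q: [head, dim]; k_cache/v_cache: [seq_len, head, dim] for one request.
    seq_len, _, dim = k_cache.shape
    start = max(seq_len - sliding_window, 0)  # window start, as derived in the kernels
    k = k_cache[start:]                       # last min(seq_len, window) cached tokens
    v = v_cache[start:]
    scores = torch.einsum("hd,nhd->hn", q, k) / (dim ** 0.5)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hn,nhd->hd", probs, v)

if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(8, 64)
    k = torch.randn(100, 8, 64)
    v = torch.randn(100, 8, 64)
    short = sliding_window_decode_reference(q, k, v, sliding_window=16)
    full = sliding_window_decode_reference(q, k, v, sliding_window=100)
    assert short.shape == (8, 64) and not torch.allclose(short, full)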