use triton to init window info (#420)
ANDgate99 authored Jun 12, 2024
1 parent 0163eb5 commit 0bf7ec9
Showing 10 changed files with 221 additions and 126 deletions.
29 changes: 11 additions & 18 deletions lightllm/models/mistral/infer_struct.py
@@ -2,23 +2,25 @@
import numpy as np
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.req_manager import ReqManager
from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd


class MistralInferStateInfo(LlamaInferStateInfo):
    def __init__(self):
        super().__init__()
        self.sliding_window = None
        self.b_start_loc_window = None
        self.b_att_seq_len = None
        self.b_att_start_loc = None
        self.total_cache_num = None
        # self.window_postion = None

    def init_some_extra_state(self, model, input_ids : torch.Tensor):
    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        self.sliding_window = model.config["sliding_window"]
        if self.is_prefill:
            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
            position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
                for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
            position_ids = torch.from_numpy(
                np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)
            ).cuda()
            self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
            self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
            position_ids = None
@@ -30,17 +32,8 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor):
            # b_loc[0, max_len_in_batch - 1].item()

        # [SYM] still reserve all kv cache
        self.b_att_seq_len = self.b_seq_len.clone()
        self.b_att_start_loc = self.b_start_loc.clone()
        self.b_start_loc_window = self.b_start_loc.clone()
        self.total_cache_num = 0
        for i in range(0, self.batch_size):
            if self.sliding_window < self.b_seq_len[i]:
                self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window
                self.b_att_seq_len[i] = self.sliding_window
            else:
                self.b_start_loc_window[i] = 0
                self.b_att_seq_len[i] = self.b_seq_len[i]
            self.b_att_start_loc[i] = self.total_cache_num
            self.total_cache_num += self.b_att_seq_len[i]
        return
        self.b_att_seq_len = torch.zeros_like(self.b_seq_len)
        init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window)
        self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len
        self.total_cache_num = torch.sum(self.b_att_seq_len).item()
        return
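A quick equivalence note on the hunk above: the deleted Python loop and the new Triton-plus-torch path compute the same three quantities per batch. A minimal CPU sketch of that relationship, using `torch.minimum` in place of the Triton kernel (the tensor values are made up for illustration):

```python
import torch

# Illustrative per-request sequence lengths and a Mistral-style window size.
b_seq_len = torch.tensor([3, 10, 5000, 4096])
sliding_window = 4096

# What init_att_window_info_fwd writes: each request attends to at most
# `sliding_window` of its most recent cached tokens.
b_att_seq_len = torch.minimum(b_seq_len, torch.tensor(sliding_window))

# Exclusive prefix sum, i.e. the start offset of each request's segment in the
# flattened attention buffer (the old loop's running `total_cache_num` counter).
b_att_start_loc = torch.cumsum(b_att_seq_len, 0) - b_att_seq_len

# Total number of attended tokens across the batch.
total_cache_num = torch.sum(b_att_seq_len).item()

print(b_att_seq_len)    # tensor([   3,   10, 4096, 4096])
print(b_att_start_loc)  # tensor([   0,    3,   13, 4109])
print(total_cache_num)  # 8205
```

The old `b_start_loc_window` tensor is dropped entirely because the attention kernels below now derive the window start as `max(seq_len - sliding_window, 0)` in-kernel.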
@@ -59,7 +59,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.b_start_loc_window,
            infer_state.b_att_start_loc,
            infer_state.b_att_seq_len,
            infer_state.sliding_window,
@@ -79,9 +78,9 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.b_start_loc_window,
            infer_state.b_att_start_loc,
            infer_state.b_att_seq_len,
            infer_state.other_kv_index,
            infer_state.sliding_window,
        )
        return o_tensor
45 changes: 45 additions & 0 deletions lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py
@@ -0,0 +1,45 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_init_att_window_info(
    b_seq_len,
    b_att_seq_len,
    batch_size,
    sliding_window,
    BLOCK_SIZE: tl.constexpr,
):
    cur_index = tl.program_id(0)
    cur_start = cur_index * BLOCK_SIZE
    offsets = cur_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < batch_size

    cur_seq_len = tl.load(b_seq_len + offsets, mask=mask)
    b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window)

    tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask)
    return


@torch.no_grad()
def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window):
    # shape constraints
    assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0]

    BLOCK_SIZE = 32
    num_warps = 1
    grid = (triton.cdiv(batch_size, BLOCK_SIZE),)

    _fwd_kernel_init_att_window_info[grid](
        b_seq_len,
        b_att_seq_len,
        batch_size=batch_size,
        sliding_window=sliding_window,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=1,
    )
    return
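A possible call pattern for the new helper, as a hedged sketch: the dtypes, device, and values below are assumptions for illustration; in the model code the tensors come from the inference state.

```python
import torch
from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd

# Assumed inputs: int32 sequence lengths already resident on the GPU.
batch_size = 4
b_seq_len = torch.tensor([3, 10, 5000, 4096], dtype=torch.int32, device="cuda")
b_att_seq_len = torch.zeros_like(b_seq_len)  # output buffer filled by the kernel

init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window=4096)

# Each entry should equal min(seq_len, sliding_window).
expected = torch.minimum(b_seq_len, torch.full_like(b_seq_len, 4096))
assert torch.equal(b_att_seq_len, expected)
```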
94 changes: 63 additions & 31 deletions lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py
@@ -7,49 +7,68 @@

@triton.jit
def _fwd_kernel_token_att1(
    Q, K, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
    B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen,
    Q,
    K,
    sm_scale,
    Req_to_tokens,
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    B_Att_Start_Loc,
    B_Att_Seqlen,
    Att_Out,
    stride_req_to_tokens_b, stride_req_to_tokens_s,
    stride_qbs, stride_qh, stride_qd,
    stride_kbs, stride_kh, stride_kd,
    att_stride_h, att_stride_bs,
    kv_group_num, sliding_window,
    stride_req_to_tokens_b,
    stride_req_to_tokens_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    att_stride_h,
    att_stride_bs,
    kv_group_num,
    sliding_window,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_n = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
    offs_d = tl.arange(0, BLOCK_DMODEL)  # [D]
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index
    cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch)  # use window index
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)

    # use new start index of k value
    cur_batch_start_index = tl.load(B_Start_Loc_Window + cur_batch)
    cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0)
    cur_batch_end_index = cur_batch_seq_len

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D]
    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd  # [D]

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32]
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)  # [32]

    # use new value to decide block mask
    block_stard_index = start_n * BLOCK_N
    block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number
    block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0)  # a number

    for start_mark in range(0, block_mask, 1):
        q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark
        offs_n_new = cur_batch_start_index + offs_n # the latest window of token
        k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
                        mask=offs_n_new < cur_batch_end_index, other=0)
        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd # [32, D], find token index
        q = tl.load(Q + off_q + start_mark)  # [SYM] why here add start_mark
        offs_n_new = cur_batch_start_index + offs_n  # the latest window of token
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        off_k = (
            k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
        )  # [32, D], find token index
        k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
        att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32]
        att_value = tl.sum(q[None, :] * k, 1)  # [1, D] * [32, D] = [32, D] -> [32]
        att_value *= sm_scale
        off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@@ -58,8 +77,8 @@ def _fwd_kernel_token_att1(

@torch.no_grad()
def token_att_fwd(
    q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
    B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, sliding_window):
    q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window
):
    BLOCK = 32
    # shape constraints
    Lq, Lk = q.shape[-1], k.shape[-1]
@@ -71,25 +90,38 @@ def token_att_fwd(

    grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK))
    kv_group_num = q.shape[1] // k.shape[1]

    if kv_group_num == 1:
        num_warps = 4
    else:
        num_warps = 2

    _fwd_kernel_token_att1[grid](
        q, k, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
        B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen,
        q,
        k,
        sm_scale,
        Req_to_tokens,
        B_req_idx,
        B_Start_Loc,
        B_Seqlen,
        B_Att_Start_Loc,
        B_Att_Seqlen,
        att_out,
        Req_to_tokens.stride(0), Req_to_tokens.stride(1),
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        att_out.stride(0), att_out.stride(1),
        Req_to_tokens.stride(0),
        Req_to_tokens.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        att_out.stride(0),
        att_out.stride(1),
        kv_group_num=kv_group_num,
        sliding_window=sliding_window,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return
    return
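The substantive change in this kernel (beyond the one-argument-per-line reformat) is that the per-request K-window start is no longer read from the removed `B_Start_Loc_Window` tensor; it is recomputed in-kernel as `tl.maximum(cur_batch_seq_len - sliding_window, 0)`. A small Python sketch of the indexing that implies, with made-up numbers:

```python
import torch

# Illustrative values for one request in the decode batch.
cur_batch_seq_len = 6000   # tokens cached for this request
sliding_window = 4096      # Mistral attention window

# In-kernel replacement for B_Start_Loc_Window: start of the attended window.
cur_batch_start_index = max(cur_batch_seq_len - sliding_window, 0)  # -> 1904

# offs_n_new walks absolute positions [start, seq_len); only the most recent
# min(seq_len, sliding_window) tokens are loaded from Req_to_tokens and scored.
offs_n_new = torch.arange(cur_batch_start_index, cur_batch_seq_len)
assert offs_n_new.numel() == min(cur_batch_seq_len, sliding_window)  # 4096
```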
@@ -13,7 +13,6 @@ def _fwd_kernel_token_att2(
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    B_Start_Loc_Window,
    B_Att_Start_Loc,
    B_Att_Seqlen,
    stride_req_to_tokens_b,
@@ -27,6 +26,7 @@ def _fwd_kernel_token_att2(
    stride_oh,
    stride_od,
    kv_group_num,
    sliding_window,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
@@ -38,7 +38,7 @@ def _fwd_kernel_token_att2(
    offs_n = tl.arange(0, BLOCK_N) # [64]
    offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_index = tl.load(B_Start_Loc_Window + cur_batch) # new index
    cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index
    # cur_batch_end_index = cur_batch_seq_len
    cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # new index
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
@@ -75,7 +75,7 @@ def _fwd_kernel_token_att2(

@torch.no_grad()
def token_att_fwd2(
    prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen
    prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window
):
    BLOCK = 128
    # BLOCK = 64 # for triton 2.0.0dev
@@ -94,7 +94,6 @@ def token_att_fwd2(
        B_req_idx,
        B_Start_Loc,
        B_Seqlen,
        B_Start_Loc_Window,
        B_Att_Start_Loc,
        B_Att_Seqlen,
        Req_to_tokens.stride(0),
@@ -108,6 +107,7 @@ def token_att_fwd2(
        out.stride(1),
        out.stride(2),
        kv_group_num=kv_group_num,
        sliding_window=sliding_window,
        BLOCK_DMODEL=dim,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
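The reduce-side kernel follows the same pattern: `B_Start_Loc_Window` is dropped from the signature, `sliding_window` is passed instead, and the window start is recomputed in-kernel. As I read the indexing, the intermediate scores produced by `token_att_fwd` sit in one flattened buffer whose per-request segments are addressed through `b_att_start_loc` / `b_att_seq_len`; a rough sketch of that layout (shapes and values are assumptions, not taken from the repo):

```python
import torch

# Assumed layout: one row per attention head, columns flattened over the batch,
# request i occupying [b_att_start_loc[i], b_att_start_loc[i] + b_att_seq_len[i]).
head_num = 2
b_att_seq_len = torch.tensor([3, 10, 4096, 4096])
b_att_start_loc = torch.cumsum(b_att_seq_len, 0) - b_att_seq_len
total_cache_num = int(b_att_seq_len.sum())

att_out = torch.randn(head_num, total_cache_num)  # stand-in for the score buffer

# Slice out the scores that request i contributes across all heads.
i = 2
start, length = int(b_att_start_loc[i]), int(b_att_seq_len[i])
scores_i = att_out[:, start : start + length]  # shape [head_num, min(seq_len_i, window)]
print(scores_i.shape)  # torch.Size([2, 4096])
```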