improve code style
ANDgate99 committed Jun 11, 2024
1 parent 37a8f49 commit 2c19173
Showing 8 changed files with 169 additions and 86 deletions.
10 changes: 6 additions & 4 deletions lightllm/models/mistral/infer_struct.py
@@ -4,6 +4,7 @@
from lightllm.common.req_manager import ReqManager
from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd


class MistralInferStateInfo(LlamaInferStateInfo):
def __init__(self):
super().__init__()
@@ -13,12 +14,13 @@ def __init__(self):
self.total_cache_num = None
# self.window_postion = None

def init_some_extra_state(self, model, input_ids : torch.Tensor):
def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.sliding_window = model.config["sliding_window"]
if self.is_prefill:
b_seq_len_numpy = self.b_seq_len.cpu().numpy()
position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
position_ids = torch.from_numpy(
np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)
).cuda()
self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
position_ids = None
@@ -34,4 +36,4 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor):
init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window)
self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len
self.total_cache_num = torch.sum(self.b_att_seq_len).item()
return
return
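
For reference, here is what this state setup amounts to. In the prefill branch, position_ids is just the concatenation of arange(len_i) over the batch (for example b_seq_len = [3, 2] gives [0, 1, 2, 0, 1]), used to gather rows of the cached cos/sin rotary tables. The decode branch builds per-request sliding-window bookkeeping; below is a minimal PyTorch sketch of those three quantities, assuming init_att_window_info_fwd clamps each length to the window — an illustration, not the module's own code.

import torch

def sketch_window_state(b_seq_len: torch.Tensor, sliding_window: int):
    # Assumed effect of init_att_window_info_fwd: each request attends to at most
    # `sliding_window` of its most recent cached tokens.
    b_att_seq_len = torch.clamp(b_seq_len, max=sliding_window)
    # Exclusive prefix sum: start offset of each request's window in the flattened buffers.
    b_att_start_loc = torch.cumsum(b_att_seq_len, 0) - b_att_seq_len
    # Total number of cached tokens touched by this decode step.
    total_cache_num = torch.sum(b_att_seq_len).item()
    return b_att_seq_len, b_att_start_loc, total_cache_num

# Example: b_seq_len = [3, 5000], sliding_window = 4096
#   -> b_att_seq_len = [3, 4096], b_att_start_loc = [0, 3], total_cache_num = 4099
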
lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py
@@ -6,8 +6,10 @@

@triton.jit
def _fwd_kernel_init_att_window_info(
b_seq_len, b_att_seq_len,
batch_size, sliding_window,
b_seq_len,
b_att_seq_len,
batch_size,
sliding_window,
BLOCK_SIZE: tl.constexpr,
):
cur_index = tl.program_id(0)
@@ -23,20 +25,21 @@ def _fwd_kernel_init_att_window_info(


@torch.no_grad()
def init_att_window_info_fwd(
batch_size, b_seq_len, b_att_seq_len, sliding_window):
def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window):
# shape constraints
assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0]

BLOCK_SIZE = 32
num_warps = 1
grid = (triton.cdiv(batch_size, BLOCK_SIZE), )
grid = (triton.cdiv(batch_size, BLOCK_SIZE),)

_fwd_kernel_init_att_window_info[grid](
b_seq_len, b_att_seq_len,
batch_size=batch_size, sliding_window=sliding_window,
b_seq_len,
b_att_seq_len,
batch_size=batch_size,
sliding_window=sliding_window,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=1,
)
return
return
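
The kernel body is largely elided by this hunk; given the launch above (a 1-D grid of cdiv(batch_size, 32) programs, BLOCK_SIZE = 32, one warp), a plausible reading is that each program clamps one block of per-request sequence lengths to the sliding window. The following is a sketch under that assumption, not the file's actual body:

import triton
import triton.language as tl

@triton.jit
def _sketch_init_att_window_info(
    b_seq_len,          # int tensor [batch_size]: current length of each request
    b_att_seq_len,      # int tensor [batch_size]: output, window-clamped length
    batch_size,
    sliding_window,
    BLOCK_SIZE: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs = cur_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < batch_size
    seq_len = tl.load(b_seq_len + offs, mask=mask, other=0)
    tl.store(b_att_seq_len + offs, tl.minimum(seq_len, sliding_window), mask=mask)

On the host this would be equivalent to torch.clamp(b_seq_len, max=sliding_window).
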
92 changes: 62 additions & 30 deletions lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py
@@ -7,49 +7,68 @@

@triton.jit
def _fwd_kernel_token_att1(
Q, K, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
B_Att_Start_Loc, B_Att_Seqlen,
Q,
K,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
B_Att_Start_Loc,
B_Att_Seqlen,
Att_Out,
stride_req_to_tokens_b, stride_req_to_tokens_s,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
att_stride_h, att_stride_bs,
kv_group_num, sliding_window,
stride_req_to_tokens_b,
stride_req_to_tokens_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
att_stride_h,
att_stride_bs,
kv_group_num,
sliding_window,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_n = tl.program_id(2)

cur_kv_head = cur_head // kv_group_num

offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index
cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)

# use new start index of k value
cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0)
cur_batch_end_index = cur_batch_seq_len

off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D]
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D]

offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32]
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32]

# use new value to decide block mask
block_stard_index = start_n * BLOCK_N
block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number
block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number

for start_mark in range(0, block_mask, 1):
        q = tl.load(Q + off_q + start_mark) # [SYM] why is start_mark added here?
offs_n_new = cur_batch_start_index + offs_n # the latest window of token
k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
mask=offs_n_new < cur_batch_end_index, other=0)
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd # [32, D], find token index
        q = tl.load(Q + off_q + start_mark)  # [SYM] why is start_mark added here?
offs_n_new = cur_batch_start_index + offs_n # the latest window of token
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
off_k = (
k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
) # [32, D], find token index
k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32]
att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32]
att_value *= sm_scale
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@@ -58,8 +77,8 @@ def _fwd_kernel_token_att1(

@torch.no_grad()
def token_att_fwd(
q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
B_Att_Start_Loc, B_Att_Seqlen, sliding_window):
q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window
):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
@@ -71,25 +90,38 @@ def token_att_fwd(

grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK))
kv_group_num = q.shape[1] // k.shape[1]

if kv_group_num == 1:
num_warps = 4
else:
num_warps = 2

_fwd_kernel_token_att1[grid](
q, k, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
B_Att_Start_Loc, B_Att_Seqlen,
q,
k,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
B_Att_Start_Loc,
B_Att_Seqlen,
att_out,
Req_to_tokens.stride(0), Req_to_tokens.stride(1),
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
att_out.stride(0), att_out.stride(1),
Req_to_tokens.stride(0),
Req_to_tokens.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
att_out.stride(0),
att_out.stride(1),
kv_group_num=kv_group_num,
sliding_window=sliding_window,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
return
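
To make the indexing above concrete, here is a hedged CPU reference of what this first stage computes. Shapes follow the wrapper's signature, and sm_scale is assumed to be 1/sqrt(head_dim), which the elided wrapper lines presumably set. Each request's single decode query attends only to the last min(seq_len, sliding_window) cached keys, with Req_to_tokens translating window slots into physical cache rows. This is an illustration, not the kernel itself:

import torch

def sketch_token_att_stage1(q, k, req_to_tokens, b_req_idx, b_seq_len,
                            b_att_start_loc, b_att_seq_len):
    # q: [batch, q_heads, D]; k: [num_cache_rows, kv_heads, D]
    sm_scale = 1.0 / (q.shape[-1] ** 0.5)          # assumed scaling
    kv_group_num = q.shape[1] // k.shape[1]
    att_out = torch.full((q.shape[1], int(b_att_seq_len.sum())), float("-inf"),
                         dtype=q.dtype, device=q.device)
    for b in range(q.shape[0]):
        att_len = int(b_att_seq_len[b])            # == min(seq_len, sliding_window)
        start = int(b_seq_len[b]) - att_len        # first cached position inside the window
        out_off = int(b_att_start_loc[b])          # this request's slice of att_out
        k_loc = req_to_tokens[b_req_idx[b], start:start + att_len]   # physical cache rows
        for h in range(q.shape[1]):
            keys = k[k_loc, h // kv_group_num]                       # [att_len, D]
            att_out[h, out_off:out_off + att_len] = sm_scale * (keys @ q[b, h])
    return att_out

The Triton version parallelizes this double loop over a (batch, head_num, cdiv(sliding_window, BLOCK)) grid, and the block_mask / start_mark construction at the top of the kernel simply turns the body into a no-op for blocks that fall beyond a request's (possibly shorter) window.
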
@@ -38,7 +38,7 @@ def _fwd_kernel_token_att2(
offs_n = tl.arange(0, BLOCK_N) # [64]
offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index
cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index
# cur_batch_end_index = cur_batch_seq_len
cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # new index
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
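
The re-based indexing in this kernel is easiest to see with numbers (an illustrative example, not taken from the source):

# Example: cur_batch_seq_len = 4096, sliding_window = 1024
seq_len, sliding_window = 4096, 1024
cur_batch_start_index = max(seq_len - sliding_window, 0)   # 3072
# Window slot j (0 <= j < 1024) therefore reads cache position 3072 + j,
# while its score/value accumulators live at B_Att_Start_Loc[b] + j in the flattened buffers.
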
@@ -6,15 +6,28 @@

@triton.jit
def _fwd_kernel(
Logics, V, Out,
Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
B_Att_Start_Loc, B_Att_Seqlen,
stride_logic_h, stride_logic_bs,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
stride_req_to_token_b, stride_req_to_token_s,
    other_kv_index, # avoid reading nan data
kv_group_num, sliding_window,
Logics,
V,
Out,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
B_Att_Start_Loc,
B_Att_Seqlen,
stride_logic_h,
stride_logic_bs,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_req_to_token_b,
stride_req_to_token_s,
    other_kv_index,  # avoid reading nan data
kv_group_num,
sliding_window,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
@@ -24,36 +37,43 @@ def _fwd_kernel(
cur_kv_head = cur_head // kv_group_num

cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index
cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index
cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index
cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index
cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index

offs_n = tl.arange(0, BLOCK_N) # [64]
offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
offs_n = tl.arange(0, BLOCK_N) # [64]
offs_d = tl.arange(0, BLOCK_DMODEL) # [D]

off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D]
off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D]
v_ptrs = V + off_v

e_max = float("-inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D]
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D]

for start_n in range(0, cur_att_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N) # check
v_index = tl.load(Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b +
(cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s,
mask=(start_n + offs_n) < cur_att_seq_len, other=other_kv_index) # [64]
start_n = tl.multiple_of(start_n, BLOCK_N) # check
v_index = tl.load(
Req_to_tokens
+ cur_batch_req_idx * stride_req_to_token_b
+ (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s,
mask=(start_n + offs_n) < cur_att_seq_len,
other=other_kv_index,
) # [64]

qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
mask=(start_n + offs_n) < cur_att_seq_len, other=float("-inf")) # [64]

n_e_max = tl.maximum(tl.max(qk, 0), e_max)
qk = tl.load(
Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
mask=(start_n + offs_n) < cur_att_seq_len,
other=float("-inf"),
) # [64]

n_e_max = tl.maximum(tl.max(qk, 0), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
e_sum = e_sum * old_scale + tl.sum(p, 0)
v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D]
acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D]
v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D]
acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D]
e_max = n_e_max

acc = acc / e_sum
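
The loop above is the standard streaming ("online") softmax: each BLOCK_N tile of logits updates a running maximum, a running normalizer, and a running V-weighted accumulator, so the full softmax over the window is never materialized. A small PyTorch sketch of the same recurrence for one (batch, head) pair, as an illustration only:

import math
import torch

def sketch_softmax_reducev(logits: torch.Tensor, values: torch.Tensor, block_n: int = 64):
    # logits: [L] attention scores for one head; values: [L, D] gathered V rows.
    e_max, e_sum = float("-inf"), 0.0
    acc = torch.zeros(values.shape[-1], dtype=torch.float32)
    for start in range(0, logits.shape[0], block_n):
        qk = logits[start:start + block_n].float()
        v = values[start:start + block_n].float()
        n_e_max = max(float(qk.max()), e_max)       # new running maximum
        old_scale = math.exp(e_max - n_e_max)       # rescale factor for previous tiles
        p = torch.exp(qk - n_e_max)                 # this tile's unnormalized weights
        e_sum = e_sum * old_scale + float(p.sum())
        acc = acc * old_scale + p @ v
        e_max = n_e_max
    return acc / e_sum   # equals torch.softmax(logits.float(), 0) @ values.float()
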
@@ -65,26 +85,50 @@ def _fwd_kernel(

@torch.no_grad()
def token_softmax_reducev_fwd(
logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len,
b_att_start_loc, b_att_seq_len, other_kv_index, sliding_window):
logics,
v,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_att_start_loc,
b_att_seq_len,
other_kv_index,
sliding_window,
):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
grid = (batch, head)
kv_group_num = logics.shape[0] // v.shape[1]

num_warps = 1
_fwd_kernel[grid](
logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len,
b_att_start_loc, b_att_seq_len,
logics.stride(0), logics.stride(1),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
req_to_tokens.stride(0), req_to_tokens.stride(1),
logics,
v,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_att_start_loc,
b_att_seq_len,
logics.stride(0),
logics.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
req_to_tokens.stride(0),
req_to_tokens.stride(1),
other_kv_index,
kv_group_num, sliding_window,
kv_group_num,
sliding_window,
BLOCK_DMODEL=v.shape[-1],
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=3
num_stages=3,
)
return
return
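
Taken together, a decode step chains the three wrappers touched by this commit roughly as follows. This is a usage sketch based only on the signatures shown above: the buffer allocations, dtypes, and the module path of token_softmax_reducev_fwd are assumptions, not code from the repository.

import torch
from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_att1 import token_att_fwd
# Assumed module path for the third wrapper (not shown in this diff):
from lightllm.models.mistral.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd

def sketch_decode_attention(q, k, v, req_to_tokens, b_req_idx, b_start_loc, b_seq_len,
                            sliding_window, other_kv_index):
    batch, q_heads, _ = q.shape

    # 1) Per-request window sizes and their offsets in the flattened attention buffers.
    b_att_seq_len = torch.empty_like(b_seq_len)
    init_att_window_info_fwd(batch, b_seq_len, b_att_seq_len, sliding_window)
    b_att_start_loc = torch.cumsum(b_att_seq_len, 0) - b_att_seq_len
    total_cache_num = int(b_att_seq_len.sum())

    # 2) Raw scores of each decode query against its window of cached keys.
    att_logits = torch.empty((q_heads, total_cache_num), dtype=q.dtype, device=q.device)
    token_att_fwd(q, k, att_logits, req_to_tokens, b_req_idx, b_start_loc, b_seq_len,
                  b_att_start_loc, b_att_seq_len, sliding_window)

    # 3) Fused softmax over each window plus the weighted sum of V.
    o = torch.empty_like(q)
    token_softmax_reducev_fwd(att_logits, v, o, req_to_tokens, b_req_idx, b_start_loc,
                              b_seq_len, b_att_start_loc, b_att_seq_len,
                              other_kv_index, sliding_window)
    return o
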