use triton to init window info #420

Merged
merged 5 commits on Jun 12, 2024
29 changes: 11 additions & 18 deletions lightllm/models/mistral/infer_struct.py
@@ -2,23 +2,25 @@
import numpy as np
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.req_manager import ReqManager
from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd


class MistralInferStateInfo(LlamaInferStateInfo):
def __init__(self):
super().__init__()
self.sliding_window = None
self.b_start_loc_window = None
self.b_att_seq_len = None
self.b_att_start_loc = None
self.total_cache_num = None
# self.window_postion = None

def init_some_extra_state(self, model, input_ids : torch.Tensor):
def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.sliding_window = model.config["sliding_window"]
if self.is_prefill:
b_seq_len_numpy = self.b_seq_len.cpu().numpy()
position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
position_ids = torch.from_numpy(
np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)
).cuda()
self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
position_ids = None
@@ -30,17 +32,8 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor):
# b_loc[0, max_len_in_batch - 1].item()

# [SYM] still reserve all kv cache
self.b_att_seq_len = self.b_seq_len.clone()
self.b_att_start_loc = self.b_start_loc.clone()
self.b_start_loc_window = self.b_start_loc.clone()
self.total_cache_num = 0
for i in range(0, self.batch_size):
if self.sliding_window < self.b_seq_len[i]:
self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window
self.b_att_seq_len[i] = self.sliding_window
else:
self.b_start_loc_window[i] = 0
self.b_att_seq_len[i] = self.b_seq_len[i]
self.b_att_start_loc[i] = self.total_cache_num
self.total_cache_num += self.b_att_seq_len[i]
return
self.b_att_seq_len = torch.zeros_like(self.b_seq_len)
init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window)
self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len
self.total_cache_num = torch.sum(self.b_att_seq_len).item()
return
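
Note: the Triton-backed initialization above replaces the old per-request Python loop. A minimal plain-PyTorch sketch of the values it produces (illustrative only, not part of this PR; the helper name is hypothetical):

import torch

def init_window_info_reference(b_seq_len: torch.Tensor, sliding_window: int):
    # Each request attends to at most the last `sliding_window` tokens.
    b_att_seq_len = b_seq_len.clamp(max=sliding_window)
    # Exclusive prefix sum: start offset of each request in the packed attention buffer.
    b_att_start_loc = torch.cumsum(b_att_seq_len, dim=0) - b_att_seq_len
    # Total number of cached tokens attended over the whole batch.
    total_cache_num = int(b_att_seq_len.sum().item())
    return b_att_seq_len, b_att_start_loc, total_cache_num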
@@ -59,7 +59,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
infer_state.sliding_window,
@@ -79,9 +78,9 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
infer_state.other_kv_index,
infer_state.sliding_window,
)
return o_tensor
45 changes: 45 additions & 0 deletions lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py
@@ -0,0 +1,45 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_init_att_window_info(
b_seq_len,
b_att_seq_len,
batch_size,
sliding_window,
BLOCK_SIZE: tl.constexpr,
):
cur_index = tl.program_id(0)
cur_start = cur_index * BLOCK_SIZE
offsets = cur_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < batch_size

cur_seq_len = tl.load(b_seq_len + offsets, mask=mask)
b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window)

tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask)
return


@torch.no_grad()
def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window):
# shape constraints
assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0]

BLOCK_SIZE = 32
num_warps = 1
grid = (triton.cdiv(batch_size, BLOCK_SIZE),)

_fwd_kernel_init_att_window_info[grid](
b_seq_len,
b_att_seq_len,
batch_size=batch_size,
sliding_window=sliding_window,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=1,
)
return
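
Usage sketch for the new helper, mirroring the call in MistralInferStateInfo.init_some_extra_state (the example tensor values and CUDA device are assumptions for illustration):

import torch
from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd

sliding_window = 4096
b_seq_len = torch.tensor([100, 5000, 4096], dtype=torch.int32, device="cuda")
b_att_seq_len = torch.zeros_like(b_seq_len)

# Fills b_att_seq_len[i] = min(b_seq_len[i], sliding_window) on the GPU.
init_att_window_info_fwd(b_seq_len.shape[0], b_seq_len, b_att_seq_len, sliding_window)
# b_att_seq_len is now [100, 4096, 4096]: each entry capped at the sliding window.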
94 changes: 63 additions & 31 deletions lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py
@@ -7,49 +7,68 @@

@triton.jit
def _fwd_kernel_token_att1(
Q, K, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen,
Q,
K,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
B_Att_Start_Loc,
B_Att_Seqlen,
Att_Out,
stride_req_to_tokens_b, stride_req_to_tokens_s,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
att_stride_h, att_stride_bs,
kv_group_num, sliding_window,
stride_req_to_tokens_b,
stride_req_to_tokens_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
att_stride_h,
att_stride_bs,
kv_group_num,
sliding_window,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_n = tl.program_id(2)

cur_kv_head = cur_head // kv_group_num

offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index
cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)

# use new start index of k value
cur_batch_start_index = tl.load(B_Start_Loc_Window + cur_batch)
cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0)
cur_batch_end_index = cur_batch_seq_len

off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D]
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D]

offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32]
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32]

# use new value to decide block mask
block_stard_index = start_n * BLOCK_N
block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number
block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number

for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark
offs_n_new = cur_batch_start_index + offs_n # the latest window of token
k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
mask=offs_n_new < cur_batch_end_index, other=0)
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd # [32, D], find token index
q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark
offs_n_new = cur_batch_start_index + offs_n # the latest window of token
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
off_k = (
k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
) # [32, D], find token index
k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32]
att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32]
att_value *= sm_scale
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@@ -58,8 +77,8 @@ def _fwd_kernel_token_att1(

@torch.no_grad()
def token_att_fwd(
q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, sliding_window):
q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window
):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
@@ -71,25 +90,38 @@ def token_att_fwd(

grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK))
kv_group_num = q.shape[1] // k.shape[1]

if kv_group_num == 1:
num_warps = 4
else:
num_warps = 2

_fwd_kernel_token_att1[grid](
q, k, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen,
B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen,
q,
k,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
B_Att_Start_Loc,
B_Att_Seqlen,
att_out,
Req_to_tokens.stride(0), Req_to_tokens.stride(1),
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
att_out.stride(0), att_out.stride(1),
Req_to_tokens.stride(0),
Req_to_tokens.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
att_out.stride(0),
att_out.stride(1),
kv_group_num=kv_group_num,
sliding_window=sliding_window,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
return
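
Since B_Start_Loc_Window is no longer passed, the window start is derived inside the kernel via tl.maximum(cur_batch_seq_len - sliding_window, 0). An illustrative equivalent of what the removed host-side tensor used to hold (not part of the PR):

def window_start(seq_len: int, sliding_window: int) -> int:
    # Index of the first token that still falls inside the attention window.
    return max(seq_len - sliding_window, 0)

assert window_start(100, 4096) == 0     # short request: attend over the whole sequence
assert window_start(5000, 4096) == 904  # long request: only the last 4096 tokens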
@@ -13,7 +13,6 @@ def _fwd_kernel_token_att2(
B_req_idx,
B_Start_Loc,
B_Seqlen,
B_Start_Loc_Window,
B_Att_Start_Loc,
B_Att_Seqlen,
stride_req_to_tokens_b,
@@ -27,6 +26,7 @@ def _fwd_kernel_token_att2(
stride_oh,
stride_od,
kv_group_num,
sliding_window,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
@@ -38,7 +38,7 @@
offs_n = tl.arange(0, BLOCK_N) # [64]
offs_d = tl.arange(0, BLOCK_DMODEL) # [D]
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_index = tl.load(B_Start_Loc_Window + cur_batch) # new index
cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index
# cur_batch_end_index = cur_batch_seq_len
cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # new index
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
@@ -75,7 +75,7 @@

@torch.no_grad()
def token_att_fwd2(
prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen
prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window
):
BLOCK = 128
# BLOCK = 64 # for triton 2.0.0dev
@@ -94,7 +94,6 @@ def token_att_fwd2(
B_req_idx,
B_Start_Loc,
B_Seqlen,
B_Start_Loc_Window,
B_Att_Start_Loc,
B_Att_Seqlen,
Req_to_tokens.stride(0),
@@ -108,6 +107,7 @@
out.stride(1),
out.stride(2),
kv_group_num=kv_group_num,
sliding_window=sliding_window,
BLOCK_DMODEL=dim,
BLOCK_N=BLOCK,
num_warps=num_warps,