【Feature】PD Mode Support (#607)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Nov 18, 2024
1 parent d9e3ba2 commit 06afb4a
Showing 54 changed files with 3,069 additions and 529 deletions.
86 changes: 38 additions & 48 deletions lightllm/common/basemodel/basemodel.py
@@ -17,6 +17,7 @@
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.common.basemodel.cuda_graph import CudaGraph
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.infer_lock import g_infer_state_lock

logger = init_logger(__name__)

@@ -38,6 +39,7 @@ class TpPartBaseModel:
splitfuse_infer_state_class = SplitFuseInferStateInfo

def __init__(self, kvargs):
self.run_mode = kvargs["run_mode"]
self.tp_rank_ = kvargs["tp_rank"]
self.world_size_ = kvargs["world_size"]
self.weight_dir_ = kvargs["weight_dir"]
@@ -67,6 +69,7 @@ def __init__(self, kvargs):
self._verify_params()
self._init_weights()
self._init_mem_manager()
self._init_kv_move_buffer()
self._check_mem_size()
self._init_req_manager()
self._init_infer_layer()
@@ -131,6 +134,11 @@ def _init_mem_manager(self):
)
return

def _init_kv_move_buffer(self):
# This initialization is only needed in the PD-disaggregated inference modes (prefill / decode)
if self.run_mode in ["prefill", "decode"]:
self.mem_manager.alloc_kv_move_buffer(self.max_seq_length)

def _check_mem_size(self):
self.max_total_token_num = self.mem_manager.size
assert self.max_seq_length < self.max_total_token_num
@@ -192,6 +200,7 @@ def forward(
total_token_num,
max_len_in_batch,
input_ids: torch.Tensor,
mem_indexes: torch.Tensor,
b_req_idx: torch.Tensor,
b_start_loc: torch.Tensor,
b_seq_len: torch.Tensor,
@@ -205,6 +214,7 @@
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -217,6 +227,7 @@
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -229,6 +240,7 @@ def _prefill(
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -256,22 +268,13 @@ def _prefill(
infer_state.mem_manager = self.mem_manager
infer_state.req_manager = self.req_manager

alloc_mem = self.mem_manager.alloc_contiguous(input_ids.shape[0])
if alloc_mem is not None:
infer_state.mem_is_contiguous = True
infer_state.mem_index = alloc_mem[0]
infer_state.mem_start = alloc_mem[1]
infer_state.mem_end = alloc_mem[2]

else:
infer_state.mem_is_contiguous = False
alloc_mem = self.mem_manager.alloc(input_ids.shape[0])
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=self.data_type,
device="cuda",
)
infer_state.mem_is_contiguous = False
infer_state.mem_index = mem_indexes
infer_state.kv_buffer = torch.empty(
(input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=self.data_type,
device="cuda",
)

init_req_to_token_indexes(
self.req_manager.req_to_token_indexs,
@@ -292,6 +295,7 @@ def _decode(
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -314,23 +318,14 @@

# When the CUDA graph feature is used, every inference run must follow exactly the same code path,
# so the contiguous-memory allocation optimization is no longer used, keeping the inference flow consistent.
alloc_mem = None if self.graph is not None else self.mem_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.mem_is_contiguous = True
infer_state.mem_index = alloc_mem[0]
infer_state.mem_start = alloc_mem[1]
infer_state.mem_end = alloc_mem[2]
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
else:
infer_state.mem_is_contiguous = False
alloc_mem = self.mem_manager.alloc(batch_size)
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=self.data_type,
device="cuda",
)
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
infer_state.mem_is_contiguous = False
infer_state.mem_index = mem_indexes
infer_state.kv_buffer = torch.empty(
(batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=self.data_type,
device="cuda",
)
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)

infer_state.init_some_extra_state(self, input_ids)
if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
@@ -347,6 +342,7 @@ def _decode(
def splitfuse_forward(
self,
input_ids,
mem_indexes,
decode_req_num,
decode_total_token_num,
decode_b_req_idx: torch.Tensor,
@@ -384,21 +380,13 @@ def splitfuse_forward(
infer_state.req_manager = self.req_manager

alloc_size = len(input_ids)
alloc_mem = self.mem_manager.alloc_contiguous(alloc_size)
if alloc_mem is not None:
infer_state.mem_is_contiguous = True
infer_state.mem_index = alloc_mem[0]
infer_state.mem_start = alloc_mem[1]
infer_state.mem_end = alloc_mem[2]
else:
infer_state.mem_is_contiguous = False
alloc_mem = self.mem_manager.alloc(alloc_size)
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=self.data_type,
device="cuda",
)
infer_state.mem_is_contiguous = False
infer_state.mem_index = mem_indexes
infer_state.kv_buffer = torch.empty(
(alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
dtype=self.data_type,
device="cuda",
)

# decode part
if decode_req_num != 0:
@@ -474,6 +462,7 @@ def _check_max_len_infer(self):
logger.info("begin check max_len infer")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
b_req_idx = self.req_manager.alloc(1).int()
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
Expand All @@ -484,6 +473,7 @@ def _check_max_len_infer(self):
total_token_num,
self.batch_max_tokens,
dummy_input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
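With this change the model no longer allocates KV-cache slots itself: the caller reserves them from the mem_manager and hands the resulting indexes to forward via the new mem_indexes argument. The sketch below mirrors the _check_max_len_infer and warmup call sites above for a single prefill request; model and prompt_ids are assumed to exist (an initialized TpPartBaseModel and a 1-D int32 CUDA tensor of token ids), and the argument order is inferred from the decode call in cuda_graph.py, so treat it as an illustration rather than the canonical API.

# Minimal sketch, assuming `model` is an initialized TpPartBaseModel and
# `prompt_ids` is a 1-D int32 CUDA tensor of token ids for one request.
import torch

prompt_len = prompt_ids.shape[0]
b_req_idx = model.req_manager.alloc(1).int()          # one request slot
mem_indexes = model.mem_manager.alloc(prompt_len)     # KV slots now allocated by the caller
b_seq_len = torch.full((1,), prompt_len, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
b_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")

logits = model.forward(
    1,                  # batch_size
    prompt_len,         # total_token_num
    prompt_len,         # max_len_in_batch
    prompt_ids,
    mem_indexes,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    b_ready_cache_len=b_ready_cache_len,
    is_prefill=True,
    multimodal_params=[],
)
# The warmup code in cuda_graph.py drops its reference to mem_indexes after the call
# and finally releases everything with model.mem_manager.free_all() / model.req_manager.free_all().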
6 changes: 6 additions & 0 deletions lightllm/common/basemodel/cuda_graph.py
@@ -52,6 +52,7 @@ def warmup(self, model):
prefill_input_len = 1
dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
b_req_idx = model.req_manager.alloc(batch_size).int()
mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
@@ -61,13 +62,15 @@
total_token_num,
prefill_input_len,
dummy_input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
b_ready_cache_len=b_ready_cache_len,
is_prefill=True,
multimodal_params=[],
)
mem_indexes = None
prob_out = torch.softmax(logics, dim=-1)
logics = None
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
@@ -79,16 +82,19 @@
b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1
mem_indexes = model.mem_manager.alloc(len(predict_ids))
logics = model.forward(
batch_size,
total_token_num,
prefill_input_len + 1,
torch.from_numpy(predict_ids).cuda().reshape(-1),
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
is_prefill=False,
)
mem_indexes = None
model.mem_manager.free_all()
model.req_manager.free_all()
# release local tensors
128 changes: 128 additions & 0 deletions lightllm/common/basemodel/infer_lock.py
@@ -0,0 +1,128 @@
# This is not a great design, but it is hard to find a better and simpler approach that is less intrusive to the architecture.
# The lock and counters declared here mainly solve a problem that arises in PD-disaggregated mode: the kv_move_manager
# process manipulates data in the radix cache and mem_manager through rpyc calls, which can cause serious data
# synchronization problems. The root cause is that there is no strict guarantee about which point each TP inference
# process has reached in its execution, so the data in the radix cache and mem manager can become inconsistent
# across processes. The implementation below solves this by using a lock and a counter object together.
from dataclasses import dataclass
import numpy as np
import threading
from lightllm.server.router.dynamic_prompt.shared_arr import SharedArray
import torch.distributed as dist
import time
import torch.multiprocessing as mp
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class InferStateLock:
def __init__(self, name):
self.infer_lock = threading.Lock()
# Reserve space for up to 128 TP ranks by default; no card today should need a TP degree that large.
self.lock_tp_infos = SharedArray(f"{name}_lock_tp_infos", shape=(129,), dtype=np.int64)
self.lock_tp_infos.arr[:] = 0
self.rank_id = dist.get_rank()
self.world_size = dist.get_world_size()

def add_cur_mark(self):
self.lock_tp_infos.arr[self.rank_id] += 1

def get_cur_mark(self):
return self.lock_tp_infos.arr[self.rank_id]

def get_max_mark_in_group(self):
return np.max(self.lock_tp_infos.arr[0 : self.world_size])

def judge_cur_mark_equal_max_mark_in_group(self):
return self.get_cur_mark() == self.get_max_mark_in_group()

def judge_mark_in_group_all_same(self):
marks = self.lock_tp_infos.arr[0 : self.world_size]
return bool(np.all(marks == marks[0]))

def acquire_lock_and_update_cur_mark(self):
self.infer_lock.acquire()
self.add_cur_mark()

def release_lock(self):
self.infer_lock.release()

def set_group_wait_mark(self):
if self.rank_id == 0:
self.lock_tp_infos.arr[-1] = 1

def unset_group_wait_mark(self):
if self.rank_id == 0:
self.lock_tp_infos.arr[-1] = 0

def get_group_wait_mark(self):
return self.lock_tp_infos.arr[-1]


@dataclass
class G_Infer_Lock:
obj: InferStateLock = None

def acquire(self):
if self.obj is not None:
# When a synchronization request is pending and this rank's mark is already the maximum mark in the group,
# sleep here instead of competing for the lock: wait_mark == 1 means acquire_lock_until_ready has been
# called, i.e. someone is requesting a synchronization-point operation.
while self.obj.get_group_wait_mark() == 1 and self.obj.judge_cur_mark_equal_max_mark_in_group():
time.sleep(0)

self.obj.acquire_lock_and_update_cur_mark()

def release(self):
if self.obj is not None:
self.obj.release_lock()


# The backend object initializes and assigns obj later, so it can be used globally.
g_infer_state_lock = G_Infer_Lock()


# The following two functions must be used as a pair.
def acquire_lock_until_ready(nccl_group):
g_infer_state_lock.obj.set_group_wait_mark()
while True:
g_infer_state_lock.obj.infer_lock.acquire()
dist.barrier(nccl_group)
judge_ans = g_infer_state_lock.obj.judge_mark_in_group_all_same()
dist.barrier(nccl_group)

if judge_ans is not True:
# release the lock and retry
g_infer_state_lock.obj.infer_lock.release()
time.sleep(0.001)
logger.info("wait get locks sleep 1ms")
else:
break

g_infer_state_lock.obj.unset_group_wait_mark()
return


def release_acquired_lock():
g_infer_state_lock.obj.infer_lock.release()


@dataclass
class G_Router_Lock:
"""
Protects certain data operations in PD-disaggregated mode.
"""

obj = None # process-level lock object

def acquire(self):
if self.obj is not None:
self.obj.acquire()

def release(self):
if self.obj is not None:
self.obj.release()


g_router_lock = G_Router_Lock()
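To show how these primitives are meant to be combined, here is a rough sketch of the intended usage pattern: only the infer_lock helpers above are real, while the service name, nccl_group, and the step/move functions are assumed names. Each TP inference process initializes g_infer_state_lock.obj once, wraps every state-mutating inference step in acquire()/release(), and the rpyc handler that the kv_move_manager triggers inside the same process brackets its radix cache / mem_manager access with the paired acquire_lock_until_ready / release_acquired_lock calls.

# Hedged sketch; requires torch.distributed to be initialized before InferStateLock is built.
from lightllm.common.basemodel.infer_lock import (
    InferStateLock,
    g_infer_state_lock,
    acquire_lock_until_ready,
    release_acquired_lock,
)

# done once per TP inference process, after dist.init_process_group() has run
g_infer_state_lock.obj = InferStateLock(name="demo_service")

def run_one_infer_step(step_fn):
    # inference-worker side: bump this rank's mark and hold the thread lock for the step
    g_infer_state_lock.acquire()
    try:
        step_fn()  # e.g. run forward() and update radix cache / mem_manager
    finally:
        g_infer_state_lock.release()

def move_kv_safely(nccl_group, move_fn):
    # kv-move side (invoked via rpyc in every TP process): spins until all ranks
    # carry the same mark, holding the lock while the marks are compared
    acquire_lock_until_ready(nccl_group)
    try:
        move_fn()  # radix cache / mem_manager state is now consistent across ranks
    finally:
        release_acquired_lock()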
