add DP framework support. #622

Merged 3 commits on Nov 28, 2024
11 changes: 9 additions & 2 deletions lightllm/common/basemodel/infer_lock.py
@@ -63,12 +63,13 @@ def get_group_wait_mark(self):
@dataclass
class G_Infer_Lock:
obj: InferStateLock = None
dp_size: int = None

def acquire(self):
if self.obj is not None:
# When a synchronization request is pending and our own mark is already the largest mark in the group, sleep here
# instead of competing for the lock, because wait_mark == 1 means wait_get_locks has been called and someone
# is requesting a synchronization-point operation
# instead of competing for the lock, because wait_mark == 1 means acquire_lock_until_ready has been called
# and an inference process is requesting a synchronization-point operation
while self.obj.get_group_wait_mark() == 1 and self.obj.judge_cur_mark_equal_max_mark_in_group():
time.sleep(0)

@@ -85,6 +86,12 @@ def release(self):

# The following two functions must be used as a pair
def acquire_lock_until_ready(nccl_group):
# In deepseekv2's mixed tp/dp running mode, there is no need to coordinate and synchronize across
# multiple inference processes, so simply acquiring and then releasing the lock is enough
if g_infer_state_lock.dp_size != 1:
g_infer_state_lock.obj.infer_lock.acquire()
return

g_infer_state_lock.obj.set_group_wait_mark()
while True:
g_infer_state_lock.obj.infer_lock.acquire()
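
For reference, a minimal, self-contained sketch of the dp short-circuit introduced above. `InferStateLikeLock` is a stand-in for the real `InferStateLock`, only the fields this sketch touches are modelled, and the real implementation additionally waits for all ranks' marks to align before returning:

```python
import threading
from dataclasses import dataclass, field


@dataclass
class InferStateLikeLock:
    # stand-in for InferStateLock: a mutex plus a group-wide wait mark
    infer_lock: threading.Lock = field(default_factory=threading.Lock)
    group_wait_mark: int = 0


def acquire_lock_until_ready(state: InferStateLikeLock, dp_size: int) -> None:
    if dp_size != 1:
        # dp/tp mixed mode (e.g. deepseekv2): each inference process runs its own
        # dp shard, so no cross-process rendezvous is needed; just take the lock.
        state.infer_lock.acquire()
        return
    # plain tp mode: raise the wait mark so peers stop competing for the lock,
    # then acquire it (the real code also spins until every rank has reached
    # the same mark before proceeding).
    state.group_wait_mark = 1
    state.infer_lock.acquire()


def release_acquired_lock(state: InferStateLikeLock) -> None:
    state.group_wait_mark = 0
    state.infer_lock.release()
```
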
14 changes: 12 additions & 2 deletions lightllm/common/deepseek2_mem_manager.py
@@ -17,7 +17,12 @@ def alloc_kv_move_buffer(self, max_req_total_len):
)
return

def send_to_decode_node(self, token_indexes: List[int], mem_managers: List["Deepseek2MemoryManager"]):
def send_to_decode_node(
self, token_indexes: List[int], mem_managers: List["Deepseek2MemoryManager"], dp_size: int, dp_index: int
):
assert dp_size == 1
assert dp_index == 0

# First move the data into the buffer on one designated card, then send it.
import torch.distributed as dist

@@ -36,7 +41,12 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :]
return move_buffer

def receive_from_prefill_node(self, token_indexes: List[int], mem_managers: List["MemoryManager"]):
def receive_from_prefill_node(
self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
):
assert dp_size == 1
assert dp_index == 0

# First receive the data into the buffer on one designated card, then copy it to the other cards.
import torch.distributed as dist

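
The comments above describe a two-step KV move: gather the selected rows into a contiguous move buffer on one designated card, then ship that buffer in a single transfer. A hedged sketch of that staging pattern follows; it assumes an already-initialized torch.distributed process group and simplifies away the per-layer batching the real code performs:

```python
import torch
import torch.distributed as dist


def send_kv_rows(kv_buffer: torch.Tensor, token_indexes: torch.Tensor, dst_rank: int) -> None:
    # kv_buffer: [layer_num, total_token_num, head_num, head_dim] on the designated card.
    # Stage the scattered rows into one contiguous buffer, then send it in one shot.
    move_buffer = kv_buffer[:, token_indexes, :, :].contiguous()
    dist.send(move_buffer, dst=dst_rank)


def receive_kv_rows(shape, dtype, device, src_rank: int) -> torch.Tensor:
    # Receive into a contiguous buffer on the designated card; the caller then
    # copies the rows to the other cards, as the comments above describe.
    move_buffer = torch.empty(shape, dtype=dtype, device=device)
    dist.recv(move_buffer, src=src_rank)
    return move_buffer
```
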
42 changes: 39 additions & 3 deletions lightllm/common/mem_manager.py
@@ -1,6 +1,7 @@
import re
import os
import torch
import torch.distributed as dist
from typing import List
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
@@ -31,7 +32,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
assert nccl_port is not None
logger.info(f"mem manger get nccl port: {str(nccl_port)}")

self.shared_can_use_token_num = SharedInt(f"{str(nccl_port)}_mem_manger_can_use_token_num")
rank_id = dist.get_rank()
self.shared_can_use_token_num = SharedInt(f"{str(nccl_port)}_mem_manger_can_use_token_num_{rank_id}")

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self._init_buffers(
@@ -77,7 +79,17 @@ def alloc_kv_move_buffer(self, max_req_total_len):
)
return

def send_to_decode_node(self, token_indexes: List[int], mem_managers: List["MemoryManager"]):
def send_to_decode_node(
self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
):
"""
dp_size and dp_index are parameters tailored for deepseekv2-style models that can run in a mixed
dp/tp mode. In plain tp mode, dp_size is always 1 and dp_index is always 0; also, in plain mode
these two parameters are not actually used.
"""
assert dp_size == 1
assert dp_index == 0

# First move the data into the buffer on one designated card, then send it.
import torch.distributed as dist

@@ -105,7 +117,17 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
move_buffer[:, :, :, :] = self.kv_buffer[layer_index, token_indexes, :, :]
return move_buffer

def receive_from_prefill_node(self, token_indexes: List[int], mem_managers: List["MemoryManager"]):
def receive_from_prefill_node(
self, token_indexes: List[int], mem_managers: List["MemoryManager"], dp_size: int, dp_index: int
):
"""
dp_size and dp_index are parameters tailored for deepseekv2-style models that can run in a mixed
dp/tp mode. In plain tp mode, dp_size is always 1 and dp_index is always 0; also, in plain mode
these two parameters are not actually used.
"""
assert dp_size == 1
assert dp_index == 0

# First receive the data into the buffer on one designated card, then copy it to the other cards.
import torch.distributed as dist

@@ -223,3 +245,17 @@ def resize_mem(self, new_size):
self._free_buffers()
self._init_buffers(size, dtype, head_num, head_dim, layer_num)
return


class ReadOnlyStaticsMemoryManager:
"""
Read some statistics information.
"""

def __init__(self, nccl_port, tp_size) -> None:
self.shared_tp_infos = [
SharedInt(f"{str(nccl_port)}_mem_manger_can_use_token_num_{tp_index}") for tp_index in range(tp_size)
]

def get_unrefed_token_num(self, tp_index: int):
return self.shared_tp_infos[tp_index].get_value()
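
A brief usage sketch for the new read-only statistics reader; the port and tp size below are placeholders and must match what the running server used when it created the per-rank `SharedInt` counters named above:

```python
from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager

nccl_port = 28765  # placeholder: must equal the server's --nccl_port
tp_size = 2        # placeholder: must equal the server's --tp

stats = ReadOnlyStaticsMemoryManager(nccl_port, tp_size)
for tp_index in range(tp_size):
    # each entry mirrors the "{nccl_port}_mem_manger_can_use_token_num_{rank}"
    # SharedInt that the corresponding inference process keeps up to date
    print(tp_index, stats.get_unrefed_token_num(tp_index))
```
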
8 changes: 8 additions & 0 deletions lightllm/server/api_cli.py
@@ -87,6 +87,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
"--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
)
parser.add_argument("--tp", type=int, default=1, help="model tp parral size, the default is 1")
parser.add_argument(
"--dp",
type=int,
default=1,
help="""This is just a useful parameter for deepseekv2. When
using the deepseekv2 model, set dp to be equal to the tp parameter. In other cases, please
do not set it and keep the default value as 1.""",
)
parser.add_argument(
"--max_req_total_len", type=int, default=2048 + 1024, help="the max value for req_input_len + req_output_len"
)
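
To illustrate the intended usage of the new flag, here is a minimal self-contained sketch that mirrors just the two arguments above rather than importing the full lightllm parser:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tp", type=int, default=1)
parser.add_argument(
    "--dp",
    type=int,
    default=1,
    help="only meaningful for deepseekv2; set it equal to --tp, otherwise keep the default of 1",
)

# deepseekv2 mixed tp/dp mode: dp == tp; any other model: leave dp at its default of 1.
args = parser.parse_args(["--tp", "8", "--dp", "8"])
assert args.dp == args.tp == 8
```
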
30 changes: 19 additions & 11 deletions lightllm/server/api_server.py
@@ -121,17 +121,25 @@ async def healthcheck(request: Request):

@app.get("/token_load", summary="Get the current server's load of tokens")
async def token_load(request: Request):
return JSONResponse(
{
# tokens currently in use, the estimated load
"current_load": float(g_objs.shared_token_load.get_current_load()),
# naive load estimate, obtained by simply adding up the input and output lengths of the current requests; no longer used, its value equals dynamic_max_load.
"logical_max_load": float(g_objs.shared_token_load.get_logical_max_load()),
# dynamically estimated maximum load, accounting for requests that exit midway
"dynamic_max_load": float(g_objs.shared_token_load.get_dynamic_max_load()),
},
status_code=200,
)
ans_dict = {
# tokens currently in use, the estimated load
"current_load": [
float(g_objs.shared_token_load.get_current_load(dp_index)) for dp_index in range(g_objs.args.dp)
],
# naive load estimate, obtained by simply adding up the input and output lengths of the current requests; no longer used, its value equals dynamic_max_load.
"logical_max_load": [
float(g_objs.shared_token_load.get_logical_max_load(dp_index)) for dp_index in range(g_objs.args.dp)
],
# dynamically estimated maximum load, accounting for requests that exit midway
"dynamic_max_load": [
float(g_objs.shared_token_load.get_dynamic_max_load(dp_index)) for dp_index in range(g_objs.args.dp)
],
}

if g_objs.args.dp == 1:
ans_dict = {k: v[0] for k, v in ans_dict.items()}

return JSONResponse(ans_dict, status_code=200)


@app.post("/generate")
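
A hedged client-side sketch of what the reworked endpoint returns; the host and port are placeholders and the example values are illustrative only:

```python
import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8000/token_load") as resp:
    load = json.load(resp)

# With --dp 1 the old scalar shape is preserved, e.g.
#   {"current_load": 0.12, "logical_max_load": 0.30, "dynamic_max_load": 0.30}
# With --dp > 1 every field becomes a per-dp-rank list, e.g.
#   {"current_load": [0.12, 0.08], "logical_max_load": [0.3, 0.2], "dynamic_max_load": [0.3, 0.2]}
print(load["dynamic_max_load"])
```
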
2 changes: 1 addition & 1 deletion lightllm/server/api_start.py
@@ -196,7 +196,7 @@ def normal_or_p_d_start(g_objs):

start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)])

g_objs.shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load", 1)
g_objs.shared_token_load = TokenLoad(f"{str(args.nccl_port)}_shared_token_load", args.dp)

g_objs.server.install_signal_handlers()
uvicorn.run(
4 changes: 4 additions & 0 deletions lightllm/server/httpserver_for_pd_master/manager.py
@@ -73,6 +73,7 @@ async def remove_pd(self, pd_info_json):
async def update_req_status(self, upkv_status: UpKVStatus):
try:
event = self.id_to_event[upkv_status.group_request_id]
event.upkv_status = upkv_status
event.set()
del self.id_to_event[upkv_status.group_request_id]
except:
@@ -175,6 +176,7 @@ async def fetch_stream(
old_max_new_tokens = sampling_params.max_new_tokens
sampling_params.max_new_tokens = 1
sampling_params.move_kv_to_decode_node = decode_node_dict if old_max_new_tokens != 1 else None
sampling_params.suggested_dp_index = None

req = await self._to_req_info(prompt, sampling_params, multimodal_params)
create_start_time = time.time()
@@ -212,6 +214,8 @@ async def fetch_stream(

sampling_params.move_kv_to_decode_node = None
sampling_params.max_new_tokens = old_max_new_tokens - 1
sampling_params.suggested_dp_index = event.upkv_status.dp_index

req = await self._to_req_info(prompt_ids, sampling_params, multimodal_params)
async with self.session.post(d_node.to_llm_url(), json=req) as response:
if response.status == 200:
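
A self-contained sketch of the new placement handshake in the pd master: the prefill request leaves the dp placement open, and once the KV move finishes the decode request is pinned to the dp rank reported back in `UpKVStatus.dp_index`. The dataclass fields follow the diff; the helper names are hypothetical:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class UpKVStatus:
    type: str = "kv_move_status"
    group_request_id: int = None
    dp_index: int = None


@dataclass
class SamplingParamsLike:
    max_new_tokens: int = 16
    suggested_dp_index: Optional[int] = None


def prepare_prefill_params(params: SamplingParamsLike) -> SamplingParamsLike:
    # the pd master does not pick a dp rank for the prefill request
    params.suggested_dp_index = None
    return params


def prepare_decode_params(params: SamplingParamsLike, status: UpKVStatus) -> SamplingParamsLike:
    # after the KV move completes, pin the decode request to the dp rank
    # that actually received the KV cache
    params.suggested_dp_index = status.dp_index
    return params
```
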
26 changes: 20 additions & 6 deletions lightllm/server/io_struct.py
@@ -71,6 +71,7 @@ def to_rpc_obj(self):
"sampling_param": self.sample_params.to_dict(),
"multimodal_params": self.multimodal_params.to_dict(),
"req_status": self.req_status,
"dp_index": self.sample_params.suggested_dp_index,
}

def __repr__(self):
@@ -272,14 +273,15 @@ def __init__(


class Batch:
def __init__(self, batch_id, reqs: List[Req]):
def __init__(self, batch_id, reqs: List[Req], dp_size: int):
self.batch_id = batch_id
self.reqs = reqs
self.id_to_reqs = {req.request_id: req for req in reqs}
self.dp_size = dp_size

# This field is only updated after batch init, prefill, and decode, and is decremented when requests are removed.
# It is only filled with correct values after batch rpc init; initialized to None.
self.batch_decode_need_tokens = None
self.batch_decode_need_tokens = [None for _ in range(dp_size)]
return

def input_tokens(self):
@@ -293,8 +295,9 @@ def mark_and_get_finished_req_and_preupdate_status(self):
for req in self.reqs:
if req.finish_status.is_finished():
finished_req_ids.append(req.request_id)
req_dp_index = req.sample_params.suggested_dp_index
# while marking, also update the amounts these removed requests contributed; a bit dirty
self.batch_decode_need_tokens -= req.get_decode_need_tokens()
self.batch_decode_need_tokens[req_dp_index] -= req.get_decode_need_tokens()
else:
unfinished_req_ids.append(req.request_id)

@@ -311,17 +314,28 @@ def pop_req(self, req_id):
self.reqs = [req for req in self.reqs if req.request_id != req_id]
req = self.id_to_reqs[req_id]
self.id_to_reqs.pop(req_id)
self.batch_decode_need_tokens -= req.get_decode_need_tokens()
req_dp_index = req.sample_params.suggested_dp_index
self.batch_decode_need_tokens[req_dp_index] -= req.get_decode_need_tokens()
return

def is_clear(self):
return len(self.reqs) == 0

def merge(self, mini_batch):
def merge(self, mini_batch: "Batch"):
for _req in mini_batch.reqs:
self.reqs.append(_req)
self.id_to_reqs = {req.request_id: req for req in self.reqs}
for dp_index in range(self.dp_size):
self.batch_decode_need_tokens[dp_index] += mini_batch.batch_decode_need_tokens[dp_index]
return

def dp_merge(self, mini_batch: "Batch"):
if mini_batch is None:
return

for _req in mini_batch.reqs:
self.reqs.append(_req)
self.id_to_reqs = {req.request_id: req for req in self.reqs}
self.batch_decode_need_tokens += mini_batch.batch_decode_need_tokens
return

def __repr__(self):
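
A self-contained sketch of the per-dp token accounting that the `Batch` changes implement; `ReqLike` and `BatchLike` are stand-ins, and in the real code the dp index lives in `req.sample_params.suggested_dp_index`:

```python
from typing import List


class ReqLike:
    def __init__(self, request_id: int, dp_index: int, decode_need: int):
        self.request_id = request_id
        self.suggested_dp_index = dp_index
        self.decode_need = decode_need

    def get_decode_need_tokens(self) -> int:
        return self.decode_need


class BatchLike:
    def __init__(self, reqs: List[ReqLike], dp_size: int):
        self.reqs = reqs
        self.dp_size = dp_size
        # one counter per dp rank instead of a single scalar
        self.batch_decode_need_tokens = [0 for _ in range(dp_size)]
        for req in reqs:
            self.batch_decode_need_tokens[req.suggested_dp_index] += req.get_decode_need_tokens()

    def pop_req(self, req_id: int) -> None:
        req = next(r for r in self.reqs if r.request_id == req_id)
        self.reqs = [r for r in self.reqs if r.request_id != req_id]
        # only the owning dp rank's counter shrinks
        self.batch_decode_need_tokens[req.suggested_dp_index] -= req.get_decode_need_tokens()


batch = BatchLike([ReqLike(0, 0, 5), ReqLike(1, 1, 7)], dp_size=2)
batch.pop_req(0)
assert batch.batch_decode_need_tokens == [0, 7]
```
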
13 changes: 11 additions & 2 deletions lightllm/server/pd_io_struct.py
@@ -39,6 +39,7 @@ def to_llm_url(self):
class UpKVStatus:
type: str = "kv_move_status"
group_request_id: int = None
dp_index: int = None

def __post_init__(self):
if self.type != "kv_move_status":
@@ -71,6 +72,10 @@ class KVMoveTask:
move_kv_len: int # because of prompt cache, after the prefill node and the decode node communicate, the number of kv entries transferred may be less than the length of prefill_value
prefill_node_id: str
decode_node: DecodeNodeInfo
# Stores the dp_index handled by the prefill and decode nodes. In plain tp mode this value is always 0;
# it only has real meaning in deepseekv2's mixed tp/dp mode.
prefill_dp_index: int
decode_dp_index: int

def __post_init__(self):
if len(self.input_tokens) <= 0:
@@ -80,12 +85,16 @@ def __post_init__(self):

def to_prefill_log_info(self):
v_len = None if self.prefill_token_indexes is None else len(self.prefill_token_indexes)
log = f"id: {self.group_request_id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len}"
d_i = self.prefill_dp_index
id = self.group_request_id
log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}"
return log

def to_decode_log_info(self):
v_len = None if self.decode_token_indexes is None else len(self.decode_token_indexes)
log = f"id: {self.group_request_id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len}"
d_i = self.decode_dp_index
id = self.group_request_id
log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}"
return log

def id(self):
22 changes: 21 additions & 1 deletion lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -360,7 +360,7 @@ def release_mem(mem_index):
return


class RadixCacheReadOnlyClient:
class _RadixCacheReadOnlyClient:
"""
Read-only client used on the router side to read the tree-structure information from shared memory, for prompt-cache scheduling estimation.
"""
@@ -391,6 +391,26 @@ def get_all_parent_shared_nodes(self, idx):
return ans_list


class RadixCacheReadOnlyClient:
def __init__(self, unique_name, total_token_num, tp_size):
self.tp_clients = [_RadixCacheReadOnlyClient(unique_name, total_token_num, tp_id) for tp_id in range(tp_size)]

def get_refed_tokens_num(self, index):
return self.tp_clients[index].get_refed_tokens_num()

def get_tree_total_tokens_num(self, index):
return self.tp_clients[index].get_tree_total_tokens_num()

def get_unrefed_tokens_num(self, index):
return self.tp_clients[index].get_unrefed_tokens_num()

def get_shared_node(self, tp_index, idx):
return self.tp_clients[tp_index].get_shared_node(idx)

def get_all_parent_shared_nodes(self, tp_index, idx):
return self.tp_clients[tp_index].get_all_parent_shared_nodes(idx)


# ///////////////////////////////////////////////////////////////////////////////

if __name__ == "__main__":
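
Finally, a usage sketch for the per-tp wrapper added above. It assumes a running server that has already created the shared radix-cache structures; the unique name, token count, and tp size are placeholders that must match the server's configuration:

```python
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCacheReadOnlyClient

# placeholder arguments: the real unique_name and total_token_num come from the server setup
client = RadixCacheReadOnlyClient("28765_radix_cache", total_token_num=120000, tp_size=2)

for tp_index in range(2):
    # per-rank cache statistics, each read from that rank's shared-memory segment
    print(
        tp_index,
        client.get_refed_tokens_num(tp_index),
        client.get_unrefed_tokens_num(tp_index),
    )
```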