Refactor reduce #626

Merged: 5 commits, Nov 29, 2024
120 changes: 58 additions & 62 deletions lightllm/common/basemodel/cuda_graph.py
@@ -1,8 +1,7 @@
import os
import torch
from lightllm.utils.log_utils import init_logger
from lightllm.distributed.parallel_state import graph_capture
from contextlib import nullcontext
from lightllm.distributed import lightllm_capture_graph

logger = init_logger(__name__)

@@ -32,8 +31,9 @@ def capture_decode(self, decode_func, input_ids, infer_state):
torch.cuda.synchronize()
decode_func(input_ids, infer_state)
torch.cuda.synchronize()
with torch.cuda.graph(graph_obj, pool=self.mempool, stream=self.stream):
predict_logics = decode_func(input_ids, infer_state)
with lightllm_capture_graph():
with torch.cuda.graph(graph_obj, pool=self.mempool):
predict_logics = decode_func(input_ids, infer_state)
self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics)
graph_obj.replay()
return predict_logics
@@ -49,65 +49,61 @@ def replay(self, input_ids, infer_state):
@torch.no_grad()
def warmup(self, model):
logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "False").upper() in ["ON", "TRUE", "1"]
graph_capture_context_manager = graph_capture() if LIGHTLLM_PYNCCL_ENABLE else nullcontext()
with graph_capture_context_manager as graph_capture_context:
self.stream = graph_capture_context.stream if graph_capture_context is not None else None
for batch_size in range(self.max_batch_size, 0, -1):
# dummy prefill
prefill_input_len = 1
dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
b_req_idx = model.req_manager.alloc(batch_size).int()
mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num = prefill_input_len * batch_size
logics = model.forward(
batch_size,
total_token_num,
prefill_input_len,
dummy_input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
b_ready_cache_len=b_ready_cache_len,
is_prefill=True,
multimodal_params=[],
)
mem_indexes = None
prob_out = torch.softmax(logics, dim=-1)
logics = None
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
prob_out = None
predict_ids = predict_ids.detach().cpu().numpy()
torch.cuda.empty_cache()
for batch_size in range(self.max_batch_size, 0, -1):
# dummy prefill
prefill_input_len = 1
dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
b_req_idx = model.req_manager.alloc(batch_size).int()
mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_start_loc = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num = prefill_input_len * batch_size
logics = model.forward(
batch_size,
total_token_num,
prefill_input_len,
dummy_input_ids,
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
b_ready_cache_len=b_ready_cache_len,
is_prefill=True,
multimodal_params=[],
)
mem_indexes = None
prob_out = torch.softmax(logics, dim=-1)
logics = None
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
prob_out = None
predict_ids = predict_ids.detach().cpu().numpy()
torch.cuda.empty_cache()

# dummy decoding, capture the cudagraph
b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1
mem_indexes = model.mem_manager.alloc(len(predict_ids))
logics = model.forward(
batch_size,
total_token_num,
prefill_input_len + 1,
torch.from_numpy(predict_ids).cuda().reshape(-1),
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
is_prefill=False,
)
mem_indexes = None
model.mem_manager.free_all()
model.req_manager.free_all()
# release local tensors
for var_name, var_value in list(locals().items()):
if isinstance(var_value, torch.Tensor):
del locals()[var_name]
torch.cuda.empty_cache()
# dummy decoding, capture the cudagraph
b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1
mem_indexes = model.mem_manager.alloc(len(predict_ids))
logics = model.forward(
batch_size,
total_token_num,
prefill_input_len + 1,
torch.from_numpy(predict_ids).cuda().reshape(-1),
mem_indexes,
b_req_idx,
b_start_loc,
b_seq_len,
is_prefill=False,
)
mem_indexes = None
model.mem_manager.free_all()
model.req_manager.free_all()
# release local tensors
for var_name, var_value in list(locals().items()):
if isinstance(var_value, torch.Tensor):
del locals()[var_name]
torch.cuda.empty_cache()
logger.info(
f"Capture cudagraph success, batch_size <={self.max_batch_size} "
f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph."
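The net effect of the cuda_graph.py changes: graph capture no longer relies on vLLM's graph_capture() context manager or a dedicated capture stream; the torch.cuda.graph capture is instead wrapped in lightllm_capture_graph(), and the warmup loop runs without an outer capture context. A minimal sketch of the new capture path, assuming CUDA is available and using placeholder decode_func / infer_state arguments (and a hypothetical capture_decode_sketch helper) in place of the real model internals:

import torch
from lightllm.distributed import lightllm_capture_graph

def capture_decode_sketch(decode_func, input_ids, infer_state, mempool=None):
    graph_obj = torch.cuda.CUDAGraph()
    # Warm-up run outside the graph so lazy CUDA initialization is not captured.
    torch.cuda.synchronize()
    decode_func(input_ids, infer_state)
    torch.cuda.synchronize()
    # lightllm_capture_graph() puts the custom all-reduce (when enabled) into
    # capture mode for the duration of the CUDA graph capture.
    with lightllm_capture_graph():
        with torch.cuda.graph(graph_obj, pool=mempool):
            predict_logics = decode_func(input_ids, infer_state)
    graph_obj.replay()
    return graph_obj, predict_logics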
1 change: 0 additions & 1 deletion lightllm/distributed/__init__.py
@@ -1,2 +1 @@
from .communication_op import *
from .parallel_state import *
67 changes: 42 additions & 25 deletions lightllm/distributed/communication_op.py
@@ -19,42 +19,59 @@

from typing import Any, Dict, Optional, Union

import os
import torch
import torch.distributed

from .parallel_state import get_tp_group, original_all_reduce
import torch.distributed as dist
from torch.distributed import ReduceOp
from lightllm.utils.log_utils import init_logger

original_all_reduce = torch.distributed.all_reduce
from contextlib import nullcontext, contextmanager

try:
HAS_VLLM = True
from .custom_all_reduce import CustomAllreduce
except:
HAS_VLLM = False

vllm_reduce = None
logger = init_logger(__name__)
# if op != ReduceOp.SUM or group != None or async_op != False:
# logger.warning("This function op, group, async_op will only run with default values")


def all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
if op != ReduceOp.SUM or group is not None or async_op:
original_all_reduce(input_, op, group, async_op)
@contextmanager
def lightllm_capture_graph():
if vllm_reduce is not None:
with vllm_reduce.capture():
yield
else:
input_.data = tensor_model_parallel_all_reduce(input_)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
yield
pass


def tensor_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)

def _all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
if op != ReduceOp.SUM or group is not None or async_op or vllm_reduce is None:
original_all_reduce(input_, op, group, async_op)
else:
if vllm_reduce is not None:
can_use = vllm_reduce.should_custom_ar(input_)
if can_use:
input_.data = vllm_reduce.custom_all_reduce(input_)
return
original_all_reduce(input_, op, group, async_op)

def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)

def set_custom_reduce():
global vllm_reduce

def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in [
"ON",
"TRUE",
"1",
]
if ENABLE_VLLM_REDUCE and HAS_VLLM:
world_size = dist.get_world_size()
ranks = list(range(world_size))
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
logger.info("Enable VLLM ALLReduce.")
dist.all_reduce = _all_reduce
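
Taken together, the new communication_op.py exposes two pieces: set_custom_reduce(), which, when ENABLE_VLLM_REDUCE is set and vLLM is importable, builds a CustomAllreduce over a gloo CPU group and replaces dist.all_reduce with the _all_reduce wrapper, and lightllm_capture_graph(), which enters the custom all-reduce's capture() context while a CUDA graph is being recorded. A hedged usage sketch, assuming torch.distributed is already initialized with a CUDA-capable backend and vLLM is installed:

import os
import torch
import torch.distributed as dist
from lightllm.distributed.communication_op import set_custom_reduce

os.environ["ENABLE_VLLM_REDUCE"] = "1"  # must be set before set_custom_reduce() reads it
set_custom_reduce()                     # patches dist.all_reduce with the _all_reduce wrapper

x = torch.ones(1024, device="cuda")
dist.all_reduce(x)  # routed through CustomAllreduce when should_custom_ar() accepts the tensor,
                    # otherwise falls back to the original torch.distributed.all_reduce

Calls with a non-SUM op, an explicit group, or async_op=True always take the original all-reduce path, so existing callers keep their semantics.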