complete all_reduce and test (#619)
WANDY666 authored Nov 27, 2024
1 parent bd27a96 commit af0d743
Showing 5 changed files with 64 additions and 19 deletions.
7 changes: 5 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -2,6 +2,7 @@
 import torch
 from lightllm.utils.log_utils import init_logger
 from lightllm.distributed.parallel_state import graph_capture
+from contextlib import nullcontext
 
 logger = init_logger(__name__)
 
@@ -48,8 +49,10 @@ def replay(self, input_ids, infer_state):
     @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
-        with graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
+        LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "True").upper() in ["ON", "TRUE", "1"]
+        graph_capture_context_manager = graph_capture() if LIGHTLLM_PYNCCL_ENABLE else nullcontext()
+        with graph_capture_context_manager as graph_capture_context:
+            self.stream = graph_capture_context.stream if graph_capture_context is not None else None
             for batch_size in range(self.max_batch_size, 0, -1):
                 # dummy prefill
                 prefill_input_len = 1
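The warmup change above only wraps capture in graph_capture() when the LIGHTLLM_PYNCCL_ENABLE flag is on, and otherwise falls back to contextlib.nullcontext(), which yields None. A minimal standalone sketch of the same pattern, assuming a stand-in fake_graph_capture (not a lightllm API):

    import os
    from contextlib import contextmanager, nullcontext

    @contextmanager
    def fake_graph_capture():
        # Stand-in for parallel_state.graph_capture: yields an object that
        # carries the stream to capture on.
        class _Ctx:
            stream = "capture-stream"
        yield _Ctx()

    use_pynccl = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "True").upper() in ["ON", "TRUE", "1"]
    cm = fake_graph_capture() if use_pynccl else nullcontext()
    with cm as ctx:
        # nullcontext() yields None, so guard the attribute access exactly as warmup() does.
        stream = ctx.stream if ctx is not None else None
        print("capture stream:", stream)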
7 changes: 5 additions & 2 deletions lightllm/distributed/communication_op.py
@@ -22,7 +22,7 @@
 import torch
 import torch.distributed
 
-from .parallel_state import get_tp_group
+from .parallel_state import get_tp_group, original_all_reduce
 from torch.distributed import ReduceOp
 from lightllm.utils.log_utils import init_logger
 
@@ -33,7 +33,10 @@
 
 
 def all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    input_.data = tensor_model_parallel_all_reduce(input_)
+    if op != ReduceOp.SUM or group is not None or async_op:
+        original_all_reduce(input_, op, group, async_op)
+    else:
+        input_.data = tensor_model_parallel_all_reduce(input_)
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
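With this change, all_reduce only takes the custom tensor-parallel path for the default case (ReduceOp.SUM on the default group, synchronous); anything else is forwarded to the stock implementation that parallel_state keeps as original_all_reduce. A reduced sketch of the same dispatch, with custom_tp_all_reduce as an illustrative placeholder rather than the real kernel:

    import torch
    import torch.distributed as dist
    from torch.distributed import ReduceOp

    # Save the stock entry point before any monkey-patching, which is what
    # original_all_reduce is assumed to hold in parallel_state.
    original_all_reduce = dist.all_reduce

    def custom_tp_all_reduce(t: torch.Tensor) -> torch.Tensor:
        # Placeholder for tensor_model_parallel_all_reduce (pynccl / custom path).
        return t

    def all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
        # Only the plain case can use the fast path; everything else falls back.
        if op != ReduceOp.SUM or group is not None or async_op:
            return original_all_reduce(input_, op, group, async_op)
        input_.data = custom_tp_all_reduce(input_)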
10 changes: 7 additions & 3 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -83,8 +83,12 @@ def init_model(self, kvargs):
         max_total_token_num = kvargs["max_total_token_num"]
 
         torch.cuda.set_device(self.tp_rank)
-        LIGHTLLM_DISTRIBUTED_ENABLE = os.getenv("LIGHTLLM_DISTRIBUTED_ENABLE", not self.disable_cudagraph)
-        if LIGHTLLM_DISTRIBUTED_ENABLE:
+        LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", str(not self.disable_cudagraph)).upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
+        if LIGHTLLM_PYNCCL_ENABLE:
             # Multiple nodes are not currently supported, so local_rank == rank
             init_distributed_environment(
                 backend="nccl",
@@ -112,7 +116,7 @@
         self.infer_state_lock = g_infer_state_lock
         # Prevent the globally shared state in InferStateLock from being re-initialized
         # abnormally and breaking synchronization, so wait on a barrier here.
-        if LIGHTLLM_DISTRIBUTED_ENABLE:
+        if LIGHTLLM_PYNCCL_ENABLE:
             self.tp_group.barrier()
         else:
             dist.barrier()
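The flag parsing above repeats a small pattern: read the variable, default to the opposite of disable_cudagraph, and treat ON/TRUE/1 (case-insensitive) as enabled. An equivalent helper is sketched below purely for illustration; env_flag is not part of lightllm:

    import os

    def env_flag(name: str, default: bool) -> bool:
        # str(default) is "True" or "False", so the same membership test
        # handles both the default and an explicit user setting.
        return os.getenv(name, str(default)).upper() in ["ON", "TRUE", "1"]

    # Example: the pynccl path defaults to on whenever cuda graphs are enabled.
    disable_cudagraph = False
    LIGHTLLM_PYNCCL_ENABLE = env_flag("LIGHTLLM_PYNCCL_ENABLE", not disable_cudagraph)
    print(LIGHTLLM_PYNCCL_ENABLE)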
30 changes: 24 additions & 6 deletions test/model/model_infer.py
@@ -1,6 +1,7 @@
 import numpy as np
 from multiprocessing import Queue
 import multiprocessing
+import os
 
 
 def test_model_inference(world_size, model_class, batch_size, input_len, output_len, extra_model_kvargs):
@@ -41,19 +42,36 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         get_tp_group,
         init_distributed_environment,
         initialize_model_parallel,
+        get_tensor_model_parallel_world_size,
+        get_tensor_model_parallel_rank,
+        all_reduce,
     )
+    import torch.distributed as dist
 
     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]
 
-    init_distributed_environment(
-        backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
-    )
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-    tp_group = get_tp_group()
+    disable_cudagraph = model_kvargs.get("disable_cudagraph", False)
     torch.cuda.set_device(rank_id)
+    LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", str(not disable_cudagraph)).upper() in [
+        "ON",
+        "TRUE",
+        "1",
+    ]
+    if LIGHTLLM_PYNCCL_ENABLE:
+        init_distributed_environment(
+            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
+        )
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+        tp_group = get_tp_group()
+        dist.all_reduce = all_reduce
+        dist.get_rank = get_tensor_model_parallel_rank
+        dist.get_world_size = get_tensor_model_parallel_world_size
+        tp_group.barrier()
+    else:
+        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
+        dist.barrier()
 
-    tp_group.barrier()
     torch.cuda.empty_cache()
 
     model_part = model_class(model_kvargs)
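In the pynccl branch the test rebinds three torch.distributed entry points so that model code which still calls dist.all_reduce, dist.get_rank, and dist.get_world_size is routed through the tensor-parallel group. The sketch below shows the redirection in isolation; the fake_* functions are stand-ins, not lightllm APIs:

    import torch
    import torch.distributed as dist

    def fake_tp_all_reduce(input_, op=None, group=None, async_op=False):
        # Stand-in for lightllm.distributed.all_reduce in a single-process demo;
        # a real implementation would reduce across the tensor-parallel ranks.
        return None

    def fake_tp_rank() -> int:
        return 0

    def fake_tp_world_size() -> int:
        return 1

    # After the rebinding, unmodified model code transparently uses the
    # tensor-parallel versions.
    dist.all_reduce = fake_tp_all_reduce
    dist.get_rank = fake_tp_rank
    dist.get_world_size = fake_tp_world_size

    x = torch.ones(4)
    dist.all_reduce(x)
    print(dist.get_rank(), dist.get_world_size())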
29 changes: 23 additions & 6 deletions test/model/test_settings/model_infer_batchs.py
@@ -68,19 +68,36 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output
         get_tp_group,
         init_distributed_environment,
         initialize_model_parallel,
+        get_tensor_model_parallel_world_size,
+        get_tensor_model_parallel_rank,
+        all_reduce,
     )
+    import torch.distributed as dist
 
     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]
 
-    init_distributed_environment(
-        backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
-    )
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-    tp_group = get_tp_group()
+    disable_cudagraph = model_kvargs.get("disable_cudagraph", False)
     torch.cuda.set_device(rank_id)
+    LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", str(not disable_cudagraph)).upper() in [
+        "ON",
+        "TRUE",
+        "1",
+    ]
+    if LIGHTLLM_PYNCCL_ENABLE:
+        init_distributed_environment(
+            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
+        )
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+        tp_group = get_tp_group()
+        dist.all_reduce = all_reduce
+        dist.get_rank = get_tensor_model_parallel_rank
+        dist.get_world_size = get_tensor_model_parallel_world_size
+        tp_group.barrier()
+    else:
+        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
+        dist.barrier()
 
-    tp_group.barrier()
     torch.cuda.empty_cache()
 
    model_part = model_class(model_kvargs)
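Because both test drivers derive the communication path from LIGHTLLM_PYNCCL_ENABLE, either branch can be exercised without touching the code. A hypothetical launcher sketch follows; worker and WORLD_SIZE are illustrative and not the repository's actual test entry point:

    import os
    import multiprocessing

    WORLD_SIZE = 2

    def worker(rank: int):
        # Setting the flag to "0" before any distributed init selects the
        # fallback branch (plain dist.init_process_group + dist.barrier()).
        os.environ["LIGHTLLM_PYNCCL_ENABLE"] = "0"
        # ... here the real test would call tppart_model_infer with
        # model_kvargs["tp_rank"] = rank ...
        print(f"rank {rank}: pynccl path disabled")

    if __name__ == "__main__":
        procs = [multiprocessing.Process(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()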
