diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 6e83074b..9c256b1a 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -2,6 +2,8 @@ import torch
 
 from lightllm.utils.log_utils import init_logger
 from lightllm.distributed.parallel_state import graph_capture
+from contextlib import nullcontext
+import os
 
 logger = init_logger(__name__)
 
@@ -48,8 +49,10 @@ def replay(self, input_ids, infer_state):
     @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
-        with graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
+        LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "True").upper() in ["ON", "TRUE", "1"]
+        graph_capture_context_manager = graph_capture() if LIGHTLLM_PYNCCL_ENABLE else nullcontext()
+        with graph_capture_context_manager as graph_capture_context:
+            self.stream = graph_capture_context.stream if graph_capture_context is not None else None
             for batch_size in range(self.max_batch_size, 0, -1):
                 # dummy prefill
                 prefill_input_len = 1
diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py
index 7c0fda20..70303823 100644
--- a/lightllm/distributed/communication_op.py
+++ b/lightllm/distributed/communication_op.py
@@ -22,7 +22,7 @@
 import torch
 import torch.distributed
 
-from .parallel_state import get_tp_group
+from .parallel_state import get_tp_group, original_all_reduce
 from torch.distributed import ReduceOp
 from lightllm.utils.log_utils import init_logger
 
@@ -33,7 +33,10 @@
 
 
 def all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    input_.data = tensor_model_parallel_all_reduce(input_)
+    if op != ReduceOp.SUM or group is not None or async_op:
+        original_all_reduce(input_, op, group, async_op)
+    else:
+        input_.data = tensor_model_parallel_all_reduce(input_)
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 4a93325f..21f510b9 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -83,8 +83,12 @@ def init_model(self, kvargs):
         max_total_token_num = kvargs["max_total_token_num"]
         torch.cuda.set_device(self.tp_rank)
 
-        LIGHTLLM_DISTRIBUTED_ENABLE = os.getenv("LIGHTLLM_DISTRIBUTED_ENABLE", not self.disable_cudagraph)
-        if LIGHTLLM_DISTRIBUTED_ENABLE:
+        LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", str(not self.disable_cudagraph)).upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
+        if LIGHTLLM_PYNCCL_ENABLE:
             # Multiple nodes are not currently supported, so local_rank == rank
             init_distributed_environment(
                 backend="nccl",
@@ -112,7 +116,7 @@
         self.infer_state_lock = g_infer_state_lock
         # To prevent the globally shared state in InferStateLock from being re-initialized unexpectedly
        # and breaking synchronization, do a barrier here and wait.
-        if LIGHTLLM_DISTRIBUTED_ENABLE:
+        if LIGHTLLM_PYNCCL_ENABLE:
             self.tp_group.barrier()
         else:
             dist.barrier()
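Note on the communication_op.py hunk: the patched all_reduce only routes the common case, a synchronous SUM on the default group, through the custom pyNCCL path; any other reduce op, an explicit group, or async_op=True falls back to the saved original_all_reduce. Below is a minimal, self-contained sketch of that dispatch pattern, not the repo's code verbatim: `custom_sum_all_reduce` is a stand-in name for lightllm's `tensor_model_parallel_all_reduce`.

```python
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

# Keep a handle to the stock implementation before monkey-patching,
# mirroring `original_all_reduce` in parallel_state.
original_all_reduce = dist.all_reduce


def custom_sum_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the CUDA-graph-capturable pyNCCL SUM all-reduce.
    original_all_reduce(t, op=ReduceOp.SUM)
    return t


def patched_all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
    # Non-SUM ops, explicit groups, and async calls keep the stock
    # torch.distributed semantics; only the common case takes the custom path.
    if op != ReduceOp.SUM or group is not None or async_op:
        return original_all_reduce(input_, op, group, async_op)
    input_.data = custom_sum_all_reduce(input_)
```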
diff --git a/test/model/model_infer.py b/test/model/model_infer.py
index f73af2bf..bb6f6e26 100644
--- a/test/model/model_infer.py
+++ b/test/model/model_infer.py
@@ -1,6 +1,7 @@
 import numpy as np
 from multiprocessing import Queue
 import multiprocessing
+import os
 
 
 def test_model_inference(world_size, model_class, batch_size, input_len, output_len, extra_model_kvargs):
@@ -41,19 +42,36 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         get_tp_group,
         init_distributed_environment,
         initialize_model_parallel,
+        get_tensor_model_parallel_world_size,
+        get_tensor_model_parallel_rank,
+        all_reduce,
     )
+    import torch.distributed as dist
 
     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]
-    init_distributed_environment(
-        backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
-    )
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-    tp_group = get_tp_group()
+    disable_cudagraph = model_kvargs.get("disable_cudagraph", False)
 
     torch.cuda.set_device(rank_id)
+    LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", str(not disable_cudagraph)).upper() in [
+        "ON",
+        "TRUE",
+        "1",
+    ]
+    if LIGHTLLM_PYNCCL_ENABLE:
+        init_distributed_environment(
+            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
+        )
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+        tp_group = get_tp_group()
+        dist.all_reduce = all_reduce
+        dist.get_rank = get_tensor_model_parallel_rank
+        dist.get_world_size = get_tensor_model_parallel_world_size
+        tp_group.barrier()
+    else:
+        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
+        dist.barrier()
 
-    tp_group.barrier()
     torch.cuda.empty_cache()
 
     model_part = model_class(model_kvargs)
diff --git a/test/model/test_settings/model_infer_batchs.py b/test/model/test_settings/model_infer_batchs.py
index c3c57566..84d61c50 100644
--- a/test/model/test_settings/model_infer_batchs.py
+++ b/test/model/test_settings/model_infer_batchs.py
@@ -68,19 +68,36 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output
         get_tp_group,
         init_distributed_environment,
         initialize_model_parallel,
+        get_tensor_model_parallel_world_size,
+        get_tensor_model_parallel_rank,
+        all_reduce,
     )
+    import torch.distributed as dist
 
     rank_id = model_kvargs["tp_rank"]
    world_size = model_kvargs["world_size"]
-    init_distributed_environment(
-        backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
-    )
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
-    tp_group = get_tp_group()
+    disable_cudagraph = model_kvargs.get("disable_cudagraph", False)
 
     torch.cuda.set_device(rank_id)
+    LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", str(not disable_cudagraph)).upper() in [
+        "ON",
+        "TRUE",
+        "1",
+    ]
+    if LIGHTLLM_PYNCCL_ENABLE:
+        init_distributed_environment(
+            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
+        )
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+        tp_group = get_tp_group()
+        dist.all_reduce = all_reduce
+        dist.get_rank = get_tensor_model_parallel_rank
+        dist.get_world_size = get_tensor_model_parallel_world_size
+        tp_group.barrier()
+    else:
+        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
+        dist.barrier()
 
-    tp_group.barrier()
     torch.cuda.empty_cache()
 
     model_part = model_class(model_kvargs)
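The two test drivers above now duplicate the same setup logic. For illustration only, it could be factored into a shared helper along these lines; `init_test_dist` and its signature are hypothetical, while the imports, the environment flag, and the init address are taken from the diff.

```python
import os

import torch
import torch.distributed as dist


def init_test_dist(rank_id, world_size, disable_cudagraph=False, init_method="tcp://127.0.0.1:28765"):
    """Hypothetical shared setup for the test drivers (not part of the diff)."""
    torch.cuda.set_device(rank_id)
    pynccl = os.getenv("LIGHTLLM_PYNCCL_ENABLE", str(not disable_cudagraph)).upper() in ["ON", "TRUE", "1"]
    if pynccl:
        from lightllm.distributed import (
            all_reduce,
            get_tensor_model_parallel_rank,
            get_tensor_model_parallel_world_size,
            get_tp_group,
            init_distributed_environment,
            initialize_model_parallel,
        )

        init_distributed_environment(
            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method=init_method
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        # Route generic torch.distributed calls through the tensor-parallel
        # wrappers, exactly as the patched tests do.
        dist.all_reduce = all_reduce
        dist.get_rank = get_tensor_model_parallel_rank
        dist.get_world_size = get_tensor_model_parallel_world_size
        get_tp_group().barrier()
    else:
        dist.init_process_group("nccl", init_method=init_method, rank=rank_id, world_size=world_size)
        dist.barrier()
```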