From 1c8cab434b563f2057e0a7737c5c3042602f6e0d Mon Sep 17 00:00:00 2001
From: gouzi <530971494@qq.com>
Date: Tue, 12 Sep 2023 16:28:40 +0800
Subject: [PATCH 1/6] [NewComm] update allreduce_op

---
 .../tensorrt/plugin/c_allreduce_op_plugin.cu  | 46 ++++++++++++++++---
 .../inference/test_trt_convert_c_allreduce.py | 15 ++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
index 8ec06071301c9..06dbb043192a0 100644
--- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
@@ -15,8 +15,13 @@
 #include 
 
 #include "glog/logging.h"
+#include "paddle/fluid/distributed/collective/utils.h"
 #include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"
 #include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#include "paddle/phi/core/flags.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
 
 namespace paddle {
 namespace inference {
@@ -175,12 +180,41 @@ int CAllReducePluginDynamic::enqueue(
     PADDLE_THROW(platform::errors::InvalidArgument("Invalid reduce type: %d",
                                                    red_type_));
   }
-
-  auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
-  cudaStream_t custream = use_calc_stream_ ? stream : comm->stream();
-  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-      sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
-
+  if (FLAGS_dynamic_static_unified_comm) {
+    const auto& comm_context_manager =
+        phi::distributed::CommContextManager::GetInstance();
+    PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id_)),
+                      true,
+                      platform::errors::InvalidArgument(
+                          "You chose to use the new communication library "
+                          "by setting the environment variable "
+                          "FLAGS_dynamic_static_unified_comm to True, "
+                          "but ring_id(%d) was "
+                          "not found in comm_context_manager.",
+                          ring_id_));
+    auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
+        comm_context_manager.Get(std::to_string(ring_id_)));
+    PADDLE_ENFORCE_NE(comm_ctx,
+                      nullptr,
+                      platform::errors::Unavailable(
+                          "NCCLCommContext is nullptr, collective op should "
+                          "have ring_id attr."));
+    auto stream = comm_ctx->GetStream();
+    ncclRedOp_t nccl_red_type = ncclSum;
+    comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
+    VLOG(3) << "new NCCLCommContext has ring_id_ " << ring_id_;
+  } else {
+    auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
+    cudaStream_t custream = use_calc_stream_ ? stream : comm->stream();
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
+                                                                recvbuff,
+                                                                numel,
+                                                                dtype,
+                                                                nccl_red_type,
+                                                                comm->comm(),
+                                                                custream));
+    VLOG(3) << "old NCCLCommContext has ring_id_ " << ring_id_;
+  }
 #endif
   return (cudaGetLastError() != cudaSuccess);
 }
diff --git a/test/ir/inference/test_trt_convert_c_allreduce.py b/test/ir/inference/test_trt_convert_c_allreduce.py
index 6e3bc5ae9a894..0412ebb2099ef 100644
--- a/test/ir/inference/test_trt_convert_c_allreduce.py
+++ b/test/ir/inference/test_trt_convert_c_allreduce.py
@@ -43,6 +43,21 @@ def test_run(self):
         if len(results) == 2 and results[0] == "c_allreduce_out":
             self.assertEqual(float(results[1]), self.target_value)
 
+    def test_allreduce_nccl_with_new_comm(self):
+        env = dict(os.environ)
+        env["CUDA_VISIBLE_DEVICES"] = "0,1"
+        env["FLAGS_dynamic_static_unified_comm"] = "1"
+        cmd = f"python -u -m paddle.distributed.fleet.launch --gpus 0,1 {self.script} {self.op_type} {self.precision}"
+        cmd = cmd.split(" ")
+
+        local_proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
+
+        local_out, local_err = local_proc.communicate()
+        for line in local_out.decode("utf-8").split("\n"):
+            results = line.split("=")
+            if len(results) == 2 and results[0] == "c_allreduce_out":
+                self.assertEqual(float(results[1]), self.target_value)
+
 
 class TestMin(TestDistTRT):
     def init_case(self):

From 151ea63da9c990b91b63c0faced392e0d58d5e89 Mon Sep 17 00:00:00 2001
From: gouzi <530971494@qq.com>
Date: Tue, 12 Sep 2023 17:58:10 +0800
Subject: [PATCH 2/6] fix

---
 .../fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
index 06dbb043192a0..ca79bbd5c75f1 100644
--- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
@@ -180,9 +180,9 @@ int CAllReducePluginDynamic::enqueue(
     PADDLE_THROW(platform::errors::InvalidArgument("Invalid reduce type: %d",
                                                    red_type_));
   }
+  const auto& comm_context_manager =
+      phi::distributed::CommContextManager::GetInstance();
   if (FLAGS_dynamic_static_unified_comm) {
-    const auto& comm_context_manager =
-        phi::distributed::CommContextManager::GetInstance();
     PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id_)),
                       true,
                       platform::errors::InvalidArgument(

From c9af2774ce4ccd59f5c6d124cc97795154c55fac Mon Sep 17 00:00:00 2001
From: gouzi <530971494@qq.com>
Date: Tue, 12 Sep 2023 18:14:36 +0800
Subject: [PATCH 3/6] fix include

---
 paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
index ca79bbd5c75f1..8f70fc43a901d 100644
--- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
@@ -19,9 +19,11 @@
 #include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/phi/core/distributed/comm_context_manager.h"
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/phi/core/distributed/nccl_comm_context.h"
 #include "paddle/phi/core/flags.h"
 PHI_DECLARE_bool(dynamic_static_unified_comm);
+#endif
 
 namespace paddle {
 namespace inference {

From c34a674d12d96160ec2e30eb19317514ae5512b4 Mon Sep 17 00:00:00 2001
From: gouzi <530971494@qq.com>
Date: Tue, 12 Sep 2023 20:14:50 +0800
Subject: [PATCH 4/6] fix AllReduce

---
 paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
index 8f70fc43a901d..44616d20841e7 100644
--- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
@@ -203,7 +203,7 @@ int CAllReducePluginDynamic::enqueue(
                           "have ring_id attr."));
     auto stream = comm_ctx->GetStream();
     ncclRedOp_t nccl_red_type = ncclSum;
-    comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
+    comm_ctx->AllReduce(output_desc, input_desc, nccl_red_type, stream);
     VLOG(3) << "new NCCLCommContext has ring_id_ " << ring_id_;
   } else {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);

From b1a351d3ed610981b7c420c988041e83fdd61fdf Mon Sep 17 00:00:00 2001
From: gouzi <530971494@qq.com>
Date: Thu, 14 Sep 2023 22:52:54 +0800
Subject: [PATCH 5/6] try fix

---
 paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
index 44616d20841e7..6d7eca27fc722 100644
--- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
@@ -203,7 +203,7 @@ int CAllReducePluginDynamic::enqueue(
                           "have ring_id attr."));
     auto stream = comm_ctx->GetStream();
     ncclRedOp_t nccl_red_type = ncclSum;
-    comm_ctx->AllReduce(output_desc, input_desc, nccl_red_type, stream);
+    comm_ctx->AllReduce(&inputs[0], inputs[0], nccl_red_type, stream);
     VLOG(3) << "new NCCLCommContext has ring_id_ " << ring_id_;
   } else {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);

From 13190788b8b1ec193af7506941aea5b408dcb93f Mon Sep 17 00:00:00 2001
From: gouzi <530971494@qq.com>
Date: Fri, 22 Sep 2023 00:02:23 +0800
Subject: [PATCH 6/6] try ci

---
 .../inference/tensorrt/plugin/c_allreduce_op_plugin.cu | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
index 6d7eca27fc722..1033dc65f2dcc 100644
--- a/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
@@ -203,7 +203,14 @@ int CAllReducePluginDynamic::enqueue(
                           "have ring_id attr."));
     auto stream = comm_ctx->GetStream();
     ncclRedOp_t nccl_red_type = ncclSum;
-    comm_ctx->AllReduce(&inputs[0], inputs[0], nccl_red_type, stream);
+    // comm_ctx->AllReduce(&inputs[0], inputs[0], nccl_red_type, stream);
+    phi::dynload::ncclAllReduce(sendbuff,
+                                recvbuff,
+                                numel,
+                                dtype,
+                                nccl_red_type,
+                                comm_ctx->GetNcclComm(),
+                                stream);
     VLOG(3) << "new NCCLCommContext has ring_id_ " << ring_id_;
   } else {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
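
---
Review note: after patch 6/6, the net dispatch in CAllReducePluginDynamic::enqueue()
looks like the sketch below. This is an illustrative paraphrase of the series' end
state, not part of the diffs: sendbuff, recvbuff, numel, dtype, and the switch that
selects nccl_red_type are assumed to be set up earlier in enqueue(), as the hunk
context implies.

  const auto& comm_context_manager =
      phi::distributed::CommContextManager::GetInstance();
  if (FLAGS_dynamic_static_unified_comm) {
    // New unified path: fetch the phi NCCLCommContext registered under ring_id_.
    auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
        comm_context_manager.Get(std::to_string(ring_id_)));
    auto stream = comm_ctx->GetStream();
    // This re-declaration fixes the reduction to ncclSum on this path.
    ncclRedOp_t nccl_red_type = ncclSum;
    // Patch 6/6 calls ncclAllReduce directly rather than comm_ctx->AllReduce(),
    // presumably because AllReduce() expects tensor arguments while the plugin
    // only holds raw device buffers.
    phi::dynload::ncclAllReduce(sendbuff, recvbuff, numel, dtype,
                                nccl_red_type, comm_ctx->GetNcclComm(), stream);
  } else {
    // Legacy path: the fluid NCCLCommContext singleton, honoring use_calc_stream_.
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
    cudaStream_t custream = use_calc_stream_ ? stream : comm->stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), custream));
  }

Two points may deserve a follow-up commit: the ncclSum re-declaration on the new
path overrides whatever reduce type was selected from red_type_ above it, and the
direct ncclAllReduce call's return value is not wrapped in
PADDLE_ENFORCE_GPU_SUCCESS the way the legacy path's call is.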