Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NewComm] No.13 compatible upgrade for allreduce #57233

Merged
merged 8 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
#include <cstring>

#include "glog/logging.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/inference/tensorrt/plugin/c_allreduce_op_plugin.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

namespace paddle {
namespace inference {
Expand Down Expand Up @@ -175,12 +182,48 @@ int CAllReducePluginDynamic::enqueue(
PADDLE_THROW(platform::errors::InvalidArgument("Invalid reduce type: %d",
red_type_));
}

auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
cudaStream_t custream = use_calc_stream_ ? stream : comm->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id_)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id_)));
auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id_)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
auto stream = comm_ctx->GetStream();
ncclRedOp_t nccl_red_type = ncclSum;
// comm_ctx->AllReduce(&inputs[0], inputs[0], nccl_red_type, stream);
phi::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm_ctx->GetNcclComm(),
stream);
VLOG(3) << "new NCCLCommContext has ring_id_ " << ring_id_;
} else {
auto comm = platform::NCCLCommContext::Instance().Get(ring_id_);
cudaStream_t custream = use_calc_stream_ ? stream : comm->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm->comm(),
custream));
VLOG(3) << "old NCCLCommContext has ring_id_ " << ring_id_;
}
#endif
return (cudaGetLastError() != cudaSuccess);
}
Expand Down
15 changes: 15 additions & 0 deletions test/ir/inference/test_trt_convert_c_allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ def test_run(self):
if len(results) == 2 and results[0] == "c_allreduce_out":
self.assertEqual(float(results[1]), self.target_value)

def test_allreduce_nccl_with_new_comm(self):
    """Run the distributed allreduce script with the new communication
    library enabled (FLAGS_dynamic_static_unified_comm=1) and check the
    reduced value printed by the worker matches ``self.target_value``.

    Relies on ``self.script``, ``self.op_type``, ``self.precision`` and
    ``self.target_value`` being set by the test fixture (init_case).
    """
    # Copy the parent environment so we only add/override what we need.
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = "0,1"
    # Opt in to the new (unified) communication context path.
    env["FLAGS_dynamic_static_unified_comm"] = "1"
    cmd = (
        f"python -u -m paddle.distributed.fleet.launch --gpus 0,1 "
        f"{self.script} {self.op_type} {self.precision}"
    )
    # .split() (no argument) is robust to repeated whitespace,
    # unlike .split(" ") which would yield empty argv entries.
    local_proc = subprocess.Popen(
        cmd.split(), stdout=subprocess.PIPE, env=env
    )

    # stderr is not piped, so communicate() returns None for it; ignore.
    local_out, _ = local_proc.communicate()
    # The worker prints lines of the form "c_allreduce_out=<value>";
    # verify the reduced result against the expected target.
    for line in local_out.decode("utf-8").split("\n"):
        results = line.split("=")
        if len(results) == 2 and results[0] == "c_allreduce_out":
            self.assertEqual(float(results[1]), self.target_value)


class TestMin(TestDistTRT):
def init_case(self):
Expand Down