[AutoParallel] Support new communication library for hogwild_worker, graph_helper, data_norm_op and margin_cross_entropy_op. (PaddlePaddle#57519)
GhostScreaming authored and jiahy0825 committed Oct 16, 2023
1 parent 10b0b0b commit 90886e4
Showing 6 changed files with 292 additions and 76 deletions.
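The diffs rendered below all apply the same switch: when the PHI flag FLAGS_dynamic_static_unified_comm is enabled, the communicator is looked up through the new phi::distributed::CommContextManager (keyed by the ring_id as a string); otherwise the code falls back to the legacy platform::NCCLCommContext singleton. The following is a minimal, self-contained C++ sketch of that control flow only; every Stub* type and the QueryNranks helper are simplified stand-ins invented for illustration, not Paddle's real classes or API.

// ---- Standalone sketch (not Paddle code): flag-gated communicator lookup ----
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

bool FLAGS_dynamic_static_unified_comm = true;  // stand-in for the PHI flag

// Stand-in for phi::distributed::NCCLCommContext (the "new" library).
struct StubCommContext {
  int size;
  int GetSize() const { return size; }
};

// Stand-in for phi::distributed::CommContextManager, keyed by ring_id string.
class StubCommContextManager {
 public:
  static StubCommContextManager &GetInstance() {
    static StubCommContextManager inst;
    return inst;
  }
  bool Has(const std::string &ring_id) const { return ctxs_.count(ring_id) > 0; }
  StubCommContext *Get(const std::string &ring_id) { return ctxs_.at(ring_id).get(); }
  void Emplace(const std::string &ring_id, int size) {
    ctxs_[ring_id].reset(new StubCommContext{size});
  }

 private:
  std::map<std::string, std::unique_ptr<StubCommContext>> ctxs_;
};

// Stand-in for the legacy platform::NCCLComm handle.
struct StubLegacyComm {
  int nranks() const { return 1; }
};

// Mirrors the branch added in HogwildWorker::CheckBatchNum and data_norm_op.cu:
// prefer the new manager when the flag is on, otherwise use the old singleton.
int QueryNranks(int ring_id) {
  if (FLAGS_dynamic_static_unified_comm) {
    StubCommContextManager &mgr = StubCommContextManager::GetInstance();
    if (!mgr.Has(std::to_string(ring_id))) {
      throw std::runtime_error("ring_id not found in comm_context_manager");
    }
    return mgr.Get(std::to_string(ring_id))->GetSize();
  }
  static StubLegacyComm legacy;  // stand-in for NCCLCommContext::Instance().Get(...)
  return legacy.nranks();
}

int main() {
  StubCommContextManager::GetInstance().Emplace("0", /*size=*/8);
  std::cout << "nranks = " << QueryNranks(0) << std::endl;  // prints: nranks = 8
  return 0;
}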
69 changes: 59 additions & 10 deletions paddle/fluid/framework/hogwild_worker.cc
@@ -22,6 +22,13 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/flags.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
@@ -30,7 +37,6 @@ limitations under the License. */
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/phi/core/flags.h"

PHI_DECLARE_bool(enable_exit_when_partial_worker);

@@ -152,16 +158,59 @@ bool HogwildWorker::CheckBatchNum(int flag) {
}
g_barrier.wait();
float *stat_ptr = sync_stat_.data<float>();
auto comm =
platform::NCCLCommContext::Instance().Get(0, place_.GetDeviceId());
int nranks = 0;
int ring_id = 0;
platform::NCCLComm *comm = nullptr;
const auto &comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
phi::distributed::NCCLCommContext *comm_ctx = nullptr;
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
nranks = comm_ctx->GetSize();
} else {
comm = platform::NCCLCommContext::Instance().Get(ring_id,
place_.GetDeviceId());
nranks = comm->nranks();
}

auto stream = static_cast<phi::GPUContext *>(dev_ctx_)->stream();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[flag],
&stat_ptr[2],
1,
ncclFloat32,
ncclProd,
comm->comm(),
stream));
if (comm_ctx) {
// comm_ctx->AllReduce only supports allreduce over a whole tensor;
// single-element allreduce is not supported yet.
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllReduce(&stat_ptr[flag],
&stat_ptr[2],
1,
ncclFloat32,
ncclProd,
comm_ctx->GetNcclComm(),
stream));

} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[flag],
&stat_ptr[2],
1,
ncclFloat32,
ncclProd,
comm->comm(),
stream));
}

PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret, // output
&stat_ptr[2],
sizeof(float),
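Note on the hunk above: even on the new-library path, CheckBatchNum still issues the raw ncclAllReduce with comm_ctx->GetNcclComm() rather than using the context's own allreduce helper, since (per the in-diff comment) that helper currently works on whole tensors while this check reduces a single float with ncclProd. The new context is otherwise used only to supply the rank count, with comm_ctx->GetSize() taking the place of comm->nranks().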
17 changes: 14 additions & 3 deletions paddle/fluid/framework/ir/graph_helper.cc
@@ -23,10 +23,14 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/platform/flags.h"
PD_DECLARE_bool(convert_all_blocks);
@@ -564,9 +568,16 @@ void ReplaceAllReduceOp(const Node &node,
all_reduce_op_desc.SetType("c_allreduce_sum");
all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});

int ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
int ring_id = -1;
if (FLAGS_dynamic_static_unified_comm) {
ring_id = phi::distributed::CommContextManager::GetInstance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
VLOG(3) << "New CommContextManager gets ring_id: " << ring_id;
} else {
ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
VLOG(3) << "Old NCCLCommContext gets ring_id: " << ring_id;
}
all_reduce_op_desc.SetAttr("ring_id", ring_id);
all_reduce_op_desc.SetAttr("use_calc_stream", false);
all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
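In ReplaceAllReduceOp, only the source of ring_id changes: it is taken from CommContextManager::GetInstance().GetRingId(...) when FLAGS_dynamic_static_unified_comm is set, and from the legacy platform::NCCLCommContext::Instance().GetRingId(...) otherwise, with a VLOG(3) line recording which path was used. The generated c_allreduce_sum op is untouched; it still receives the result through the same ring_id attribute.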
115 changes: 90 additions & 25 deletions paddle/fluid/operators/data_norm_op.cu
@@ -21,6 +21,10 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif

namespace paddle {
@@ -213,31 +217,92 @@ class DataNormGradKernel<T, phi::GPUContext> : public framework::OpKernel<T> {

if (need_sync_stats) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_size),
reinterpret_cast<void *>(d_batch_size),
C,
platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm->comm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_sum),
reinterpret_cast<void *>(d_batch_sum),
C,
platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm->comm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_square_sum),
reinterpret_cast<void *>(d_batch_square_sum),
C,
platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm->comm(),
stream));
int rid = 0;
platform::NCCLComm *comm = nullptr;
const auto &comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
phi::distributed::NCCLCommContext *comm_ctx = nullptr;
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(
comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
} else {
comm = paddle::platform::NCCLCommContext::Instance().Get(
rid, ctx.GetPlace());
}

if (comm_ctx) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_size),
reinterpret_cast<void *>(d_batch_size),
C,
platform::ToNCCLDataType(
framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm_ctx->GetNcclComm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_sum),
reinterpret_cast<void *>(d_batch_sum),
C,
platform::ToNCCLDataType(
framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm_ctx->GetNcclComm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_square_sum),
reinterpret_cast<void *>(d_batch_square_sum),
C,
platform::ToNCCLDataType(
framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm_ctx->GetNcclComm(),
stream));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_size),
reinterpret_cast<void *>(d_batch_size),
C,
platform::ToNCCLDataType(
framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm->comm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_sum),
reinterpret_cast<void *>(d_batch_sum),
C,
platform::ToNCCLDataType(
framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm->comm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_square_sum),
reinterpret_cast<void *>(d_batch_square_sum),
C,
platform::ToNCCLDataType(
framework::TransToProtoVarType(x->dtype())),
ncclSum,
comm->comm(),
stream));
}
platform::GpuStreamSync(stream);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
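data_norm_op.cu follows the same pattern as hogwild_worker.cc: one flag-gated communicator lookup, after which the three ncclSum all-reduces over d_batch_size, d_batch_sum and d_batch_square_sum are issued on either comm_ctx->GetNcclComm() or the legacy comm->comm(). The stream synchronization via platform::GpuStreamSync and the non-NCCL #else branch, which still throws PreconditionNotMet, are left as they were.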