From d522a86ec6a5f9a8f9139af8a084f9c81f30a0db Mon Sep 17 00:00:00 2001 From: Ghost Screaming Date: Thu, 21 Sep 2023 11:10:14 +0800 Subject: [PATCH] [AutoParallel] Support new communication library for hogwild_worker, graph_helper, data_norm_op and margin_cross_entropy_op. (#57519) --- paddle/fluid/framework/hogwild_worker.cc | 69 +++++++-- paddle/fluid/framework/ir/graph_helper.cc | 17 +- paddle/fluid/operators/data_norm_op.cu | 115 +++++++++++--- .../operators/margin_cross_entropy_op.cu | 145 +++++++++++++----- .../core/distributed/comm_context_manager.cc | 14 ++ .../core/distributed/comm_context_manager.h | 8 + 6 files changed, 292 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index cc2c70506a34cf..e638fbcb8a54dc 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -22,6 +22,13 @@ limitations under the License. */ #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/lodtensor_printer.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/flags.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/phi/core/distributed/nccl_comm_context.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); +#endif #if defined PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/ps/service/communicator/communicator.h" @@ -30,7 +37,6 @@ limitations under the License. */ #if defined(PADDLE_WITH_GLOO) #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #endif -#include "paddle/phi/core/flags.h" PHI_DECLARE_bool(enable_exit_when_partial_worker); @@ -152,16 +158,59 @@ bool HogwildWorker::CheckBatchNum(int flag) { } g_barrier.wait(); float *stat_ptr = sync_stat_.data(); - auto comm = - platform::NCCLCommContext::Instance().Get(0, place_.GetDeviceId()); + int nranks = 0; + int ring_id = 0; + platform::NCCLComm *comm = nullptr; + const auto &comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + phi::distributed::NCCLCommContext *comm_ctx = nullptr; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), + true, + platform::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(ring_id))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(ring_id))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + platform::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + nranks = comm_ctx->GetSize(); + } else { + comm = platform::NCCLCommContext::Instance().Get(ring_id, + place_.GetDeviceId()); + nranks = comm->nranks(); + } + auto stream = static_cast(dev_ctx_)->stream(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[flag], - &stat_ptr[2], - 1, - ncclFloat32, - ncclProd, - comm->comm(), - stream)); + if (comm_ctx) { + // comm_ctx->AllReduce only support allreduce on the whole tensor, + // single element is not supported now. + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclAllReduce(&stat_ptr[flag], + &stat_ptr[2], + 1, + ncclFloat32, + ncclProd, + comm_ctx->GetNcclComm(), + stream)); + + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[flag], + &stat_ptr[2], + 1, + ncclFloat32, + ncclProd, + comm->comm(), + stream)); + } + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret, // output &stat_ptr[2], sizeof(float), diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index b322e3f8bce287..5d7054721db53a 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -23,10 +23,14 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/program_utils.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/framework/details/nccl_op_handle.h" #include "paddle/fluid/platform/collective_helper.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif #include "paddle/fluid/platform/flags.h" PD_DECLARE_bool(convert_all_blocks); @@ -564,9 +568,16 @@ void ReplaceAllReduceOp(const Node &node, all_reduce_op_desc.SetType("c_allreduce_sum"); all_reduce_op_desc.SetInput("X", {all_reduce_var_name}); all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name}); - - int ring_id = platform::NCCLCommContext::Instance().GetRingId( - dynamic_cast(&op_handle)->GetComm()); + int ring_id = -1; + if (FLAGS_dynamic_static_unified_comm) { + ring_id = phi::distributed::CommContextManager::GetInstance().GetRingId( + dynamic_cast(&op_handle)->GetComm()); + VLOG(3) << "New CommContextManager gets ring_id: " << ring_id; + } else { + ring_id = platform::NCCLCommContext::Instance().GetRingId( + dynamic_cast(&op_handle)->GetComm()); + VLOG(3) << "Old NCCLCommContext gets ring_id: " << ring_id; + } all_reduce_op_desc.SetAttr("ring_id", ring_id); all_reduce_op_desc.SetAttr("use_calc_stream", false); all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), diff --git a/paddle/fluid/operators/data_norm_op.cu b/paddle/fluid/operators/data_norm_op.cu index a212bc0ee94782..509c067e24e421 100644 --- a/paddle/fluid/operators/data_norm_op.cu +++ b/paddle/fluid/operators/data_norm_op.cu @@ -21,6 +21,10 @@ limitations under the License. */ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif namespace paddle { @@ -213,31 +217,92 @@ class DataNormGradKernel : public framework::OpKernel { if (need_sync_stats) { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace()); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - reinterpret_cast(d_batch_size), - reinterpret_cast(d_batch_size), - C, - platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())), - ncclSum, - comm->comm(), - stream)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - reinterpret_cast(d_batch_sum), - reinterpret_cast(d_batch_sum), - C, - platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())), - ncclSum, - comm->comm(), - stream)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( - reinterpret_cast(d_batch_square_sum), - reinterpret_cast(d_batch_square_sum), - C, - platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())), - ncclSum, - comm->comm(), - stream)); + int rid = 0; + platform::NCCLComm *comm = nullptr; + const auto &comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + phi::distributed::NCCLCommContext *comm_ctx = nullptr; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ( + comm_context_manager.Has(std::to_string(rid)), + true, + platform::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE( + comm_ctx, + nullptr, + platform::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + } else { + comm = paddle::platform::NCCLCommContext::Instance().Get( + rid, ctx.GetPlace()); + } + + if (comm_ctx) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_size), + reinterpret_cast(d_batch_size), + C, + platform::ToNCCLDataType( + framework::TransToProtoVarType(x->dtype())), + ncclSum, + comm_ctx->GetNcclComm(), + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_sum), + reinterpret_cast(d_batch_sum), + C, + platform::ToNCCLDataType( + framework::TransToProtoVarType(x->dtype())), + ncclSum, + comm_ctx->GetNcclComm(), + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_square_sum), + reinterpret_cast(d_batch_square_sum), + C, + platform::ToNCCLDataType( + framework::TransToProtoVarType(x->dtype())), + ncclSum, + comm_ctx->GetNcclComm(), + stream)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_size), + reinterpret_cast(d_batch_size), + C, + platform::ToNCCLDataType( + framework::TransToProtoVarType(x->dtype())), + ncclSum, + comm->comm(), + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_sum), + reinterpret_cast(d_batch_sum), + C, + platform::ToNCCLDataType( + framework::TransToProtoVarType(x->dtype())), + ncclSum, + comm->comm(), + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + reinterpret_cast(d_batch_square_sum), + reinterpret_cast(d_batch_square_sum), + C, + platform::ToNCCLDataType( + framework::TransToProtoVarType(x->dtype())), + ncclSum, + comm->comm(), + stream)); + } platform::GpuStreamSync(stream); #else PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu index d741bc5b425495..75ef56accb10b4 100644 --- a/paddle/fluid/operators/margin_cross_entropy_op.cu +++ b/paddle/fluid/operators/margin_cross_entropy_op.cu @@ -30,6 +30,7 @@ namespace cub = hipcub; #include "paddle/phi/kernels/margin_cross_entropy_grad_kernel.h" #include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" @@ -39,6 +40,9 @@ namespace cub = hipcub; #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif #include "paddle/phi/backends/gpu/gpu_context.h" @@ -87,21 +91,50 @@ void GetClassInterval(const gpuStream_t& stream, auto task = pg->AllReduce(in_tensor, out_tensor, opts); task->Wait(); } else { - const auto& comm = - paddle::platform::NCCLCommContext::Instance().Get(rid, place); + paddle::platform::NCCLComm* comm = nullptr; + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + phi::distributed::NCCLCommContext* comm_ctx = nullptr; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), + true, + paddle::platform::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + paddle::platform::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + } else { + comm = paddle::platform::NCCLCommContext::Instance().Get(rid, place); + } + // use global calculate stream const auto calcu_stream = static_cast(phi::DeviceContextPool::Instance().Get(place)) ->stream(); - - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( - num_classes_per_device_ptr, - num_classes_per_device_ptr, - num_classes_per_device.numel(), - phi::ToNCCLDataType(num_classes_per_device.dtype()), - ncclSum, - comm->comm(), - calcu_stream)); + if (comm_ctx) { + comm_ctx->AllReduce(&num_classes_per_device, + num_classes_per_device, + ncclSum, + calcu_stream); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( + num_classes_per_device_ptr, + num_classes_per_device_ptr, + num_classes_per_device.numel(), + phi::ToNCCLDataType(num_classes_per_device.dtype()), + ncclSum, + comm->comm(), + calcu_stream)); + } } class_interval->Resize({nranks + 1}); @@ -238,7 +271,10 @@ void MarginCrossEntropyKernel(const Context& dev_ctx, const auto& place = dev_ctx.GetPlace(); // old code #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - paddle::platform::NCCLComm* comm; + paddle::platform::NCCLComm* comm = nullptr; + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + phi::distributed::NCCLCommContext* comm_ctx = nullptr; paddle::distributed::ProcessGroup* pg = nullptr; gpuStream_t stream; if (nranks > 1) { @@ -247,8 +283,29 @@ void MarginCrossEntropyKernel(const Context& dev_ctx, // Use ProcessGroup pg = map->get(ring_id); } else { - comm = paddle::platform::NCCLCommContext::Instance().Get(ring_id, place); - + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ( + comm_context_manager.Has(std::to_string(ring_id)), + true, + paddle::platform::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(ring_id))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(ring_id))); + PADDLE_ENFORCE_NE( + comm_ctx, + nullptr, + paddle::platform::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + } else { + comm = + paddle::platform::NCCLCommContext::Instance().Get(ring_id, place); + } // use global calculate stream stream = static_cast( phi::DeviceContextPool::Instance().Get(place)) @@ -361,14 +418,18 @@ void MarginCrossEntropyKernel(const Context& dev_ctx, auto task = pg->AllReduce(in_tensor, out_tensor, opts); task->Wait(); } else { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclAllReduce(logits_max_buff, - logits_max_buff, - logits_max.numel(), - phi::ToNCCLDataType(logits_max.dtype()), - ncclMax, - comm->comm(), - stream)); + if (comm_ctx) { + comm_ctx->AllReduce(&logits_max, logits_max, ncclMax, stream); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclAllReduce(logits_max_buff, + logits_max_buff, + logits_max.numel(), + phi::ToNCCLDataType(logits_max.dtype()), + ncclMax, + comm->comm(), + stream)); + } } } #endif @@ -402,14 +463,18 @@ void MarginCrossEntropyKernel(const Context& dev_ctx, auto task = pg->AllReduce(in_tensor, out_tensor, opts); task->Wait(); } else { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( - sum_exp_logits_buff, - sum_exp_logits_buff, - sum_exp_logits.numel(), - phi::ToNCCLDataType(sum_exp_logits.dtype()), - ncclSum, - comm->comm(), - stream)); + if (comm_ctx) { + comm_ctx->AllReduce(&sum_exp_logits, sum_exp_logits, ncclSum, stream); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( + sum_exp_logits_buff, + sum_exp_logits_buff, + sum_exp_logits.numel(), + phi::ToNCCLDataType(sum_exp_logits.dtype()), + ncclSum, + comm->comm(), + stream)); + } } } #endif @@ -460,14 +525,18 @@ void MarginCrossEntropyKernel(const Context& dev_ctx, auto task = pg->AllReduce(in_tensor, out_tensor, opts); task->Wait(); } else { - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::ncclAllReduce(loss_ptr, - loss_ptr, - loss->numel(), - phi::ToNCCLDataType(loss->dtype()), - ncclSum, - comm->comm(), - stream)); + if (comm_ctx) { + comm_ctx->AllReduce(loss, *loss, ncclSum, stream); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclAllReduce(loss_ptr, + loss_ptr, + loss->numel(), + phi::ToNCCLDataType(loss->dtype()), + ncclSum, + comm->comm(), + stream)); + } } } #endif diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index e7931282724ab8..342a86313bf3fb 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -176,6 +176,20 @@ CommContext* CommContextManager::Get(const std::string& unique_comm_key) const { return id_to_comm_context_.at(unique_comm_key).get(); } +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +int CommContextManager::GetRingId(const ncclComm_t& comm) const { + for (auto iter = id_to_comm_context_.begin(); + iter != id_to_comm_context_.end(); + ++iter) { + if (static_cast(iter->second.get()) + ->GetNcclComm() == comm) { + return std::stoi(iter->first); + } + } + return -1; +} +#endif + bool CommContextManager::Has(const std::string& unique_comm_key) const { return id_to_comm_context_.find(unique_comm_key) != id_to_comm_context_.end(); } diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index e2cb298a0984b8..dcbfaab55af903 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -22,6 +22,10 @@ #include "paddle/phi/core/distributed/comm_context.h" #include "paddle/phi/core/macros.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/phi/backends/gpu/forwards.h" +#endif + namespace phi { namespace distributed { @@ -44,6 +48,10 @@ class CommContextManager { CommContext* Get(const std::string& unique_comm_key) const; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + int GetRingId(const ncclComm_t& comm) const; +#endif + bool Has(const std::string& unique_comm_key) const; static void SetDeviceId(int dev_id);