diff --git a/paddle/fluid/operators/fused/fused_attention_utils.h b/paddle/fluid/operators/fused/fused_attention_utils.h index 26cab895f0dfc..2d599a16e25fe 100644 --- a/paddle/fluid/operators/fused/fused_attention_utils.h +++ b/paddle/fluid/operators/fused/fused_attention_utils.h @@ -18,8 +18,13 @@ #include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif +#include "paddle/fluid/distributed/collective/utils.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/errors.h" namespace phi { @@ -47,11 +52,46 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT auto place = dev_ctx.GetPlace(); void *recvbuff = dev_ctx.template Alloc(&tensor, tensor.numel() * sizeof(T)); - auto comm = - paddle::platform::NCCLCommContext::Instance().Get(ring_id, place); - auto stream = dev_ctx.stream(); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( - sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream)); + gpuStream_t stream = nullptr; + platform::NCCLComm *comm = nullptr; + phi::distributed::NCCLCommContext *comm_ctx = nullptr; + + const auto &comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + + if (FLAGS_dynamic_static_unified_comm) { + // Use New Communication Library + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), + true, + platform::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(ring_id))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(ring_id))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + platform::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + + stream = comm_ctx->GetStream(); + VLOG(3) << "new comm_context_manager has ring_id" << ring_id; + } else { + comm = paddle::platform::NCCLCommContext::Instance().Get(ring_id, place); + + stream = dev_ctx.stream(); + VLOG(3) << "old NCCLCommContext has ring_id " << ring_id; + } + if (comm_ctx) { + comm_ctx->AllReduce(&tensor, tensor, ncclSum, stream); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( + sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream)); + } } #else PADDLE_THROW(phi::errors::Unimplemented( diff --git a/test/legacy_test/test_fused_attention_op.py b/test/legacy_test/test_fused_attention_op.py index af734c96d19d8..0e012659f95f6 100644 --- a/test/legacy_test/test_fused_attention_op.py +++ b/test/legacy_test/test_fused_attention_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy as np @@ -31,6 +32,7 @@ class TestFusedAttentionOp(OpTest): def setUp(self): + self.with_new_comm() self.config() self.generate_input_data() @@ -79,6 +81,9 @@ def setUp(self): paddle.set_default_dtype(self.x_type) self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + def with_new_comm(self): + os.environ["FLAGS_dynamic_static_unified_comm"] = "0" + def config(self): self.x_type = np.float32 self.attn_mask_type = np.float64 @@ -350,6 +355,11 @@ def test_fused_attention_op(self): ) +class TestFusedAttentionOpWithNewComm(TestFusedAttentionOp): + def with_new_comm(self): + os.environ["FLAGS_dynamic_static_unified_comm"] = "1" + + class TestFusedAttentionOpBiasIsNone(TestFusedAttentionOp): def config(self): super().config()