diff --git a/paddle/fluid/operators/class_center_sample_op.cu b/paddle/fluid/operators/class_center_sample_op.cu
index f63baadbde526..efac6332c6d29 100644
--- a/paddle/fluid/operators/class_center_sample_op.cu
+++ b/paddle/fluid/operators/class_center_sample_op.cu
@@ -30,6 +30,7 @@ namespace cub = hipcub;
 #include <random>
 
 #include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -37,6 +38,9 @@ namespace cub = hipcub;
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#include "paddle/phi/core/flags.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
 #endif
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -364,21 +368,47 @@ void ClassCenterSampleKernel(const Context& dev_ctx,
       auto task = pg->AllReduce(in_tensor, out_tensor, opts);
       task->Wait();
     } else {
-      const auto& comm = paddle::platform::NCCLCommContext::Instance().Get(
-          ring_id, dev_ctx.GetPlace());
+      paddle::platform::NCCLComm* comm = nullptr;
+      phi::distributed::NCCLCommContext* comm_ctx = nullptr;
       // use global calculate stream
-      const auto calcu_stream =
+      auto stream =
           static_cast<GPUContext*>(
               phi::DeviceContextPool::Instance().Get(dev_ctx.GetPlace()))
               ->stream();
-      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
-          num_classes_per_device_ptr,
-          num_classes_per_device_ptr,
-          num_classes_per_device.numel(),
-          phi::ToNCCLDataType(num_classes_per_device.dtype()),
-          ncclSum,
-          comm->comm(),
-          calcu_stream));
+      const auto& comm_context_manager =
+          phi::distributed::CommContextManager::GetInstance();
+      if (FLAGS_dynamic_static_unified_comm) {
+        PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
+                          true,
+                          errors::InvalidArgument(
+                              "You choose to use new communication library by "
+                              "setting environment "
+                              "variable FLAGS_dynamic_static_unified_comm "
+                              "True. But ring_id(%d) is "
+                              "not found in comm_context_manager.",
+                              std::to_string(ring_id)));
+        comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
+            comm_context_manager.Get(std::to_string(ring_id)));
+        stream = comm_ctx->GetStream();
+      } else {
+        comm = paddle::platform::NCCLCommContext::Instance().Get(
+            ring_id, dev_ctx.GetPlace());
+      }
+
+      if (comm_ctx) {
+        comm_ctx->AllReduce(
+            &num_classes_per_device, num_classes_per_device, ncclSum, stream);
+        paddle::platform::GpuStreamSync(stream);
+      } else {
+        PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
+            num_classes_per_device_ptr,
+            num_classes_per_device_ptr,
+            num_classes_per_device.numel(),
+            phi::ToNCCLDataType(num_classes_per_device.dtype()),
+            ncclSum,
+            comm->comm(),
+            stream));
+      }
     }
   }
 #endif
diff --git a/test/legacy_test/test_class_center_sample_op.py b/test/legacy_test/test_class_center_sample_op.py
index 546a63faa200e..da903b5a16689 100644
--- a/test/legacy_test/test_class_center_sample_op.py
+++ b/test/legacy_test/test_class_center_sample_op.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import unittest
 
 import numpy as np
@@ -87,7 +88,11 @@ def init_dtype(self):
     def init_fix_seed(self):
         self.fix_seed = True
 
+    def with_new_comm(self):
+        os.environ["FLAGS_dynamic_static_unified_comm"] = "0"
+
     def setUp(self):
+        self.with_new_comm()
         self.initParams()
         self.init_dtype()
         self.init_fix_seed()
@@ -126,6 +131,11 @@ def init_fix_seed(self):
         self.fix_seed = True
 
 
+class TestClassCenterSampleOpWithNewComm(TestClassCenterSampleOp):
+    def with_new_comm(self):
+        os.environ["FLAGS_dynamic_static_unified_comm"] = "1"
+
+
 class TestClassCenterSampleV2(unittest.TestCase):
     def setUp(self):
         self.initParams()
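Note for reviewers unfamiliar with the NewComm migration: the kernel change above boils down to a flag-gated dispatch. When FLAGS_dynamic_static_unified_comm is set, the kernel fetches a phi::distributed::NCCLCommContext from the CommContextManager keyed by ring_id and lets it issue the AllReduce on its own stream; otherwise it falls back to the legacy paddle::platform::NCCLCommContext lookup and the raw ncclAllReduce call. The sketch below is a minimal, self-contained C++ model of that dispatch pattern; LegacyComm, NewCommContext, and the g_* globals are hypothetical stand-ins for the Paddle types named in the diff, not real Paddle APIs.

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Stand-in for paddle::platform::NCCLComm (legacy communicator).
struct LegacyComm {
  void AllReduceRaw() { std::cout << "legacy ncclAllReduce path\n"; }
};

// Stand-in for phi::distributed::NCCLCommContext (new communication library).
struct NewCommContext {
  void AllReduce() { std::cout << "comm_ctx->AllReduce path\n"; }
};

// Stand-ins for FLAGS_dynamic_static_unified_comm and CommContextManager.
static bool g_dynamic_static_unified_comm = true;
static std::map<std::string, NewCommContext> g_comm_context_manager = {
    {"0", {}}};

void AllReduceNumClassesPerDevice(int ring_id) {
  static LegacyComm legacy;
  LegacyComm* comm = nullptr;
  NewCommContext* comm_ctx = nullptr;

  if (g_dynamic_static_unified_comm) {
    // New path: a comm context must already be registered under ring_id,
    // mirroring the PADDLE_ENFORCE_EQ check in the kernel.
    auto it = g_comm_context_manager.find(std::to_string(ring_id));
    if (it == g_comm_context_manager.end()) {
      throw std::invalid_argument("ring_id not found in comm_context_manager");
    }
    comm_ctx = &it->second;
  } else {
    // Legacy path: look up the old per-ring NCCL communicator.
    comm = &legacy;
  }

  if (comm_ctx) {
    comm_ctx->AllReduce();  // the new library wraps the NCCL call itself
  } else {
    comm->AllReduceRaw();   // otherwise fall back to the raw ncclAllReduce
  }
}

int main() {
  AllReduceNumClassesPerDevice(0);        // exercises the new comm path
  g_dynamic_static_unified_comm = false;
  AllReduceNumClassesPerDevice(0);        // exercises the legacy path
}

One design point the sketch does not model: in the diff, the new path all-reduces on the comm context's own stream (comm_ctx->GetStream()) and then calls paddle::platform::GpuStreamSync(stream) before returning, since the rest of the kernel runs on the global calculate stream.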