Skip to content

Commit

Permalink
[NewComm] No.1 compatible upgrade for class_center_sample op (#57153)
Browse files Browse the repository at this point in the history
* [NewComm] upgrade class_center_sample_op

* fix bug
  • Loading branch information
GreatV authored Sep 13, 2023
1 parent 9f321c8 commit b38c24c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
52 changes: 41 additions & 11 deletions paddle/fluid/operators/class_center_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@ namespace cub = hipcub;
#include <random>

#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_utils.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
Expand Down Expand Up @@ -364,21 +368,47 @@ void ClassCenterSampleKernel(const Context& dev_ctx,
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
} else {
const auto& comm = paddle::platform::NCCLCommContext::Instance().Get(
ring_id, dev_ctx.GetPlace());
paddle::platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
// use global calculate stream
const auto calcu_stream =
auto stream =
static_cast<GPUContext*>(
phi::DeviceContextPool::Instance().Get(dev_ctx.GetPlace()))
->stream();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
num_classes_per_device_ptr,
num_classes_per_device_ptr,
num_classes_per_device.numel(),
phi::ToNCCLDataType(num_classes_per_device.dtype()),
ncclSum,
comm->comm(),
calcu_stream));
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm "
"True. But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
stream = comm_ctx->GetStream();
} else {
comm = paddle::platform::NCCLCommContext::Instance().Get(
ring_id, dev_ctx.GetPlace());
}

if (comm_ctx) {
comm_ctx->AllReduce(
&num_classes_per_device, num_classes_per_device, ncclSum, stream);
paddle::platform::GpuStreamSync(stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
num_classes_per_device_ptr,
num_classes_per_device_ptr,
num_classes_per_device.numel(),
phi::ToNCCLDataType(num_classes_per_device.dtype()),
ncclSum,
comm->comm(),
stream));
}
}
}
#endif
Expand Down
10 changes: 10 additions & 0 deletions test/legacy_test/test_class_center_sample_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np
Expand Down Expand Up @@ -87,7 +88,11 @@ def init_dtype(self):
def init_fix_seed(self):
    # Pin the RNG seed so the sampled class centers are reproducible
    # across runs of this test case.
    setattr(self, "fix_seed", True)

def with_new_comm(self):
    # Base case runs against the legacy communication library:
    # leave FLAGS_dynamic_static_unified_comm disabled ("0").
    os.environ.update({"FLAGS_dynamic_static_unified_comm": "0"})

def setUp(self):
self.with_new_comm()
self.initParams()
self.init_dtype()
self.init_fix_seed()
Expand Down Expand Up @@ -126,6 +131,11 @@ def init_fix_seed(self):
self.fix_seed = True


class TestClassCenterSampleOpWithNewComm(TestClassCenterSampleOp):
    """Variant of TestClassCenterSampleOp that enables the new comm library.

    Overrides with_new_comm so setUp turns FLAGS_dynamic_static_unified_comm
    on ("1") before the inherited test cases run.
    """

    def with_new_comm(self):
        # Enable the flag so the op takes the new communication-library path.
        os.environ.update({"FLAGS_dynamic_static_unified_comm": "1"})


class TestClassCenterSampleV2(unittest.TestCase):
def setUp(self):
self.initParams()
Expand Down

0 comments on commit b38c24c

Please sign in to comment.