[CPU] Support SHM based inference_all_reduce in TorchBackend #5391

Merged: 10 commits, Apr 17, 2024
6 changes: 4 additions & 2 deletions accelerator/cpu_accelerator.py
@@ -300,12 +300,14 @@ def get_op_builder(self, class_name):
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
from op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
from op_builder.cpu import CCLCommBuilder, SHMCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
except ImportError:
from deepspeed.ops.op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder
from deepspeed.ops.op_builder.cpu import CCLCommBuilder, SHMCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder

if class_name == "CCLCommBuilder":
return CCLCommBuilder
elif class_name == "SHMCommBuilder":
return SHMCommBuilder
elif class_name == "FusedAdamBuilder":
return FusedAdamBuilder
elif class_name == "CPUAdamBuilder":
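
The dispatch above can be exercised through the accelerator interface. A minimal sketch, assuming a CPU install of DeepSpeed with the CPU accelerator active (get_accelerator and get_op_builder are existing accelerator APIs; only the "SHMCommBuilder" entry is new in this PR):

# Sketch: resolve the new SHM comm builder class through the CPU accelerator's
# get_op_builder dispatch added above, then instantiate it.
from deepspeed.accelerator import get_accelerator

builder_cls = get_accelerator().get_op_builder("SHMCommBuilder")
shm_builder = builder_cls()           # deepspeed_shm_comm op builder
print(shm_builder.absolute_name())    # -> deepspeed.ops.comm.deepspeed_shm_comm_op
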
9 changes: 2 additions & 7 deletions csrc/cpu/comm/ccl.cpp
@@ -247,7 +247,7 @@ void all_reduce_caching(torch::Tensor& data,
.wait());
}

void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op)
void inference_all_reduce(torch::Tensor& data)
{
#ifdef DO_PROFILE
static double total_time = 0.0;
@@ -263,11 +263,6 @@ void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op)
auto start = std::chrono::system_clock::now();
#endif

static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));

assert(py::int_(op.attr("value")) == ReduceOpSum);

auto numel = data.numel();

int data_size = 0;
@@ -285,7 +280,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op)
data.data_ptr(),
data.numel(),
get_ccl_datatype(data.scalar_type()),
get_ccl_reduce_op(op, data),
ccl::reduction::sum,
_get_comm_from_group())
.wait());
} else {
115 changes: 115 additions & 0 deletions csrc/cpu/comm/shm_interface.cpp
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

#include "shm.h"

// #define DO_PROFILE
#ifdef DO_PROFILE
#include <cfloat>
#include <chrono>
#endif

// Communication settings
static int world_rank = -1;
static int world_size = -1;

static bool is_initialized = false;

static bool all_ranks_local_p = false;

void initialize(int size, int rank)
{
if (is_initialized) return;

// Check whether all ranks are on the same physical machine.
// If true, we will use an SHM-based low-latency allreduce

auto ls_string = std::getenv("LOCAL_SIZE");
int ls = 0;
if (ls_string != NULL) { ls = std::stoi(ls_string); }

if (size >= 1 && size == ls) { all_ranks_local_p = true; }

world_size = size;
world_rank = rank;
is_initialized = true;

auto addr_string = std::getenv("MASTER_ADDR");
if (addr_string == NULL) { addr_string = ""; }
auto port_string = std::getenv("MASTER_PORT");
if (port_string == NULL) { port_string = ""; }

if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
}

int get_rank(int group = 0) { return world_rank; }

int get_world_size(int group = 0) { return world_size; }

// Success - return 0
// Fail (cannot honor the request and the caller needs to fall back) - return -1
int inference_all_reduce(torch::Tensor& data)
{
if (!all_ranks_local_p) return -1;
#ifdef DO_PROFILE
static double total_time = 0.0;
static double total_time_sq = 0.0;
static int count = -16; // warmup
static double max_time = 0.0;
static double min_time = DBL_MAX;
// make sure all ranks reach this point before measuring time
// turn this on if you suspect the ranks did not reach this point at the same time (straggler)
// if (all_ranks_local_p) { barrier_wait(0, world_size); }
auto start = std::chrono::system_clock::now();
#endif

auto numel = data.numel();

int data_size = 0;
bool data_type_fallback = false;

switch (data.scalar_type()) {
case c10::ScalarType::BFloat16: data_size = numel * 2; break;
case c10::ScalarType::Float: data_size = numel * 4; break;
default: data_type_fallback = true;
}

if (data_type_fallback) return -1;

all_reduce_outer_loop(data, numel, data_size);

#ifdef DO_PROFILE
auto end = std::chrono::system_clock::now();
count++;
if (count > 0) {
double elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
if (elapsed > max_time) { max_time = elapsed; }
if (elapsed < min_time) { min_time = elapsed; }
total_time += elapsed;
total_time_sq += elapsed * elapsed;
if (world_rank == 0 && count == 1000) {
auto avg = total_time / count;
auto sd =
sqrt(total_time_sq / count - total_time * total_time / (count * count)) / avg * 100;
printf(" C++ kernel\t\t %.2f\t %.2f\t%.2f\t %.2f\n",
min_time,
max_time,
total_time / count,
sd);
}
}
#endif
return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("initialize", &initialize, "shm initialize");
m.def("get_rank", &get_rank, "get rank");
m.def("get_world_size", &get_world_size, "get world size");
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
}
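
A hedged usage sketch of the interface bound above, assuming the op has been built and loaded through the accelerator; per the contract documented above, a -1 return means SHM cannot honor the request and the caller must fall back to another allreduce path:

# Sketch: load the deepspeed_shm_comm module and call it directly. RANK and
# WORLD_SIZE are assumed to be set by the launcher (e.g. the deepspeed launcher).
import os
import torch
from deepspeed.accelerator import get_accelerator

shm = get_accelerator().create_op_builder("SHMCommBuilder").load()
shm.initialize(int(os.environ["WORLD_SIZE"]), int(os.environ["RANK"]))

t = torch.ones(1024, dtype=torch.bfloat16)
if shm.inference_all_reduce(t) == -1:
    # unsupported dtype or ranks spanning multiple nodes: fall back
    # (requires an initialized torch.distributed process group)
    torch.distributed.all_reduce(t)
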
6 changes: 3 additions & 3 deletions deepspeed/comm/ccl.py
@@ -98,12 +98,12 @@ def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
else:
return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)

def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
def inference_all_reduce(self, tensor, group=None):
name = "inference_all_reduce"
if name in self.available_coll:
return self.ccl_comm_op.inference_all_reduce(tensor, op, async_op)
return self.ccl_comm_op.inference_all_reduce(tensor)
else:
return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=async_op)
return self.run_collective(name=name, tensor=tensor, op=ReduceOp.SUM, group=None, async_op=False)

def broadcast(self, tensor, src, group=None, async_op=False):
return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
2 changes: 1 addition & 1 deletion deepspeed/comm/comm.py
@@ -505,7 +505,7 @@ def inference_all_reduce(tensor,
log_name='all_reduce',
debug=get_caller_func()):
global cdb
return cdb.inference_all_reduce(tensor, op, group, async_op)
return cdb.inference_all_reduce(tensor, group)


@timed_op
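
From user code, the entry point above now forwards only the tensor and group; the reduction is implicitly SUM. A hedged end-to-end sketch, assuming a multi-rank CPU run started with the deepspeed launcher so the usual rank/world-size environment is set:

# Sketch: calling the updated deepspeed.comm.inference_all_reduce.
import torch
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()               # backend selection left to DeepSpeed
t = torch.ones(4096, dtype=torch.float32)
dist.inference_all_reduce(t)               # in-place SUM across ranks
print(t[0].item())                         # equals the world size when every rank holds ones
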
22 changes: 19 additions & 3 deletions deepspeed/comm/torch.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import deepspeed
from deepspeed import utils

from .utils import *
@@ -19,6 +20,15 @@
DS_COMM_REDUCE_OFF = False


def build_shm_op():
builder = get_accelerator().create_op_builder("SHMCommBuilder")
if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
return None
shm_cpp_module = builder.load()
print(f'DeepSpeed {builder.absolute_name()} built successfully')
return shm_cpp_module


def has_coalescing_manager():
has_c10d = hasattr(torch.distributed, 'distributed_c10d')
return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')
@@ -88,6 +98,7 @@ class TorchBackend(Backend):

def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
super(TorchBackend, self).__init__()
self.shm_comm_op = build_shm_op()
self.has_all_reduce_coalesced = has_all_reduce_coalesced()
self.has_coalescing_manager = has_coalescing_manager()
self.all_gather_function = self.get_all_gather_function()
@@ -99,6 +110,8 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
# it is not so we can run on a single GPU without doing any init_process_group
self.single_gpu_mode = True
self.init_process_group(backend, timeout, init_method, rank, world_size)
if self.shm_comm_op is not None:
self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())

@classmethod
@compiler.disable
@@ -139,9 +152,12 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def inference_all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
def inference_all_reduce(self, tensor, group=None):
if self.shm_comm_op is None or self.shm_comm_op.inference_all_reduce(tensor) == -1:
return torch.distributed.all_reduce(tensor=tensor,
op=torch.distributed.ReduceOp.SUM,
group=group,
async_op=False)

@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
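
The control flow added to TorchBackend.inference_all_reduce above can be read as a hedged standalone sketch (shm_op stands in for the loaded deepspeed_shm_comm module; the helper name is illustrative):

# Sketch: try the low-latency SHM kernel first and fall back to
# torch.distributed only when SHM is unavailable or declines the request
# (unsupported dtype, ranks spread across multiple nodes, ...).
import torch

def sum_all_reduce(tensor, shm_op, group=None):
    if shm_op is not None and shm_op.inference_all_reduce(tensor) == 0:
        return  # handled entirely inside shared memory
    torch.distributed.all_reduce(tensor,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=group,
                                 async_op=False)
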
2 changes: 1 addition & 1 deletion op_builder/cpu/__init__.py
@@ -4,7 +4,7 @@
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''

from .comm import CCLCommBuilder
from .comm import CCLCommBuilder, SHMCommBuilder
from .fused_adam import FusedAdamBuilder
from .cpu_adam import CPUAdamBuilder
from .no_impl import NotImplementedBuilder
27 changes: 27 additions & 0 deletions op_builder/cpu/comm.py
@@ -42,3 +42,30 @@ def extra_ldflags(self):
return []
else:
return ['-lccl', f'-L{ccl_root_path}/lib']


class SHMCommBuilder(CPUOpBuilder):
BUILD_VAR = "DS_BUILD_SHM_COMM"
NAME = "deepspeed_shm_comm"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)

def absolute_name(self):
return f'deepspeed.ops.comm.{self.NAME}_op'

def sources(self):
return ['csrc/cpu/comm/shm_interface.cpp', 'csrc/cpu/comm/shm.cpp']

def include_paths(self):
includes = ['csrc/cpu/includes']
return includes

def cxx_args(self):
return ['-O2', '-fopenmp']

def is_compatible(self, verbose=True):
# TODO: add soft compatibility check for private binary release.
# a soft check, as in we know it can be trivially changed.
return super().is_compatible(verbose)
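
A hedged sketch of how the new builder is typically used; per the BUILD_VAR above, the op can be pre-compiled at install time with DS_BUILD_SHM_COMM=1, otherwise load() JIT-compiles the sources listed in sources():

# Sketch: build/load the deepspeed_shm_comm op directly from its builder.
from deepspeed.ops.op_builder.cpu import SHMCommBuilder

builder = SHMCommBuilder()
if builder.is_compatible():
    shm_module = builder.load()   # JIT-compiles shm_interface.cpp / shm.cpp if needed
    print(f"loaded {builder.absolute_name()}")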