Commit

add corresponding func in CCLBackend
Liangliang-Ma committed Sep 27, 2023
1 parent fe27d76 commit b1d4a74
Showing 2 changed files with 95 additions and 36 deletions.
38 changes: 32 additions & 6 deletions csrc/cpu/comm/ccl.cpp
@@ -277,15 +277,15 @@ int world_size = -1;

std::set<int> _comm_ids;
std::set<int> _colors;
-ccl::vector_class<ccl::communicator> _ccl_comms;
+std::vector<ccl::communicator> _ccl_comms;
+ccl::shared_ptr_class<ccl::kvs> sub_kvs;
std::map<std::vector<int>, int> group_to_comm_id;

ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }
ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
-ccl::communicator& _get_comm_from_group(std::vector<int> ranks)
-{
-    if (group_to_comm_id.find(ranks) != group_to_comm_id.end())
-    {
+ccl::communicator& _get_comm_from_group(std::vector<int> ranks)
+{
+    if (group_to_comm_id.find(ranks) != group_to_comm_id.end()) {
        auto id = group_to_comm_id.find(ranks);
        return _ccl_comms[id->second];
    }
@@ -412,6 +412,31 @@ py::object new_group(std::vector<int> ranks)
              << std::endl;
}

+std::vector<uint8_t> get_sub_kvs_addr(bool first)
+{
+    if (first) {
+        sub_kvs = ccl::create_main_kvs();
+        ccl::kvs::address_type main_addr = sub_kvs->get_address();
+        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
+        return ccl_kvs_addr;
+    } else {
+        ccl::kvs::address_type main_addr;
+        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
+        return ccl_kvs_addr;
+    }
+}
+
+void initialize_sub_comm(int size, int rank, torch::Tensor& kvs_data, std::vector<int> ranks)
+{
+    ccl::kvs::address_type main_addr;
+    if (rank != 0) {
+        memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
+        sub_kvs = ccl::create_kvs(main_addr);
+    }
+    _ccl_comms.push_back(ccl::create_communicator(size, rank, sub_kvs));
+    group_to_comm_id[ranks] = _ccl_comms.size() - 1;
+}
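Taken together, the two new bindings implement the key-value-store (KVS) handshake for a oneCCL sub-communicator: the first rank of the group calls get_sub_kvs_addr(true) to create a main KVS and obtain its address bytes, that address is broadcast to the other members, and every member then calls initialize_sub_comm, where non-owner ranks (rank != 0 within the group) rebuild the KVS from the received address before the communicator is created and registered in group_to_comm_id. A minimal sketch of that call sequence from the Python side, assuming ccl_comm_op is the loaded CCLCommBuilder module and a torch process group covering ranks already exists for the broadcast (the helper name and the CPU placement of the address tensor are illustrative only; the committed _new_group below moves the tensor to the rank's XPU device):

import torch
import torch.distributed as dist

def create_ccl_sub_comm(ccl_comm_op, ranks, global_rank, torch_group):
    # Hypothetical helper mirroring CCLBackend._new_group from this commit.
    owner = ranks[0]
    # Owner creates the main KVS and returns its address; others get a placeholder.
    addr = torch.tensor(ccl_comm_op.get_sub_kvs_addr(global_rank == owner), dtype=torch.uint8)
    # Ship the KVS address from the owner rank to the rest of the group.
    dist.broadcast(addr, owner, group=torch_group)
    # Every member creates the sub-communicator; non-owners rebuild the KVS
    # from the received address inside initialize_sub_comm.
    ccl_comm_op.initialize_sub_comm(len(ranks), ranks.index(global_rank), addr, ranks)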

ccl::datatype get_ccl_datatype(c10::ScalarType type)
{
    ccl::datatype ccl_type;
@@ -586,7 +611,8 @@ void barrier(std::vector<int> group, bool async_op)

std::vector<std::string> get_available_coll()
{
-    std::vector<std::string> colls{"broadcast", "all_reduce", "inference_all_reduce", "all_reduce_caching", "barrier"};
+    std::vector<std::string> colls{
+        "broadcast", "all_reduce", "inference_all_reduce", "all_reduce_caching", "barrier"};
    return colls;
}

93 changes: 63 additions & 30 deletions deepspeed/comm/ccl.py
@@ -20,14 +20,17 @@ def build_ccl_op():
    print(f'DeepSpeed {builder.absolute_name()} built successfully')
    return ccl_cpp_module


class CCLHandler():

    def __init__(self, ccl_comm_op=None):
        self.ccl_comm_op = ccl_comm_op

    def wait(self):
        # backend covered it
        pass


class CCLBackend(TorchBackend):

    def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
@@ -50,50 +53,76 @@ def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, i
        super(CCLBackend, self).broadcast(main_kvs, 0)
        self.ccl_comm_op.initialize(size, rank, main_kvs)
        self.initialized = True
        self.groups = [tuple(range(self.get_world_size()))]
        self.available_coll = self.ccl_comm_op.get_available_coll()

    def is_initialized(self):
        return self.initialized

    def run_collective(self, name, **kwargs):
        if name in self.available_coll:
            kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
            if 'dst' in kwargs:
                kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
            if 'src' in kwargs:
                kwargs['src'] = kwargs['group'].index(kwargs['src'])
            func = "self.ccl_comm_op." + name
            eval(func)(*(kwargs.values()))
            return CCLHandler(self.ccl_comm_op)
        else:
            func = "super(CCLBackend, self)." + name
            return eval(func)(*(kwargs.values()))
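The routing in run_collective is string-driven: collectives advertised by the C++ extension through get_available_coll are forwarded to it via eval, after the torch process group has been translated into a list of global ranks and any dst/src argument has been remapped to an index within that list; everything else falls back to the parent TorchBackend. A getattr-based sketch of the same routing, written out only to make the mechanism explicit (hypothetical helper, not the committed code):

def route_collective(backend, name, **kwargs):
    # backend is assumed to be a CCLBackend instance.
    if name in backend.available_coll:
        # Translate the torch group object into a list of global ranks.
        kwargs['group'] = backend.get_all_ranks_from_group(kwargs['group'])
        # dst/src arrive as global ranks; the CCL op expects group-local indices.
        for key in ('dst', 'src'):
            if key in kwargs:
                kwargs[key] = kwargs['group'].index(kwargs[key])
        getattr(backend.ccl_comm_op, name)(*kwargs.values())
        return CCLHandler(backend.ccl_comm_op)
    # Collectives the CCL op does not provide go through the torch backend.
    return getattr(super(CCLBackend, backend), name)(*kwargs.values())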

    def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
        use_caching = False
        if use_caching:
            match_id = f"{tensor.size()}-{op}"
-            return self.run_collective(name="all_reduce_caching", tensor=tensor, op=op, match_id=match_id, group=group, async_op=async_op)
+            return self.run_collective(name="all_reduce_caching",
+                                       tensor=tensor,
+                                       op=op,
+                                       match_id=match_id,
+                                       group=group,
+                                       async_op=async_op)
        else:
            return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)

    def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
        return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)

    def broadcast(self, tensor, src, group=None, async_op=False):
-        return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
+        return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)

-    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
-        return self.run_collective(name="all_gather", tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
+    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
+        return self.run_collective(name="all_gather",
+                                   tensor_list=tensor_list,
+                                   tensor=tensor,
+                                   group=group,
+                                   async_op=async_op)

    def reduce_scatter_tensor(self, output_tensor, input_tensor, op, group=None, async_op=False):
-        return self.run_collective(name="reduce_scatter_tensor", output_tensor=output_tensor, input_tensor=input_tensor, op=op, group=group)
+        return self.run_collective(name="reduce_scatter_tensor",
+                                   output_tensor=output_tensor,
+                                   input_tensor=input_tensor,
+                                   op=op,
+                                   group=group)

    def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
-        return self.run_collective(name="all_gather_into_tensor", output_tensor=output_tensor, input_tensor=input_tensor, group=group)
+        return self.run_collective(name="all_gather_into_tensor",
+                                   output_tensor=output_tensor,
+                                   input_tensor=input_tensor,
+                                   group=group)

    def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes, group=None, async_op=False):
-        return self.run_collective(name="all_to_all_single", output=output, input=input, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=group)
+        return self.run_collective(name="all_to_all_single",
+                                   output=output,
+                                   input=input,
+                                   output_split_sizes=output_split_sizes,
+                                   input_split_sizes=input_split_sizes,
+                                   group=group)

    def send(self, tensor, dst, group=None, async_op=False):
        return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op)

    def recv(self, tensor, src, group=None, async_op=False):
        return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op)

@@ -103,43 +132,47 @@ def gather(self, tensor, gather_list, dst, group=None, async_op=False):
    def scatter(self, tensor, gather_list, dst, group=None, async_op=False):
        return self.run_collective(name="scatter", tensor=tensor, gather_list=gather_list, dst=dst, group=group)

-    def barrier(self, group=None, async_op=False):
+    def barrier(self, group=None, async_op=False):
        return self.run_collective(name="barrier", group=group, async_op=async_op)

    def monitored_barrier(self, group=None, timeout=None, wait_all_ranks=False):
        return self.run_collective(name="monitored_barrier", group=group)

    def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
-        return self.run_collective(name="reduce_scatter", output=output, input_list=input_list, op=op, group=group, async_op=async_op)
-
-    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+        return self.run_collective(name="reduce_scatter",
+                                   output=output,
+                                   input_list=input_list,
+                                   op=op,
+                                   group=group,
+                                   async_op=async_op)
+
+    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
        return self.run_collective(name="reduce", tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)

    def new_group(self, ranks):
        return super(CCLBackend, self).new_group(ranks)

    def _new_group(self, ranks, group):
        size = len(ranks)
        rank = self.get_rank()
        if tuple(ranks) in self.groups or rank not in ranks:
            return
        sub_main_kvs = self.ccl_comm_op.get_sub_kvs_addr(rank == ranks[0])
-        sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to("xpu:"+str(rank))
-        torch_new_group = super(CCLBackend, self).new_group(ranks)
-        super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], torch_new_group, False)
+        sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to("xpu:" + str(rank))
+        torch_new_group = group
+        super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], group)
        self.ccl_comm_op.initialize_sub_comm(size, ranks.index(rank), sub_main_kvs, ranks)
        self.groups.append(tuple(ranks))
        return torch_new_group

    def get_all_ranks_from_group(self, group):
        if group is None:
            return list(range(self.get_world_size()))
-        rank=0
-        results=[]
+        rank = 0
+        results = []
        try:
            while True:
-                results.append(torch.distributed.distributed_c10d._get_global_rank(group, rank))
-                rank+=1
+                results.append(super(CCLBackend, self).get_global_rank(group, rank))
+                rank += 1
        except RuntimeError:
            pass

        if tuple(results) not in self.groups:
-            self.new_group(results)
+            self._new_group(results, group)
        return results
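With this change, CCLBackend.new_group only creates the torch-side group; the matching oneCCL sub-communicator is built lazily by _new_group, which get_all_ranks_from_group triggers the first time a collective actually runs on that group. A rough end-to-end usage sketch, assuming the usual deepspeed.comm entry points and one XPU per rank (the device string and the choice of even ranks are illustrative, not part of this commit):

import torch
import deepspeed.comm as dist

# Initialize the CCL backend (rank and world size come from the launcher environment).
dist.init_distributed(dist_backend='ccl')
rank = dist.get_rank()

# Sub-group of the even ranks; the CCL sub-communicator is created lazily
# the first time a collective runs on this group.
group = dist.new_group([r for r in range(dist.get_world_size()) if r % 2 == 0])

t = torch.ones(4, device=f'xpu:{rank}')
if rank % 2 == 0:
    dist.all_reduce(t, group=group)  # routed through CCLBackend.run_collective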
