diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp
index accf431f6929..9a04b6f873f2 100644
--- a/csrc/cpu/comm/ccl.cpp
+++ b/csrc/cpu/comm/ccl.cpp
@@ -277,10 +277,20 @@ int world_size = -1;
 
 std::set<int> _comm_ids;
 std::set<int> _colors;
-ccl::vector_class<ccl::communicator> _ccl_comms;
+std::vector<ccl::communicator> _ccl_comms;
+ccl::shared_ptr_class<ccl::kvs> sub_kvs;
+std::map<std::vector<int>, int> group_to_comm_id;
 
 ccl::communicator& _get_comm_from_group() { return _ccl_comms[0]; }
 ccl::communicator& _get_comm_from_group(py::object group) { return _ccl_comms[0]; }
+ccl::communicator& _get_comm_from_group(std::vector<int> ranks)
+{
+    if (group_to_comm_id.find(ranks) != group_to_comm_id.end()) {
+        auto id = group_to_comm_id.find(ranks);
+        return _ccl_comms[id->second];
+    }
+    return _ccl_comms[0];
+}
 
 #define CCLCHECK(cmd) \
     do {              \
@@ -394,12 +404,29 @@ int next_unique_val(std::set<int> s)
     }
 }
 
-py::object new_group(std::vector<int> ranks)
+std::vector<uint8_t> get_sub_kvs_addr(bool first)
+{
+    if (first) {
+        sub_kvs = ccl::create_main_kvs();
+        ccl::kvs::address_type main_addr = sub_kvs->get_address();
+        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
+        return ccl_kvs_addr;
+    } else {
+        ccl::kvs::address_type main_addr;
+        auto ccl_kvs_addr = std::vector<uint8_t>(main_addr.begin(), main_addr.end());
+        return ccl_kvs_addr;
+    }
+}
+
+void initialize_sub_comm(int size, int rank, torch::Tensor& kvs_data, std::vector<int> ranks)
 {
-    int comm_id = next_unique_val(_comm_ids);
-    int color = next_unique_val(_colors);
-    std::cout << "RANK: " << get_rank() << " COMM_ID: " << comm_id << " COLOR: " << color
-              << std::endl;
+    ccl::kvs::address_type main_addr;
+    if (rank != 0) {
+        memcpy(main_addr.data(), kvs_data.data_ptr(), main_addr.size());
+        sub_kvs = ccl::create_kvs(main_addr);
+    }
+    _ccl_comms.push_back(ccl::create_communicator(size, rank, sub_kvs));
+    group_to_comm_id[ranks] = _ccl_comms.size() - 1;
 }
 
 ccl::datatype get_ccl_datatype(c10::ScalarType type)
@@ -452,7 +479,7 @@ ccl::reduction get_ccl_reduce_op(py::object op, at::Tensor& input)
     return ccl_op;
 }
 
-void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
+void broadcast(torch::Tensor& data, int src, std::vector<int> group, bool async_op)
 {
     CCLCHECK(ccl::broadcast(data.data_ptr(),
                             data.numel(),
@@ -463,7 +490,7 @@ void broadcast(torch::Tensor& data, int src, py::object group, bool async_op)
 }
 
 // TODO: implement torch's async_op behavior, document it.
-void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
+void all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
 {
     CCLCHECK(ccl::allreduce(data.data_ptr(),
                             data.data_ptr(),
@@ -477,7 +504,7 @@ void all_reduce(torch::Tensor& data, py::object op, py::object group, bool async
 void all_reduce_caching(torch::Tensor& data,
                         py::object op,
                         std::string match_id,
-                        py::object group,
+                        std::vector<int> group,
                         bool async_op)
 {
     ccl::allreduce_attr attr = ccl::default_allreduce_attr;
@@ -510,7 +537,7 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes)
     }
 }
 
-void inference_all_reduce(torch::Tensor& data, py::object op, py::object group, bool async_op)
+void inference_all_reduce(torch::Tensor& data, py::object op, std::vector<int> group, bool async_op)
 {
     static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
     static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
@@ -583,11 +610,18 @@ void inference_all_reduce(torch::Tensor& data, py::object op, py::object group,
     }
 }
 
-void barrier(py::object group, bool async_op)
+void barrier(std::vector<int> group, bool async_op)
 {
     CCLCHECK(ccl::barrier(_get_comm_from_group(group)).wait());
 }
 
+std::vector<std::string> get_available_coll()
+{
+    std::vector<std::string> colls{
+        "broadcast", "all_reduce", "inference_all_reduce", "all_reduce_caching", "barrier"};
+    return colls;
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("get_kvs_addr", &get_kvs_addr, "create and get main kvs addr");
@@ -599,4 +633,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
     m.def("all_reduce_caching", &all_reduce_caching, "ccl all_reduce with caching");
     m.def("barrier", &barrier, "barrier");
+    m.def("initialize_sub_comm", &initialize_sub_comm, "initialize_sub_comm");
+    m.def("get_sub_kvs_addr", &get_sub_kvs_addr, "get_sub_kvs_addr");
+    m.def("get_available_coll", &get_available_coll, "get_available_coll");
 }
diff --git a/deepspeed/comm/ccl.py b/deepspeed/comm/ccl.py
index 65646f6ba646..859e70c7db3b 100644
--- a/deepspeed/comm/ccl.py
+++ b/deepspeed/comm/ccl.py
@@ -21,6 +21,16 @@ def build_ccl_op():
     return ccl_cpp_module
 
 
+class CCLHandler():
+
+    def __init__(self, ccl_comm_op=None):
+        self.ccl_comm_op = ccl_comm_op
+
+    def wait(self):
+        # backend covered it
+        pass
+
+
 class CCLBackend(TorchBackend):
 
     def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
@@ -39,27 +49,129 @@ def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, i
         size = self.get_world_size()
         rank = self.get_rank()
         main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
-        main_kvs = torch.tensor(main_kvs).to(torch.uint8)
+        main_kvs = torch.tensor(main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
         super(CCLBackend, self).broadcast(main_kvs, 0)
         self.ccl_comm_op.initialize(size, rank, main_kvs)
         self.initialized = True
+        self.groups = [tuple(range(self.get_world_size()))]
+        self.available_coll = self.ccl_comm_op.get_available_coll()
 
     def is_initialized(self):
         return self.initialized
 
-    def broadcast(self, tensor, src, group=None, async_op=False):
-        self.ccl_comm_op.broadcast(tensor, src, group, async_op)
+    def run_collective(self, name, **kwargs):
+        if name in self.available_coll:
+            kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
+            if 'dst' in kwargs:
+                kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
+            if 'src' in kwargs:
+                kwargs['src'] = kwargs['group'].index(kwargs['src'])
+            func = "self.ccl_comm_op." + name
+            eval(func)(*(kwargs.values()))
+            return CCLHandler(self.ccl_comm_op)
+        else:
+            func = "super(CCLBackend, self)." + name
+            return eval(func)(*(kwargs.values()))
 
     def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
         use_caching = False
         if use_caching:
             match_id = f"{tensor.size()}-{op}"
-            self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
+            return self.run_collective(name="all_reduce_caching",
+                                       tensor=tensor,
+                                       op=op,
+                                       match_id=match_id,
+                                       group=group,
+                                       async_op=async_op)
         else:
-            self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
+            return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
 
     def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
-        self.ccl_comm_op.inference_all_reduce(tensor, op, group, async_op)
+        return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
+
+    def broadcast(self, tensor, src, group=None, async_op=False):
+        return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
+
+    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
+        return self.run_collective(name="all_gather",
+                                   tensor_list=tensor_list,
+                                   tensor=tensor,
+                                   group=group,
+                                   async_op=async_op)
+
+    def reduce_scatter_tensor(self, output_tensor, input_tensor, op, group=None, async_op=False):
+        return self.run_collective(name="reduce_scatter_tensor",
+                                   output_tensor=output_tensor,
+                                   input_tensor=input_tensor,
+                                   op=op,
+                                   group=group)
+
+    def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
+        return self.run_collective(name="all_gather_into_tensor",
+                                   output_tensor=output_tensor,
+                                   input_tensor=input_tensor,
+                                   group=group)
+
+    def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes, group=None, async_op=False):
+        return self.run_collective(name="all_to_all_single",
+                                   output=output,
+                                   input=input,
+                                   output_split_sizes=output_split_sizes,
+                                   input_split_sizes=input_split_sizes,
+                                   group=group)
+
+    def send(self, tensor, dst, group=None, async_op=False):
+        return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op)
+
+    def recv(self, tensor, src, group=None, async_op=False):
+        return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op)
+
+    def gather(self, tensor, gather_list, dst, group=None, async_op=False):
+        return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
+
+    def scatter(self, tensor, gather_list, dst, group=None, async_op=False):
+        return self.run_collective(name="scatter", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
 
     def barrier(self, group=None, async_op=False):
-        self.ccl_comm_op.barrier(group, async_op)
+        return self.run_collective(name="barrier", group=group, async_op=async_op)
+
+    def monitored_barrier(self, group=None, timeout=None, wait_all_ranks=False):
+        return self.run_collective(name="monitored_barrier", group=group)
+
+    def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
+        return self.run_collective(name="reduce_scatter",
+                                   output=output,
+                                   input_list=input_list,
+                                   op=op,
+                                   group=group,
+                                   async_op=async_op)
+
+    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+        return self.run_collective(name="reduce", tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
+
+    def new_group(self, ranks):
+        return super(CCLBackend, self).new_group(ranks)
+
+    def _new_group(self, ranks, group):
+        size = len(ranks)
+        rank = self.get_rank()
+        sub_main_kvs = self.ccl_comm_op.get_sub_kvs_addr(rank == ranks[0])
+        sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
+        super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], group)
+        self.ccl_comm_op.initialize_sub_comm(size, ranks.index(rank), sub_main_kvs, ranks)
+        self.groups.append(tuple(ranks))
+
+    def get_all_ranks_from_group(self, group):
+        if group is None:
+            return list(range(self.get_world_size()))
+        rank = 0
+        results = []
+        try:
+            while True:
+                results.append(super(CCLBackend, self).get_global_rank(group, rank))
+                rank += 1
+        except RuntimeError:
+            pass
+        if tuple(results) not in self.groups:
+            self._new_group(results, group)
+        return results
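
Not part of the patch itself: a minimal usage sketch of how the lazy sub-communicator path added above is expected to be driven from `deepspeed.comm`. The even-rank group, tensor shape, and launcher setup are illustrative assumptions only; the first supported collective issued on a new group is what triggers `get_all_ranks_from_group` / `_new_group` on the Python side and `get_sub_kvs_addr` / `initialize_sub_comm` in the C++ extension.

```python
# Illustrative sketch only, assuming a multi-process CPU launch with the
# DeepSpeed CCL backend selected (oneCCL installed, CPU accelerator active).
import torch
import deepspeed.comm as dist

dist.init_distributed(dist_backend='ccl')

rank = dist.get_rank()
world_size = dist.get_world_size()

# new_group() itself still defers to the torch backend; the matching oneCCL
# sub-communicator is created lazily the first time a collective listed by
# get_available_coll() runs on the group.
even_ranks = list(range(0, world_size, 2))   # arbitrary example group
group = dist.new_group(even_ranks)

t = torch.ones(4, dtype=torch.float32)
if rank in even_ranks:
    # First call on this group broadcasts the sub-KVS address through the
    # torch backend and registers the ranks in group_to_comm_id.
    dist.all_reduce(t, group=group)
```

Subsequent collectives on the same rank list hit the cached entry in `group_to_comm_id` and reuse the existing oneCCL communicator instead of creating a new one.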