From 2ce6bf8ce019bf445a5148ffc55dfd8178dd280e Mon Sep 17 00:00:00 2001
From: RyanInnerpeace
Date: Sat, 9 Dec 2023 07:01:52 +0800
Subject: [PATCH] [NPU] Add HcclBackend for 1-bit adam, 1-bit lamb, 0/1 adam (#4733)

To support features such as 1-bit Adam, 1-bit Lamb, and 0/1 Adam on NPU
devices, this PR adds an HcclBackend and the corresponding import logic to
each of these optimizers. See #4567 for the related work that has already
been done. A usage sketch for selecting the new backend is appended after
the patch.

---------

Co-authored-by: ryan
Co-authored-by: Conglong Li
---
 deepspeed/runtime/comm/hccl.py          | 124 ++++++++++++++++++++++++
 deepspeed/runtime/fp16/onebit/adam.py   |  11 +--
 deepspeed/runtime/fp16/onebit/lamb.py   |  12 +--
 deepspeed/runtime/fp16/onebit/zoadam.py |  11 +--
 4 files changed, 140 insertions(+), 18 deletions(-)
 create mode 100644 deepspeed/runtime/comm/hccl.py

diff --git a/deepspeed/runtime/comm/hccl.py b/deepspeed/runtime/comm/hccl.py
new file mode 100644
index 000000000000..09fb11a731b8
--- /dev/null
+++ b/deepspeed/runtime/comm/hccl.py
@@ -0,0 +1,124 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import numpy as np
+import torch
+import torch_npu
+import deepspeed.comm as dist
+
+
+class HcclBackend(object):
+
+    def __init__(self, mpu=None):
+        if mpu is None:
+            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
+        else:
+            self.mpu = mpu
+            self.world_group = self.mpu.get_data_parallel_group()
+        self.size = dist.get_world_size(group=self.world_group)
+        self.rank = dist.get_rank(group=self.world_group)
+
+    def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
+        req = []
+        if rank == root:
+            for idx in range(size):
+                if idx != rank:
+                    req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
+                else:
+                    recvbuf[rank] = sendbuf
+        else:
+            req.append(dist.isend(sendbuf, group=group, dst=root))
+        return req
+
+    def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
+        if rank == root:
+            for idx in range(size):
+                if idx != rank:
+                    dist.recv(recvbuf[idx], src=idx, group=group)
+                else:
+                    recvbuf[rank] = sendbuf
+        else:
+            dist.send(sendbuf, group=group, dst=root)
+
+    def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank):
+        original_shape = buffer_m.size()
+        if len(original_shape) > 1:
+            buffer_m = torch.flatten(buffer_m)
+
+        # align size of original_buffer and error
+        original_size = buffer_m.numel()
+        worker_error_size = worker_error.numel()
+        if original_size != worker_error_size:
+            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
+            buffer_m = torch.cat([buffer_m, empty_tensor])
+
+        buffer_m.add_(worker_error)
+        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
+
+        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
+
+        sign_list_packed_tmp = torch_npu.npu_sign_bits_pack(buffer_m, self.size).type(torch.int8)
+
+        recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
+                                   dtype=sign_list_packed_tmp[0].dtype,
+                                   device=sign_list_packed_tmp.device)
+
+        sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]
+
+        recvbuf_scale = [
+            torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(local_rank)) for _ in range(self.size)
+        ]
+
+        # communication phase 1
+        # all to all for sign
+        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
+        # all gather for scale
+        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)
+
+        flattened_recvbuf_sign = 
recvbuf_sign.type(torch.uint8).flatten() + compensated_server_m = torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign, self.size, torch.float32) \ + .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) + + compensated_server_m.add_(server_error) + + server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) + + server_error.set_(compensated_server_m - + server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) + + server_sign_packed = torch_npu.npu_sign_bits_pack(compensated_server_m, 1).type(torch.int8) + + # recvbuf_sign_server + recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])], + dtype=recvbuf_sign.dtype, + device=server_sign_packed.device) + + recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)] + + # recvbuf_scale_server + recvbuf_scale_server_tmp = torch.zeros([self.size, 1], + dtype=worker_scale.dtype, + device=server_sign_packed.device) + + recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)] + + # communication Phase 2 + dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group) + dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group) + + recvbuf_sign_server = torch.stack(recvbuf_sign_server) + + flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten() + + buffer_m.data.copy_( + torch_npu.npu_sign_bits_unpack(flattened_recvbuf_sign_server, self.size, + torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data) + + if original_size != worker_error_size: + buffer_m = buffer_m[0:original_size] + if len(original_shape) > 1: + buffer_m = buffer_m.reshape(original_shape) + + return buffer_m diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py index 236eea8cadc5..ae3e5f573850 100644 --- a/deepspeed/runtime/fp16/onebit/adam.py +++ b/deepspeed/runtime/fp16/onebit/adam.py @@ -70,8 +70,6 @@ def __init__(self, super(OnebitAdam, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.comm_time = 0.0 self.step_time = 0.0 self.ave_step = 1 @@ -86,22 +84,23 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': assert ( required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." 
from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) - + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size self.divider = int(self.size * 8 / np.gcd(self.size, 8)) diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py index 0662fabeeee1..9cd2e0f25648 100644 --- a/deepspeed/runtime/fp16/onebit/lamb.py +++ b/deepspeed/runtime/fp16/onebit/lamb.py @@ -93,8 +93,6 @@ def __init__(self, super(OnebitLamb, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.deepspeed = deepspeed self.lamb_freeze_key = False self.initialize = False @@ -108,21 +106,23 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': assert ( required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size @@ -161,7 +161,7 @@ def step(self, closure=None, grads=None): else: grads_group = grads - #remove the previous stats + # remove the previous stats del self.lamb_coeffs[:] if self.lamb_freeze_key: diff --git a/deepspeed/runtime/fp16/onebit/zoadam.py b/deepspeed/runtime/fp16/onebit/zoadam.py index 922263ad6a76..9ef671e7e3b7 100644 --- a/deepspeed/runtime/fp16/onebit/zoadam.py +++ b/deepspeed/runtime/fp16/onebit/zoadam.py @@ -83,8 +83,6 @@ def __init__(self, super(ZeroOneAdam, self).__init__(params, defaults) self.eps_mode = 0 if eps_inside_sqrt else 1 - assert (dist.is_initialized()) - self.deepspeed = deepspeed self.initialize = False self.cuda_aware = cuda_aware @@ -99,22 +97,23 @@ def __init__(self, self.comm_backend_name = comm_backend_name + assert dist.is_initialized(), "Please initialize the torch distributed backend." # Empty initializer. Set handle based on the comm backend as follows. self.comm_backend_handle = None - if self.comm_backend_name == 'nccl': assert ( required_torch_version(min_version=1.8) ), "Please use torch 1.8 or greater to enable NCCL backend in 0/1 Adam. 
Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" - assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) - elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend self.comm_backend_handle = MpiBackend(cuda_aware) - + elif self.comm_backend_name == 'hccl': + from deepspeed.runtime.comm.hccl import HcclBackend + self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = HcclBackend(self.deepspeed.mpu) self.size = self.comm_backend_handle.size self.divider = int(self.size * 8 / np.gcd(self.size, 8))
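
Usage sketch: with this patch applied, the HCCL path is selected the same way the NCCL and MPI paths are, through the optimizer's `comm_backend_name` option. The snippet below is a minimal, hypothetical example, not part of the patch: it assumes an NPU machine with `torch_npu` installed and the job launched through the `deepspeed` launcher, the model and hyperparameters are placeholders, and the config keys follow the existing 1-bit Adam options (`freeze_step`, `cuda_aware`, `comm_backend_name`).

    # Hypothetical sketch (assumptions noted above): selecting the new HCCL
    # backend via the existing `comm_backend_name` knob of 1-bit Adam.
    import torch
    import deepspeed

    ds_config = {
        "train_batch_size": 8,
        "fp16": {"enabled": True},
        "optimizer": {
            "type": "OneBitAdam",
            "params": {
                "lr": 1e-4,
                "freeze_step": 400,
                "cuda_aware": False,
                # Routes backend selection to deepspeed.runtime.comm.hccl.HcclBackend
                "comm_backend_name": "hccl",
            },
        },
    }

    # Placeholder model; any torch.nn.Module works here.
    model = torch.nn.Linear(16, 16)
    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   model_parameters=model.parameters(),
                                                   config=ds_config)

The same `"comm_backend_name": "hccl"` value applies to OneBitLamb and ZeroOneAdam, since all three optimizers share the backend-selection block this patch extends.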