From c632ea09f8d107d10f76aa2b776e4df3c1ccf98a Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Fri, 19 Apr 2024 08:58:27 -0700 Subject: [PATCH 01/23] Selective dequantization (#5375) This PR adds a new functionality for the dequantizer function, called `selective_dequantize`, which enables partially dequantizing a 3-dimensional matrix in case we don't need to dequantize all the data from lower bit (like fp8/fp6) to bf16. I also added a unit test to check its functionality. --------- Co-authored-by: Reza Yazdani Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- csrc/fp_quantizer/includes/quantize.h | 11 ++ csrc/fp_quantizer/quantize.cpp | 31 +++++ csrc/fp_quantizer/quantize.cu | 135 +++++++++++++++++++ deepspeed/ops/fp_quantizer/quantize.py | 35 +++++ tests/unit/ops/fp_quantizer/test_fp_quant.py | 29 ++++ 5 files changed, 241 insertions(+) diff --git a/csrc/fp_quantizer/includes/quantize.h b/csrc/fp_quantizer/includes/quantize.h index 2204c1ba74fc..507252d6e722 100644 --- a/csrc/fp_quantizer/includes/quantize.h +++ b/csrc/fp_quantizer/includes/quantize.h @@ -113,3 +113,14 @@ void launch_dequantization(uint8_t* val, int q_mantisa_bits, int q_exponent_bits, cudaStream_t stream); + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream); diff --git a/csrc/fp_quantizer/quantize.cpp b/csrc/fp_quantizer/quantize.cpp index 4a88ff767636..ec631c576e27 100644 --- a/csrc/fp_quantizer/quantize.cpp +++ b/csrc/fp_quantizer/quantize.cpp @@ -78,8 +78,39 @@ void dequantize(torch::Tensor& val, #endif } +#define DISPATCH_DEQUANTIZE_INDEX(T_TYPE, C_TYPE, mantisa) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_selective_dequantization((uint8_t*)val_q.data_ptr(), \ + (C_TYPE*)val.data_ptr(), \ + (int32_t*)indexes.data_ptr(), \ + num_groups, \ + group_size, \ + num_indexes, \ + q_mantisa_bits, \ + q_exponent_bits, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } +void selective_dequantize(torch::Tensor& val, + torch::Tensor& val_q, + torch::Tensor& indexes, + int group_size, + int q_mantisa_bits, + int q_exponent_bits) +{ + int total_elems = at::numel(val); + int num_indexes = indexes.size(0); + int num_groups = total_elems / group_size; + + DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10); +#ifdef BF16_AVAILABLE + DISPATCH_DEQUANTIZE_INDEX(kBFloat16, __nv_bfloat16, 7); +#endif +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("quantize", &quantize, "quantize function"); m.def("dequantize", &dequantize, "dequantize function"); + m.def("selective_dequantize", &selective_dequantize, "selective dequantize function"); } diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu index 5f0b58f124f0..5ada6894747f 100644 --- a/csrc/fp_quantizer/quantize.cu +++ b/csrc/fp_quantizer/quantize.cu @@ -270,6 +270,7 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int mem_access::load_global( int8_data + quantization::quanitzed_access_granularity_6bits * 2, load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } else { mem_access::load_global(int8_data, load_base_ptr); @@ -393,3 +394,137 @@ void launch_dequantization(uint8_t* val, INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7); #endif INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10); + +template +__global__ void apply_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int group_size, + int total_num_elements) +{ + int index = indexes[blockIdx.x]; + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size; + int input_index = index * total_num_elements + tidx; + constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); + const uint32_t g_index = (input_index / group_size); + const uint32_t group_size_bytes = (group_size * quantized_bits / 8); + const uint8_t* load_base_ptr = + val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8; + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements; + float scale; + + uint8_t* scale_as_int8 = reinterpret_cast(&scale); + if (quantized_bits == 6) { + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + mem_access::load_global( + scale_as_int8 + quantization::quanitzed_access_granularity_6bits, + val + g_index * (group_size_bytes + 4) + group_size_bytes + + quantization::quanitzed_access_granularity_6bits); + } else + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + + if (tidx < total_num_elements) { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + if (quantized_bits == 6) { + mem_access::load_global( + int8_data, load_base_ptr); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } else { + mem_access::load_global(int8_data, + load_base_ptr); + if (quantized_bits > 4) { + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) { + mem_access::load_global( + int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2); + } + } + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); + } else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = + ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) | + (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); + } + mem_access::store_global(store_base_ptr, store_buf); + } +} + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream) +{ + int total_elements_per_index = (num_groups / num_indexes) * group_size; + int blocks = (total_elements_per_index - 1) / + (quantization::threads * (quantization::access_granularity / sizeof(T))) + + 1; + const dim3 grid(num_indexes, blocks); + const dim3 block(quantization::threads); + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + apply_selective_dequantization + <<>>(val, q_val, indexes, group_size, total_elements_per_index); + }); +} +#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \ + template void launch_selective_dequantization( \ + uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t); +// fp8(E4M3) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7); +#endif +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10); diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 5dc3c190ae5d..0d4bf7bc6db1 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -77,3 +77,38 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out + + def selective_dequantize(self, + input_q, + indexes, + fp_out=None, + q_bits=8, + q_mantisa_bits=3, + scale=None) -> torch.Tensor: + assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \ + "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function." + assert (self.orig_dtype is not None), \ + "[De-quantization Error]: you need to call quantize before dequantizing!" + fp_out = torch.empty( + (indexes.shape[0], + *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" + + if scale is not None: + assert input_q.numel() == fp_out.numel(), \ + f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' + input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + + fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits, + q_bits - q_mantisa_bits - 1) + return fp_out diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py index 101f4cd69811..bed8bd7e3bcc 100644 --- a/tests/unit/ops/fp_quantizer/test_fp_quant.py +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -61,6 +61,35 @@ def test_fp_quant_meta(dtype): assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_fp_quant_selective(dtype): + group_size = 128 + q_bits = 8 + exp_bits = 4 + man_bits = 3 + + fpq = FP_Quantize(group_size=group_size) + indexes = torch.zeros(2, dtype=torch.int32, device='cuda') + indexes[0] = 1 + indexes[1] = 3 + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device='cuda') + + x = x.reshape(4, 1, x.shape[-1]) + ds_x = x.clone() + x_quantized = fpq.quantize(ds_x, q_bits=q_bits) + x_dequantized = fpq.selective_dequantize(x_quantized, indexes, q_bits=q_bits) + + qtorch_out = qtorch_quantize(x.index_select(0, indexes), + exp_bits=exp_bits, + man_bits=man_bits, + group_size=group_size) + qtorch_error = (qtorch_out - x.index_select(0, indexes)).abs().sum() / x.numel() + ds_error = (x_dequantized - x.index_select(0, indexes)).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" + + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"]) def test_fp_quant(dtype, q_bits): From 99951caa3d2155a3bb84109a0828543793e088cc Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 19 Apr 2024 14:19:47 -0700 Subject: [PATCH 02/23] Fix sorting of shard optimizer states files for universal checkpoint (#5395) This PR resolves the issue reported in #5283. To resolve the issue, we sort files of sharded optimizer states based on DP indices. --------- Co-authored-by: Olatunji Ruwase --- deepspeed/checkpoint/ds_to_universal.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index d5eca81c804f..63fa866718de 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): cnt = 0 +def dp_index_to_str(dp_index): + return f"{dp_index:0>2d}" + + def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): global cnt # temp hack @@ -140,9 +144,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, os.makedirs(param_base_path, exist_ok=True) cnt += 1 - counter = f"{dp_index:0>2d}" - path = os.path.join(param_base_path, f"{state_name}.{counter}") + path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") #print(f"{param_name}: {offset}: {numel} => {path}") @@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): slices = [] for tp_index in range(tp_degree): prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") - paths = sorted(list(glob.glob(f"{prefix_path}.*"))) + paths = glob.glob(f"{prefix_path}.*") + if len(paths) == 0: continue + pattern = re.compile(f"{prefix_path}\\.([0-9]+)") + dp_indices = set() + for p in paths: + m = pattern.match(p) + if m: + dp_indices.add(int(m.group(1))) + else: + raise ValueError(f"Cannot parse dp_rank from {p}") + + paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))] shards = [torch.load(p) for p in paths] if state == "step": From 3f875d95193fbd3a0c7f0c0dcc8d39469061bb66 Mon Sep 17 00:00:00 2001 From: shiyuan680 <72335504+shiyuan680@users.noreply.github.com> Date: Sun, 21 Apr 2024 07:35:50 +0800 Subject: [PATCH 03/23] add device config env for the accelerator (#5396) Thank you for [pr](https://github.com/microsoft/DeepSpeed/pull/5369) and @delock contribution of ideas. As mentioned in this [pr](https://github.com/microsoft/DeepSpeed/pull/5369), each device has its own environmental variables. We create visible_devices_envs() and set_visible_devices_envs() methods on the accelerator class to enable each accelerator to implement env settings within the interface , which is more generic to other accelerators. this commit has tested on npu, each one has 8 ascend npus --------- Co-authored-by: yangcheng Co-authored-by: eigen2017 Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- README.md | 11 ++++++----- accelerator/abstract_accelerator.py | 8 ++++++++ accelerator/cpu_accelerator.py | 8 ++++++++ accelerator/cuda_accelerator.py | 7 +++++++ accelerator/hpu_accelerator.py | 7 +++++++ accelerator/mps_accelerator.py | 9 +++++++++ accelerator/npu_accelerator.py | 7 +++++++ accelerator/xpu_accelerator.py | 7 +++++++ deepspeed/launcher/launch.py | 11 +++++++---- 9 files changed, 66 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index a1335caa4949..201b9016f8ab 100755 --- a/README.md +++ b/README.md @@ -159,11 +159,12 @@ dynamically link them at runtime. ## Contributed HW support * DeepSpeed now support various HW accelerators. -| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated | -| ----------- | -------- | ---------------- | --------------------- | ------------------ | -| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes | -| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes | -| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes | +| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated | +|-------------|-------------------------------------|------------------| --------------------- |--------------------| +| Huawei | Huawei Ascend NPU | npu | Yes | No | +| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes | +| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes | +| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes | ## PyPI We regularly push releases to [PyPI](https://pypi.org/project/deepspeed/) and encourage users to install from there in most cases. diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 3c5d799e293e..768d5ea34e5e 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -287,3 +287,11 @@ def build_extension(self): @abc.abstractmethod def export_envs(self): ... + + @abc.abstractmethod + def visible_devices_envs(self): + ... + + @abc.abstractmethod + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + ... diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index a0171723cfb8..237e7f51dcb4 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -322,3 +322,11 @@ def build_extension(self): def export_envs(self): return [] + + # TODO: cpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES + def visible_devices_envs(self): + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 3d5e9c168c16..2fc0cfd94125 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -360,3 +360,10 @@ def build_extension(self): def export_envs(self): return ['NCCL'] + + def visible_devices_envs(self): + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 30b115e8b1ab..326efc8fa01b 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -294,3 +294,10 @@ def build_extension(self): def export_envs(self): return [] + + def visible_devices_envs(self): + return ['HABANA_VISIBLE_MODULES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index 972b33caece1..ff70b860d7c7 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -258,3 +258,12 @@ def build_extension(self): def export_envs(self): return [] + + # TODO: mpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES + def visible_devices_envs(self): + # TODO: could not find visible devices env for mps + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 472157e32c02..5d891ecb707d 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -278,3 +278,10 @@ def build_extension(self): def export_envs(self): return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH'] + + def visible_devices_envs(self): + return ['ASCEND_RT_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index 3f65263946ab..c59f60077d2f 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -289,3 +289,10 @@ def build_extension(self): def export_envs(self): return [] + + def visible_devices_envs(self): + return ['ZE_AFFINITY_MASK'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index ffb9cbc18fbd..079182a3b45b 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -22,6 +22,7 @@ from collections import defaultdict from typing import Dict from argparse import ArgumentParser, REMAINDER +from deepspeed.accelerator import get_accelerator from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT from ..nebula.constants import DLTS_POD_ENV_PATH from ..utils import logger, get_numactl_cmd @@ -146,8 +147,8 @@ def main(): node_list = list(world_info.keys()) args.nnodes = len(node_list) local_node = node_list[args.node_rank] - local_gpu_ids = world_info[local_node] - num_local_procs = len(local_gpu_ids) + local_accelerator_ids = world_info[local_node] + num_local_procs = len(local_accelerator_ids) logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}") global_rank_mapping = defaultdict(list) @@ -161,8 +162,10 @@ def main(): curr_global_rank += 1 logger.info(f"global_rank_mapping={global_rank_mapping}") logger.info(f"dist_world_size={dist_world_size}") - current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids)) - logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}") + + get_accelerator().set_visible_devices_envs(current_env, local_accelerator_ids) + for env in get_accelerator().visible_devices_envs(): + logger.info(f"Setting {env}={current_env[env]}") # set PyTorch distributed related environmental variables current_env["MASTER_ADDR"] = args.master_addr From 9b6ef9e1f0d8acaefd989440b27da9069aa69207 Mon Sep 17 00:00:00 2001 From: Wei Fu <36355462+garrett4wade@users.noreply.github.com> Date: Tue, 23 Apr 2024 03:47:00 +0800 Subject: [PATCH 04/23] 64bit indexing fused adam (#5187) ## The Issue Applying `FusedAdam` on large tensors will cause an error `CUDA error: an illegal memory access was encountered`. https://github.com/microsoft/DeepSpeed/issues/3429 https://github.com/NVIDIA/apex/issues/1654 ## PR Content Following the solution in the apex repository (https://github.com/NVIDIA/apex/pull/1765), changing indexing type to `int64` if necessary. --------- Co-authored-by: Michael Wyatt Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- csrc/adam/multi_tensor_adam.cu | 77 +++++++++++++++++++++++--------- csrc/adam/multi_tensor_apply.cuh | 10 ++--- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/csrc/adam/multi_tensor_adam.cu b/csrc/adam/multi_tensor_adam.cu index d6b9b2f70710..a1fc7d15aec9 100644 --- a/csrc/adam/multi_tensor_adam.cu +++ b/csrc/adam/multi_tensor_adam.cu @@ -30,7 +30,7 @@ typedef enum : int { using MATH_T = float; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, @@ -48,13 +48,13 @@ struct AdamFunctor { // if(*noop_gmem == 1) // return; - int tensor_loc = tl.block_to_tensor[blockIdx.x]; + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; // potentially use to pass in list of scalar // int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; T* g = (T*)tl.addresses[0][tensor_loc]; g += chunk_idx * chunk_size; @@ -71,7 +71,8 @@ struct AdamFunctor { n -= chunk_idx * chunk_size; // see note in multi_tensor_scale_kernel.cu - for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + for (index_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; @@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size, bias_correction2 = 1 - std::pow(beta2, step); } + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { break; } + } + // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), - 0, - "adam", - multi_tensor_apply<4>(BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor(), - beta1, - beta2, - bias_correction1, - bias_correction2, - epsilon, - lr, - (adamMode_t)mode, - weight_decay);) + if (requires_64bit_indexing) { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, + (int64_t)chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } else { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/csrc/adam/multi_tensor_apply.cuh b/csrc/adam/multi_tensor_apply.cuh index 12f41cb49c6b..342376c141be 100644 --- a/csrc/adam/multi_tensor_apply.cuh +++ b/csrc/adam/multi_tensor_apply.cuh @@ -35,7 +35,7 @@ struct TensorListMetadata { }; template -__global__ void multi_tensor_apply_kernel(int chunk_size, +__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop_flag, T tl, U callable, @@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size, } template -void multi_tensor_apply(int block_size, - int chunk_size, +void multi_tensor_apply(int64_t block_size, + int64_t chunk_size, const at::Tensor& noop_flag, const std::vector>& tensor_lists, T callable, @@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size, tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); loc_tensor_info++; - int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { // std::cout << chunks_this_tensor << std::endl; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_chunk[loc_block_info] = chunk; From c292b03a403cf0f6fd747f6106ad40060c70d3f8 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Mon, 22 Apr 2024 12:50:15 -0700 Subject: [PATCH 05/23] Improve parallel process of universal checkpoint conversion (#5343) The conversion script from a regular checkpoint to the universal one runs the followings in parallel. 1. extracts zero sharded optimizer states 2. merge the shards However, it passes `map()` a set of only a few tasks (the number specified as workers). Thus it needs to wait for the slowest tasks to finish for every set. This PR submits all the tasks to the pool and wait until the futures get ready. We can keep all workers running. --------- Co-authored-by: Olatunji Ruwase --- deepspeed/checkpoint/ds_to_universal.py | 34 ++++++++----------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index 63fa866718de..b1a8276589b6 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -10,7 +10,7 @@ import argparse import glob import itertools -import multiprocessing +from concurrent.futures import ProcessPoolExecutor import os import re import shutil @@ -292,27 +292,18 @@ def get_matched_sub_params_pattern(name_): return unmatched_patterns -def _get_chunks(l, n): - for i in range(0, len(l), n): - yield l[i:i + n] - - def _do_parallel_work(do_work, work_chunks, num_workers): + results = [] if num_workers > 1: - pool = multiprocessing.Pool(num_workers) - results = [] - for batch in tqdm.tqdm(work_chunks): - res = pool.map(do_work, batch) - results.extend(res) - pool.close() - pool.join() + with ProcessPoolExecutor(max_workers=num_workers) as executor: + future_list = [executor.submit(do_work, work) for work in work_chunks] + for f in tqdm.tqdm(future_list): + results.append(f.result()) else: # No parallel pass for unit testing # We can't create child processes in tests - results = [] - for batch in tqdm.tqdm(work_chunks): - res = [do_work(x) for x in batch] - results.extend(res) + for work in tqdm.tqdm(work_chunks): + results.append(do_work(work)) return results @@ -321,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), range(ds_checkpoint.dp_degree))) #pprint(f'{_3d_range_list=}') - work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers)) - #pprint(f'{work_chunks=}') - # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0]) do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint) - _do_parallel_work(do_work, work_chunks, args.num_extract_workers) + _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers) def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir): - work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers)) - #pprint(work_chunks) zero_output_folder = os.path.join(args.output_folder, "zero") do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree) - unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers) + unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers) # verify that all patterns were used # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert From c66bc4269e9484b6e57d6f5521df02c70d399246 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Apr 2024 07:27:09 +0800 Subject: [PATCH 06/23] set the default to use set_to_none for clearing gradients in BF16 optimizer. (#5434) as discussed in #5175, set the default to use set_to_none for clearing gradients in BF16 optimizer. Additionally, for the case of zero clearing, use foreach_zero. Verified correctness with mega-ds llama 7B training. FYI @loadams --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/bf16_optimizer.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index f970e582b354..1f3365b20f4e 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -341,7 +341,7 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads): # clear gradients if clear_lp_grads: - lp.grad._zero() + lp.grad.zero_() @torch.no_grad() def _update_hp_grads_func(self, clear_lp_grads=False): @@ -441,11 +441,20 @@ def clear_hp_grads(self): self.fp32_groups_has_gradients[i] = [False] * len(group) def clear_lp_grads(self): + + # using zero_() fixed memory address for graph replay + set_to_none = False if self.graph_harvesting else True + zero_grads_list = [] for group in self.bf16_groups: for param in group: - if param.grad is not None: - # Using zero_() fixed memory address for graph replay - param.grad.zero_() + if set_to_none: + param.grad = None + elif param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + zero_grads_list.append(param.grad) + if not set_to_none and len(zero_grads_list) > 0: + torch._foreach_zero_(zero_grads_list) def state_dict(self): state_dict = {} From 5e6c9b931184bf8f0a245a2fef183078139b6d7e Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 23 Apr 2024 12:24:37 -0700 Subject: [PATCH 07/23] OptimizedLinear implementation (#5355) Optimized version of `nn.Linear` that adds features such as: * LoRA w. base weight sharding * FP [6,8,12] quantization Depends on #5336 being merged first Co-authored-by: @rajhans Co-authored-by: @aurickq --------- Co-authored-by: Rajhans Samdani Co-authored-by: Jeff Rasley --- deepspeed/linear/__init__.py | 7 ++ deepspeed/linear/config.py | 39 +++++++ deepspeed/linear/optimized_linear.py | 150 +++++++++++++++++++++++++ deepspeed/linear/quantization.py | 137 ++++++++++++++++++++++ deepspeed/ops/fp_quantizer/__init__.py | 2 +- deepspeed/ops/fp_quantizer/quantize.py | 33 +++++- tests/unit/linear/test_linear.py | 128 +++++++++++++++++++++ tests/unit/linear/test_quant_param.py | 58 ++++++++++ 8 files changed, 550 insertions(+), 4 deletions(-) create mode 100644 deepspeed/linear/__init__.py create mode 100644 deepspeed/linear/config.py create mode 100644 deepspeed/linear/optimized_linear.py create mode 100644 deepspeed/linear/quantization.py create mode 100644 tests/unit/linear/test_linear.py create mode 100644 tests/unit/linear/test_quant_param.py diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py new file mode 100644 index 000000000000..a27f1c3eaee7 --- /dev/null +++ b/deepspeed/linear/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .optimized_linear import OptimizedLinear +from .config import LoRAConfig, QuantizationConfig diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py new file mode 100644 index 000000000000..ae9050a3c92b --- /dev/null +++ b/deepspeed/linear/config.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from dataclasses import dataclass + + +@dataclass +class LoRAConfig: + """ + Configuration settings for LoRAOptimizedLinear. + + Attributes: + lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64. + lora_alpha (float): LoRA scaling factor, default is 16. + base_weight_sharding (int): The degree to which the base weights are sharded, + should typically be set to the data-parallel world size to maximize the memory + reduction benefits. Defaults to 1, which means this feature is disabled. + """ + lora_r: int = 64 + lora_alpha: float = 16. + base_weight_sharding: int = 1 + + +@dataclass +class QuantizationConfig: + """ + Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear, + and QuantizedParameter + + Attributes: + q_bits (int): The number of bits used for quantization. Default is 8. + mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3. + group_size (int): The size of the group used for quantization. Default is 512. + """ + q_bits: int = 8 + mantissa_bits: int = 3 + group_size: int = 512 diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py new file mode 100644 index 000000000000..138bd493ffc7 --- /dev/null +++ b/deepspeed/linear/optimized_linear.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import is_dataclass +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +from .config import LoRAConfig, QuantizationConfig +from .quantization import QuantizedParameter, QuantizedLinear + + +class OptimizedLinear(nn.Module): + """ + Optimized version of nn.Linear that adds features such as: + * LoRA w. base weight sharding + * FP [6,8,12] quantization + + Arguments: + input_dim: Required: size of each input sample + output_dim: Required: size of each output sample + bias: Optional: If set to False, the layer will not learn an additive bias. Default: False + lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree + quantization_config: Optional: QuantizationConfig defining quantization features + dtype: Optional: parameter dtype, only supports bfloat16 currently + + Returns: + Returns a new nn.Module depending on the input config. Either native + torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear. + """ + + def __new__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + dtype=torch.bfloat16): + + if quantization_config is not None and not is_dataclass(quantization_config): + raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") + if lora_config is not None and not is_dataclass(lora_config): + raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") + if lora_config is None and quantization_config is None: + # Everything disabled, fall back to normal nn.Linear + self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) + + elif lora_config: + # lora enabled, quantization may or may not be + self = LoRAOptimizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=dtype) + + elif quantization_config: + # only quantization enabled, no lora + self = QuantizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + quantization_config=quantization_config, + dtype=dtype) + return self + + +class LoRAOptimizedLinear(nn.Module): + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + device=None, + dtype=torch.bfloat16): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.bias = bias + self.lora_config = lora_config + self.quantization_config = quantization_config + device = get_accelerator().current_device() if device is None else device + assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" + + self.zero_shards = self.lora_config.base_weight_sharding + self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) + w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype)) + torch.nn.init.xavier_uniform_(w) + + if self.quantization_config is not None: + assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" + self.base_weight = QuantizedParameter(w, quantization_config=quantization_config) + else: + self.base_weight = w + + self.base_weight.requires_grad = False + + # Use RS lora for now. + self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r) + # Keeping lora weights in bf16 precision for ease of training. + self.lora_weight_1 = nn.Linear(self.input_dim, + self.lora_config.lora_r, + bias=self.bias, + device=device, + dtype=dtype) + self.lora_weight_2 = nn.Linear(self.lora_config.lora_r, + self.output_dim, + bias=self.bias, + device=device, + dtype=dtype) + self.lora_weight_1.weight.requires_grad = True + self.lora_weight_2.weight.requires_grad = True + + def full_weight(self): + # This assumes weights are evenly sharded across gpus. which might not be correct. + # in that case, we should flatten before all_gather. + local_weight = self.base_weight.dequantized() if isinstance(self.base_weight, + QuantizedParameter) else self.base_weight + tensor_list = [ + torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype) + for _ in range(self.zero_shards) + ] + dist.all_gather(tensor_list, local_weight) + weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1)) + return weight + + def linear_without_F_linear(self, input, weight): + output = torch.mm(input.reshape(-1, input.shape[-1]), weight) + output = output.view(*input.shape[:-1], weight.shape[1]) + return output + + def forward(self, input_tensor): + # Gather the sharded base weight + if self.zero_shards > 1: + with torch.no_grad(): + base_weight = self.full_weight() + elif self.quantization_config: + base_weight = self.base_weight.dequantized() + else: + base_weight = self.base_weight + + base_weight_output = F.linear(input_tensor, base_weight) + lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor)) + return base_weight_output + self.lora_scaling_factor * lora_output diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py new file mode 100644 index 000000000000..f5343af45fb8 --- /dev/null +++ b/deepspeed/linear/quantization.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize +from .config import QuantizationConfig + + +class QuantizedParameter(nn.Parameter): + """ + Quantized parameter class that implements weight quantization. Weights + are stored in quantized form on GPUs, and can be dequantized on-the-fly when + needed by the model. The weights are actually quantized during any `.to(device)`. + + Arguments: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. Defaults + to False and is not supported to be True. Argument provided only for interface + compatibility with torch.nn.Parameter. + quantization_config (QuantizationConfig, optional): + quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer + that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also + required since the quantizer is stashed in the Parameter itself, some models + may clone the Parameter by passing an attribute __dict__. For an example, see + tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone + """ + + def __new__( + cls, + data: Optional[torch.Tensor] = None, + requires_grad: bool = False, # quantized weights must be frozen + quantization_config: QuantizationConfig = None, + quantizer: Quantizer = None, + ): + if requires_grad: + raise ValueError(f"requires_grad=True is not supported with QuantizedParameter") + if data is None: + data = torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config + if quantizer is not None: + self.quantizer = quantizer + else: + # if FPQuantizerBuilder is not compatible in this env this init will fail + self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size) + self._ensure_quantized(self) + return self + + def _ensure_quantized(self, tensor: torch.Tensor): + # If the tensor is on the accelerator and is not quantized, then quantize it in-place. + if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8: + with get_accelerator().stream(get_accelerator().current_stream(tensor.device)): + tensor.data = self.quantizer.quantize(tensor.data, + q_bits=self.quantization_config.q_bits, + q_mantisa_bits=self.quantization_config.mantissa_bits) + assert tensor.dtype == torch.int8 + + def dequantized(self) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8: + with get_accelerator().stream(get_accelerator().current_stream(self.data.device)): + return self.quantizer.dequantize(self.data, + q_bits=self.quantization_config.q_bits, + q_mantisa_bits=self.quantization_config.mantissa_bits) + return self.data + + def __getstate__(self): + state = self.__dict__ + state["data"] = self.data + state["quantization_config"] = self.quantization_config + state["requires_grad"] = self.requires_grad + return state + + def __setstate__(self, state): + self.quantizer = state["quantizer"] + self.quantization_config = state["quantization_config"] + self.data = state["data"] + self.requires_grad = state["requires_grad"] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + new_instance.quantizer = copy.deepcopy(state["quantizer"]) + new_instance.quantization_config = copy.deepcopy(state["quantization_config"]) + new_instance.data = copy.deepcopy(state["data"]) + return new_instance + + def __copy__(self): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def cuda(self, device=None, non_blocking=False): + return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + + def to(self, *args, **kwargs): + """ + Move the parameter to the given device. Then, if the device is a cuda device, + quantize it. + """ + tensor = super().to(*args, **kwargs) + self._ensure_quantized(tensor) + return tensor + + +class QuantizedLinear(nn.Linear): + """ + Linear layer that implements weight quantization. Parameters + are stored via `QuantizedParameter` and are dequantized on-the-fly during any + forward pass. + """ + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + quantization_config: QuantizationConfig = None, + dtype=torch.bfloat16): + super().__init__(input_dim, output_dim, bias=bias, dtype=dtype) + assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype" + self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.dequantized(), self.bias) diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 5575f3567185..995bbae4aeaf 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .quantize import FP_Quantize +from .quantize import FP_Quantize, Quantizer diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 0d4bf7bc6db1..f8435bda16c1 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -4,20 +4,47 @@ # DeepSpeed Team import torch +import abc +from abc import ABC from deepspeed.ops.op_builder import FPQuantizerBuilder fp_quant_module = None -class FP_Quantize: +class Quantizer(ABC): + """ + Abstract Quantizer class that implmenents quantize/dequantize methods. + + Arguments: + group_size (int, optional): number of values or elements that are grouped + together for the quantization process. + """ + + def __init__(self, group_size=512) -> None: + self.group_size = group_size + + @abc.abstractmethod + def quantize(self, + input, + q_bits=8, + q_mantisa_bits=3, + stochastic_mode=False, + return_meta_tensor=False) -> torch.Tensor: + ... + + @abc.abstractmethod + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: + ... + + +class FP_Quantize(Quantizer): def __init__(self, group_size=512) -> None: global fp_quant_module + super().__init__(group_size=group_size) if fp_quant_module is None: fp_quant_module = FPQuantizerBuilder().load() - - self.group_size = group_size self.orig_dtype = None def quantize(self, diff --git a/tests/unit/linear/test_linear.py b/tests/unit/linear/test_linear.py new file mode 100644 index 000000000000..ccd26b4cd726 --- /dev/null +++ b/tests/unit/linear/test_linear.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +import deepspeed.comm as dist + +from deepspeed.accelerator import get_accelerator +from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig +from unit.common import DistributedTest + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + + +class TestBasicLinear(DistributedTest): + world_size = 2 + + def test(self): + lora_config = None + quantization_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 1 # Number of samples in a batch + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + + dummy_input = torch.rand(batch_size, input_features, dtype=torch.bfloat16) + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("base_weight_sharding", [1, 2]) +class TestLoRALinear(DistributedTest): + world_size = 2 + + def test(self, base_weight_sharding): + rank = dist.get_rank() + lora_config = None + quantization_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding) + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + if rank == 0: + for n, p in linear_layer.named_parameters(): + print(f"{n}, {p.shape}") + + dummy_input = torch.rand(batch_size, input_features, device=device, dtype=torch.bfloat16) + + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("q_bits", [8, 6]) +class TestQuantLinear(DistributedTest): + world_size = 2 + + def test(self, q_bits): + rank = dist.get_rank() + lora_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = None + quantization_config = QuantizationConfig(q_bits=q_bits) + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16) + + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("base_weight_sharding", [1, 2], ids=['bws1', 'bws2']) +@pytest.mark.parametrize("q_bits", [8, 6], ids=['qbit8', 'qbit6']) +class TestOptimizedLinear(DistributedTest): + world_size = 2 + + def test(self, base_weight_sharding, q_bits): + rank = dist.get_rank() + lora_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding) + quantization_config = QuantizationConfig(q_bits=q_bits) + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16) + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py new file mode 100644 index 000000000000..9479b3cba8a0 --- /dev/null +++ b/tests/unit/linear/test_quant_param.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.accelerator import get_accelerator +from deepspeed.linear.quantization import QuantizedParameter +from deepspeed.linear.config import QuantizationConfig + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + + +class TestQuantParam(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize('dtype', [torch.half, torch.float]) + def test_unsupported_dtypes(self, dtype): + device = get_accelerator().current_device_name() + data = torch.rand(5, 5, device='cpu', dtype=dtype) + qp = QuantizedParameter(data) + with pytest.raises(AssertionError): + qp.to(device) + + def test_requires_grad(self): + data = torch.rand(5, 5, dtype=torch.bfloat16) + with pytest.raises(ValueError): + QuantizedParameter(data, requires_grad=True) + + def test_move_to_accelerator(self): + device = get_accelerator().current_device() + data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16) + qp = QuantizedParameter(data) + assert qp.device == torch.device('cpu') + qp = qp.to(get_accelerator().current_device_name()) + assert qp.device == torch.device(device) + assert qp.dtype == torch.int8 + + def test_hf_clone(self): + device = get_accelerator().current_device_name() + data = torch.rand(5, 5, device=device, dtype=torch.bfloat16) + + quantization_config = QuantizationConfig(q_bits=6) + qp = QuantizedParameter(data, quantization_config=quantization_config) + + # should be able to clone parameter via dict, HF expects this to work + qp_copy = QuantizedParameter(qp.data, **qp.__dict__) + + assert all(qp.data == qp_copy.data) + assert qp.quantization_config == qp_copy.quantization_config From ad2027952f9730cbd1a8385e4a441e470248645e Mon Sep 17 00:00:00 2001 From: Jhonso7393 <167781426+Jhonso7393@users.noreply.github.com> Date: Tue, 23 Apr 2024 23:45:47 +0300 Subject: [PATCH 08/23] Update README.md (#5453) Fixing a minor typo at the README file Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 201b9016f8ab..f9d81eddfdae 100755 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). * [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)] -* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19) +* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19) * [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html) * [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp) * [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md)] From 5f631abc2f930ecece38fae05dc9bd3923c555dd Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Tue, 23 Apr 2024 16:24:12 -0700 Subject: [PATCH 09/23] Update PyTest torch version to match PyTorch latest official (2.3.0) (#5454) --- .github/workflows/cpu-torch-latest.yml | 4 ++-- .github/workflows/nv-torch-latest-v100.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml index 9c1ad02f75a6..5727ff2e1cde 100644 --- a/.github/workflows/cpu-torch-latest.yml +++ b/.github/workflows/cpu-torch-latest.yml @@ -50,5 +50,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2" - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3" diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index 3ca8ac43dfa4..2e0490c18ba7 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -29,7 +29,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118 + pip install -U --cache-dir $TORCH_CACHE torch==2.2.2 torchvision --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" From fbdf0eaf15f45cd2a8b5846ad9428609a3860b41 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 23 Apr 2024 16:27:27 -0700 Subject: [PATCH 10/23] Update version.txt after 0.14.2 release (#5458) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.14.2 Author - @loadams Co-authored-by: loadams --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index e867cc2a66a8..ac4a79626c87 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.14.2 +0.14.3 From fa8458b1a80d6ba55091b17f092de19bbf95eb3d Mon Sep 17 00:00:00 2001 From: vikram singh shekhawat Date: Wed, 24 Apr 2024 20:55:18 +0530 Subject: [PATCH 11/23] Add getter and setter methods for compile_backend across accelerators. (#5299) Add getter and setter methods for `compile_backend` across accelerators, which provide a mechanism to retrieve the compile backend. These APIs handle user-defined backend selection and raise a `ValueError` with informative error messages for unsupported backends. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- accelerator/abstract_accelerator.py | 9 +++++++++ accelerator/cpu_accelerator.py | 12 ++++++++++++ accelerator/cuda_accelerator.py | 12 ++++++++++++ accelerator/hpu_accelerator.py | 12 ++++++++++++ accelerator/mps_accelerator.py | 12 ++++++++++++ accelerator/npu_accelerator.py | 12 ++++++++++++ accelerator/xpu_accelerator.py | 12 ++++++++++++ tests/unit/runtime/compile/test_compile_wrapper.py | 4 +--- tests/unit/runtime/compile/test_compile_zero.py | 4 +--- tests/unit/runtime/compile/test_load_config.py | 4 +--- 10 files changed, 84 insertions(+), 9 deletions(-) diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 768d5ea34e5e..690f19292299 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -12,6 +12,7 @@ class DeepSpeedAccelerator(ABC): def __init__(self): self._name = None self._communication_backend_name = None + self._compile_backend = None @abc.abstractmethod def is_synchronized_device(self): @@ -295,3 +296,11 @@ def visible_devices_envs(self): @abc.abstractmethod def set_visible_devices_envs(self, current_env, local_accelerator_ids): ... + + @abc.abstractmethod + def get_compile_backend(self): + ... + + @abc.abstractmethod + def set_compile_backend(self, backend): + ... diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 237e7f51dcb4..bd11d034f312 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -20,6 +20,7 @@ class CPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cpu' + self._compile_backend = "inductor" if oneccl_imported_p: self._communication_backend_name = 'ccl' else: @@ -330,3 +331,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 2fc0cfd94125..60d66b6cdbab 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -25,6 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cuda' self._communication_backend_name = 'nccl' + self._compile_backend = "inductor" if pynvml is None: self._init_pynvml() @@ -367,3 +368,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 326efc8fa01b..114f367e879d 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -16,6 +16,7 @@ class HPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'hpu' self._communication_backend_name = 'hccl' + self._compile_backend = "hpu_backend" try: import habana_frameworks.torch.hpu as hpu hpu.setDeterministic(True) @@ -301,3 +302,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index ff70b860d7c7..5fc9b1c8cfb6 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -20,6 +20,7 @@ class MPS_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = "mps" self._communication_backend_name = None + self._compile_backend = "inductor" def is_synchronized_device(self): return False @@ -267,3 +268,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 5d891ecb707d..b0e0ff948e52 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -20,6 +20,7 @@ def __init__(self): super().__init__() self._name = 'npu' self._communication_backend_name = 'hccl' + self._compile_backend = "inductor" # dict that holds class name <--> class type mapping i.e. # 'AsyncIOBuilder': # this dict will be filled at init stage @@ -285,3 +286,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }") diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index c59f60077d2f..9c4a9c903f96 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -14,6 +14,7 @@ class XPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'xpu' self._communication_backend_name = 'ccl' + self._compile_backend = "inductor" self.aligned_tensors = [] def is_synchronized_device(self): @@ -296,3 +297,14 @@ def visible_devices_envs(self): def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py index d1830534f6ea..62af25ac3ba4 100644 --- a/tests/unit/runtime/compile/test_compile_wrapper.py +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -31,11 +31,9 @@ def base_config(): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' return config_dict diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index 7568c27e3ed2..9890ea708eec 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -51,12 +51,10 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' if offload_device == OffloadDeviceEnum.cpu: config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} elif offload_device == OffloadDeviceEnum.nvme: diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py index 601adae58884..cee8d3b23f6b 100644 --- a/tests/unit/runtime/compile/test_load_config.py +++ b/tests/unit/runtime/compile/test_load_config.py @@ -47,12 +47,10 @@ def base_config(): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' return config_dict From fcc731f09d8e09b04a816b3ea0f83ab1d15169b3 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Thu, 25 Apr 2024 11:01:35 -0700 Subject: [PATCH 12/23] Fix torch.compile error for PyTorch v2.3 (#5463) PyTorch v2.3 throws an error when it tries to compile `iter_params` used for ZeRO3. This PR excludes the function from the compilation targets. After this PR is merged, we can [unpin the torch version for unit tests](https://github.com/microsoft/DeepSpeed/pull/5459). --- deepspeed/runtime/zero/partitioned_param_coordinator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 8fc962c4f2a7..bdec8a55fcbc 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -34,6 +34,7 @@ def get_all_parameters(sub_module, recurse=False): return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) +@compiler.disable def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: return map(lambda pair: pair[1], get_all_parameters(module, recurse)) From bc48371c5e1fb8fd70fc79285e66201dbb65679b Mon Sep 17 00:00:00 2001 From: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Date: Thu, 25 Apr 2024 11:37:15 -0700 Subject: [PATCH 13/23] Revert "stage3: efficient compute of scaled_global_grad_norm (#5256)" (#5461) This reverts commit 54c06872647ca60699f752e60ac1643bd05aa63c due to #5256 causing bugs when the ZeRO3 + ZeRO Offload features are enabled. This bug was discovered due to failures in the DS Chat CI workflow. Failing tests across CI failures: | Failing Test Name | | --- | | test_ds_chat[zero3--offload-] | | test_ds_chat[zero3--offload-lora] | | test_ds_chat[zero3-he-offload-] | | test_ds_chat[zero3-he-offload-lora] | Error message: ``` RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cpu! ``` It seems that `torch.stack()` or `torch.norm()` is having issues when the offload feature is enabled and tensors are split between CPU/GPU, however this is just an initial guess and would require more investigation. @nelyahu Since you are the original author of the PR, if you have some bandwidth, any help here is greatly appreciated! After reverting this commit, all tests pass in the DS Chat CI workflow: https://github.com/microsoft/DeepSpeed/actions/runs/8824064414/job/24225802763 @tjruwase for context. --- deepspeed/runtime/zero/stage3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 68cab13c4a93..c6ff216edfcb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -15,7 +15,7 @@ from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -2027,7 +2027,7 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() - scaled_global_grad_norm = torch.norm(torch.stack(norm_groups)) + scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.loss_scale From e37296b23c1df99dfba508a678d9f310b22aeeaa Mon Sep 17 00:00:00 2001 From: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Date: Thu, 25 Apr 2024 13:36:46 -0700 Subject: [PATCH 14/23] Update ds-chat CI workflow paths to include zero stage 1-3 files (#5462) This PR updates the ds-chat CI workflow to run when ZeRO stage 1-3 files are updated. --- .github/workflows/nv-ds-chat.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 94571eb101bb..cf8756fbd528 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -10,6 +10,10 @@ on: required: false default: 'master' type: string + pull_request: + paths: + - "deepspeed/runtime/zero/stage_1_and_2.py" + - "deepspeed/runtime/zero/stage3.py" concurrency: group: ${{ github.workflow }}-${{ github.ref }} From 4c15ad9f8d51a1950842c69bbbc9d93c73afbcfc Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Thu, 25 Apr 2024 14:44:39 -0700 Subject: [PATCH 15/23] Update with ops not supported on Windows (#5468) --- build_win.bat | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/build_win.bat b/build_win.bat index ec8c8a362a78..af5c5103fa4b 100644 --- a/build_win.bat +++ b/build_win.bat @@ -1,6 +1,10 @@ @echo off set DS_BUILD_AIO=0 +set DS_BUILD_CUTLASS_OPS=0 +set DS_BUILD_EVOFORMER_ATTN=0 +set DS_BUILD_FP_QUANTIZER=0 +set DS_BUILD_RAGGED_DEVICE_OPS=0 set DS_BUILD_SPARSE_ATTN=0 echo Administrative permissions required. Detecting permissions... From 059bb2085cf404caa5874004a252a56fc74c952e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ant=C3=B4nio=20Vieira?= Date: Mon, 29 Apr 2024 14:37:54 -0300 Subject: [PATCH 16/23] fix: swapping order of parameters in create_dir_symlink method. (#5465) Order of parameters in create_dir_symlink method looks wrong. Because this we get the error "PermissionError: [WinError 5] Denied access: '.\\deepspeed\\ops\\csrc'" when install deepspeed >= 0.4.0 on Windows enviroment. Please check this out @eltonzheng and @jeffra. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index f1367b850e02..839941b989c9 100755 --- a/setup.py +++ b/setup.py @@ -219,9 +219,9 @@ def create_dir_symlink(src, dest): if sys.platform == "win32": # This creates a symbolic links on Windows. # It needs Administrator privilege to create symlinks on Windows. - create_dir_symlink('..\\..\\csrc', '.\\deepspeed\\ops\\csrc') - create_dir_symlink('..\\..\\op_builder', '.\\deepspeed\\ops\\op_builder') - create_dir_symlink('..\\accelerator', '.\\deepspeed\\accelerator') + create_dir_symlink('.\\deepspeed\\ops\\csrc', '..\\..\\csrc') + create_dir_symlink('.\\deepspeed\\ops\\op_builder', '..\\..\\op_builder') + create_dir_symlink('.\\deepspeed\\accelerator', '..\\accelerator') egg_info.manifest_maker.template = 'MANIFEST_win.in' # Parse the DeepSpeed version string from version.txt. From f32ad3e1c562be80bd8be6b7b2246dc3041b5bfd Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:39:12 -0700 Subject: [PATCH 17/23] Un-pin torch version in nv-torch-latest back to latest and skip test_compile_zero tests on v100 (#5459) Torch updating to 2.3.0 broke some test_compile_zero tests, we pinned it, @tohtana pushed fixes in #5463, this should un-pin and move us back to the latest. Failing test that indicates the generated code cannot run bf16 on V100 [here](https://github.com/microsoft/DeepSpeed/actions/runs/8838672379/job/24270349996?pr=5459#step:8:5157). --- .github/workflows/nv-torch-latest-v100.yml | 6 +++--- tests/unit/runtime/compile/test_compile_zero.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index 2e0490c18ba7..3109f6060944 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -29,7 +29,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch==2.2.2 torchvision --index-url https://download.pytorch.org/whl/cu118 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -55,5 +55,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.2" --cuda_ver="11.8" - pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.2" --cuda_ver="11.8" + pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.3" --cuda_ver="11.8" + pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.3" --cuda_ver="11.8" diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index 9890ea708eec..a0736b0f5425 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -12,7 +12,7 @@ from unit.runtime.compile.util import compare_loss from unit.common import DistributedTest -from unit.util import bf16_required_version_check +from unit.util import bf16_required_version_check, skip_on_arch pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), reason="Compile tests requires Pytorch version 2.1 or above") @@ -26,9 +26,11 @@ class TestZeRO(DistributedTest): @pytest.mark.parametrize('zero_stage', [1, 2, 3]) @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): + if dtype == torch.bfloat16: + skip_on_arch(min_arch=8) if dtype == torch.bfloat16 and not bf16_required_version_check(): pytest.skip( - " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly" ) if get_accelerator().device_name() == "cpu": pytest.skip("CPU does not support this test yet") From 90793aab546beedfd805eddd8fd2d807b3862d53 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Fri, 3 May 2024 23:22:29 +0300 Subject: [PATCH 18/23] re-introduce: stage3: efficient compute of scaled_global_grad_norm (#5493) reverting previous revert of this feature: https://github.com/nelyahu/DeepSpeed/commit/bc48371c5e1fb8fd70fc79285e66201dbb65679b in addition, bug fix for offload mode. --- deepspeed/runtime/zero/stage3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c6ff216edfcb..13ca29c9fceb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1409,7 +1409,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): norm_is_nan = total_norm.isnan() inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm From 0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94 Mon Sep 17 00:00:00 2001 From: harygo2 <168085522+harygo2@users.noreply.github.com> Date: Tue, 7 May 2024 08:05:54 +0800 Subject: [PATCH 19/23] Fix crash when creating Torch tensor on NPU with device=get_accelerator().current_device() (#5464) Creating a Torch tensor with the parameter `device=get_accelerator().current_device()` can result in a crash when using an NPU. This issue arises because the `current_device` API across all accelerators is expected to return a device id as an integer, according to the [interface docs.](https://github.com/microsoft/DeepSpeed/blob/fa8458b1a80d6ba55091b17f092de19bbf95eb3d/docs/_tutorials/accelerator-abstraction-interface.md?plain=1#L52C1-L56C103) However, specifying `device` as an interger when creating tensors by default directs Torch to use the CUDA backend, which leads to crash on NPUs (and potentially other accelerators as well). To resolve this, we should use `get_accelerator().current_device_name()` instead, which returns the correct device identifier strings such as `"npu:0", "cuda:0", or "xpu:0"`. This API provides the appropriate context needed for creating tensors on specific hardware accelerators. I also notice that `device=get_accelerator().current_device()` is used across several files under deepspeed/inference, and may also lead to crash on other accelerators. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/linear/optimized_linear.py | 2 +- deepspeed/runtime/fp16/fused_optimizer.py | 2 +- deepspeed/runtime/utils.py | 8 +++++--- tests/unit/moe/test_moe.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py index 138bd493ffc7..e982785a8122 100644 --- a/deepspeed/linear/optimized_linear.py +++ b/deepspeed/linear/optimized_linear.py @@ -85,7 +85,7 @@ def __init__(self, self.bias = bias self.lora_config = lora_config self.quantization_config = quantization_config - device = get_accelerator().current_device() if device is None else device + device = get_accelerator().current_device_name() if device is None else device assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" self.zero_shards = self.lora_config.base_weight_sharding diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index bf1693307ea7..49093bb73c8f 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -241,7 +241,7 @@ def _get_norm_mask_idx(self, group): group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) grad_flat_st_idx = grad_flat_en_idx - return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name()) def step(self, closure=None): """ diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 7744b2ee8b98..2c01c3475a70 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -171,7 +171,7 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group): # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'. # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=group) all_groups_norm = scaled_norm_tensor.item() #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") @@ -424,9 +424,11 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) # # for mask_idx in grad_norm_mask[idx]: # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True - cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(), dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) - mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = torch.zeros(p.shape[0] + 1, + device=get_accelerator().current_device_name(), + dtype=p.dtype) mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index d39f9fe3d651..fdff9430a4e6 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -177,7 +177,7 @@ class TestTopk(DistributedTest): world_size = 2 def test(self): - device = get_accelerator().current_device() + device = get_accelerator().current_device_name() if dist.get_rank() == 0: logits = torch.rand(2, 2, device=device) elif dist.get_rank() == 1: From 0b224edcf7d83713b95ad6b989694a8bdf01809e Mon Sep 17 00:00:00 2001 From: BacharL Date: Wed, 8 May 2024 12:53:25 +0300 Subject: [PATCH 20/23] Fix compile wrapper (#5455) compile wrapper will inherit from user module class and copy it's __dict__ This should resolve most issues in #5383 except potential extra user forward hooks. @tohtana @loadams Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> --- deepspeed/runtime/compiler.py | 163 ++++++++++++++++--------------- deepspeed/runtime/engine.py | 7 -- deepspeed/runtime/pipe/engine.py | 4 +- 3 files changed, 83 insertions(+), 91 deletions(-) diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index b5e4e33425d0..66fe29fbbea2 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values): return field_value -class CompiledModuleWrapper(torch.nn.Module): - - def __init__(self, module, compile_config: Union[CompileConfig, None] = None): - super().__init__() - - assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." - - modules = self.__dict__.get('_modules') - modules['wrapped'] = module - self.__dict__['wrapped'] = module - self._is_compiled = False - self._backend = get_backend_fn(compile_config.backend) - self._compile_kwargs = compile_config.kwargs - self._compiler_fn = None - - def __getattr__(self, name): - return getattr(self.__dict__['wrapped'], name) - - def set_backend(self, backend: Union[str, Callable]): - """Set the backend for torch.compile. - - Args: - backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. - You can directly pass a function that works as a backend. - See also `backend` field in `CompileConfig` for more details. - """ - self._backend = get_backend_fn(backend) - - def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: - """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. - You can also pass a backend name with "backend" key to change the backend. - - Args: - kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. - """ - - if "backend" in kwargs: - raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") - self._compile_kwargs.update(kwargs) - - def set_compiler_fn(self, compiler_fn: Callable) -> None: - """Set a function to be used for compiling the module. - This function should take a torch.nn.Module as input and return a compiled module. - Note that other compile options are ignored when a compiler_fn is set. - - Example: - ```python - def my_compiler_fn(module: torch.nn.Module): - ... - return torch.compile(module, ...) - - engine.set_compiler_fn(my_compiler_fn) - ``` - """ - self._compiler_fn = compiler_fn - - def forward(self, *args, **kwargs) -> Any: - if not self.is_compiled: - if self._compiler_fn is None: - self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs) - else: - self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) - self._is_compiled = True - - return self.__dict__['wrapped'](*args, **kwargs) - - @property - def is_compiled(self) -> bool: - return self._is_compiled - - @property - def backend(self) -> Union[str, Callable]: - return self._backend - - @property - def torch_compile_kwargs(self) -> Dict[str, Any]: - return self._compile_kwargs - - @property - def compiler_fn(self) -> Union[Callable, None]: - return self._compiler_fn +def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None): + + class wrapper(mod.__class__): + + def __init__(self, module, compile_config: Union[CompileConfig, None] = None): + self.__dict__ = module.__dict__.copy() + + assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." + + self.__dict__['wrapped'] = module + self._is_compiled = False + self._backend = get_backend_fn(compile_config.backend) + self._compile_kwargs = compile_config.kwargs + self._compiler_fn = None + + def set_backend(self, backend: Union[str, Callable]): + """Set the backend for torch.compile. + + Args: + backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. + You can directly pass a function that works as a backend. + See also `backend` field in `CompileConfig` for more details. + """ + self._backend = get_backend_fn(backend) + + def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: + """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. + You can also pass a backend name with "backend" key to change the backend. + + Args: + kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. + """ + + if "backend" in kwargs: + raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") + self._compile_kwargs.update(kwargs) + + def set_compiler_fn(self, compiler_fn: Callable) -> None: + """Set a function to be used for compiling the module. + This function should take a torch.nn.Module as input and return a compiled module. + Note that other compile options are ignored when a compiler_fn is set. + + Example: + ```python + def my_compiler_fn(module: torch.nn.Module): + ... + return torch.compile(module, ...) + + engine.set_compiler_fn(my_compiler_fn) + ``` + """ + self._compiler_fn = compiler_fn + + def forward(self, *args, **kwargs) -> Any: + if not self.is_compiled: + if self._compiler_fn is None: + self.__dict__['wrapped'] = torch.compile(self.wrapped, + backend=self._backend, + **self._compile_kwargs) + else: + self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) + self._is_compiled = True + + return self.__dict__['wrapped'](*args, **kwargs) + + @property + def is_compiled(self) -> bool: + return self._is_compiled + + @property + def backend(self) -> Union[str, Callable]: + return self._backend + + @property + def torch_compile_kwargs(self) -> Dict[str, Any]: + return self._compile_kwargs + + @property + def compiler_fn(self) -> Union[Callable, None]: + return self._compiler_fn + + return wrapper(mod, compile_config) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9a2b943b0992..34263444c1b7 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -469,13 +469,6 @@ def __getattr__(self, name): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) - elif isinstance(_module, CompiledModuleWrapper): - try: - return getattr(_module, name) - except AttributeError: - raise AttributeError( - f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'" - ) else: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 1dda7f1aad32..be8fe1a368c6 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine): def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): super().__init__(*super_args, **super_kwargs) - assert isinstance(self.module, PipelineModule) \ - or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \ - "model must base PipelineModule" + assert isinstance(self.module, PipelineModule), "model must base PipelineModule" assert self.zero_optimization_stage( ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" From 3dd7ccff8103be60c31d963dd2278d43abb68fd1 Mon Sep 17 00:00:00 2001 From: Yejing-Lai Date: Thu, 9 May 2024 06:04:02 +0800 Subject: [PATCH 21/23] enable phi3_mini autotp (#5501) This PR aims to enable phi3 mini autotp. Phi3 mini uses chunk MLP. We adjust this linear layer weight order to support this model. Please kindly review~ Thanks! --------- Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> --- deepspeed/module_inject/auto_tp.py | 9 +++-- deepspeed/module_inject/fusedqkv_utils.py | 40 ++++++++++++++++++++++- deepspeed/module_inject/replace_module.py | 6 +++- deepspeed/module_inject/tp_shard.py | 10 ++++++ 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 88f7086518e8..4944e1954e37 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -13,7 +13,7 @@ from deepspeed import comm as dist from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce from deepspeed.accelerator import get_accelerator -from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw +from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list @@ -133,7 +133,8 @@ def is_load_module(module): load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm] load_layer_names = [ "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", - "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm" + "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding", + "Phi3RMSNorm" ] return module.__class__ in load_layers or module._get_name() in load_layer_names @@ -328,6 +329,10 @@ def _replace(self, child, name, conv_linear_layer): # For mixtral-7x8b, need to skip MoE gate linear replace. if name == "block_sparse_moe.gate": return child + # for phi3. + if 'gate_up_proj' in name: + weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size()) + return LinearLayer(weight=weight, bias=bias) if name in self.all_reduce_linears: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] # else [weight_shape[0], weight_shape[1] // mp_size] diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index cf087c16da8a..33d36fbfae54 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,7 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once -from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0): @@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index): "FalconDecoderLayer": 'bloomtype', "GPTBigCodeBlock": 'bigcodetype', "DecoderLayer": 'glmtype', + "Phi3DecoderLayer": "phi3type" } def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): @@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size): split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0) return torch.cat((split_q[gpu_index], kv), dim=0) + def _phi3_type_transpose(input, mp_size): + num_kv_heads = get_num_kv_heads() + num_heads = get_num_attention_heads() + hidden_size = input.shape[1] + head_dim = hidden_size // num_heads + q_pos = input.shape[0] - 2 * num_kv_heads * head_dim + q = input[:q_pos] + k = input[q_pos:q_pos + num_kv_heads * head_dim] + v = input[q_pos + num_kv_heads * head_dim:] + split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0) + split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0) + split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0) + return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0) + def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following @@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): return _qwen_type_transpose(src, mp_size, module) elif fused_qkv_type == 'bigcodetype': return _bigcode_type_transpose(src, mp_size) + elif fused_qkv_type == 'phi3type': + return _phi3_type_transpose(src, mp_size) raise ValueError("unknown fused_qkv_type") @@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type," f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors") return _bloom_type_transpose(src, mp_size) + + +# For phi3 with chunk mlp, adjust the weight order. +def shard_chunk_mlp( + weight, + bias, + rank, + world_size, +): + weight_gate, weight_states = weight.chunk(2, dim=0) + total_size = weight_gate.shape[0] + split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0) + if bias is not None: + bias_gate, bias_states = bias.chunk(2, dim=0) + split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0) + + return shard_weight, None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e1703562d180..3029a79698dc 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -16,7 +16,7 @@ from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from deepspeed import comm as dist -from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd +from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads from .load_checkpoint import load_model_with_checkpoint import time @@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): # 4.2 set n_embd set_n_embd(n_embd) + # 4.3 set attention_heads + if hasattr(model_config, 'num_attention_heads'): + set_num_attention_heads(getattr(model_config, 'num_attention_heads')) + # 5. Set linear policies _autotp.update_linear_policies() diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 79c19b5f1272..6758c7a657f6 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -12,6 +12,11 @@ def set_num_kv_heads(num): num_kv_heads = num +def set_num_attention_heads(num): + global num_attention_heads + num_attention_heads = num + + def set_n_embd(num): global n_embd n_embd = num @@ -22,6 +27,11 @@ def get_num_kv_heads(): return num_kv_heads +def get_num_attention_heads(): + global num_attention_heads + return num_attention_heads + + def get_shard_size(total_size, mp_size, name=None, rank=None): global num_kv_heads last_linear = ["lm_head", "embed_out"] From df4ef0ab698a6ad071b4aad282690eac690b16fd Mon Sep 17 00:00:00 2001 From: BacharL Date: Fri, 10 May 2024 23:53:55 +0300 Subject: [PATCH 22/23] Fused adam for HPU (#5500) Co-authored-by: Olatunji Ruwase --- op_builder/hpu/fused_adam.py | 94 ++++++++++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 9 deletions(-) diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py index d77228317ddb..5acb121668e3 100644 --- a/op_builder/hpu/fused_adam.py +++ b/op_builder/hpu/fused_adam.py @@ -4,10 +4,88 @@ # DeepSpeed Team -from .builder import CPUOpBuilder +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder +try: + import torch + import math +except ImportError as e: + pass -class FusedAdamBuilder(CPUOpBuilder): + +class HPUFusedAdam: + htcore = None + is_lazy_mode = None + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + + if HPUFusedAdam.htcore is None: + from habana_frameworks.torch import core as htcore + from habana_frameworks.torch.utils.internal import is_lazy + HPUFusedAdam.htcore = htcore + HPUFusedAdam.is_lazy_mode = is_lazy() + + htcore = HPUFusedAdam.htcore + + htcore.step_closure._mark_step_if_lazy() + step_size = lr + if bias_correction: + bias_correction1 = 1.0 - pow(beta1, step) + bias_correction2 = 1.0 - pow(beta2, step) + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + neg_step = -step_size + neg_step_t = (torch.tensor([neg_step], dtype=torch.float, + requires_grad=False).to(tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, + non_blocking=True)) + + weight_decay = weight_decay if adam_w_mode else 0 + + # since lr is fed into the kernel as tensor, perform the scalar multiplication of wd here + # NOTE: TODO if lr is updated every step, then we need to convert it as tensor and + # perform weight decay unconditonally. + modified_wd = 1.0 - weight_decay * lr + + if HPUFusedAdam.is_lazy_mode: + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd, + ) + else: + modified_wd_t = (torch.tensor([modified_wd], dtype=torch.float, requires_grad=False).to( + tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, non_blocking=True)) + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd_t, + modified_wd != 1.0, + ) + + htcore.step_closure._mark_step_if_lazy() + + +class FusedAdamBuilder(OpBuilder): BUILD_VAR = "DS_BUILD_FUSED_ADAM" NAME = "fused_adam" @@ -18,12 +96,10 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] - - def cxx_args(self): - args = super().cxx_args() - args += ['-DENABLE_BFLOAT16'] - return args + return [] def include_paths(self): - return ['csrc/includes'] + return [] + + def load(self, verbose=True): + return HPUFusedAdam From 4696afd27bd54abe0ee313048f176133402e2497 Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Tue, 14 May 2024 02:26:11 +0800 Subject: [PATCH 23/23] [manifest] update mainfest to add hpp file in csrc. (#5522) Update the mainfest to cover hpp file in csrc. --- MANIFEST.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index ab79573ef96c..85e00695d648 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,7 +3,7 @@ include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so recursive-include requirements *.txt recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json -recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc +recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc recursive-include op_builder *.py recursive-include benchmarks *.py recursive-include accelerator *.py