diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml index 9c1ad02f75a6..5727ff2e1cde 100644 --- a/.github/workflows/cpu-torch-latest.yml +++ b/.github/workflows/cpu-torch-latest.yml @@ -50,5 +50,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2" - TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3" + TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3" diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 94571eb101bb..cf8756fbd528 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -10,6 +10,10 @@ on: required: false default: 'master' type: string + pull_request: + paths: + - "deepspeed/runtime/zero/stage_1_and_2.py" + - "deepspeed/runtime/zero/stage3.py" concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index 3ca8ac43dfa4..3109f6060944 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -55,5 +55,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.2" --cuda_ver="11.8" - pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.2" --cuda_ver="11.8" + pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.3" --cuda_ver="11.8" + pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.3" --cuda_ver="11.8" diff --git a/MANIFEST.in b/MANIFEST.in index ab79573ef96c..85e00695d648 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,7 +3,7 @@ include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so recursive-include requirements *.txt recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json -recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc +recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc recursive-include op_builder *.py recursive-include benchmarks *.py recursive-include accelerator *.py diff --git a/README.md b/README.md index a1335caa4949..f9d81eddfdae 100755 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). 
* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)] -* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19) +* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19) * [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html) * [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp) * [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md)] @@ -159,11 +159,12 @@ dynamically link them at runtime. ## Contributed HW support * DeepSpeed now support various HW accelerators. -| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated | -| ----------- | -------- | ---------------- | --------------------- | ------------------ | -| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes | -| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes | -| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes | +| Contributor | Hardware | Accelerator Name | Contributor validated | Upstream validated | +|-------------|-------------------------------------|------------------| --------------------- |--------------------| +| Huawei | Huawei Ascend NPU | npu | Yes | No | +| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes | +| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes | +| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes | ## PyPI We regularly push releases to [PyPI](https://pypi.org/project/deepspeed/) and encourage users to install from there in most cases. diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 3c5d799e293e..690f19292299 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -12,6 +12,7 @@ class DeepSpeedAccelerator(ABC): def __init__(self): self._name = None self._communication_backend_name = None + self._compile_backend = None @abc.abstractmethod def is_synchronized_device(self): @@ -287,3 +288,19 @@ def build_extension(self): @abc.abstractmethod def export_envs(self): ... + + @abc.abstractmethod + def visible_devices_envs(self): + ... 
+ + @abc.abstractmethod + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + ... + + @abc.abstractmethod + def get_compile_backend(self): + ... + + @abc.abstractmethod + def set_compile_backend(self, backend): + ... diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index a0171723cfb8..bd11d034f312 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -20,6 +20,7 @@ class CPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cpu' + self._compile_backend = "inductor" if oneccl_imported_p: self._communication_backend_name = 'ccl' else: @@ -322,3 +323,22 @@ def build_extension(self): def export_envs(self): return [] + + # TODO: cpu's visible envs is confirmed, keep as CUDA_VISIBLE_DEVICES + def visible_devices_envs(self): + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 3d5e9c168c16..60d66b6cdbab 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -25,6 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cuda' self._communication_backend_name = 'nccl' + self._compile_backend = "inductor" if pynvml is None: self._init_pynvml() @@ -360,3 +361,21 @@ def build_extension(self): def export_envs(self): return ['NCCL'] + + def visible_devices_envs(self): + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. 
Supported Backends are {supported_backends}") diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 30b115e8b1ab..114f367e879d 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -16,6 +16,7 @@ class HPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'hpu' self._communication_backend_name = 'hccl' + self._compile_backend = "hpu_backend" try: import habana_frameworks.torch.hpu as hpu hpu.setDeterministic(True) @@ -294,3 +295,21 @@ def build_extension(self): def export_envs(self): return [] + + def visible_devices_envs(self): + return ['HABANA_VISIBLE_MODULES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index 972b33caece1..5fc9b1c8cfb6 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -20,6 +20,7 @@ class MPS_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = "mps" self._communication_backend_name = None + self._compile_backend = "inductor" def is_synchronized_device(self): return False @@ -258,3 +259,23 @@ def build_extension(self): def export_envs(self): return [] + + # TODO: mps's visible devices env is not confirmed; keep as CUDA_VISIBLE_DEVICES for now + def visible_devices_envs(self): + # TODO: could not find visible devices env for mps + return ['CUDA_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 472157e32c02..b0e0ff948e52 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -20,6 +20,7 @@ def __init__(self): super().__init__() self._name = 'npu' self._communication_backend_name = 'hccl' + self._compile_backend = "inductor" # dict that holds class name <--> class type mapping i.e.
# 'AsyncIOBuilder': # this dict will be filled at init stage @@ -278,3 +279,21 @@ def build_extension(self): def export_envs(self): return ['ASCEND', 'HCCL', 'LD_LIBRARY', 'PATH'] + + def visible_devices_envs(self): + return ['ASCEND_RT_VISIBLE_DEVICES'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends }") diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index 3f65263946ab..9c4a9c903f96 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -14,6 +14,7 @@ class XPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'xpu' self._communication_backend_name = 'ccl' + self._compile_backend = "inductor" self.aligned_tensors = [] def is_synchronized_device(self): @@ -289,3 +290,21 @@ def build_extension(self): def export_envs(self): return [] + + def visible_devices_envs(self): + return ['ZE_AFFINITY_MASK'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + supported_backends = torch._dynamo.list_backends(exclude_tags=()) + if backend in supported_backends: + self._compile_backend = backend + else: + raise ValueError( + f"{backend} not supported by {self.device_name()}. Supported Backends are {supported_backends}") diff --git a/build_win.bat b/build_win.bat index ec8c8a362a78..af5c5103fa4b 100644 --- a/build_win.bat +++ b/build_win.bat @@ -1,6 +1,10 @@ @echo off set DS_BUILD_AIO=0 +set DS_BUILD_CUTLASS_OPS=0 +set DS_BUILD_EVOFORMER_ATTN=0 +set DS_BUILD_FP_QUANTIZER=0 +set DS_BUILD_RAGGED_DEVICE_OPS=0 set DS_BUILD_SPARSE_ATTN=0 echo Administrative permissions required. Detecting permissions... 
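The accelerator changes above introduce a uniform interface for device visibility (`visible_devices_envs` / `set_visible_devices_envs`) and `torch.compile` backend selection (`get_compile_backend` / `set_compile_backend`) across all backends. The sketch below is illustrative only and not part of the patch: the method names and environment variables come from the diff above, while the surrounding script is hypothetical.

```python
# Hypothetical usage sketch for the new accelerator APIs added in this patch.
import os
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()

# Restrict a process to accelerator ids 0 and 1 via the backend-specific env var
# (CUDA_VISIBLE_DEVICES, HABANA_VISIBLE_MODULES, ASCEND_RT_VISIBLE_DEVICES,
# ZE_AFFINITY_MASK, ...), as the launcher change later in this diff does.
env = os.environ.copy()
acc.set_visible_devices_envs(env, [0, 1])
for var in acc.visible_devices_envs():
    print(f"{var}={env[var]}")

# Query or override the torch.compile backend; an unknown backend raises ValueError.
print(acc.get_compile_backend())      # "inductor" on cuda/cpu/xpu/npu/mps, "hpu_backend" on hpu
acc.set_compile_backend("inductor")   # must be in torch._dynamo.list_backends(exclude_tags=())
```

The launcher change in deepspeed/launcher/launch.py (later in this diff) uses the same pair of calls to export the per-backend visibility variable instead of hard-coding CUDA_VISIBLE_DEVICES.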
diff --git a/csrc/adam/multi_tensor_adam.cu b/csrc/adam/multi_tensor_adam.cu index d6b9b2f70710..a1fc7d15aec9 100644 --- a/csrc/adam/multi_tensor_adam.cu +++ b/csrc/adam/multi_tensor_adam.cu @@ -30,7 +30,7 @@ typedef enum : int { using MATH_T = float; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, @@ -48,13 +48,13 @@ struct AdamFunctor { // if(*noop_gmem == 1) // return; - int tensor_loc = tl.block_to_tensor[blockIdx.x]; + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; // potentially use to pass in list of scalar // int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; T* g = (T*)tl.addresses[0][tensor_loc]; g += chunk_idx * chunk_size; @@ -71,7 +71,8 @@ struct AdamFunctor { n -= chunk_idx * chunk_size; // see note in multi_tensor_scale_kernel.cu - for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + for (index_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) { MATH_T r_g[ILP]; MATH_T r_p[ILP]; MATH_T r_m[ILP]; @@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size, bias_correction2 = 1 - std::pow(beta2, step); } + size_t max_size = 0; + bool requires_64bit_indexing = false; + for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) { + for (auto it2 = it->begin(); it2 != it->end(); it2++) { + if (it2->numel() > max_size) { + max_size = it2->numel(); + if (max_size >= INT_MAX) { + requires_64bit_indexing = true; + break; + } + } + } + if (requires_64bit_indexing) { break; } + } + // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), - 0, - "adam", - multi_tensor_apply<4>(BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor(), - beta1, - beta2, - bias_correction1, - bias_correction2, - epsilon, - lr, - (adamMode_t)mode, - weight_decay);) + if (requires_64bit_indexing) { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, + (int64_t)chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } else { + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + AdamFunctor(), + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + (adamMode_t)mode, + weight_decay);) + } AT_CUDA_CHECK(cudaGetLastError()); } diff --git a/csrc/adam/multi_tensor_apply.cuh b/csrc/adam/multi_tensor_apply.cuh index 12f41cb49c6b..342376c141be 100644 --- a/csrc/adam/multi_tensor_apply.cuh +++ b/csrc/adam/multi_tensor_apply.cuh @@ -35,7 +35,7 @@ struct TensorListMetadata { }; template -__global__ void multi_tensor_apply_kernel(int chunk_size, +__global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int* noop_flag, T tl, U callable, @@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size, } template -void multi_tensor_apply(int block_size, - int chunk_size, +void multi_tensor_apply(int64_t block_size, + int64_t chunk_size, const at::Tensor& noop_flag, const std::vector>& tensor_lists, T callable, @@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size, 
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); loc_tensor_info++; - int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { // std::cout << chunks_this_tensor << std::endl; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_chunk[loc_block_info] = chunk; diff --git a/csrc/fp_quantizer/includes/quantize.h b/csrc/fp_quantizer/includes/quantize.h index 2204c1ba74fc..507252d6e722 100644 --- a/csrc/fp_quantizer/includes/quantize.h +++ b/csrc/fp_quantizer/includes/quantize.h @@ -113,3 +113,14 @@ void launch_dequantization(uint8_t* val, int q_mantisa_bits, int q_exponent_bits, cudaStream_t stream); + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream); diff --git a/csrc/fp_quantizer/quantize.cpp b/csrc/fp_quantizer/quantize.cpp index 4a88ff767636..ec631c576e27 100644 --- a/csrc/fp_quantizer/quantize.cpp +++ b/csrc/fp_quantizer/quantize.cpp @@ -78,8 +78,39 @@ void dequantize(torch::Tensor& val, #endif } +#define DISPATCH_DEQUANTIZE_INDEX(T_TYPE, C_TYPE, mantisa) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_selective_dequantization((uint8_t*)val_q.data_ptr(), \ + (C_TYPE*)val.data_ptr(), \ + (int32_t*)indexes.data_ptr(), \ + num_groups, \ + group_size, \ + num_indexes, \ + q_mantisa_bits, \ + q_exponent_bits, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } +void selective_dequantize(torch::Tensor& val, + torch::Tensor& val_q, + torch::Tensor& indexes, + int group_size, + int q_mantisa_bits, + int q_exponent_bits) +{ + int total_elems = at::numel(val); + int num_indexes = indexes.size(0); + int num_groups = total_elems / group_size; + + DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10); +#ifdef BF16_AVAILABLE + DISPATCH_DEQUANTIZE_INDEX(kBFloat16, __nv_bfloat16, 7); +#endif +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("quantize", &quantize, "quantize function"); m.def("dequantize", &dequantize, "dequantize function"); + m.def("selective_dequantize", &selective_dequantize, "selective dequantize function"); } diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu index 5f0b58f124f0..5ada6894747f 100644 --- a/csrc/fp_quantizer/quantize.cu +++ b/csrc/fp_quantizer/quantize.cu @@ -270,6 +270,7 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int mem_access::load_global( int8_data + quantization::quanitzed_access_granularity_6bits * 2, load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } else { mem_access::load_global(int8_data, load_base_ptr); @@ -393,3 +394,137 @@ void launch_dequantization(uint8_t* val, INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7); #endif INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10); + +template +__global__ void apply_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int group_size, + int total_num_elements) +{ + int index = indexes[blockIdx.x]; + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size; + int input_index = index * total_num_elements + tidx; + constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; + 
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); + const uint32_t g_index = (input_index / group_size); + const uint32_t group_size_bytes = (group_size * quantized_bits / 8); + const uint8_t* load_base_ptr = + val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8; + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements; + float scale; + + uint8_t* scale_as_int8 = reinterpret_cast(&scale); + if (quantized_bits == 6) { + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + mem_access::load_global( + scale_as_int8 + quantization::quanitzed_access_granularity_6bits, + val + g_index * (group_size_bytes + 4) + group_size_bytes + + quantization::quanitzed_access_granularity_6bits); + } else + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + + if (tidx < total_num_elements) { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + if (quantized_bits == 6) { + mem_access::load_global( + int8_data, load_base_ptr); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } else { + mem_access::load_global(int8_data, + load_base_ptr); + if (quantized_bits > 4) { + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) { + mem_access::load_global( + int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2); + } + } + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); + } else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = + ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) | + (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); + } + mem_access::store_global(store_base_ptr, store_buf); + } +} + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int 
q_exponent_bits, + cudaStream_t stream) +{ + int total_elements_per_index = (num_groups / num_indexes) * group_size; + int blocks = (total_elements_per_index - 1) / + (quantization::threads * (quantization::access_granularity / sizeof(T))) + + 1; + const dim3 grid(num_indexes, blocks); + const dim3 block(quantization::threads); + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + apply_selective_dequantization + <<>>(val, q_val, indexes, group_size, total_elements_per_index); + }); +} +#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \ + template void launch_selective_dequantization( \ + uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t); +// fp8(E4M3) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7); +#endif +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10); diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index d5eca81c804f..b1a8276589b6 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -10,7 +10,7 @@ import argparse import glob import itertools -import multiprocessing +from concurrent.futures import ProcessPoolExecutor import os import re import shutil @@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): cnt = 0 +def dp_index_to_str(dp_index): + return f"{dp_index:0>2d}" + + def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): global cnt # temp hack @@ -140,9 +144,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, os.makedirs(param_base_path, exist_ok=True) cnt += 1 - counter = f"{dp_index:0>2d}" - path = os.path.join(param_base_path, f"{state_name}.{counter}") + path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") #print(f"{param_name}: {offset}: {numel} => {path}") @@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): slices = [] for tp_index in range(tp_degree): prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") - paths = sorted(list(glob.glob(f"{prefix_path}.*"))) + paths = glob.glob(f"{prefix_path}.*") + if len(paths) == 0: continue + pattern = re.compile(f"{prefix_path}\\.([0-9]+)") + dp_indices = set() + for p in paths: + m = pattern.match(p) + if m: + dp_indices.add(int(m.group(1))) + else: + raise ValueError(f"Cannot parse dp_rank from {p}") + + paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))] shards = [torch.load(p) for p in paths] if state == "step": @@ -278,27 +292,18 @@ def get_matched_sub_params_pattern(name_): return unmatched_patterns -def _get_chunks(l, n): - for i in range(0, len(l), n): - yield l[i:i + n] - - def _do_parallel_work(do_work, work_chunks, num_workers): + results = [] if num_workers > 1: - pool = multiprocessing.Pool(num_workers) - results = [] - for batch in tqdm.tqdm(work_chunks): - res = pool.map(do_work, batch) - results.extend(res) - pool.close() - pool.join() + with ProcessPoolExecutor(max_workers=num_workers) as executor: + future_list = [executor.submit(do_work, work) for work in work_chunks] + for f in tqdm.tqdm(future_list): + results.append(f.result()) else: # No parallel pass for unit testing # We can't create child processes in tests - results = [] - for batch in tqdm.tqdm(work_chunks): - res = [do_work(x) for x in batch] - results.extend(res) + for work in tqdm.tqdm(work_chunks): + results.append(do_work(work)) return 
results @@ -307,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), range(ds_checkpoint.dp_degree))) #pprint(f'{_3d_range_list=}') - work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers)) - #pprint(f'{work_chunks=}') - # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0]) do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint) - _do_parallel_work(do_work, work_chunks, args.num_extract_workers) + _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers) def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir): - work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers)) - #pprint(work_chunks) zero_output_folder = os.path.join(args.output_folder, "zero") do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree) - unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers) + unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers) # verify that all patterns were used # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index ffb9cbc18fbd..079182a3b45b 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -22,6 +22,7 @@ from collections import defaultdict from typing import Dict from argparse import ArgumentParser, REMAINDER +from deepspeed.accelerator import get_accelerator from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT from ..nebula.constants import DLTS_POD_ENV_PATH from ..utils import logger, get_numactl_cmd @@ -146,8 +147,8 @@ def main(): node_list = list(world_info.keys()) args.nnodes = len(node_list) local_node = node_list[args.node_rank] - local_gpu_ids = world_info[local_node] - num_local_procs = len(local_gpu_ids) + local_accelerator_ids = world_info[local_node] + num_local_procs = len(local_accelerator_ids) logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}") global_rank_mapping = defaultdict(list) @@ -161,8 +162,10 @@ def main(): curr_global_rank += 1 logger.info(f"global_rank_mapping={global_rank_mapping}") logger.info(f"dist_world_size={dist_world_size}") - current_env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, local_gpu_ids)) - logger.info(f"Setting CUDA_VISIBLE_DEVICES={current_env['CUDA_VISIBLE_DEVICES']}") + + get_accelerator().set_visible_devices_envs(current_env, local_accelerator_ids) + for env in get_accelerator().visible_devices_envs(): + logger.info(f"Setting {env}={current_env[env]}") # set PyTorch distributed related environmental variables current_env["MASTER_ADDR"] = args.master_addr diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py new file mode 100644 index 000000000000..a27f1c3eaee7 --- /dev/null +++ b/deepspeed/linear/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .optimized_linear import OptimizedLinear +from .config import LoRAConfig, QuantizationConfig diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py new file mode 100644 index 000000000000..ae9050a3c92b --- /dev/null +++ b/deepspeed/linear/config.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + + # DeepSpeed Team + + from dataclasses import dataclass + + + @dataclass + class LoRAConfig: + """ + Configuration settings for LoRAOptimizedLinear. + + Attributes: + lora_r (int): LoRA attention dimension, also known as the rank. Default is 64. + lora_alpha (float): LoRA scaling factor, default is 16. + base_weight_sharding (int): The degree to which the base weights are sharded, + should typically be set to the data-parallel world size to maximize the memory + reduction benefits. Defaults to 1, which means this feature is disabled. + """ + lora_r: int = 64 + lora_alpha: float = 16. + base_weight_sharding: int = 1 + + + @dataclass + class QuantizationConfig: + """ + Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear, + and QuantizedParameter. + + Attributes: + q_bits (int): The number of bits used for quantization. Default is 8. + mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3. + group_size (int): The size of the group used for quantization. Default is 512. + """ + q_bits: int = 8 + mantissa_bits: int = 3 + group_size: int = 512 diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py new file mode 100644 index 000000000000..e982785a8122 --- /dev/null +++ b/deepspeed/linear/optimized_linear.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import is_dataclass +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +from .config import LoRAConfig, QuantizationConfig +from .quantization import QuantizedParameter, QuantizedLinear + + +class OptimizedLinear(nn.Module): + """ + Optimized version of nn.Linear that adds features such as: + * LoRA w. base weight sharding + * FP [6,8,12] quantization + + Arguments: + input_dim: Required: size of each input sample + output_dim: Required: size of each output sample + bias: Optional: If set to False, the layer will not learn an additive bias. Default: False + lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree + quantization_config: Optional: QuantizationConfig defining quantization features + dtype: Optional: parameter dtype, only supports bfloat16 currently + + Returns: + Returns a new nn.Module depending on the input config. Either native + torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
+ """ + + def __new__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + dtype=torch.bfloat16): + + if quantization_config is not None and not is_dataclass(quantization_config): + raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") + if lora_config is not None and not is_dataclass(lora_config): + raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") + if lora_config is None and quantization_config is None: + # Everything disabled, fall back to normal nn.Linear + self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) + + elif lora_config: + # lora enabled, quantization may or may not be + self = LoRAOptimizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=dtype) + + elif quantization_config: + # only quantization enabled, no lora + self = QuantizedLinear(input_dim=input_dim, + output_dim=output_dim, + bias=bias, + quantization_config=quantization_config, + dtype=dtype) + return self + + +class LoRAOptimizedLinear(nn.Module): + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + lora_config: LoRAConfig = None, + quantization_config: QuantizationConfig = None, + device=None, + dtype=torch.bfloat16): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.bias = bias + self.lora_config = lora_config + self.quantization_config = quantization_config + device = get_accelerator().current_device_name() if device is None else device + assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" + + self.zero_shards = self.lora_config.base_weight_sharding + self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) + w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype)) + torch.nn.init.xavier_uniform_(w) + + if self.quantization_config is not None: + assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" + self.base_weight = QuantizedParameter(w, quantization_config=quantization_config) + else: + self.base_weight = w + + self.base_weight.requires_grad = False + + # Use RS lora for now. + self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r) + # Keeping lora weights in bf16 precision for ease of training. + self.lora_weight_1 = nn.Linear(self.input_dim, + self.lora_config.lora_r, + bias=self.bias, + device=device, + dtype=dtype) + self.lora_weight_2 = nn.Linear(self.lora_config.lora_r, + self.output_dim, + bias=self.bias, + device=device, + dtype=dtype) + self.lora_weight_1.weight.requires_grad = True + self.lora_weight_2.weight.requires_grad = True + + def full_weight(self): + # This assumes weights are evenly sharded across gpus. which might not be correct. + # in that case, we should flatten before all_gather. 
+ local_weight = self.base_weight.dequantized() if isinstance(self.base_weight, + QuantizedParameter) else self.base_weight + tensor_list = [ + torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype) + for _ in range(self.zero_shards) + ] + dist.all_gather(tensor_list, local_weight) + weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1)) + return weight + + def linear_without_F_linear(self, input, weight): + output = torch.mm(input.reshape(-1, input.shape[-1]), weight) + output = output.view(*input.shape[:-1], weight.shape[1]) + return output + + def forward(self, input_tensor): + # Gather the sharded base weight + if self.zero_shards > 1: + with torch.no_grad(): + base_weight = self.full_weight() + elif self.quantization_config: + base_weight = self.base_weight.dequantized() + else: + base_weight = self.base_weight + + base_weight_output = F.linear(input_tensor, base_weight) + lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor)) + return base_weight_output + self.lora_scaling_factor * lora_output diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py new file mode 100644 index 000000000000..f5343af45fb8 --- /dev/null +++ b/deepspeed/linear/quantization.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize +from .config import QuantizationConfig + + +class QuantizedParameter(nn.Parameter): + """ + Quantized parameter class that implements weight quantization. Weights + are stored in quantized form on GPUs, and can be dequantized on-the-fly when + needed by the model. The weights are actually quantized during any `.to(device)`. + + Arguments: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. Defaults + to False and is not supported to be True. Argument provided only for interface + compatibility with torch.nn.Parameter. + quantization_config (QuantizationConfig, optional): + quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer + that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also + required since the quantizer is stashed in the Parameter itself, some models + may clone the Parameter by passing an attribute __dict__. 
For an example, see + tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone + """ + + def __new__( + cls, + data: Optional[torch.Tensor] = None, + requires_grad: bool = False, # quantized weights must be frozen + quantization_config: QuantizationConfig = None, + quantizer: Quantizer = None, + ): + if requires_grad: + raise ValueError(f"requires_grad=True is not supported with QuantizedParameter") + if data is None: + data = torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config + if quantizer is not None: + self.quantizer = quantizer + else: + # if FPQuantizerBuilder is not compatible in this env this init will fail + self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size) + self._ensure_quantized(self) + return self + + def _ensure_quantized(self, tensor: torch.Tensor): + # If the tensor is on the accelerator and is not quantized, then quantize it in-place. + if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8: + with get_accelerator().stream(get_accelerator().current_stream(tensor.device)): + tensor.data = self.quantizer.quantize(tensor.data, + q_bits=self.quantization_config.q_bits, + q_mantisa_bits=self.quantization_config.mantissa_bits) + assert tensor.dtype == torch.int8 + + def dequantized(self) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8: + with get_accelerator().stream(get_accelerator().current_stream(self.data.device)): + return self.quantizer.dequantize(self.data, + q_bits=self.quantization_config.q_bits, + q_mantisa_bits=self.quantization_config.mantissa_bits) + return self.data + + def __getstate__(self): + state = self.__dict__ + state["data"] = self.data + state["quantization_config"] = self.quantization_config + state["requires_grad"] = self.requires_grad + return state + + def __setstate__(self, state): + self.quantizer = state["quantizer"] + self.quantization_config = state["quantization_config"] + self.data = state["data"] + self.requires_grad = state["requires_grad"] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + new_instance.quantizer = copy.deepcopy(state["quantizer"]) + new_instance.quantization_config = copy.deepcopy(state["quantization_config"]) + new_instance.data = copy.deepcopy(state["data"]) + return new_instance + + def __copy__(self): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def cuda(self, device=None, non_blocking=False): + return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + + def to(self, *args, **kwargs): + """ + Move the parameter to the given device. Then, if the device is a cuda device, + quantize it. + """ + tensor = super().to(*args, **kwargs) + self._ensure_quantized(tensor) + return tensor + + +class QuantizedLinear(nn.Linear): + """ + Linear layer that implements weight quantization. Parameters + are stored via `QuantizedParameter` and are dequantized on-the-fly during any + forward pass. 
+ """ + + def __init__(self, + input_dim: int, + output_dim: int, + bias: bool = False, + quantization_config: QuantizationConfig = None, + dtype=torch.bfloat16): + super().__init__(input_dim, output_dim, bias=bias, dtype=dtype) + assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype" + self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.dequantized(), self.bias) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 88f7086518e8..4944e1954e37 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -13,7 +13,7 @@ from deepspeed import comm as dist from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce from deepspeed.accelerator import get_accelerator -from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw +from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list @@ -133,7 +133,8 @@ def is_load_module(module): load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm] load_layer_names = [ "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", - "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm" + "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding", + "Phi3RMSNorm" ] return module.__class__ in load_layers or module._get_name() in load_layer_names @@ -328,6 +329,10 @@ def _replace(self, child, name, conv_linear_layer): # For mixtral-7x8b, need to skip MoE gate linear replace. if name == "block_sparse_moe.gate": return child + # for phi3. 
+ if 'gate_up_proj' in name: + weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size()) + return LinearLayer(weight=weight, bias=bias) if name in self.all_reduce_linears: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] # else [weight_shape[0], weight_shape[1] // mp_size] diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index cf087c16da8a..33d36fbfae54 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,7 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once -from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0): @@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index): "FalconDecoderLayer": 'bloomtype', "GPTBigCodeBlock": 'bigcodetype', "DecoderLayer": 'glmtype', + "Phi3DecoderLayer": "phi3type" } def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): @@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size): split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0) return torch.cat((split_q[gpu_index], kv), dim=0) + def _phi3_type_transpose(input, mp_size): + num_kv_heads = get_num_kv_heads() + num_heads = get_num_attention_heads() + hidden_size = input.shape[1] + head_dim = hidden_size // num_heads + q_pos = input.shape[0] - 2 * num_kv_heads * head_dim + q = input[:q_pos] + k = input[q_pos:q_pos + num_kv_heads * head_dim] + v = input[q_pos + num_kv_heads * head_dim:] + split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0) + split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0) + split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0) + return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0) + def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following @@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): return _qwen_type_transpose(src, mp_size, module) elif fused_qkv_type == 'bigcodetype': return _bigcode_type_transpose(src, mp_size) + elif fused_qkv_type == 'phi3type': + return _phi3_type_transpose(src, mp_size) raise ValueError("unknown fused_qkv_type") @@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type," f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors") return _bloom_type_transpose(src, mp_size) + + +# For phi3 with chunk mlp, adjust the weight order. 
+def shard_chunk_mlp( + weight, + bias, + rank, + world_size, +): + weight_gate, weight_states = weight.chunk(2, dim=0) + total_size = weight_gate.shape[0] + split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0) + if bias is not None: + bias_gate, bias_states = bias.chunk(2, dim=0) + split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0) + + return shard_weight, None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e1703562d180..3029a79698dc 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -16,7 +16,7 @@ from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from deepspeed import comm as dist -from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd +from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads from .load_checkpoint import load_model_with_checkpoint import time @@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): # 4.2 set n_embd set_n_embd(n_embd) + # 4.3 set attention_heads + if hasattr(model_config, 'num_attention_heads'): + set_num_attention_heads(getattr(model_config, 'num_attention_heads')) + # 5. Set linear policies _autotp.update_linear_policies() diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 79c19b5f1272..6758c7a657f6 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -12,6 +12,11 @@ def set_num_kv_heads(num): num_kv_heads = num +def set_num_attention_heads(num): + global num_attention_heads + num_attention_heads = num + + def set_n_embd(num): global n_embd n_embd = num @@ -22,6 +27,11 @@ def get_num_kv_heads(): return num_kv_heads +def get_num_attention_heads(): + global num_attention_heads + return num_attention_heads + + def get_shard_size(total_size, mp_size, name=None, rank=None): global num_kv_heads last_linear = ["lm_head", "embed_out"] diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 5575f3567185..995bbae4aeaf 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .quantize import FP_Quantize +from .quantize import FP_Quantize, Quantizer diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 5dc3c190ae5d..f8435bda16c1 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -4,20 +4,47 @@ # DeepSpeed Team import torch +import abc +from abc import ABC from deepspeed.ops.op_builder import FPQuantizerBuilder fp_quant_module = None -class FP_Quantize: +class Quantizer(ABC): + """ + Abstract Quantizer class that implements quantize/dequantize methods. + + Arguments: + group_size (int, optional): number of values or elements that are grouped + together for the quantization process.
+ """ + + def __init__(self, group_size=512) -> None: + self.group_size = group_size + + @abc.abstractmethod + def quantize(self, + input, + q_bits=8, + q_mantisa_bits=3, + stochastic_mode=False, + return_meta_tensor=False) -> torch.Tensor: + ... + + @abc.abstractmethod + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: + ... + + +class FP_Quantize(Quantizer): def __init__(self, group_size=512) -> None: global fp_quant_module + super().__init__(group_size=group_size) if fp_quant_module is None: fp_quant_module = FPQuantizerBuilder().load() - - self.group_size = group_size self.orig_dtype = None def quantize(self, @@ -77,3 +104,38 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out + + def selective_dequantize(self, + input_q, + indexes, + fp_out=None, + q_bits=8, + q_mantisa_bits=3, + scale=None) -> torch.Tensor: + assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \ + "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function." + assert (self.orig_dtype is not None), \ + "[De-quantization Error]: you need to call quantize before dequantizing!" + fp_out = torch.empty( + (indexes.shape[0], + *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" + + if scale is not None: + assert input_q.numel() == fp_out.numel(), \ + f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' 
+ input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + + fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits, + q_bits - q_mantisa_bits - 1) + return fp_out diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index f970e582b354..1f3365b20f4e 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -341,7 +341,7 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads): # clear gradients if clear_lp_grads: - lp.grad._zero() + lp.grad.zero_() @torch.no_grad() def _update_hp_grads_func(self, clear_lp_grads=False): @@ -441,11 +441,20 @@ def clear_hp_grads(self): self.fp32_groups_has_gradients[i] = [False] * len(group) def clear_lp_grads(self): + + # using zero_() fixed memory address for graph replay + set_to_none = False if self.graph_harvesting else True + zero_grads_list = [] for group in self.bf16_groups: for param in group: - if param.grad is not None: - # Using zero_() fixed memory address for graph replay - param.grad.zero_() + if set_to_none: + param.grad = None + elif param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + zero_grads_list.append(param.grad) + if not set_to_none and len(zero_grads_list) > 0: + torch._foreach_zero_(zero_grads_list) def state_dict(self): state_dict = {} diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index b5e4e33425d0..66fe29fbbea2 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values): return field_value -class CompiledModuleWrapper(torch.nn.Module): - - def __init__(self, module, compile_config: Union[CompileConfig, None] = None): - super().__init__() - - assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." - - modules = self.__dict__.get('_modules') - modules['wrapped'] = module - self.__dict__['wrapped'] = module - self._is_compiled = False - self._backend = get_backend_fn(compile_config.backend) - self._compile_kwargs = compile_config.kwargs - self._compiler_fn = None - - def __getattr__(self, name): - return getattr(self.__dict__['wrapped'], name) - - def set_backend(self, backend: Union[str, Callable]): - """Set the backend for torch.compile. - - Args: - backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. - You can directly pass a function that works as a backend. - See also `backend` field in `CompileConfig` for more details. - """ - self._backend = get_backend_fn(backend) - - def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: - """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. - You can also pass a backend name with "backend" key to change the backend. - - Args: - kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. - """ - - if "backend" in kwargs: - raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") - self._compile_kwargs.update(kwargs) - - def set_compiler_fn(self, compiler_fn: Callable) -> None: - """Set a function to be used for compiling the module. - This function should take a torch.nn.Module as input and return a compiled module. - Note that other compile options are ignored when a compiler_fn is set. - - Example: - ```python - def my_compiler_fn(module: torch.nn.Module): - ... 
- return torch.compile(module, ...) - - engine.set_compiler_fn(my_compiler_fn) - ``` - """ - self._compiler_fn = compiler_fn - - def forward(self, *args, **kwargs) -> Any: - if not self.is_compiled: - if self._compiler_fn is None: - self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs) - else: - self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) - self._is_compiled = True - - return self.__dict__['wrapped'](*args, **kwargs) - - @property - def is_compiled(self) -> bool: - return self._is_compiled - - @property - def backend(self) -> Union[str, Callable]: - return self._backend - - @property - def torch_compile_kwargs(self) -> Dict[str, Any]: - return self._compile_kwargs - - @property - def compiler_fn(self) -> Union[Callable, None]: - return self._compiler_fn +def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None): + + class wrapper(mod.__class__): + + def __init__(self, module, compile_config: Union[CompileConfig, None] = None): + self.__dict__ = module.__dict__.copy() + + assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." + + self.__dict__['wrapped'] = module + self._is_compiled = False + self._backend = get_backend_fn(compile_config.backend) + self._compile_kwargs = compile_config.kwargs + self._compiler_fn = None + + def set_backend(self, backend: Union[str, Callable]): + """Set the backend for torch.compile. + + Args: + backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. + You can directly pass a function that works as a backend. + See also `backend` field in `CompileConfig` for more details. + """ + self._backend = get_backend_fn(backend) + + def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: + """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. + You can also pass a backend name with "backend" key to change the backend. + + Args: + kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. + """ + + if "backend" in kwargs: + raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") + self._compile_kwargs.update(kwargs) + + def set_compiler_fn(self, compiler_fn: Callable) -> None: + """Set a function to be used for compiling the module. + This function should take a torch.nn.Module as input and return a compiled module. + Note that other compile options are ignored when a compiler_fn is set. + + Example: + ```python + def my_compiler_fn(module: torch.nn.Module): + ... + return torch.compile(module, ...) 
+ + engine.set_compiler_fn(my_compiler_fn) + ``` + """ + self._compiler_fn = compiler_fn + + def forward(self, *args, **kwargs) -> Any: + if not self.is_compiled: + if self._compiler_fn is None: + self.__dict__['wrapped'] = torch.compile(self.wrapped, + backend=self._backend, + **self._compile_kwargs) + else: + self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) + self._is_compiled = True + + return self.__dict__['wrapped'](*args, **kwargs) + + @property + def is_compiled(self) -> bool: + return self._is_compiled + + @property + def backend(self) -> Union[str, Callable]: + return self._backend + + @property + def torch_compile_kwargs(self) -> Dict[str, Any]: + return self._compile_kwargs + + @property + def compiler_fn(self) -> Union[Callable, None]: + return self._compiler_fn + + return wrapper(mod, compile_config) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9a2b943b0992..34263444c1b7 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -469,13 +469,6 @@ def __getattr__(self, name): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) - elif isinstance(_module, CompiledModuleWrapper): - try: - return getattr(_module, name) - except AttributeError: - raise AttributeError( - f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'" - ) else: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index bf1693307ea7..49093bb73c8f 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -241,7 +241,7 @@ def _get_norm_mask_idx(self, group): group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) grad_flat_st_idx = grad_flat_en_idx - return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name()) def step(self, closure=None): """ diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 1dda7f1aad32..be8fe1a368c6 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine): def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): super().__init__(*super_args, **super_kwargs) - assert isinstance(self.module, PipelineModule) \ - or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \ - "model must base PipelineModule" + assert isinstance(self.module, PipelineModule), "model must base PipelineModule" assert self.zero_optimization_stage( ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 7744b2ee8b98..2c01c3475a70 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -171,7 +171,7 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group): # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'. 
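    # (e.g. with 4 ranks whose local norms are N0..N3, the code below scales each to Ni/4 and the all_reduce sum returns their average)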
# Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=group) all_groups_norm = scaled_norm_tensor.item() #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") @@ -424,9 +424,11 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) # # for mask_idx in grad_norm_mask[idx]: # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True - cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(), dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) - mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = torch.zeros(p.shape[0] + 1, + device=get_accelerator().current_device_name(), + dtype=p.dtype) mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 8fc962c4f2a7..bdec8a55fcbc 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -34,6 +34,7 @@ def get_all_parameters(sub_module, recurse=False): return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) +@compiler.disable def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: return map(lambda pair: pair[1], get_all_parameters(module, recurse)) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 68cab13c4a93..13ca29c9fceb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -15,7 +15,7 @@ from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -1409,7 +1409,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): norm_is_nan = total_norm.isnan() inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm @@ -2027,7 +2027,7 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() - scaled_global_grad_norm = torch.norm(torch.stack(norm_groups)) + scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) # Stash unscaled gradient norm self._global_grad_norm = 
scaled_global_grad_norm / self.loss_scale diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py index d77228317ddb..5acb121668e3 100644 --- a/op_builder/hpu/fused_adam.py +++ b/op_builder/hpu/fused_adam.py @@ -4,10 +4,88 @@ # DeepSpeed Team -from .builder import CPUOpBuilder +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder +try: + import torch + import math +except ImportError as e: + pass -class FusedAdamBuilder(CPUOpBuilder): + +class HPUFusedAdam: + htcore = None + is_lazy_mode = None + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + + if HPUFusedAdam.htcore is None: + from habana_frameworks.torch import core as htcore + from habana_frameworks.torch.utils.internal import is_lazy + HPUFusedAdam.htcore = htcore + HPUFusedAdam.is_lazy_mode = is_lazy() + + htcore = HPUFusedAdam.htcore + + htcore.step_closure._mark_step_if_lazy() + step_size = lr + if bias_correction: + bias_correction1 = 1.0 - pow(beta1, step) + bias_correction2 = 1.0 - pow(beta2, step) + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + neg_step = -step_size + neg_step_t = (torch.tensor([neg_step], dtype=torch.float, + requires_grad=False).to(tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, + non_blocking=True)) + + weight_decay = weight_decay if adam_w_mode else 0 + + # since lr is fed into the kernel as tensor, perform the scalar multiplication of wd here + # NOTE: TODO if lr is updated every step, then we need to convert it as tensor and + # perform weight decay unconditonally. + modified_wd = 1.0 - weight_decay * lr + + if HPUFusedAdam.is_lazy_mode: + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd, + ) + else: + modified_wd_t = (torch.tensor([modified_wd], dtype=torch.float, requires_grad=False).to( + tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, non_blocking=True)) + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd_t, + modified_wd != 1.0, + ) + + htcore.step_closure._mark_step_if_lazy() + + +class FusedAdamBuilder(OpBuilder): BUILD_VAR = "DS_BUILD_FUSED_ADAM" NAME = "fused_adam" @@ -18,12 +96,10 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] - - def cxx_args(self): - args = super().cxx_args() - args += ['-DENABLE_BFLOAT16'] - return args + return [] def include_paths(self): - return ['csrc/includes'] + return [] + + def load(self, verbose=True): + return HPUFusedAdam diff --git a/setup.py b/setup.py index f1367b850e02..839941b989c9 100755 --- a/setup.py +++ b/setup.py @@ -219,9 +219,9 @@ def create_dir_symlink(src, dest): if sys.platform == "win32": # This creates a symbolic links on Windows. # It needs Administrator privilege to create symlinks on Windows. 
- create_dir_symlink('..\\..\\csrc', '.\\deepspeed\\ops\\csrc') - create_dir_symlink('..\\..\\op_builder', '.\\deepspeed\\ops\\op_builder') - create_dir_symlink('..\\accelerator', '.\\deepspeed\\accelerator') + create_dir_symlink('.\\deepspeed\\ops\\csrc', '..\\..\\csrc') + create_dir_symlink('.\\deepspeed\\ops\\op_builder', '..\\..\\op_builder') + create_dir_symlink('.\\deepspeed\\accelerator', '..\\accelerator') egg_info.manifest_maker.template = 'MANIFEST_win.in' # Parse the DeepSpeed version string from version.txt. diff --git a/tests/unit/linear/test_linear.py b/tests/unit/linear/test_linear.py new file mode 100644 index 000000000000..ccd26b4cd726 --- /dev/null +++ b/tests/unit/linear/test_linear.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +import deepspeed.comm as dist + +from deepspeed.accelerator import get_accelerator +from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig +from unit.common import DistributedTest + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + + +class TestBasicLinear(DistributedTest): + world_size = 2 + + def test(self): + lora_config = None + quantization_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 1 # Number of samples in a batch + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + + dummy_input = torch.rand(batch_size, input_features, dtype=torch.bfloat16) + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("base_weight_sharding", [1, 2]) +class TestLoRALinear(DistributedTest): + world_size = 2 + + def test(self, base_weight_sharding): + rank = dist.get_rank() + lora_config = None + quantization_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding) + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + if rank == 0: + for n, p in linear_layer.named_parameters(): + print(f"{n}, {p.shape}") + + dummy_input = torch.rand(batch_size, input_features, device=device, dtype=torch.bfloat16) + + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("q_bits", [8, 6]) +class TestQuantLinear(DistributedTest): + world_size = 2 + + def test(self, q_bits): + rank = dist.get_rank() + lora_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = None + quantization_config = QuantizationConfig(q_bits=q_bits) + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + 
dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16) + + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) + + +@pytest.mark.parametrize("base_weight_sharding", [1, 2], ids=['bws1', 'bws2']) +@pytest.mark.parametrize("q_bits", [8, 6], ids=['qbit8', 'qbit6']) +class TestOptimizedLinear(DistributedTest): + world_size = 2 + + def test(self, base_weight_sharding, q_bits): + rank = dist.get_rank() + lora_config = None + + input_features = 64 # Number of input features + output_features = 64 # Number of output features + batch_size = 5 # Number of samples in a batch + + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding) + quantization_config = QuantizationConfig(q_bits=q_bits) + + linear_layer = OptimizedLinear(input_dim=input_features, + output_dim=output_features, + lora_config=lora_config, + quantization_config=quantization_config, + dtype=torch.bfloat16) + device = get_accelerator().current_device_name() + linear_layer = linear_layer.to(device) + dummy_input = torch.rand([batch_size, input_features], device=device, dtype=torch.bfloat16) + output = linear_layer(dummy_input) + assert output.shape == (batch_size, output_features) diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py new file mode 100644 index 000000000000..9479b3cba8a0 --- /dev/null +++ b/tests/unit/linear/test_quant_param.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.accelerator import get_accelerator +from deepspeed.linear.quantization import QuantizedParameter +from deepspeed.linear.config import QuantizationConfig + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + + +class TestQuantParam(DistributedTest): + world_size = 1 + + @pytest.mark.parametrize('dtype', [torch.half, torch.float]) + def test_unsupported_dtypes(self, dtype): + device = get_accelerator().current_device_name() + data = torch.rand(5, 5, device='cpu', dtype=dtype) + qp = QuantizedParameter(data) + with pytest.raises(AssertionError): + qp.to(device) + + def test_requires_grad(self): + data = torch.rand(5, 5, dtype=torch.bfloat16) + with pytest.raises(ValueError): + QuantizedParameter(data, requires_grad=True) + + def test_move_to_accelerator(self): + device = get_accelerator().current_device() + data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16) + qp = QuantizedParameter(data) + assert qp.device == torch.device('cpu') + qp = qp.to(get_accelerator().current_device_name()) + assert qp.device == torch.device(device) + assert qp.dtype == torch.int8 + + def test_hf_clone(self): + device = get_accelerator().current_device_name() + data = torch.rand(5, 5, device=device, dtype=torch.bfloat16) + + quantization_config = QuantizationConfig(q_bits=6) + qp = QuantizedParameter(data, quantization_config=quantization_config) + + # should be able to clone parameter via dict, HF expects this to work + qp_copy = QuantizedParameter(qp.data, **qp.__dict__) + + assert all(qp.data == qp_copy.data) + assert qp.quantization_config == 
qp_copy.quantization_config diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index d39f9fe3d651..fdff9430a4e6 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -177,7 +177,7 @@ class TestTopk(DistributedTest): world_size = 2 def test(self): - device = get_accelerator().current_device() + device = get_accelerator().current_device_name() if dist.get_rank() == 0: logits = torch.rand(2, 2, device=device) elif dist.get_rank() == 1: diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py index 101f4cd69811..bed8bd7e3bcc 100644 --- a/tests/unit/ops/fp_quantizer/test_fp_quant.py +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -61,6 +61,35 @@ def test_fp_quant_meta(dtype): assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_fp_quant_selective(dtype): + group_size = 128 + q_bits = 8 + exp_bits = 4 + man_bits = 3 + + fpq = FP_Quantize(group_size=group_size) + indexes = torch.zeros(2, dtype=torch.int32, device='cuda') + indexes[0] = 1 + indexes[1] = 3 + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device='cuda') + + x = x.reshape(4, 1, x.shape[-1]) + ds_x = x.clone() + x_quantized = fpq.quantize(ds_x, q_bits=q_bits) + x_dequantized = fpq.selective_dequantize(x_quantized, indexes, q_bits=q_bits) + + qtorch_out = qtorch_quantize(x.index_select(0, indexes), + exp_bits=exp_bits, + man_bits=man_bits, + group_size=group_size) + qtorch_error = (qtorch_out - x.index_select(0, indexes)).abs().sum() / x.numel() + ds_error = (x_dequantized - x.index_select(0, indexes)).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" + + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"]) def test_fp_quant(dtype, q_bits): diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py index d1830534f6ea..62af25ac3ba4 100644 --- a/tests/unit/runtime/compile/test_compile_wrapper.py +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -31,11 +31,9 @@ def base_config(): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' return config_dict diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index 7568c27e3ed2..a0736b0f5425 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -12,7 +12,7 @@ from unit.runtime.compile.util import compare_loss from unit.common import DistributedTest -from unit.util import bf16_required_version_check +from unit.util import bf16_required_version_check, skip_on_arch pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), reason="Compile tests requires Pytorch version 2.1 or above") @@ -26,9 +26,11 @@ class TestZeRO(DistributedTest): @pytest.mark.parametrize('zero_stage', [1, 2, 3]) @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): + if dtype == torch.bfloat16: + skip_on_arch(min_arch=8) if dtype == torch.bfloat16 and not 
bf16_required_version_check(): pytest.skip( - " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly" ) if get_accelerator().device_name() == "cpu": pytest.skip("CPU does not support this test yet") @@ -51,12 +53,10 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' if offload_device == OffloadDeviceEnum.cpu: config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} elif offload_device == OffloadDeviceEnum.nvme: diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py index 601adae58884..cee8d3b23f6b 100644 --- a/tests/unit/runtime/compile/test_load_config.py +++ b/tests/unit/runtime/compile/test_load_config.py @@ -47,12 +47,10 @@ def base_config(): }, "compile": { "enabled": True, - "backend": "inductor" + "backend": get_accelerator().get_compile_backend() } } - if get_accelerator().device_name() == 'hpu': - config_dict['compile']['backend'] = 'hpu_backend' return config_dict diff --git a/version.txt b/version.txt index e867cc2a66a8..ac4a79626c87 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.14.2 +0.14.3
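
The `clear_lp_grads` change in `deepspeed/runtime/bf16_optimizer.py` above keeps gradient buffers at fixed addresses when `graph_harvesting` is enabled, rather than freeing them with `param.grad = None`. A minimal standalone sketch of that pattern follows, assuming an illustrative helper `clear_grads` and flag `graph_capture_mode` (neither is DeepSpeed API):

```python
import torch


def clear_grads(params, graph_capture_mode: bool) -> None:
    """Sketch of graph-replay-friendly gradient clearing (illustrative only)."""
    if graph_capture_mode:
        grads = []
        for p in params:
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()  # drop autograd history but keep the same storage
                grads.append(p.grad)
        if grads:
            torch._foreach_zero_(grads)  # zero in place; tensor addresses stay stable for replay
    else:
        for p in params:
            p.grad = None  # cheaper when no CUDA graph will be replayed
```

Replaying a captured CUDA graph reuses the memory addresses recorded at capture time, so gradient tensors must be zeroed in place rather than reallocated between iterations.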