From 3c9c91a9ab4e740a5a1becef423f5b8f5fe65618 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Fri, 15 Nov 2024 12:54:54 -0800 Subject: [PATCH] Formatting --- csrc/includes/quantization_utils.h | 37 ++++--- csrc/quantization/pt_binding.cpp | 98 +++++++++---------- csrc/quantization/quant_reduce.cu | 57 +++++------ csrc/quantization/swizzled_quantize.cu | 92 +++++++++-------- .../runtime/comm/coalesced_collectives.py | 41 ++++---- deepspeed/runtime/zero/stage3.py | 16 ++- 6 files changed, 162 insertions(+), 179 deletions(-) diff --git a/csrc/includes/quantization_utils.h b/csrc/includes/quantization_utils.h index d1e1db52cc38..de52c8c376d7 100644 --- a/csrc/includes/quantization_utils.h +++ b/csrc/includes/quantization_utils.h @@ -147,28 +147,22 @@ Group stats tracks the necessary statistics about the quantized group to abstract the particulars for the main loop. */ // Helper functions -DS_D_INLINE __half h_abs(const __half& val) { - return __habs(val); -} +DS_D_INLINE __half h_abs(const __half& val) { return __habs(val); } -DS_D_INLINE __half2 h_abs(const __half2& val) { - return __habs2(val); -} +DS_D_INLINE __half2 h_abs(const __half2& val) { return __habs2(val); } -DS_D_INLINE float to_max_float(const __half& val) { - return __half2float(val); -} +DS_D_INLINE float to_max_float(const __half& val) { return __half2float(val); } -DS_D_INLINE float to_min_float(const __half& val) { - return __half2float(val); -} +DS_D_INLINE float to_min_float(const __half& val) { return __half2float(val); } -DS_D_INLINE float to_max_float(const __half2& val) { +DS_D_INLINE float to_max_float(const __half2& val) +{ const float2 partial_max = conversion::to(val); return reduce::element(partial_max.x, partial_max.y); } -DS_D_INLINE float to_min_float(const __half2& val) { +DS_D_INLINE float to_min_float(const __half2& val) +{ const float2 partial_min = conversion::to(val); return reduce::element(partial_min.x, partial_min.y); } @@ -185,14 +179,16 @@ class GroupStats { DS_D_INLINE GroupStats() { cur_max = reduce::init(); } - DS_D_INLINE void update(DataType val) { + DS_D_INLINE void update(DataType val) + { cur_max = reduce::element(cur_max, h_abs(val)); } template DS_D_INLINE Params get_params( cg::thread_block& tb, - cg::thread_block_tile& warp) { + cg::thread_block_tile& warp) + { float max = to_max_float(cur_max); reduce::partitioned_block(tb, warp, max); Params params(max); @@ -207,12 +203,14 @@ class GroupStats { DataType cur_max; DataType cur_min; - DS_D_INLINE GroupStats() { + DS_D_INLINE GroupStats() + { cur_max = reduce::init(); cur_min = reduce::init(); } - DS_D_INLINE void update(DataType val) { + DS_D_INLINE void update(DataType val) + { cur_max = reduce::element(cur_max, val); cur_min = reduce::element(cur_min, val); } @@ -220,7 +218,8 @@ class GroupStats { template DS_D_INLINE Params get_params( cg::thread_block& tb, - cg::thread_block_tile& warp) { + cg::thread_block_tile& warp) + { float max = to_max_float(cur_max); float min = to_min_float(cur_min); reduce::partitioned_block(tb, warp, max, min); diff --git a/csrc/quantization/pt_binding.cpp b/csrc/quantization/pt_binding.cpp index 950655373a1b..1361472d4a8d 100644 --- a/csrc/quantization/pt_binding.cpp +++ b/csrc/quantization/pt_binding.cpp @@ -177,14 +177,14 @@ at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in, } std::vector ds_loco_swizzle_quant(at::Tensor& input_vals, - at::Tensor& error_feedback, - float err_beta, - int groups, - int num_bits, - quantize::Type quant_type, - int pipeline_size, - int nodes, - int 
devices_per_node) + at::Tensor& error_feedback, + float err_beta, + int groups, + int num_bits, + quantize::Type quant_type, + int pipeline_size, + int nodes, + int devices_per_node) { auto scales_options = at::TensorOptions() .dtype(at::kFloat) @@ -206,21 +206,19 @@ std::vector ds_loco_swizzle_quant(at::Tensor& input_vals, auto output = torch::empty({compressed_vals}, output_options); const int elems_per_group = at::numel(input_vals) / groups; - launch_loco_swizzled_quant( - reinterpret_cast(output.data_ptr()), - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(input_vals.data_ptr()), - reinterpret_cast<__half*>(error_feedback.data_ptr()), - err_beta, - num_bits, - quant_type, - groups, - elems_per_group, - pipeline_size, - nodes, - devices_per_node, - at::cuda::getCurrentCUDAStream() - ); + launch_loco_swizzled_quant(reinterpret_cast(output.data_ptr()), + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(input_vals.data_ptr()), + reinterpret_cast<__half*>(error_feedback.data_ptr()), + err_beta, + num_bits, + quant_type, + groups, + elems_per_group, + pipeline_size, + nodes, + devices_per_node, + at::cuda::getCurrentCUDAStream()); return {output, scales}; } @@ -315,14 +313,14 @@ std::vector quantized_reduction(at::Tensor& input_vals, } std::vector loco_quantized_reduction(at::Tensor& input_vals, - at::Tensor& input_scales, - at::Tensor& error_feedback, - float err_beta, - int in_groups, - int out_groups, - int num_bits, - quantize::Type quant_type, - int devices_per_node) + at::Tensor& input_scales, + at::Tensor& error_feedback, + float err_beta, + int in_groups, + int out_groups, + int num_bits, + quantize::Type quant_type, + int devices_per_node) { auto scales_options = at::TensorOptions() .dtype(at::kFloat) @@ -341,7 +339,7 @@ std::vector loco_quantized_reduction(at::Tensor& input_vals, .requires_grad(false); std::vector sz(input_vals.sizes().begin(), input_vals.sizes().end()); - sz[sz.size() - 1] = sz.back() / devices_per_node; + sz[sz.size() - 1] = sz.back() / devices_per_node; const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node; @@ -350,23 +348,21 @@ std::vector loco_quantized_reduction(at::Tensor& input_vals, const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node); const int elems_per_out_group = elems_per_in_tensor / out_groups; - launch_loco_dequant_reduce( - (int8_t*)output.data_ptr(), - (float*)scales.data_ptr(), - (const int8_t*)input_vals.data_ptr(), - (const float*)input_scales.data_ptr(), - devices_per_node, - num_bits, - quant_type, - out_groups, - elems_per_out_group, - elems_per_in_tensor, - in_groups / devices_per_node, - elems_per_in_group, - (half*)error_feedback.data_ptr(), - err_beta, - at::cuda::getCurrentCUDAStream() - ); + launch_loco_dequant_reduce((int8_t*)output.data_ptr(), + (float*)scales.data_ptr(), + (const int8_t*)input_vals.data_ptr(), + (const float*)input_scales.data_ptr(), + devices_per_node, + num_bits, + quant_type, + out_groups, + elems_per_out_group, + elems_per_in_tensor, + in_groups / devices_per_node, + elems_per_in_group, + (half*)error_feedback.data_ptr(), + err_beta, + at::cuda::getCurrentCUDAStream()); return {output, scales}; } @@ -402,5 +398,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("swizzle_quant", &ds_swizzle_quant); m.def("quantized_reduction", &quantized_reduction); m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel"); - m.def("loco_quantized_reduction", &loco_quantized_reduction, "LoCo Quantization and Reduction Kernel"); + 
m.def("loco_quantized_reduction", + &loco_quantized_reduction, + "LoCo Quantization and Reduction Kernel"); } diff --git a/csrc/quantization/quant_reduce.cu b/csrc/quantization/quant_reduce.cu index ef39cde6b9c1..ed1dae8470c7 100644 --- a/csrc/quantization/quant_reduce.cu +++ b/csrc/quantization/quant_reduce.cu @@ -262,8 +262,6 @@ void launch_dequant_reduce(int8_t* reduced_data, } } - - /* Modified loco_dequant_reduce function that performs dequantization and reduction, and incorporates error-feedback by updating the error_feedback tensor in-place. @@ -303,14 +301,13 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data __half local_buffer[totalChunks * storage_values]; __half err_buffer[totalChunks * storage_values]; - quantize::GroupStats stats; + quantize::GroupStats stats; #pragma unroll for (int i = 0; i < totalChunks; i++) { __half* iteration_buffer = local_buffer + i * storage_values; __half* iter_err_buffer = err_buffer + i * storage_values; - #pragma unroll for (int j = 0; j < storage_values; j++) { iteration_buffer[j] = reduce::init(); @@ -334,7 +331,8 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data input_scales + j * groups_per_in_tensor, iter_scale_idx); __half dequant_buffer[storage_values]; - dequantize::chunk<__half, numBits, quantType>(dequant_buffer, load_buffer, params); + dequantize::chunk<__half, numBits, quantType>( + dequant_buffer, load_buffer, params); #pragma unroll for (int k = 0; k < storage_values; k++) { @@ -356,7 +354,8 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data input_scales + j * groups_per_in_tensor, iter_scale_idx); __half dequant_buffer[storage_values]; - dequantize::chunk<__half, numBits, quantType>(dequant_buffer, load_buffer, params); + dequantize::chunk<__half, numBits, quantType>( + dequant_buffer, load_buffer, params); #pragma unroll for (int k = 0; k < storage_values; k++) { @@ -367,7 +366,7 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data } } mem_access::load_global( - iter_err_buffer, error_feedback + iter_offset_err, do_loads); + iter_err_buffer, error_feedback + iter_offset_err, do_loads); #pragma unroll for (int k = 0; k < storage_values; k++) { iteration_buffer[k] = __hadd(iteration_buffer[k], iter_err_buffer[k]); @@ -380,14 +379,10 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data // Initialize dequantization parameters based on params auto de_params = params; de_params.scale = 1.0f / params.scale; - if constexpr (quantType == quantize::Type::Asymmetric) { - de_params.offset = params.offset; - } + if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; } + + if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); } - if (tb.thread_index().x == 0) { - params.store(reduced_scales, tb.group_index().x); - } - #pragma unroll for (int i = 0; i < totalChunks; i++) { const int iter_offset = i * stride + base_offset; @@ -398,8 +393,7 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data if (i * stride + elem_offset < elems_per_out_group) { // ----------- Begin Error-Feedback Modification ----------- int8_t local_output[elems_per_load]; - quantize::_chunk( - local_output, iteration_buffer, params); + quantize::_chunk(local_output, iteration_buffer, params); mem_access::store_global(reduced_data + iter_offset, local_output); // Dequantize the quantized output to compute the dequantized value 
@@ -410,10 +404,10 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data for (int k = 0; k < storage_values; k++) { __half new_error = __hsub(iteration_buffer[k], dequant_buffer[k]); iter_err_buffer[k] = __hmul(iter_err_buffer[k], __float2half(err_beta)) + - __hmul(__float2half(1.0f - err_beta), new_error); + __hmul(__float2half(1.0f - err_beta), new_error); } mem_access::store_global<16>(error_feedback + iter_offset_err, iter_err_buffer); - } + } } } @@ -480,19 +474,19 @@ void launch_loco_dequant_reduce_impl(int8_t* reduced_data, } } -#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \ - launch_loco_dequant_reduce_impl(reduced_data, \ - reduced_scales, \ - input_data, \ - input_scales, \ - out_groups, \ - elems_per_out_group, \ - elems_per_in_tensor, \ - groups_per_in_tensor, \ - elems_per_in_group, \ - num_gpus, \ - error_feedback, \ - err_beta, \ +#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \ + launch_loco_dequant_reduce_impl(reduced_data, \ + reduced_scales, \ + input_data, \ + input_scales, \ + out_groups, \ + elems_per_out_group, \ + elems_per_in_tensor, \ + groups_per_in_tensor, \ + elems_per_in_group, \ + num_gpus, \ + error_feedback, \ + err_beta, \ stream); void launch_loco_dequant_reduce(int8_t* reduced_data, @@ -549,4 +543,3 @@ void launch_loco_dequant_reduce(int8_t* reduced_data, } } } - diff --git a/csrc/quantization/swizzled_quantize.cu b/csrc/quantization/swizzled_quantize.cu index 1aa5ef483c90..ef539ef18200 100644 --- a/csrc/quantization/swizzled_quantize.cu +++ b/csrc/quantization/swizzled_quantize.cu @@ -3,10 +3,10 @@ // DeepSpeed Team +#include "dequantization_utils.h" #include "memory_access_utils.h" #include "quantization_utils.h" #include "reduction_utils.h" -#include "dequantization_utils.h" using rop = reduce::ROpType; @@ -196,7 +196,6 @@ void launch_swizzled_quant(int8_t* q_data, } } - template __global__ void loco_swizzled_quant_kernel(int8_t* quantized_data, float* quantized_scales, @@ -219,33 +218,29 @@ __global__ void loco_swizzled_quant_kernel(int8_t* quantized_data, int device_id = blockIdx_z_swizzled / nodes; int node_id = blockIdx_z_swizzled % nodes; - int blockIdx_z_orig = node_id * devices_per_node + device_id; - - int block_rank_orig = blockIdx_x + blockIdx_y * gridDim.x + blockIdx_z_orig * gridDim.x * gridDim.y; - + int block_rank_orig = + blockIdx_x + blockIdx_y * gridDim.x + blockIdx_z_orig * gridDim.x * gridDim.y; const int elem_offset = threadIdx.x * quantize::h_per_load; const int block_offset_orig = block_rank_orig * elems_per_group; const int base_offset_orig = block_offset_orig + elem_offset; const int stride = blockDim.x * quantize::h_per_load; const __half* uncompressed_data_base = uncompressed_data + base_offset_orig; - - const int block_rank_swizzled = blockIdx_x + blockIdx_y * gridDim.x + blockIdx_z_swizzled * gridDim.x * gridDim.y; + + const int block_rank_swizzled = + blockIdx_x + blockIdx_y * gridDim.x + blockIdx_z_swizzled * gridDim.x * gridDim.y; const int block_offset_swizzled = block_rank_swizzled * elems_per_group; const int base_offset_swizzled = block_offset_swizzled + elem_offset; const __half* error_feedback_base = error_feedback + base_offset_swizzled; - __half local_buffer[totalChunks * quantize::h_per_load]; __half err_buffer[totalChunks * quantize::h_per_load]; + quantize::GroupStats stats; - quantize::GroupStats stats; - - - #pragma unroll +#pragma unroll for (int i = 0; i < totalChunks; i++) { __half* iteration_buffer = local_buffer + i * 
quantize::h_per_load; __half* iter_err_buffer = err_buffer + i * quantize::h_per_load; @@ -253,33 +248,26 @@ __global__ void loco_swizzled_quant_kernel(int8_t* quantized_data, bool do_loads = (elem_offset + i_stride) < elems_per_group; mem_access::load_global( - iteration_buffer, uncompressed_data_base + i_stride, - do_loads); - + iteration_buffer, uncompressed_data_base + i_stride, do_loads); + mem_access::load_global( - iter_err_buffer, error_feedback_base + i_stride, - do_loads); - - #pragma unroll + iter_err_buffer, error_feedback_base + i_stride, do_loads); + +#pragma unroll for (int j = 0; j < quantize::h_per_load; j++) { iteration_buffer[j] = __hadd(iteration_buffer[j], iter_err_buffer[j]); stats.update(iteration_buffer[j]); } } - auto params = stats.template get_params(tb, warp); // Initialize dequantization parameters based on params auto de_params = params; de_params.scale = 1.0f / params.scale; - if constexpr (quantType == quantize::Type::Asymmetric) { - de_params.offset = params.offset; - } + if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; } - if (threadIdx.x == 0) { - params.store(quantized_scales, block_rank_swizzled); - } + if (threadIdx.x == 0) { params.store(quantized_scales, block_rank_swizzled); } constexpr int out_scalar_effect = 8 / numBits; const int out_block_offset = block_rank_swizzled * elems_per_group / out_scalar_effect; @@ -289,8 +277,7 @@ __global__ void loco_swizzled_quant_kernel(int8_t* quantized_data, const int out_stride = stride / out_scalar_effect; constexpr int num_int8_out = quantize::h_per_load / out_scalar_effect; - - #pragma unroll +#pragma unroll for (int i = 0; i < totalChunks; i++) { const int iter_offset = i * stride + base_offset_swizzled; __half* iteration_buffer = local_buffer + i * quantize::h_per_load; @@ -298,33 +285,37 @@ __global__ void loco_swizzled_quant_kernel(int8_t* quantized_data, if (i * stride + elem_offset < elems_per_group) { int8_t local_output[quantize::h_per_load / out_scalar_effect]; - quantize::_chunk( - local_output, iteration_buffer, params); + quantize::_chunk(local_output, iteration_buffer, params); mem_access::store_global(out_base + i * out_stride, local_output); // Dequantize the quantized output to compute the dequantized value __half dequant_buffer[quantize::h_per_load]; dequantize::chunk<__half, numBits, quantType>(dequant_buffer, local_output, de_params); - // Compute new error: sum - dequant_buffer - #pragma unroll +// Compute new error: sum - dequant_buffer +#pragma unroll for (int j = 0; j < quantize::h_per_load; j++) { __half new_error = __hsub(iteration_buffer[j], dequant_buffer[j]); iter_err_buffer[j] = __hmul(iter_err_buffer[j], __float2half(err_beta)) + - __hmul(__float2half(1.0f - err_beta), new_error); + __hmul(__float2half(1.0f - err_beta), new_error); } mem_access::store_global<16>(error_feedback + iter_offset, iter_err_buffer); } } } - - - -#define LAUNCH_LOCO_SWIZZLE_QUANT(total_chunks, threads) \ - loco_swizzled_quant_kernel<<>>( \ - output_data, params, input_data, error_feedback, err_beta, \ - groups, elems_per_group, pipelining, nodes, devices_per_node); +#define LAUNCH_LOCO_SWIZZLE_QUANT(total_chunks, threads) \ + loco_swizzled_quant_kernel \ + <<>>(output_data, \ + params, \ + input_data, \ + error_feedback, \ + err_beta, \ + groups, \ + elems_per_group, \ + pipelining, \ + nodes, \ + devices_per_node); template void launch_loco_swizzled_quant_impl(int8_t* output_data, @@ -383,12 +374,17 @@ void launch_loco_swizzled_quant_impl(int8_t* output_data, } 
#define DISPATCH_LOCO_SWIZZLE_QUANT(num_bits, qtype) \ - launch_loco_swizzled_quant_impl( \ - output_data, params, input_data, error_feedback, err_beta, \ - groups, elems_per_group, pipelining, nodes, devices_per_node, \ - stream); - - + launch_loco_swizzled_quant_impl(output_data, \ + params, \ + input_data, \ + error_feedback, \ + err_beta, \ + groups, \ + elems_per_group, \ + pipelining, \ + nodes, \ + devices_per_node, \ + stream); void launch_loco_swizzled_quant(int8_t* output_data, float* params, @@ -417,4 +413,4 @@ void launch_loco_swizzled_quant(int8_t* output_data, DISPATCH_LOCO_SWIZZLE_QUANT(8, quantize::Type::Symmetric); } } -} \ No newline at end of file +} diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py index 65144087d79f..c2fa907d7dbb 100644 --- a/deepspeed/runtime/comm/coalesced_collectives.py +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -75,11 +75,13 @@ def all_to_all_quant_reduce(tensors: List[Tensor], groups: {}) -> List[Tensor]: output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1) return output_lst + @instrument_w_nvtx @torch.no_grad() -def all_to_all_loco_quant_reduce(params: List[Tensor], - groups: {}, - loco_param: Any = None, +def all_to_all_loco_quant_reduce( + params: List[Tensor], + groups: {}, + loco_param: Any = None, ) -> List[Tensor]: global quantizer_module global loco_idx @@ -110,50 +112,47 @@ def all_to_all_loco_quant_reduce(params: List[Tensor], if not hasattr(p, 'intra_ef_buf') or loco_idx > reset_T: loco_idx = 0 intra_err = torch.zeros_like(p.grad) - inter_err = torch.zeros(tensor.numel()//local_world_size, - device = tensor.device, dtype = tensor.dtype) + inter_err = torch.zeros(tensor.numel() // local_world_size, device=tensor.device, dtype=tensor.dtype) else: - intra_err = quantizer_module.dequantize(p.intra_ef_buf[0], p.intra_ef_buf[1], p.intra_ef_buf[1].numel(), - 8, quantizer_module.Symmetric) - inter_err = quantizer_module.dequantize(p.inter_ef_buf[0], p.inter_ef_buf[1], p.inter_ef_buf[1].numel(), - 8, quantizer_module.Symmetric) + intra_err = quantizer_module.dequantize(p.intra_ef_buf[0], p.intra_ef_buf[1], + p.intra_ef_buf[1].numel(), 8, quantizer_module.Symmetric) + inter_err = quantizer_module.dequantize(p.inter_ef_buf[0], p.inter_ef_buf[1], + p.inter_ef_buf[1].numel(), 8, quantizer_module.Symmetric) intra_quant_group = max(tensor.shape[0], tensor.shape[1], global_world_size) inter_quant_group = intra_quant_group // local_world_size intra_quant_int4, intra_q_scales = quantizer_module.loco_swizzle_quant(tensor, intra_err, err_beta, - intra_quant_group, 4, - quantizer_module.Symmetric, 1, num_nodes, - local_world_size) + intra_quant_group, 4, + quantizer_module.Symmetric, 1, + num_nodes, local_world_size) local_output = torch.empty_like(intra_quant_int4) scale_output = torch.empty_like(intra_q_scales) all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}']) all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}']) - p.intra_ef_buf = quantizer_module.quantize(intra_err, intra_quant_group, 8, - quantizer_module.Symmetric) + p.intra_ef_buf = quantizer_module.quantize(intra_err, intra_quant_group, 8, quantizer_module.Symmetric) global_input_tensor, global_scales = quantizer_module.loco_quantized_reduction( - local_output, scale_output, inter_err, err_beta, - intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric, - local_world_size) - + local_output, scale_output, inter_err, 
err_beta, intra_quant_group, inter_quant_group, 4, + quantizer_module.Symmetric, local_world_size) + global_output = torch.empty_like(global_input_tensor) global_scale_output = torch.empty_like(global_scales) all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}']) all_to_all_single(global_scale_output, global_scales, group=groups[f'global_{inter_idx}']) - p.inter_ef_buf = quantizer_module.quantize(inter_err, inter_quant_group, 8, - quantizer_module.Symmetric) + p.inter_ef_buf = quantizer_module.quantize(inter_err, inter_quant_group, 8, quantizer_module.Symmetric) final_output = quantizer_module.dequantize(global_output, global_scale_output, global_scale_output.numel(), 4, quantizer_module.Symmetric) assert final_output.numel( ) % num_nodes == 0, f"final_output.numel()={final_output.numel()} is not divisible by num_nodes={num_nodes}" output_lst[idx] = (sum(list(final_output.chunk(num_nodes))) / num_nodes).view(-1) - loco_idx +=1 + loco_idx += 1 return output_lst + @instrument_w_nvtx @torch.no_grad() def reduce_scatter_coalesced( diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 7c52ca3f17f3..2088999a6843 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -158,7 +158,7 @@ def __init__( zero_quantized_weights=False, zero_quantized_nontrainable_weights=False, zero_module_granularity_threshold=0, - zeropp_loco_param = None, + zeropp_loco_param=None, ): see_memory_usage("Stage 3 initialize beginning", force=True) @@ -1386,11 +1386,10 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor] global_world_size = dist.get_world_size() num_nodes = global_world_size // local_world_size if self.all2all_process_group is not None and num_nodes > 1: - grad_partitions_for_rank = ( - all_to_all_loco_quant_reduce(params_to_reduce, self.all2all_process_group, self.zeropp_loco_param) - if self.zeropp_loco_param is not None else - all_to_all_quant_reduce(full_grads_for_rank, self.all2all_process_group) - ) + grad_partitions_for_rank = (all_to_all_loco_quant_reduce(params_to_reduce, self.all2all_process_group, + self.zeropp_loco_param) + if self.zeropp_loco_param is not None else all_to_all_quant_reduce( + full_grads_for_rank, self.all2all_process_group)) else: grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group) @@ -2035,7 +2034,6 @@ def _loco_err_buf_update(self, overflow: bool, scale=1.0): p.intra_ef_buf[1] *= scale p.inter_ef_buf[1] *= scale - @instrument_w_nvtx def _overflow_check_and_loss_scale_update(self): @@ -2049,9 +2047,9 @@ def _overflow_check_and_loss_scale_update(self): if self.overflow: self._overflow_clean_up(prev_scale) - + #update loco error buf - self._loco_err_buf_update(self.overflow, self.loss_scale/prev_scale) + self._loco_err_buf_update(self.overflow, self.loss_scale / prev_scale) return self.overflow
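
For readers tracing the LoCo code paths touched by this formatting pass: the kernels pair low-bit quantization with an error-feedback buffer, so the signal lost to quantization is carried into the next step instead of being discarded. Below is a minimal PyTorch sketch of just that update rule, err <- err_beta * err + (1 - err_beta) * (x - dequant(quant(x))), written against made-up helper names (`fake_symmetric_quant`, `loco_error_feedback_step`); it is not the fused CUDA path or the `quantizer_module` API used in the patch.

```python
# Illustrative sketch only: fake_symmetric_quant and loco_error_feedback_step
# are hypothetical names modeling the error-feedback rule that the CUDA
# kernels above apply in fused form.
import torch


def fake_symmetric_quant(x: torch.Tensor, num_bits: int, group_size: int) -> torch.Tensor:
    """Per-group symmetric quantize/dequantize, used only to model quantization loss."""
    q_max = 2 ** (num_bits - 1) - 1
    groups = x.reshape(-1, group_size)
    scales = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / q_max
    q = torch.clamp(torch.round(groups / scales), -q_max - 1, q_max)
    return (q * scales).reshape_as(x)


def loco_error_feedback_step(grad: torch.Tensor,
                             err: torch.Tensor,
                             err_beta: float,
                             num_bits: int = 4,
                             group_size: int = 128) -> torch.Tensor:
    """One step: quantize (grad + err), then update err in place."""
    # The kernels load error_feedback and __hadd it into iteration_buffer
    # before computing group statistics; this is the same compensation.
    compensated = grad + err
    dequant = fake_symmetric_quant(compensated, num_bits, group_size)
    # Matches the update in loco_swizzled_quant_kernel / loco_dequant_reduce:
    # err = err_beta * err + (1 - err_beta) * (compensated - dequant)
    err.mul_(err_beta).add_(compensated - dequant, alpha=1.0 - err_beta)
    return dequant  # stand-in for what the receiving ranks reconstruct


if __name__ == "__main__":
    torch.manual_seed(0)
    g = torch.randn(4096)
    e = torch.zeros_like(g)
    for _ in range(10):
        out = loco_error_feedback_step(g, e, err_beta=0.8)  # beta value is arbitrary here
    print("residual error norm:", e.norm().item())
```

In the real path, `all_to_all_loco_quant_reduce` additionally keeps the intra- and inter-node error buffers (`p.intra_ef_buf`, `p.inter_ef_buf`) stored 8-bit quantized between calls and dequantizes them before reuse; the sketch leaves the buffer in full precision for clarity.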
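The other index trick visible in the reformatted `loco_swizzled_quant_kernel` is the mapping between the swizzled z block index and the original partition: `device_id = z / nodes`, `node_id = z % nodes`, `original = node_id * devices_per_node + device_id`. A tiny sketch of just that arithmetic (the function name and the example topology are hypothetical):

```python
def swizzled_to_original(swizzled_idx: int, nodes: int, devices_per_node: int) -> int:
    # Mirrors the blockIdx_z math in loco_swizzled_quant_kernel.
    device_id = swizzled_idx // nodes
    node_id = swizzled_idx % nodes
    return node_id * devices_per_node + device_id


# Example: 2 nodes x 4 GPUs per node. The swizzled traversal visits the
# original partitions as [0, 4, 1, 5, 2, 6, 3, 7]: device 0's partition on
# every node first, then device 1's, and so on.
order = [swizzled_to_original(i, nodes=2, devices_per_node=4) for i in range(8)]
assert order == [0, 4, 1, 5, 2, 6, 3, 7]
```

The reordering groups data bound for the same local device across nodes contiguously in the quantized buffer, so that the subsequent `all_to_all_single` calls in `all_to_all_loco_quant_reduce` can exchange plain contiguous chunks.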