Merge LoCo with Zero++
ChuanxinTang committed Nov 8, 2024
1 parent a1b0c35 commit 9118612
Showing 9 changed files with 825 additions and 82 deletions.
30 changes: 30 additions & 0 deletions csrc/includes/quantization.h
@@ -52,6 +52,36 @@ void launch_swizzled_quant(int8_t* q_data,
                           int devices_per_node,
                           cudaStream_t stream);

void launch_loco_swizzled_quant(int8_t* quantized_data,
                                float* quantized_scales,
                                const __half* uncompressed_data,
                                __half* error_feedback,
                                const float err_beta,
                                int num_bits,
                                quantize::Type quant_type,
                                int groups,
                                int elems_per_group,
                                int pipelining,
                                int nodes,
                                int devices_per_node,
                                cudaStream_t stream);

void launch_loco_dequant_reduce(int8_t* reduced_data,
                                float* reduced_scales,
                                const int8_t* input_data,
                                const float* input_scales,
                                int num_gpus,
                                int num_bits,
                                quantize::Type quant_type,
                                int out_groups,
                                int elems_per_out_group,
                                int elems_per_in_tensor,
                                int groups_per_in_tensor,
                                int elems_per_in_group,
                                __half* error_feedback,
                                const float err_beta,
                                cudaStream_t stream);

void launch_dequant_reduce(int8_t* reduced_data,
                           float* reduced_scales,
                           const int8_t* input_data,
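The two declarations above are the LoCo additions to this header: compared with the existing launch_swizzled_quant and launch_dequant_reduce entry points, each gains an error_feedback buffer and an err_beta decay factor. The kernel bodies live in the .cu files of this commit; as a rough illustration of what those two parameters imply, here is a minimal scalar sketch of quantization with error feedback. The quantize/dequantize helpers and the exact update rule are assumptions for illustration, not the kernel code.

#include <cmath>
#include <cstdint>

// Hypothetical 8-bit symmetric round trip (no saturation, for brevity).
static int8_t quantize(float x, float scale) { return static_cast<int8_t>(std::lrintf(x / scale)); }
static float dequantize(int8_t q, float scale) { return static_cast<float>(q) * scale; }

// Sketch of the error-feedback idea suggested by error_feedback/err_beta:
// fold the stored residual into the input, quantize, then refresh the
// residual with the new quantization error, decayed by err_beta.
static float loco_step(float x, float& error_feedback, float err_beta, float scale)
{
    const float compensated = x + error_feedback;
    const float restored = dequantize(quantize(compensated, scale), scale);
    error_feedback = err_beta * error_feedback + (compensated - restored);
    return restored;
}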
128 changes: 51 additions & 77 deletions csrc/includes/quantization_utils.h
@@ -24,6 +24,7 @@ constexpr int max_threads = 1024;
 Class to hold the quantization parameters for a given tensor.
 Holds the implementation of the quantization operation.
 */
+
 template <Type qType, int numBits>
 class Params {
 public:
@@ -145,112 +146,85 @@ class Params<Type::Asymmetric, numBits> {
 Group stats tracks the necessary statistics about the quantized group
 to abstract the particulars for the main loop.
 */
-template <Type qType>
-class GroupStats {
-public:
-    DS_D_INLINE void update(__half2 val);
-
-    DS_D_INLINE void reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp);
-};
-
-template <>
-class GroupStats<Type::Symmetric> {
+// Helper functions
+DS_D_INLINE __half h_abs(const __half& val) {
+    return __habs(val);
+}
+
+DS_D_INLINE __half2 h_abs(const __half2& val) {
+    return __habs2(val);
+}
+
+DS_D_INLINE float to_max_float(const __half& val) {
+    return __half2float(val);
+}
+
+DS_D_INLINE float to_min_float(const __half& val) {
+    return __half2float(val);
+}
+
+DS_D_INLINE float to_max_float(const __half2& val) {
+    const float2 partial_max = conversion::to<float2>(val);
+    return reduce::element<rop::Max>(partial_max.x, partial_max.y);
+}
+
+DS_D_INLINE float to_min_float(const __half2& val) {
+    const float2 partial_min = conversion::to<float2>(val);
+    return reduce::element<rop::Min>(partial_min.x, partial_min.y);
+}
+
+// GroupStats class template
+template <Type qType, typename DataType = __half2>
+class GroupStats;
+
+// Symmetric Quantization
+template <typename DataType>
+class GroupStats<Type::Symmetric, DataType> {
 public:
     // Symmetric quantization only tracks the maximum absolute value
-    __half2 cur_max;
-    float max;
+    DataType cur_max;

     /*
     Technically, this would give bad results if there
     are 0 values to process since the reduction would
     give -inf instead of 0. We do not consider this
     to be a reasonable edge case.
     */
-    DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, __half2>(); }
+    DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, DataType>(); }

     /*
     Updated the running absmax used to calculate params.
     Function Arguments :
         val : The __half2 value to update the running min and max with.
     */
-    DS_D_INLINE void update(__half2 val)
-    {
-        cur_max = reduce::element<rop::Max>(cur_max, __habs2(val));
+    DS_D_INLINE void update(DataType val) {
+        cur_max = reduce::element<rop::Max>(cur_max, h_abs(val));
     }

     /*
     Function to return calculated quantization params.
     Template Arguments :
         numBits - Number of bits in quantized element. int : 8 or 4
     Function Arguments :
         tb - Threadblock object. cg::thread_block
         warp - Warp object. cg::thread_block_tile<hw_warp_size>
     */
     template <int numBits, int threads_per_group>
     DS_D_INLINE Params<Type::Symmetric, numBits> get_params(
         cg::thread_block& tb,
-        cg::thread_block_tile<hw_warp_size>& warp)
-    {
-        const float2 partial_max = conversion::to<float2>(cur_max);
-        float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);
-
+        cg::thread_block_tile<hw_warp_size>& warp) {
+        float max = to_max_float(cur_max);
         reduce::partitioned_block<rop::Max, threads_per_group>(tb, warp, max);
         Params<Type::Symmetric, numBits> params(max);
-
         return params;
     }
 };

-template <>
-class GroupStats<Type::Asymmetric> {
+// Asymmetric Quantization
+template <typename DataType>
+class GroupStats<Type::Asymmetric, DataType> {
 public:
-    __half2 cur_max;
-    __half2 cur_min;
+    DataType cur_max;
+    DataType cur_min;

     /*
     Initialize cur_max to -inf, cur_min to inf since
     we are doing a true range analysis.
     */
-    DS_D_INLINE GroupStats()
-    {
-        cur_max = reduce::init<rop::Max, __half2>();
-        cur_min = reduce::init<rop::Min, __half2>();
+    DS_D_INLINE GroupStats() {
+        cur_max = reduce::init<rop::Max, DataType>();
+        cur_min = reduce::init<rop::Min, DataType>();
     }

     /*
     Updated the running min and max used to calculate params.
     Function Arguments :
         val : The __half2 value to update the running min and max with.
     */
-    DS_D_INLINE void update(__half2 val)
-    {
+    DS_D_INLINE void update(DataType val) {
         cur_max = reduce::element<rop::Max>(cur_max, val);
         cur_min = reduce::element<rop::Min>(cur_min, val);
     }

     /*
     Function to return calculated quantization params.
     Template Arguments :
         numBits - Number of bits in quantized element. int : 8 or 4
     Function Arguments :
         tb - Threadblock object. cg::thread_block
         warp - Warp object. cg::thread_block_tile<hw_warp_size>
     */
     template <int numBits, int threads_per_group>
     DS_D_INLINE Params<Type::Asymmetric, numBits> get_params(
         cg::thread_block& tb,
-        cg::thread_block_tile<hw_warp_size>& warp)
-    {
-        const float2 partial_max = conversion::to<float2>(cur_max);
-        float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);
-
-        const float2 partial_min = conversion::to<float2>(cur_min);
-        float min = reduce::element<rop::Min>(partial_min.x, partial_min.y);
-
+        cg::thread_block_tile<hw_warp_size>& warp) {
+        float max = to_max_float(cur_max);
+        float min = to_min_float(cur_min);
         reduce::partitioned_block<rop::Max, rop::Min, threads_per_group>(tb, warp, max, min);
-
         Params<Type::Asymmetric, numBits> params(max, min);
-
         return params;
     }
 };
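The refactor above follows a standard C++ pattern: the primary GroupStats template is declared but never defined, only the Symmetric and Asymmetric partial specializations are, and the overloaded helpers (h_abs, to_max_float, to_min_float) absorb the difference between __half and __half2 so each specialization can be written once over a generic DataType. A host-side sketch of the same pattern, using plain float/double in place of the CUDA half types (GroupStatsSketch and elem_abs are illustrative names, not part of the commit):

#include <algorithm>
#include <cmath>
#include <limits>

enum class Type { Symmetric, Asymmetric };

// Overloads play the role of h_abs for different element types.
inline float elem_abs(float v) { return std::fabs(v); }
inline double elem_abs(double v) { return std::fabs(v); }

// Primary template: declared, never defined; only specializations exist.
template <Type qType, typename DataType = float>
class GroupStatsSketch;

template <typename DataType>
class GroupStatsSketch<Type::Symmetric, DataType> {
public:
    DataType cur_max = -std::numeric_limits<DataType>::infinity();
    void update(DataType val) { cur_max = std::max(cur_max, elem_abs(val)); }
};

// Usage: GroupStatsSketch<Type::Symmetric, double> stats; stats.update(-3.5);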
108 changes: 108 additions & 0 deletions csrc/quantization/pt_binding.cpp
@@ -176,6 +176,55 @@ at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
    return output;
}

std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
                                              at::Tensor& error_feedback,
                                              float err_beta,
                                              int groups,
                                              int num_bits,
                                              quantize::Type quant_type,
                                              int pipeline_size,
                                              int nodes,
                                              int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);
    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;
    auto scales = torch::empty({groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int quantization_scalar = 8 / num_bits;
    const int compressed_vals = at::numel(input_vals) / quantization_scalar;

    auto output = torch::empty({compressed_vals}, output_options);
    const int elems_per_group = at::numel(input_vals) / groups;

    launch_loco_swizzled_quant(reinterpret_cast<int8_t*>(output.data_ptr()),
                               reinterpret_cast<float*>(scales.data_ptr()),
                               reinterpret_cast<const __half*>(input_vals.data_ptr()),
                               reinterpret_cast<__half*>(error_feedback.data_ptr()),
                               err_beta,
                               num_bits,
                               quant_type,
                               groups,
                               elems_per_group,
                               pipeline_size,
                               nodes,
                               devices_per_node,
                               at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
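The sizing logic in ds_loco_swizzle_quant packs 8 / num_bits quantized values into each output byte and emits one scale row per group. A worked example of that arithmetic, with illustrative (assumed) sizes:

// Assumed shapes: a flat tensor of 4096 __half values, 4-bit quantization, 64 groups.
constexpr int numel = 4096;
constexpr int num_bits = 4;
constexpr int groups = 64;
constexpr int quantization_scalar = 8 / num_bits;             // 2 values per int8 byte
constexpr int compressed_vals = numel / quantization_scalar;  // 2048-element int8 output
constexpr int elems_per_group = numel / groups;               // 64 values share one scale
static_assert(compressed_vals == 2048 && elems_per_group == 64, "sizing check");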

std::vector<at::Tensor> ds_swizzle_quant(at::Tensor& input_vals,
                                         int groups,
                                         int num_bits,
@@ -265,6 +314,63 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
    return {output, scales};
}

std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
                                                 at::Tensor& input_scales,
                                                 at::Tensor& error_feedback,
                                                 float err_beta,
                                                 int in_groups,
                                                 int out_groups,
                                                 int num_bits,
                                                 quantize::Type quant_type,
                                                 int devices_per_node)
{
    auto scales_options = at::TensorOptions()
                              .dtype(at::kFloat)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    const int scales_elems = (quantize::requires_offset(quant_type)) ? 2 : 1;

    auto scales = torch::empty({out_groups, scales_elems}, scales_options);

    auto output_options = at::TensorOptions()
                              .dtype(at::kChar)
                              .layout(at::kStrided)
                              .device(at::kCUDA)
                              .requires_grad(false);

    std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
    sz[sz.size() - 1] = sz.back() / devices_per_node;

    const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;

    auto output = torch::empty(sz, output_options);

    const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
    const int elems_per_out_group = elems_per_in_tensor / out_groups;

    launch_loco_dequant_reduce((int8_t*)output.data_ptr(),
                               (float*)scales.data_ptr(),
                               (const int8_t*)input_vals.data_ptr(),
                               (const float*)input_scales.data_ptr(),
                               devices_per_node,
                               num_bits,
                               quant_type,
                               out_groups,
                               elems_per_out_group,
                               elems_per_in_tensor,
                               in_groups / devices_per_node,
                               elems_per_in_group,
                               (half*)error_feedback.data_ptr(),
                               err_beta,
                               at::cuda::getCurrentCUDAStream());

    return {output, scales};
}
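loco_quantized_reduction derives its group geometry from the fact that each of the devices_per_node ranks owns 1/devices_per_node of the input tensor and of the input groups. A worked example of that arithmetic, with illustrative (assumed) sizes:

// Assumed sizes: 8 GPUs per node, 65536 packed input values,
// 512 input groups, 64 output groups.
constexpr int devices_per_node = 8;
constexpr int input_numel = 65536;
constexpr int in_groups = 512;
constexpr int out_groups = 64;
constexpr int elems_per_in_tensor = input_numel / devices_per_node;             // 8192
constexpr int groups_per_in_tensor = in_groups / devices_per_node;              // 64
constexpr int elems_per_in_group = elems_per_in_tensor / groups_per_in_tensor;  // 128
constexpr int elems_per_out_group = elems_per_in_tensor / out_groups;           // 128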

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
@@ -295,4 +401,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
        "Dequantize int8 to half (experimental)");
    m.def("swizzle_quant", &ds_swizzle_quant);
    m.def("quantized_reduction", &quantized_reduction);
    m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel");
    m.def("loco_quantized_reduction",
          &loco_quantized_reduction,
          "LoCo Quantization and Reduction Kernel");
}
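With the bindings registered, the new ops are exposed to Python as loco_swizzle_quant and loco_quantized_reduction. Since the binding functions are ordinary C++ functions over at::Tensor, a direct call can be sketched as below, assuming the call site lives in the same translation unit as the binding (the function is not exported through a header); all shapes and parameter values are illustrative assumptions, and error_feedback is expected to persist across training steps:

#include <torch/torch.h>

// Illustrative call into the binding layer; all sizes/parameters are assumptions.
void example_loco_swizzle_quant()
{
    auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    at::Tensor grads = torch::randn({4096}, opts);
    at::Tensor error_feedback = torch::zeros({4096}, opts);  // reused across steps

    auto out = ds_loco_swizzle_quant(grads,
                                     error_feedback,
                                     /*err_beta=*/0.9f,
                                     /*groups=*/64,
                                     /*num_bits=*/4,
                                     quantize::Type::Symmetric,
                                     /*pipeline_size=*/1,
                                     /*nodes=*/1,
                                     /*devices_per_node=*/8);

    at::Tensor quantized = out[0];  // packed int8 payload
    at::Tensor scales = out[1];     // one fp32 scale per group
}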