Commit 3c9c91a: Formatting

loadams committed Nov 15, 2024
1 parent 03dc2b9

Showing 6 changed files with 162 additions and 179 deletions.
37 changes: 18 additions & 19 deletions csrc/includes/quantization_utils.h
@@ -147,28 +147,22 @@ Group stats tracks the necessary statistics about the quantized group
 to abstract the particulars for the main loop.
 */
 // Helper functions
-DS_D_INLINE __half h_abs(const __half& val) {
-    return __habs(val);
-}
+DS_D_INLINE __half h_abs(const __half& val) { return __habs(val); }
 
-DS_D_INLINE __half2 h_abs(const __half2& val) {
-    return __habs2(val);
-}
+DS_D_INLINE __half2 h_abs(const __half2& val) { return __habs2(val); }
 
-DS_D_INLINE float to_max_float(const __half& val) {
-    return __half2float(val);
-}
+DS_D_INLINE float to_max_float(const __half& val) { return __half2float(val); }
 
-DS_D_INLINE float to_min_float(const __half& val) {
-    return __half2float(val);
-}
+DS_D_INLINE float to_min_float(const __half& val) { return __half2float(val); }
 
-DS_D_INLINE float to_max_float(const __half2& val) {
+DS_D_INLINE float to_max_float(const __half2& val)
+{
     const float2 partial_max = conversion::to<float2>(val);
     return reduce::element<rop::Max>(partial_max.x, partial_max.y);
 }
 
-DS_D_INLINE float to_min_float(const __half2& val) {
+DS_D_INLINE float to_min_float(const __half2& val)
+{
     const float2 partial_min = conversion::to<float2>(val);
     return reduce::element<rop::Min>(partial_min.x, partial_min.y);
 }
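
A note on the helper overloads in the hunk above: the __half2 variants simply widen the packed pair and reduce across its two lanes. Below is a minimal host-only sketch of that shape, using a plain struct as a hypothetical stand-in for CUDA's __half2; it is illustrative only, not DeepSpeed code.

// Host-only sketch of the __half2 overload pattern: the packed pair is
// widened to two floats and reduced to a single value. half2_like is a
// hypothetical stand-in for __half2.
#include <algorithm>
#include <cstdio>

struct half2_like {
    float x, y;  // the two lanes, already widened to float
};

float to_max_float(const half2_like& v) { return std::max(v.x, v.y); }
float to_min_float(const half2_like& v) { return std::min(v.x, v.y); }

int main()
{
    const half2_like v{1.5f, -3.0f};
    std::printf("max=%f min=%f\n", to_max_float(v), to_min_float(v));
    return 0;
}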
@@ -185,14 +179,16 @@ class GroupStats<Type::Symmetric, DataType> {

     DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, DataType>(); }
 
-    DS_D_INLINE void update(DataType val) {
+    DS_D_INLINE void update(DataType val)
+    {
         cur_max = reduce::element<rop::Max>(cur_max, h_abs(val));
     }
 
     template <int numBits, int threads_per_group>
     DS_D_INLINE Params<Type::Symmetric, numBits> get_params(
         cg::thread_block& tb,
-        cg::thread_block_tile<hw_warp_size>& warp) {
+        cg::thread_block_tile<hw_warp_size>& warp)
+    {
         float max = to_max_float(cur_max);
         reduce::partitioned_block<rop::Max, threads_per_group>(tb, warp, max);
         Params<Type::Symmetric, numBits> params(max);
@@ -207,20 +203,23 @@ class GroupStats<Type::Asymmetric, DataType> {
     DataType cur_max;
     DataType cur_min;
 
-    DS_D_INLINE GroupStats() {
+    DS_D_INLINE GroupStats()
+    {
         cur_max = reduce::init<rop::Max, DataType>();
         cur_min = reduce::init<rop::Min, DataType>();
     }
 
-    DS_D_INLINE void update(DataType val) {
+    DS_D_INLINE void update(DataType val)
+    {
         cur_max = reduce::element<rop::Max>(cur_max, val);
         cur_min = reduce::element<rop::Min>(cur_min, val);
     }
 
     template <int numBits, int threads_per_group>
     DS_D_INLINE Params<Type::Asymmetric, numBits> get_params(
         cg::thread_block& tb,
-        cg::thread_block_tile<hw_warp_size>& warp) {
+        cg::thread_block_tile<hw_warp_size>& warp)
+    {
         float max = to_max_float(cur_max);
         float min = to_min_float(cur_min);
         reduce::partitioned_block<rop::Max, rop::Min, threads_per_group>(tb, warp, max, min);
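
The two GroupStats specializations in this file differ only in what they track: the symmetric form folds |x| into a running max (a single scale suffices), while the asymmetric form tracks min and max separately and therefore also needs an offset. A host-side sketch of the idea, with float in place of __half, the warp/block reduction elided, and textbook scale/offset formulas standing in for whatever quantize::Params actually computes:

// Host-side sketch of what GroupStats tracks and what get_params derives.
// The scale/offset math below is the generic textbook form and is only
// illustrative; the real derivation lives in quantize::Params.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    const float group[4] = {-2.5f, 0.5f, 1.75f, -0.25f};
    constexpr int num_bits = 8;

    // Symmetric: update() folds |x| into a running max.
    float abs_max = 0.0f;
    for (float v : group) abs_max = std::max(abs_max, std::fabs(v));
    const float sym_scale = abs_max / ((1 << (num_bits - 1)) - 1);  // q = round(x / scale)

    // Asymmetric: update() tracks min and max; an offset shifts the range.
    float mn = group[0], mx = group[0];
    for (float v : group) { mn = std::min(mn, v); mx = std::max(mx, v); }
    const float asym_scale = (mx - mn) / ((1 << num_bits) - 1);
    const float offset = mn;  // q = round((x - offset) / scale)

    std::printf("sym scale=%f, asym scale=%f offset=%f\n", sym_scale, asym_scale, offset);
    return 0;
}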
98 changes: 48 additions & 50 deletions csrc/quantization/pt_binding.cpp
@@ -177,14 +177,14 @@ at::Tensor dequantize_int8_to_half_experimental(at::Tensor& data_in,
 }
 
 std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
                                               at::Tensor& error_feedback,
                                               float err_beta,
                                               int groups,
                                               int num_bits,
                                               quantize::Type quant_type,
                                               int pipeline_size,
                                               int nodes,
                                               int devices_per_node)
 {
     auto scales_options = at::TensorOptions()
                               .dtype(at::kFloat)
@@ -206,21 +206,19 @@ std::vector<at::Tensor> ds_loco_swizzle_quant(at::Tensor& input_vals,
     auto output = torch::empty({compressed_vals}, output_options);
     const int elems_per_group = at::numel(input_vals) / groups;
 
-    launch_loco_swizzled_quant(
-        reinterpret_cast<int8_t*>(output.data_ptr()),
-        reinterpret_cast<float*>(scales.data_ptr()),
-        reinterpret_cast<const __half*>(input_vals.data_ptr()),
-        reinterpret_cast<__half*>(error_feedback.data_ptr()),
-        err_beta,
-        num_bits,
-        quant_type,
-        groups,
-        elems_per_group,
-        pipeline_size,
-        nodes,
-        devices_per_node,
-        at::cuda::getCurrentCUDAStream()
-    );
+    launch_loco_swizzled_quant(reinterpret_cast<int8_t*>(output.data_ptr()),
+                               reinterpret_cast<float*>(scales.data_ptr()),
+                               reinterpret_cast<const __half*>(input_vals.data_ptr()),
+                               reinterpret_cast<__half*>(error_feedback.data_ptr()),
+                               err_beta,
+                               num_bits,
+                               quant_type,
+                               groups,
+                               elems_per_group,
+                               pipeline_size,
+                               nodes,
+                               devices_per_node,
+                               at::cuda::getCurrentCUDAStream());
 
     return {output, scales};
 }
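
For sizing intuition on the binding above: scales are allocated as floats (one group's worth of parameters per quantization group), and the quantized payload is packed int8 storage, so for num_bits of 4 or 8 it works out to numel * num_bits / 8 bytes. The exact compressed_vals expression sits in lines collapsed out of this hunk, so the sketch below recomputes the sizes under that assumption, with made-up dimensions.

// Back-of-envelope sizing for ds_loco_swizzle_quant's outputs. All sizes
// are made up; compressed_vals is assumed to be numel * num_bits / 8.
#include <cstdint>
#include <cstdio>

int main()
{
    const int64_t numel = 1 << 20;  // at::numel(input_vals), assumed
    const int groups = 512;         // quantization groups, assumed
    const int num_bits = 4;

    const int64_t elems_per_group = numel / groups;         // as computed in the binding
    const int64_t compressed_bytes = numel * num_bits / 8;  // packed int8 payload

    std::printf("elems_per_group=%lld compressed_bytes=%lld\n",
                (long long)elems_per_group,
                (long long)compressed_bytes);
    return 0;
}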
@@ -315,14 +313,14 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
 }
 
 std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
                                                  at::Tensor& input_scales,
                                                  at::Tensor& error_feedback,
                                                  float err_beta,
                                                  int in_groups,
                                                  int out_groups,
                                                  int num_bits,
                                                  quantize::Type quant_type,
                                                  int devices_per_node)
 {
     auto scales_options = at::TensorOptions()
                               .dtype(at::kFloat)
@@ -341,7 +339,7 @@ std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
                               .requires_grad(false);
 
     std::vector<int64_t> sz(input_vals.sizes().begin(), input_vals.sizes().end());
     sz[sz.size() - 1] = sz.back() / devices_per_node;
 
     const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;
 
@@ -350,23 +348,21 @@ std::vector<at::Tensor> loco_quantized_reduction(at::Tensor& input_vals,
     const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
     const int elems_per_out_group = elems_per_in_tensor / out_groups;
 
-    launch_loco_dequant_reduce(
-        (int8_t*)output.data_ptr(),
-        (float*)scales.data_ptr(),
-        (const int8_t*)input_vals.data_ptr(),
-        (const float*)input_scales.data_ptr(),
-        devices_per_node,
-        num_bits,
-        quant_type,
-        out_groups,
-        elems_per_out_group,
-        elems_per_in_tensor,
-        in_groups / devices_per_node,
-        elems_per_in_group,
-        (half*)error_feedback.data_ptr(),
-        err_beta,
-        at::cuda::getCurrentCUDAStream()
-    );
+    launch_loco_dequant_reduce((int8_t*)output.data_ptr(),
+                               (float*)scales.data_ptr(),
+                               (const int8_t*)input_vals.data_ptr(),
+                               (const float*)input_scales.data_ptr(),
+                               devices_per_node,
+                               num_bits,
+                               quant_type,
+                               out_groups,
+                               elems_per_out_group,
+                               elems_per_in_tensor,
+                               in_groups / devices_per_node,
+                               elems_per_in_group,
+                               (half*)error_feedback.data_ptr(),
+                               err_beta,
+                               at::cuda::getCurrentCUDAStream());
 
     return {output, scales};
 }
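
The group arithmetic in loco_quantized_reduction is easy to misread: in_groups counts groups across all device shards, so the per-shard group count is in_groups / devices_per_node. A worked example with made-up sizes, mirroring the binding's expressions:

// Worked example of the reduction's group math. All sizes are invented;
// the expressions mirror the ones computed in loco_quantized_reduction.
#include <cassert>
#include <cstdio>

int main()
{
    const int numel = 8192;         // at::numel(input_vals), assumed
    const int devices_per_node = 8;
    const int in_groups = 64;       // groups across all device shards
    const int out_groups = 8;       // groups in the reduced output

    const int elems_per_in_tensor = numel / devices_per_node;  // 1024
    const int elems_per_in_group =
        elems_per_in_tensor / (in_groups / devices_per_node);  // 1024 / 8 = 128
    const int elems_per_out_group = elems_per_in_tensor / out_groups;  // 128

    assert(elems_per_in_group == 128 && elems_per_out_group == 128);
    std::printf("%d %d %d\n", elems_per_in_tensor, elems_per_in_group, elems_per_out_group);
    return 0;
}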
@@ -402,5 +398,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("swizzle_quant", &ds_swizzle_quant);
m.def("quantized_reduction", &quantized_reduction);
m.def("loco_swizzle_quant", &ds_loco_swizzle_quant, "LoCo Swizzled Quantization Kernel");
m.def("loco_quantized_reduction", &loco_quantized_reduction, "LoCo Quantization and Reduction Kernel");
m.def("loco_quantized_reduction",
&loco_quantized_reduction,
"LoCo Quantization and Reduction Kernel");
}
57 changes: 25 additions & 32 deletions csrc/quantization/quant_reduce.cu
@@ -262,8 +262,6 @@ void launch_dequant_reduce(int8_t* reduced_data,
     }
 }
 
-
-
 /*
 Modified loco_dequant_reduce function that performs dequantization and reduction,
 and incorporates error-feedback by updating the error_feedback tensor in-place.
@@ -303,14 +301,13 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
     __half local_buffer[totalChunks * storage_values];
     __half err_buffer[totalChunks * storage_values];
 
     quantize::GroupStats<quantType, __half> stats;
 
 #pragma unroll
     for (int i = 0; i < totalChunks; i++) {
         __half* iteration_buffer = local_buffer + i * storage_values;
         __half* iter_err_buffer = err_buffer + i * storage_values;
 
-
 #pragma unroll
         for (int j = 0; j < storage_values; j++) {
             iteration_buffer[j] = reduce::init<rop::Add, __half>();
@@ -334,7 +331,8 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
                 input_scales + j * groups_per_in_tensor, iter_scale_idx);
 
             __half dequant_buffer[storage_values];
-            dequantize::chunk<__half, numBits, quantType>(dequant_buffer, load_buffer, params);
+            dequantize::chunk<__half, numBits, quantType>(
+                dequant_buffer, load_buffer, params);
 
 #pragma unroll
             for (int k = 0; k < storage_values; k++) {
@@ -356,7 +354,8 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
                 input_scales + j * groups_per_in_tensor, iter_scale_idx);
 
             __half dequant_buffer[storage_values];
-            dequantize::chunk<__half, numBits, quantType>(dequant_buffer, load_buffer, params);
+            dequantize::chunk<__half, numBits, quantType>(
+                dequant_buffer, load_buffer, params);
 
 #pragma unroll
             for (int k = 0; k < storage_values; k++) {
@@ -367,7 +366,7 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
             }
         }
         mem_access::load_global<quantize::granularity>(
             iter_err_buffer, error_feedback + iter_offset_err, do_loads);
 #pragma unroll
         for (int k = 0; k < storage_values; k++) {
             iteration_buffer[k] = __hadd(iteration_buffer[k], iter_err_buffer[k]);
@@ -380,14 +379,10 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
     // Initialize dequantization parameters based on params
     auto de_params = params;
     de_params.scale = 1.0f / params.scale;
-    if constexpr (quantType == quantize::Type::Asymmetric) {
-        de_params.offset = params.offset;
-    }
+    if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; }
 
+    if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }
 
-    if (tb.thread_index().x == 0) {
-        params.store(reduced_scales, tb.group_index().x);
-    }
-
 #pragma unroll
     for (int i = 0; i < totalChunks; i++) {
         const int iter_offset = i * stride + base_offset;
@@ -398,8 +393,7 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
         if (i * stride + elem_offset < elems_per_out_group) {
             // ----------- Begin Error-Feedback Modification -----------
             int8_t local_output[elems_per_load];
-            quantize::_chunk<numBits, quantType>(
-                local_output, iteration_buffer, params);
+            quantize::_chunk<numBits, quantType>(local_output, iteration_buffer, params);
             mem_access::store_global<mem_granularity>(reduced_data + iter_offset, local_output);
 
             // Dequantize the quantized output to compute the dequantized value
@@ -410,10 +404,10 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
                 for (int k = 0; k < storage_values; k++) {
                     __half new_error = __hsub(iteration_buffer[k], dequant_buffer[k]);
                     iter_err_buffer[k] = __hmul(iter_err_buffer[k], __float2half(err_beta)) +
                                          __hmul(__float2half(1.0f - err_beta), new_error);
                 }
                 mem_access::store_global<16>(error_feedback + iter_offset_err, iter_err_buffer);
             }
         }
     }
 }
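
The tail of this kernel is the LoCo error-feedback step: the reduced value is quantized, immediately dequantized, and the residual is blended into the persistent error_feedback buffer with weight (1 - err_beta), exactly as the __hmul/__hadd arithmetic above spells out. A scalar sketch of that update rule, with float in place of __half and a toy symmetric int8 round-trip standing in for quantize::_chunk / dequantize::chunk:

// Scalar sketch of the LoCo error-feedback update. quant_dequant is a toy
// symmetric int8 round-trip, not the DeepSpeed quantizer.
#include <cmath>
#include <cstdio>

float quant_dequant(float x, float scale)
{
    float q = std::round(x / scale);               // quantize
    q = std::fmin(127.0f, std::fmax(-127.0f, q));  // clamp to int8 range
    return q * scale;                              // dequantize
}

int main()
{
    float err = 0.02f;        // persistent error_feedback entry
    const float x = 0.7345f;  // reduced value (old feedback already added)
    const float scale = 0.01f;
    const float err_beta = 0.9f;

    const float new_error = x - quant_dequant(x, scale);   // residual of this step
    err = err_beta * err + (1.0f - err_beta) * new_error;  // the blended update
    std::printf("updated feedback: %f\n", err);
    return 0;
}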

Expand Down Expand Up @@ -480,19 +474,19 @@ void launch_loco_dequant_reduce_impl(int8_t* reduced_data,
}
}

#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \
launch_loco_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data, \
reduced_scales, \
input_data, \
input_scales, \
out_groups, \
elems_per_out_group, \
elems_per_in_tensor, \
groups_per_in_tensor, \
elems_per_in_group, \
num_gpus, \
error_feedback, \
err_beta, \
#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \
launch_loco_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data, \
reduced_scales, \
input_data, \
input_scales, \
out_groups, \
elems_per_out_group, \
elems_per_in_tensor, \
groups_per_in_tensor, \
elems_per_in_group, \
num_gpus, \
error_feedback, \
err_beta, \
stream);

 void launch_loco_dequant_reduce(int8_t* reduced_data,
@@ -549,4 +543,3 @@ void launch_loco_dequant_reduce(int8_t* reduced_data,
         }
     }
 }
-
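
Finally, the LAUNCH_LOCO_DEQUANT_REDUCE_IMPL macro exists to keep the runtime-to-template dispatch readable: each (num_bits, num_gpus) combination must instantiate a different template, and the macro collapses the long argument list to one line per case. The real switch body is collapsed out of the diff above, so the sketch below shows only the general pattern; the names and the supported case set are illustrative.

// Sketch of macro-assisted dispatch from runtime values onto template
// parameters. Hypothetical names; the real dispatch lives in
// launch_loco_dequant_reduce.
#include <cstdio>

template <int numBits, int numTensors>
void launch_impl()
{
    std::printf("launch kernel: %d bits, %d tensors\n", numBits, numTensors);
}

#define LAUNCH_IMPL(NUM_BITS, NUM_GPUS) launch_impl<NUM_BITS, NUM_GPUS>()

void launch(int num_bits, int num_gpus)
{
    if (num_bits == 4) {
        switch (num_gpus) {
            case 8: LAUNCH_IMPL(4, 8); break;
            case 16: LAUNCH_IMPL(4, 16); break;
            default: std::printf("unsupported num_gpus %d\n", num_gpus); break;
        }
    } else {  // 8-bit path
        switch (num_gpus) {
            case 8: LAUNCH_IMPL(8, 8); break;
            case 16: LAUNCH_IMPL(8, 16); break;
            default: std::printf("unsupported num_gpus %d\n", num_gpus); break;
        }
    }
}

int main()
{
    launch(4, 8);
    launch(8, 16);
    return 0;
}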