
Commit

Merge branch 'master' into master
tjruwase authored Apr 19, 2024
2 parents e79ac1e + 99951ca commit 17b5529
Showing 10 changed files with 357 additions and 49 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/amd-mi200.yml
@@ -2,6 +2,10 @@ name: amd-mi200

on:
  workflow_dispatch:
  pull_request:
    paths:
      - '.github/workflows/amd-mi200.yml'
      - 'requirements/**'
  schedule:
    - cron: "0 0 * * *"

11 changes: 11 additions & 0 deletions csrc/fp_quantizer/includes/quantize.h
@@ -113,3 +113,14 @@ void launch_dequantization(uint8_t* val,
                           int q_mantisa_bits,
                           int q_exponent_bits,
                           cudaStream_t stream);

template <typename T, int mantisa>
void launch_selective_dequantization(uint8_t* val,
                                     T* q_val,
                                     int32_t* indexes,
                                     int num_groups,
                                     int group_size,
                                     int num_indexes,
                                     int q_mantisa_bits,
                                     int q_exponent_bits,
                                     cudaStream_t stream);
31 changes: 31 additions & 0 deletions csrc/fp_quantizer/quantize.cpp
@@ -78,8 +78,39 @@ void dequantize(torch::Tensor& val,
#endif
}

#define DISPATCH_DEQUANTIZE_INDEX(T_TYPE, C_TYPE, mantisa)                             \
    if (val.options().dtype() == torch::T_TYPE) {                                      \
        launch_selective_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(),   \
                                                         (C_TYPE*)val.data_ptr(),      \
                                                         (int32_t*)indexes.data_ptr(), \
                                                         num_groups,                   \
                                                         group_size,                   \
                                                         num_indexes,                  \
                                                         q_mantisa_bits,               \
                                                         q_exponent_bits,              \
                                                         at::cuda::getCurrentCUDAStream()); \
        return;                                                                        \
    }
void selective_dequantize(torch::Tensor& val,
                          torch::Tensor& val_q,
                          torch::Tensor& indexes,
                          int group_size,
                          int q_mantisa_bits,
                          int q_exponent_bits)
{
    int total_elems = at::numel(val);
    int num_indexes = indexes.size(0);
    int num_groups = total_elems / group_size;

    DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
    DISPATCH_DEQUANTIZE_INDEX(kBFloat16, __nv_bfloat16, 7);
#endif
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("quantize", &quantize, "quantize function");
    m.def("dequantize", &dequantize, "dequantize function");
    m.def("selective_dequantize", &selective_dequantize, "selective dequantize function");
}
135 changes: 135 additions & 0 deletions csrc/fp_quantizer/quantize.cu
@@ -270,6 +270,7 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);

        } else {
            mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
                                                                                load_base_ptr);
@@ -393,3 +394,137 @@ void launch_dequantization(uint8_t* val,
INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
#endif
INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10);

template <typename T,
          int q_mantisa_bits,
          int total_q_bits = 16,
          int _mantisa_bits = 3,
          int _exponent_bits = 4>
__global__ void apply_selective_dequantization(uint8_t* val,
                                               T* q_val,
                                               int32_t* indexes,
                                               int group_size,
                                               int total_num_elements)
{
    int index = indexes[blockIdx.x];
    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
    int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
    int input_index = index * total_num_elements + tidx;
    constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
    constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
    constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
    constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
    const uint32_t g_index = (input_index / group_size);
    const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
    const uint8_t* load_base_ptr =
        val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8;

    int mantisa_mask = ((1 << q_mantisa_bits) - 1);
    mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

    T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements;
    float scale;

    uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
    if (quantized_bits == 6) {
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
        mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
            scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
            val + g_index * (group_size_bytes + 4) + group_size_bytes +
                quantization::quanitzed_access_granularity_6bits);
    } else
        mem_access::load_global<quantization::quanitzed_access_granularity>(
            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);

    if (tidx < total_num_elements) {
        uint64_t q_buf_in;
        uint64_t q_buf_in1;
        uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
        uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
        if (quantized_bits == 6) {
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data, load_base_ptr);
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data + quantization::quanitzed_access_granularity_6bits,
                load_base_ptr + quantization::quanitzed_access_granularity_6bits);
            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
        } else {
            mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
                                                                                load_base_ptr);
            if (quantized_bits > 4) {
                mem_access::load_global<quantization::quanitzed_access_granularity>(
                    int8_data + quantization::quanitzed_access_granularity,
                    load_base_ptr + quantization::quanitzed_access_granularity);
                if (quantized_bits == 12) {
                    mem_access::load_global<quantization::quanitzed_access_granularity>(
                        int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
                }
            }
        }
        T store_buf[vector_size];
        uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
#pragma unroll
        for (int j = 0; j < vector_size; j++) {
            uint16_t new_data;
            if (j < 5 || quantized_bits != 12) {
                new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
            } else {
                if (j == 5) {
                    new_data = (uint16_t)(q_buf_in1);
                    new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
                } else
                    new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
            }

            uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
            uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
            uint16_t dst_mantisa = (new_data & _mantisa_mask);

            if (dst_exponent != (1 << q_exponent_bits) - 1)
                dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
                               (1 << (q_exponent_bits - 1)) - 1;

            q_buf[j] =
                ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
                 (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
            float up_cast = conversion::to<float>(store_buf[j]);
            store_buf[j] = conversion::to<T>(up_cast * scale);
        }
        mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
    }
}
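The heart of the per-element loop is the bit-field conversion from the low-precision float to the 16-bit output: extract sign, exponent, and mantissa, re-bias the exponent from the source format to the target format, reassemble, then multiply by the group's float32 scale. A worked Python sketch for one FP8 (E4M3) byte going to FP16, assuming the 1/4/3 source layout and 1/5/10 target layout used here; the all-ones-exponent guard is a simplification of the kernel's check, and the group scale step is omitted:

import struct

def e4m3_byte_to_fp16_bits(byte: int) -> int:
    sign = (byte >> 7) & 0x1
    exp = (byte >> 3) & 0xF          # 4 exponent bits
    man = byte & 0x7                 # 3 mantissa bits
    if exp != 0xF:                   # simplified guard for the all-ones pattern
        exp = exp - 7 + 15           # re-bias: subtract E4M3 bias (7), add FP16 bias (15)
    return (sign << 15) | (exp << 10) | (man << 7)   # mantissa shifted into the top bits

# Example: 0x40 in E4M3 is sign=0, biased exp=8 (unbiased 1), mantissa=0 -> 2.0
bits = e4m3_byte_to_fp16_bits(0x40)
value = struct.unpack('<e', bits.to_bytes(2, 'little'))[0]
print(value)  # 2.0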

template <typename T, int mantisa>
void launch_selective_dequantization(uint8_t* val,
                                     T* q_val,
                                     int32_t* indexes,
                                     int num_groups,
                                     int group_size,
                                     int num_indexes,
                                     int q_mantisa_bits,
                                     int q_exponent_bits,
                                     cudaStream_t stream)
{
    int total_elements_per_index = (num_groups / num_indexes) * group_size;
    int blocks = (total_elements_per_index - 1) /
                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
                 1;
    const dim3 grid(num_indexes, blocks);
    const dim3 block(quantization::threads);
    DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
        apply_selective_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
            <<<grid, block, 0, stream>>>(val, q_val, indexes, group_size, total_elements_per_index);
    });
}
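The launch geometry is a standard ceiling division: each thread decodes access_granularity / sizeof(T) elements, blockIdx.x walks the selected rows, and blockIdx.y tiles within a row. A quick sketch of the arithmetic with illustrative values for the constants (the real ones live in the quantization headers):

# Illustrative constants; the real values come from the CUDA quantization headers.
threads = 256                     # assumed quantization::threads
access_granularity = 16           # assumed bytes per vectorized access
sizeof_T = 2                      # FP16
vector_size = access_granularity // sizeof_T    # 8 elements per thread
total_elements_per_index = 4096
blocks = (total_elements_per_index - 1) // (threads * vector_size) + 1
print(blocks)  # 2 -> grid = (num_indexes, 2), block = (256,)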
#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \
    template void launch_selective_dequantization<T, mantisa>(  \
        uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t);
// fp8(E4M3)
#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7);
#endif
INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10);
20 changes: 17 additions & 3 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
cnt = 0


def dp_index_to_str(dp_index):
    return f"{dp_index:0>2d}"


def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

    global cnt  # temp hack
@@ -140,9 +144,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,
    os.makedirs(param_base_path, exist_ok=True)

    cnt += 1
-   counter = f"{dp_index:0>2d}"

-   path = os.path.join(param_base_path, f"{state_name}.{counter}")
+   path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")

    #print(f"{param_name}: {offset}: {numel} => {path}")

@@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
    slices = []
    for tp_index in range(tp_degree):
        prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
-       paths = sorted(list(glob.glob(f"{prefix_path}.*")))
+       paths = glob.glob(f"{prefix_path}.*")
+
+       if len(paths) == 0:
+           continue
+
+       pattern = re.compile(f"{prefix_path}\\.([0-9]+)")
+       dp_indices = set()
+       for p in paths:
+           m = pattern.match(p)
+           if m:
+               dp_indices.add(int(m.group(1)))
+           else:
+               raise ValueError(f"Cannot parse dp_rank from {p}")
+
+       paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
        shards = [torch.load(p) for p in paths]

        if state == "step":
35 changes: 35 additions & 0 deletions deepspeed/ops/fp_quantizer/quantize.py
@@ -77,3 +77,38 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None

        fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
        return fp_out

    def selective_dequantize(self,
                             input_q,
                             indexes,
                             fp_out=None,
                             q_bits=8,
                             q_mantisa_bits=3,
                             scale=None) -> torch.Tensor:
        assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \
            "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function."
        assert (self.orig_dtype is not None), \
            "[De-quantization Error]: you need to call quantize before dequantizing!"
        fp_out = torch.empty(
            (indexes.shape[0],
             *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out
        if q_bits == 8:
            pass
        elif q_bits == 12:
            q_mantisa_bits = 4
        elif q_bits == 6:
            q_mantisa_bits = 2
        elif q_bits == 4:
            q_mantisa_bits = 1
        else:
            assert (0), \
                f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

        if scale is not None:
            assert input_q.numel() == fp_out.numel(), \
                f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

        fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                             q_bits - q_mantisa_bits - 1)
        return fp_out
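A usage sketch of the new method, assuming the surrounding FP_Quantize class, a CUDA build of the fp_quantizer op, and an import path matching this module; the constructor argument, shapes, and dtype are illustrative:

import torch
from deepspeed.ops.fp_quantizer import FP_Quantize  # assumed import path

quantizer = FP_Quantize(group_size=512)   # group_size is illustrative
x = torch.randn(8, 1024, 4096, dtype=torch.half, device='cuda')  # 3-D input
x_q = quantizer.quantize(x, q_bits=8)     # quantize first: records orig shape/dtype

rows = torch.tensor([0, 3, 5], dtype=torch.int32, device='cuda')
out = quantizer.selective_dequantize(x_q, rows, q_bits=8)
print(out.shape)  # torch.Size([3, 1024, 4096]) -- only the selected rows are restored

This avoids reconstructing the full tensor when only a few rows (for example, a handful of KV-cache entries) are needed.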
