From 3194fe85c5b62940f5479deba0396121f2d87d59 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Thu, 18 Apr 2024 09:27:07 -0700 Subject: [PATCH 1/5] Add required paths to trigger AMD tests on PRs (#5406) Even though AMD tests are currently broken, this will at least trigger them on PRs that touch files that might impact them. Since the test name is listed as `amd-tests` rather than `unit-tests` they will currently not be required, however. Co-authored-by: root --- .github/workflows/amd-mi200.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/amd-mi200.yml b/.github/workflows/amd-mi200.yml index 00ff72ac8929..cd1cafe8e679 100644 --- a/.github/workflows/amd-mi200.yml +++ b/.github/workflows/amd-mi200.yml @@ -2,6 +2,10 @@ name: amd-mi200 on: workflow_dispatch: + pull_request: + paths: + - '.github/workflows/amd-mi200.yml' + - 'requirements/**' schedule: - cron: "0 0 * * *" From aaaf8bc5e07535e263f83733f8905400bf6f5aca Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Thu, 18 Apr 2024 20:39:07 +0200 Subject: [PATCH 2/5] Bug fix in `split_index` method (#5292) Bug description: on a dataset of 20 samples, when running 4 workers with 8 threads per worker, then the `split_dataset` would return for worker id `1`: ``` self.worker_splits [[0, 5], [5, 10], [10, 15], [15, 20]] self.thread_splits [[5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 10], [11, 10], [12, 10]] ``` `thread_splits` is wrong and causes a crash in the `DataAnalyzer`: the end sample id is lower than the initial one on the last 2 threads. This PR fixes that by fixing the behaviour of `split_index` --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/data_pipeline/data_sampling/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/utils.py b/deepspeed/runtime/data_pipeline/data_sampling/utils.py index 9c643f3705de..dc55f96e222d 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/utils.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/utils.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import math import numpy as np from deepspeed.utils import logger @@ -32,10 +31,8 @@ def find_fit_int_dtype(min_value, max_value): def split_index(start_idx, end_idx, num_partitions): - partition_size = math.ceil((end_idx - start_idx) / num_partitions) - partitions = [[start_idx + x * partition_size, - min(end_idx, start_idx + (x + 1) * partition_size)] for x in range(num_partitions)] - return partitions + partition_boundaries = np.linspace(start_idx, end_idx, dtype=int, num=num_partitions + 1) + return [(partition_boundaries[i], partition_boundaries[i + 1]) for i in range(num_partitions)] def split_dataset(dataset, num_workers, worker_id, num_threads): From 64defe65b73f856466935befc49fb188756ab558 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Thu, 18 Apr 2024 23:14:08 +0200 Subject: [PATCH 3/5] Parallel map step for `DistributedDataAnalyzer` map-reduce (#5291) - adds multi CPU-processing to the `DistributedDataAnalyzer` map operation (parallelism set with parameter `num_workers`). Works with a `SharedMemory` / `Manager's` queue per metric, written concurrently by processes. - much faster `write_buffer_to_file` in `DistributedDataAnalyzer` reduce operation by copying to cpu and "detaching" output tensor. 
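
A minimal sketch of the parallel map pattern described above, assuming one `Manager` queue per metric and one `Process` per thread split (the number of processes follows the new `num_threads` argument added in the diff below; the dataset, splits, and metric functions here are illustrative only, not the PR's actual API):

```
from multiprocessing import Manager, Process


def run_map_helper(thread_id, thread_splits, dataset, metric_queues):
    # each process works on its own contiguous slice of this worker's samples
    start, end = thread_splits[thread_id]
    per_sample = [dataset[i] % 7 for i in range(start, end)]  # 'single_value_per_sample'-style metric
    accumulated = sum(dataset[start:end])                     # 'accumulate_value_over_samples'-style metric
    # hand results back to the parent through one shared queue per metric
    for queue, result in zip(metric_queues, (per_sample, accumulated)):
        queue.put((thread_id, result))


if __name__ == "__main__":
    dataset = list(range(20))
    thread_splits = [(0, 5), (5, 10), (10, 15), (15, 20)]

    with Manager() as manager:
        metric_queues = [manager.Queue() for _ in range(2)]  # one queue per metric
        procs = [
            Process(target=run_map_helper, args=(t, thread_splits, dataset, metric_queues))
            for t in range(len(thread_splits))
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

        # drain each queue and reassemble the per-thread results in thread order
        for m_idx, queue in enumerate(metric_queues):
            gathered = {}
            while not queue.empty():
                t_idx, t_results = queue.get()
                gathered[t_idx] = t_results
            print(f"metric {m_idx}:", [gathered[t] for t in sorted(gathered)])
```

The per-metric queue keeps the reduce step simple: results arrive tagged with the thread id, so they can be stitched back together in sample order regardless of which process finishes first.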
--------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Conglong Li --- .../data_sampling/data_analyzer.py | 128 ++++++++++++------ .../data_sampling/indexed_dataset.py | 6 +- 2 files changed, 93 insertions(+), 41 deletions(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index 7088df223bd8..93d351169834 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -4,6 +4,7 @@ # DeepSpeed Team import os +import sys from collections import defaultdict import csv import time @@ -12,8 +13,8 @@ import torch from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset -from deepspeed.utils import logger import deepspeed.comm as dist +from deepspeed.utils import logger from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset import MMapIndexedDataset, valid_dtypes from deepspeed.runtime.data_pipeline.data_sampling.utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype @@ -457,6 +458,7 @@ def __init__( self, dataset, num_workers=1, + num_threads=1, worker_id=0, batch_size=1, metric_names=[], @@ -477,6 +479,8 @@ def __init__( self.collate_fn = collate_fn self.device = device self.sample_indices = sample_indices + self.num_threads = num_threads + self.worker_id = worker_id if not dist.is_initialized(): dist.init_distributed() @@ -494,13 +498,9 @@ def __init__( if self.worker_id == 0: logger.info(f"Distributed data analyzer initialized with {self.num_workers} workers.") - def run_map_reduce(self): - - # setup individual dataloaders - worker_splits, _ = split_dataset(self.dataset, self.num_workers, self.worker_id, num_threads=1) - start_idx, end_idx = worker_splits[self.worker_id] - logger.info(f"worker {self.worker_id}: start working on data subset {start_idx} to {end_idx}") - worker_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) + def run_map_helper(self, thread_id=0, metric_queues=None): + thread_start_idx, thread_end_idx = self.thread_splits[thread_id][0], self.thread_splits[thread_id][1] + worker_dataset = Subset(self.dataset, list(range(thread_start_idx, thread_end_idx))) sampler = BatchSampler(SequentialSampler(worker_dataset), batch_size=self.batch_size, drop_last=False) dataloader = DataLoader(dataset=worker_dataset, batch_sampler=sampler, @@ -516,7 +516,7 @@ def run_map_reduce(self): metric_results.append([] if metric_type == 'single_value_per_sample' else None) # iterate dataloader and store metric results - batch_start_idx = start_idx + batch_start_idx = thread_start_idx for data in dataloader: for m_idx in range(len(self.metric_names)): metric_type, metric_function = self.metric_types[m_idx], self.metric_functions[m_idx] @@ -544,15 +544,73 @@ def run_map_reduce(self): metric_results[m_idx].add_(metric_values) batch_start_idx += len(data) + if self.num_threads == 1: + return metric_results + + # copy metric_results to the shared queue + assert metric_queues + for m_idx in range(len(self.metric_names)): + results = metric_results[m_idx] + if torch.is_tensor(results): + results = results.item() if results.dim() == 0 else results.tolist() + try: + metric_queues[m_idx].put((thread_id, results)) + except Exception as e: + logger.error(f"Error putting metric results to queue: {e}") + sys.exit(1) + + def run_map_reduce(self): + + # setup individual dataloaders + 
self.worker_splits, self.thread_splits = split_dataset(self.dataset, + self.num_workers, + self.worker_id, + num_threads=self.num_threads) + node_start_idx, node_end_idx = self.worker_splits[self.worker_id] + logger.info(f"worker {self.worker_id} working on data subset {node_start_idx} to {node_end_idx}.") + + if self.num_threads in [0, 1, None]: + metric_results = self.run_map_helper() + metric_results = [torch.tensor(m).to(self.device) for m in metric_results] + else: + + # create a shared queue of results per metric to be populated by individual threads + with Manager() as manager: + metric_queues = [manager.Queue() for _ in self.metric_names] + threads = [ + Process(target=self.run_map_helper, args=(t, metric_queues)) for t in range(self.num_threads) + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # gather results from shared queues into metric_results + metric_results = [None for _ in self.metric_names] + for m_idx, (queue, metric_type) in enumerate(zip(metric_queues, self.metric_types)): + while not queue.empty(): + t_idx, t_results = queue.get() + t_start_idx, t_end_idx = self.thread_splits[t_idx] + if t_start_idx >= t_end_idx: # no results from this thread + continue #corner case for small datasets and high thread count + t_results = torch.tensor(t_results) + if metric_type == 'single_value_per_sample': + # add thread results to the metric_results list, ordered by thread idx + if metric_results[m_idx] is None: # initialize if needed + metric_results[m_idx] = torch.zeros(node_end_idx - node_start_idx, + t_results.size(1)).to(self.device) + metric_results[m_idx][t_start_idx - node_start_idx:t_end_idx - node_start_idx] = t_results + else: + if metric_results[m_idx] is None: # initialize if needed + metric_results[m_idx] = torch.zeros(t_results.size()).to(self.device) + metric_results[m_idx].add_(t_results) + # compute dtype for sample ids total_num_samples = len(self.dataset) sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1) logger.info(f"Total number of data samples: {total_num_samples}.") logger.info(f"Will use {sample_idx_dtype} to store the sample indexes.") - # convert to list of tensors - metric_results = [torch.tensor(m).to(self.device) for m in metric_results] - for m_idx in range(len(self.metric_names)): metric_values, metric_name, metric_type = \ metric_results[m_idx], self.metric_names[m_idx], self.metric_types[m_idx] @@ -611,8 +669,8 @@ def run_map_reduce(self): def file_write_ordered(self, tensor_list, fname, numpy_dtype): """ MPI_file_write_ordered extended to write a list of tensors, by one rank, iteratively """ - # each not has a list of rows (tensors) to be written to the file. - # we will serialize it to communicate it in one comm step. + # each node has a list of rows (tensors) to be written to the file. + # we will serialize it in order to communicate it in one comm step. 
tkwargs = dict(dtype=torch.int64, device=self.device) @@ -636,17 +694,13 @@ def file_write_ordered(self, tensor_list, fname, numpy_dtype): def write_buffer_to_file(buff, src, builder): assert self.worker_id == 0, "only rank 0 can write to file" - # # write one buffer at a time - # for row_len in row_lens[src]: - # builder.add_item(buff[:row_len].cpu()) - # buff = buff[row_len:] - - # collect all buffers and write them all at once - buffer_list = [] - for row_len in row_lens[src]: - buffer_list.append(buff[:row_len].cpu()) - buff = buff[row_len:] - builder.add_items(buffer_list) + # collect all buffers and write them at once + buff = buff.cpu().detach().numpy() + row_offsets = np.cumsum([0] + row_lens[src].tolist()) + arr_list = [] + for i in range(len(row_lens[src])): + arr_list.append(buff[row_offsets[i]:row_offsets[i + 1]]) + builder.add_items(arr_list) # 5. rank 0 prepares output folder and file if self.worker_id == 0: @@ -700,7 +754,7 @@ def gather_v(tensor, dst, comm_group, num_workers, worker_id): # all_gather requires all tensors to be of same size so we need to pad them max_size = max(sizes).item() buffer = torch.empty(max_size, dtype=tensor.dtype, device=tensor.device) - buffer[0:size] = torch.tensor(tensor, dtype=tensor.dtype, device=tensor.device) + buffer[0:size] = tensor.data buffer_list = None if worker_id == 0: # create padded recv buffers buffer_list = [torch.empty(max_size, dtype=tensor.dtype, device=tensor.device) for _ in range(num_workers)] @@ -763,16 +817,18 @@ def sample_sort(tensor, comm_group, num_workers, n_samples=100): def test_compare_both_data_analyzers(dataset): """ given a dataset, compare file and memory based data analyser""" - id = lambda t: torch.tensor(t).to(torch.int64) # identity + id = lambda t: t.to(torch.int64) # identity batch_sum = lambda t: id(t).sum() #sum batch + num_threads = 4 kwargs = dict( dataset=dataset, - batch_size=3, + batch_size=2**10, worker_id=int(os.environ['RANK']), num_workers=int(os.environ['WORLD_SIZE']), metric_names=["mod", "batch_sum"], metric_functions=[id, batch_sum], metric_types=['single_value_per_sample', 'accumulate_value_over_samples'], + num_threads=num_threads, ) dda = DistributedDataAnalyzer( @@ -785,10 +841,9 @@ def test_compare_both_data_analyzers(dataset): if dda.worker_id == 0: print("DistributedDataAnalyzer runtime: %s seconds " % (time.time() - start_time)) - da = DataAnalyzer(num_threads=2, - num_threads_reduce=2, - metric_dtypes=[torch.int64, torch.int64], + da = DataAnalyzer(num_threads_reduce=num_threads, save_path="./output_disk", + metric_dtypes=[torch.int64, torch.int64], **kwargs) start_time = time.time() da.run_map_reduce() @@ -815,14 +870,11 @@ def test_compare_both_data_analyzers(dataset): class TestDataset(torch.utils.data.Dataset): - def __init__(self, size=20): - self.values = [1001 + x % 6 for x in range(size)] + def __init__(self, size=10_000_000): + self.values = [(x + 7) % 10_000 for x in range(size)] self.size = size - def __len__(self): - return self.size - - def __getitem__(self, idx): - return self.values[idx] + __len__ = lambda self: self.size + __getitem__ = lambda self, idx: self.values[idx] test_compare_both_data_analyzers(TestDataset()) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py index 453e6ba6039d..872d05de0145 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py @@ -586,9 +586,9 @@ def 
add_item(self, tensor): self._data_file.write(np_array.tobytes(order='C')) self._sizes.append(np_array.size) - def add_items(self, tensor_list): - """ write a list of tensors to the file and update their sizes in the index""" - np_arrays = [np.array(t.numpy(), dtype=self._dtype) for t in tensor_list] + def add_items(self, arr_list): + """ write a list of arrays to the file and update their sizes in the index""" + np_arrays = [arr.astype(self._dtype) for arr in arr_list] self._data_file.writelines([arr.tobytes(order='C') for arr in np_arrays]) for arr in np_arrays: self._sizes.append(arr.size) From c632ea09f8d107d10f76aa2b776e4df3c1ccf98a Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Fri, 19 Apr 2024 08:58:27 -0700 Subject: [PATCH 4/5] Selective dequantization (#5375) This PR adds a new functionality for the dequantizer function, called `selective_dequantize`, which enables partially dequantizing a 3-dimensional matrix in case we don't need to dequantize all the data from lower bit (like fp8/fp6) to bf16. I also added a unit test to check its functionality. --------- Co-authored-by: Reza Yazdani Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- csrc/fp_quantizer/includes/quantize.h | 11 ++ csrc/fp_quantizer/quantize.cpp | 31 +++++ csrc/fp_quantizer/quantize.cu | 135 +++++++++++++++++++ deepspeed/ops/fp_quantizer/quantize.py | 35 +++++ tests/unit/ops/fp_quantizer/test_fp_quant.py | 29 ++++ 5 files changed, 241 insertions(+) diff --git a/csrc/fp_quantizer/includes/quantize.h b/csrc/fp_quantizer/includes/quantize.h index 2204c1ba74fc..507252d6e722 100644 --- a/csrc/fp_quantizer/includes/quantize.h +++ b/csrc/fp_quantizer/includes/quantize.h @@ -113,3 +113,14 @@ void launch_dequantization(uint8_t* val, int q_mantisa_bits, int q_exponent_bits, cudaStream_t stream); + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream); diff --git a/csrc/fp_quantizer/quantize.cpp b/csrc/fp_quantizer/quantize.cpp index 4a88ff767636..ec631c576e27 100644 --- a/csrc/fp_quantizer/quantize.cpp +++ b/csrc/fp_quantizer/quantize.cpp @@ -78,8 +78,39 @@ void dequantize(torch::Tensor& val, #endif } +#define DISPATCH_DEQUANTIZE_INDEX(T_TYPE, C_TYPE, mantisa) \ + if (val.options().dtype() == torch::T_TYPE) { \ + launch_selective_dequantization((uint8_t*)val_q.data_ptr(), \ + (C_TYPE*)val.data_ptr(), \ + (int32_t*)indexes.data_ptr(), \ + num_groups, \ + group_size, \ + num_indexes, \ + q_mantisa_bits, \ + q_exponent_bits, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } +void selective_dequantize(torch::Tensor& val, + torch::Tensor& val_q, + torch::Tensor& indexes, + int group_size, + int q_mantisa_bits, + int q_exponent_bits) +{ + int total_elems = at::numel(val); + int num_indexes = indexes.size(0); + int num_groups = total_elems / group_size; + + DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10); +#ifdef BF16_AVAILABLE + DISPATCH_DEQUANTIZE_INDEX(kBFloat16, __nv_bfloat16, 7); +#endif +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("quantize", &quantize, "quantize function"); m.def("dequantize", &dequantize, "dequantize function"); + m.def("selective_dequantize", &selective_dequantize, "selective dequantize function"); } diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu index 5f0b58f124f0..5ada6894747f 100644 --- a/csrc/fp_quantizer/quantize.cu 
+++ b/csrc/fp_quantizer/quantize.cu @@ -270,6 +270,7 @@ __global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int mem_access::load_global( int8_data + quantization::quanitzed_access_granularity_6bits * 2, load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } else { mem_access::load_global(int8_data, load_base_ptr); @@ -393,3 +394,137 @@ void launch_dequantization(uint8_t* val, INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7); #endif INSTANTIATE_LAUNCH_DEQUANTIZATION(__half, 10); + +template +__global__ void apply_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int group_size, + int total_num_elements) +{ + int index = indexes[blockIdx.x]; + constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T); + int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size; + int input_index = index * total_num_elements + tidx; + constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1; + constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1; + constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1; + constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits; + constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits); + const uint32_t g_index = (input_index / group_size); + const uint32_t group_size_bytes = (group_size * quantized_bits / 8); + const uint8_t* load_base_ptr = + val + g_index * (group_size_bytes + 4) + (input_index % group_size) * quantized_bits / 8; + + int mantisa_mask = ((1 << q_mantisa_bits) - 1); + mantisa_mask <<= (_mantisa_bits - q_mantisa_bits); + + T* store_base_ptr = q_val + tidx + blockIdx.x * total_num_elements; + float scale; + + uint8_t* scale_as_int8 = reinterpret_cast(&scale); + if (quantized_bits == 6) { + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + mem_access::load_global( + scale_as_int8 + quantization::quanitzed_access_granularity_6bits, + val + g_index * (group_size_bytes + 4) + group_size_bytes + + quantization::quanitzed_access_granularity_6bits); + } else + mem_access::load_global( + scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes); + + if (tidx < total_num_elements) { + uint64_t q_buf_in; + uint64_t q_buf_in1; + uint8_t* int8_data = reinterpret_cast(&q_buf_in); + uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1); + if (quantized_bits == 6) { + mem_access::load_global( + int8_data, load_base_ptr); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits, + load_base_ptr + quantization::quanitzed_access_granularity_6bits); + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity_6bits * 2, + load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2); + } else { + mem_access::load_global(int8_data, + load_base_ptr); + if (quantized_bits > 4) { + mem_access::load_global( + int8_data + quantization::quanitzed_access_granularity, + load_base_ptr + quantization::quanitzed_access_granularity); + if (quantized_bits == 12) { + mem_access::load_global( + int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2); + } + } + } + T store_buf[vector_size]; + uint16_t* q_buf = reinterpret_cast(store_buf); +#pragma unroll + for (int j = 0; j < vector_size; j++) { + uint16_t new_data; + if (j < 5 || quantized_bits != 12) { + new_data = (uint16_t)(q_buf_in >> (j * quantized_bits)); + } else { + if (j == 5) { + new_data = (uint16_t)(q_buf_in1); + new_data = 
(uint16_t)((new_data << 4) | (q_buf_in >> 60)); + } else + new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8)); + } + + uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits); + uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits; + uint16_t dst_mantisa = (new_data & _mantisa_mask); + + if (dst_exponent != (1 << q_exponent_bits) - 1) + dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) + + (1 << (q_exponent_bits - 1)) - 1; + + q_buf[j] = + ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) | + (dst_mantisa << (q_mantisa_bits - _mantisa_bits))); + float up_cast = conversion::to(store_buf[j]); + store_buf[j] = conversion::to(up_cast * scale); + } + mem_access::store_global(store_base_ptr, store_buf); + } +} + +template +void launch_selective_dequantization(uint8_t* val, + T* q_val, + int32_t* indexes, + int num_groups, + int group_size, + int num_indexes, + int q_mantisa_bits, + int q_exponent_bits, + cudaStream_t stream) +{ + int total_elements_per_index = (num_groups / num_indexes) * group_size; + int blocks = (total_elements_per_index - 1) / + (quantization::threads * (quantization::access_granularity / sizeof(T))) + + 1; + const dim3 grid(num_indexes, blocks); + const dim3 block(quantization::threads); + DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] { + apply_selective_dequantization + <<>>(val, q_val, indexes, group_size, total_elements_per_index); + }); +} +#define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \ + template void launch_selective_dequantization( \ + uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t); +// fp8(E4M3) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7); +#endif +INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__half, 10); diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 5dc3c190ae5d..0d4bf7bc6db1 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -77,3 +77,38 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out + + def selective_dequantize(self, + input_q, + indexes, + fp_out=None, + q_bits=8, + q_mantisa_bits=3, + scale=None) -> torch.Tensor: + assert (not hasattr(self, 'orig_shape') or len(self.orig_shape) == 3), \ + "Selective-Dequantization works on 3d tensor only! Please reshape the tensor before calling dequantize function." + assert (self.orig_dtype is not None), \ + "[De-quantization Error]: you need to call quantize before dequantizing!" + fp_out = torch.empty( + (indexes.shape[0], + *self.orig_shape[1:]), dtype=self.orig_dtype, device=input_q.device) if fp_out is None else fp_out + if q_bits == 8: + pass + elif q_bits == 12: + q_mantisa_bits = 4 + elif q_bits == 6: + q_mantisa_bits = 2 + elif q_bits == 4: + q_mantisa_bits = 1 + else: + assert (0), \ + f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" + + if scale is not None: + assert input_q.numel() == fp_out.numel(), \ + f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' 
+ input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + + fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits, + q_bits - q_mantisa_bits - 1) + return fp_out diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py index 101f4cd69811..bed8bd7e3bcc 100644 --- a/tests/unit/ops/fp_quantizer/test_fp_quant.py +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -61,6 +61,35 @@ def test_fp_quant_meta(dtype): assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +def test_fp_quant_selective(dtype): + group_size = 128 + q_bits = 8 + exp_bits = 4 + man_bits = 3 + + fpq = FP_Quantize(group_size=group_size) + indexes = torch.zeros(2, dtype=torch.int32, device='cuda') + indexes[0] = 1 + indexes[1] = 3 + for i in range(10): + x = torch.rand(4, 1024, dtype=dtype, device='cuda') + + x = x.reshape(4, 1, x.shape[-1]) + ds_x = x.clone() + x_quantized = fpq.quantize(ds_x, q_bits=q_bits) + x_dequantized = fpq.selective_dequantize(x_quantized, indexes, q_bits=q_bits) + + qtorch_out = qtorch_quantize(x.index_select(0, indexes), + exp_bits=exp_bits, + man_bits=man_bits, + group_size=group_size) + qtorch_error = (qtorch_out - x.index_select(0, indexes)).abs().sum() / x.numel() + ds_error = (x_dequantized - x.index_select(0, indexes)).abs().sum() / x.numel() + + assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" + + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"]) def test_fp_quant(dtype, q_bits): From 99951caa3d2155a3bb84109a0828543793e088cc Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 19 Apr 2024 14:19:47 -0700 Subject: [PATCH 5/5] Fix sorting of shard optimizer states files for universal checkpoint (#5395) This PR resolves the issue reported in #5283. To resolve the issue, we sort files of sharded optimizer states based on DP indices. 
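
A small illustration of why the shards must be ordered by the parsed DP rank rather than by filename, assuming shard files named `{state}.{dp_index:0>2d}` as produced by `dp_index_to_str` in the diff below; with three-digit DP ranks (e.g. 100 or more data-parallel shards) a plain lexicographic `sorted(glob.glob(...))` interleaves them:

```
import re

# hypothetical shard filenames for 128 data-parallel ranks
paths = [f"exp_avg.{dp_index:0>2d}" for dp_index in range(128)]

# lexicographic sort: "exp_avg.100" lands between "exp_avg.10" and "exp_avg.11"
print(sorted(paths)[10:14])  # ['exp_avg.10', 'exp_avg.100', 'exp_avg.101', 'exp_avg.102']

# sort on the parsed dp_rank instead
pattern = re.compile(r"exp_avg\.([0-9]+)")
by_dp_rank = sorted(paths, key=lambda p: int(pattern.match(p).group(1)))
print(by_dp_rank[10:14])  # ['exp_avg.10', 'exp_avg.11', 'exp_avg.12', 'exp_avg.13']
```

The PR's actual code collects the parsed indices into a set and rebuilds the paths via `dp_index_to_str`, which yields the same ordering as the key-based sort shown here.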
--------- Co-authored-by: Olatunji Ruwase --- deepspeed/checkpoint/ds_to_universal.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index d5eca81c804f..63fa866718de 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): cnt = 0 +def dp_index_to_str(dp_index): + return f"{dp_index:0>2d}" + + def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): global cnt # temp hack @@ -140,9 +144,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, os.makedirs(param_base_path, exist_ok=True) cnt += 1 - counter = f"{dp_index:0>2d}" - path = os.path.join(param_base_path, f"{state_name}.{counter}") + path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") #print(f"{param_name}: {offset}: {numel} => {path}") @@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): slices = [] for tp_index in range(tp_degree): prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") - paths = sorted(list(glob.glob(f"{prefix_path}.*"))) + paths = glob.glob(f"{prefix_path}.*") + if len(paths) == 0: continue + pattern = re.compile(f"{prefix_path}\\.([0-9]+)") + dp_indices = set() + for p in paths: + m = pattern.match(p) + if m: + dp_indices.add(int(m.group(1))) + else: + raise ValueError(f"Cannot parse dp_rank from {p}") + + paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))] shards = [torch.load(p) for p in paths] if state == "step":