From 14f2bbe14f11c3efd75b147a2b5798ec980dfbe0 Mon Sep 17 00:00:00 2001
From: Bruno Magalhaes
Date: Fri, 9 Feb 2024 15:10:56 +0000
Subject: [PATCH 1/6] added assert of torch vs numpy types

---
 .../data_pipeline/data_sampling/data_analyzer.py   | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
index cb0d366ce798..014caf75d602 100644
--- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
+++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
@@ -84,12 +84,14 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtyp
             metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
         return metric_results
 
-    def update_metric_results(self, data, metric_types, metric_functions, metric_results):
+    def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results):
         for m_idx in range(len(metric_types)):
-            metric_type, metric_function, metric_result = metric_types[m_idx], \
-                metric_functions[m_idx], metric_results[m_idx]
+            metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \
+                metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx]
+            metric_values = metric_function(data)
+            assert metric_values.numpy().dtype == metric_dtype, \
+                f"dtype {type(m_value)} returned by metric_function {metric_function} is not consistent with the metric_dtype {metric_dtype}"
             if metric_type == 'single_value_per_sample':
-                metric_values = metric_function(data)
                 for row in range(metric_values.size()[0]):
                     metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
                     metric_result["metric_to_sample_dict"][metric_values[row].item()].append(
@@ -102,7 +104,6 @@ def update_metric_results(self, data, metric_types, metric_functions, metric_res
                         writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
                         metric_result["metric_to_sample_dict"][m_value] = []
             elif metric_type == 'accumulate_value_over_samples':
-                metric_values = metric_function(data)
                 if metric_result["metric_value"] is None:
                     metric_result["metric_value"] = metric_values
                 else:
@@ -158,7 +159,7 @@ def run_map_helper(self, thread_id):
             try:
                 data = next(iterator)
                 if self.custom_map_update is None:
-                    self.update_metric_results(data, self.metric_types, self.metric_functions, metric_results)
+                    self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions, metric_results)
                 else:
                     self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results)
                 processed_sample += self.batch_size
@@ -415,3 +416,4 @@ def run_reduce(self):
         else:
             self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
                                self.num_threads, self.num_threads_reduce)
+

From 295fba6797526aee8b4ca475f0e90489e5fed2bb Mon Sep 17 00:00:00 2001
From: Bruno Magalhaes
Date: Wed, 14 Feb 2024 13:31:08 +0000
Subject: [PATCH 2/6] added check for single node reduce. added barriers
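
The barrier in run_map() keeps all ranks in lockstep before the reduce step,
and run_reduce() now merges the per-worker files on worker 0 only. A minimal
usage sketch of the intended multi-rank flow (the dataset and seqlen_metric
below are placeholders, not part of this patch):

    import numpy as np
    import deepspeed
    import deepspeed.comm as dist
    from deepspeed.runtime.data_pipeline.data_sampling.data_analyzer import DataAnalyzer

    deepspeed.init_distributed()
    analyzer = DataAnalyzer(dataset,
                            worker_id=dist.get_rank(),
                            num_workers=dist.get_world_size(),
                            metric_names=["seqlen"],
                            metric_functions=[seqlen_metric],
                            metric_types=["single_value_per_sample"],
                            metric_dtypes=[np.int64],
                            save_path="./metrics",
                            comm_group=None)  # None falls back to the default process group
    analyzer.run_map()     # each rank maps its shard, then waits at the barrier
    analyzer.run_reduce()  # worker 0 merges the per-worker files; other ranks just wait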
---
 .../data_pipeline/data_sampling/data_analyzer.py   | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
index 014caf75d602..b1f8e6aaeb24 100644
--- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
+++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
@@ -36,7 +36,8 @@ def __init__(self,
                  custom_map_init=None,
                  custom_map_update=None,
                  custom_map_finalize=None,
-                 custom_reduce=None):
+                 custom_reduce=None,
+                 comm_group=None):
         super().__init__()
         self.dataset = dataset
         self.num_workers = num_workers
@@ -55,6 +56,7 @@ def __init__(self,
         self.custom_map_update = custom_map_update
         self.custom_map_finalize = custom_map_finalize
         self.custom_reduce = custom_reduce
+        self.comm_group = comm_group
 
     def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
         metric_results = []
@@ -196,6 +198,7 @@ def run_map(self):
         else:
             assert self.num_threads == 1
             self.run_map_helper(0)
+        dist.barrier(group=self.comm_group)
 
     def get_metric_value_percentiles(self, metric_name, num_sample_per_value, total_num_samples):
         logger.info(f"Checking the value percentiles of metric {metric_name}...")
@@ -410,10 +413,12 @@ def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_
             close_mmap_dataset_builder(metric_value_builder, metric_value_fname)
 
     def run_reduce(self):
-        if self.custom_reduce is None:
+        if self.worker_id == 0:  # only one node does merging of files
+            if self.custom_reduce is None:
             self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
                                    self.num_workers, self.num_threads, self.num_threads_reduce)
-        else:
+            else:
             self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
                                self.num_threads, self.num_threads_reduce)
+        dist.barrier(group=self.comm_group)

From f28e829b5ab970f2779638aa3a7699ed18375385 Mon Sep 17 00:00:00 2001
From: Bruno Magalhaes
Date: Fri, 16 Feb 2024 09:00:37 +0000
Subject: [PATCH 3/6] recovered master branch
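
Reverts the run_reduce() layout back to master and moves dtype validation
into update_metric_results(): a metric_function must now return a torch
tensor or numpy array whose dtype equals the declared metric_dtype. A sketch
of a conforming metric (illustrative only, not part of this patch):

    import torch

    def seqlen_metric(batch):
        # declared as metric_dtypes=[torch.int64], so return int64 values
        return torch.tensor([len(sample) for sample in batch], dtype=torch.int64)

Sample ids now default to the dataset iteration order (batch_start_idx + row);
a Megatron-style 'index' field in the batch, or a user-provided sample_indices
mapping, overrides that default for shuffled datasets.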
---
 .../data_sampling/data_analyzer.py                 | 66 +++++++++++--------
 1 file changed, 38 insertions(+), 28 deletions(-)

diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
index b1f8e6aaeb24..3d4d8bde7d1c 100644
--- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
+++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
@@ -13,7 +13,7 @@
 from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset
 
 from deepspeed.utils import logger
-from .indexed_dataset import MMapIndexedDataset
+from .indexed_dataset import MMapIndexedDataset, valid_dtypes
 from .utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype
 
 
@@ -37,7 +37,7 @@ def __init__(self,
                  custom_map_update=None,
                  custom_map_finalize=None,
                  custom_reduce=None,
-                 comm_group=None):
+                 sample_indices=None):
         super().__init__()
         self.dataset = dataset
         self.num_workers = num_workers
@@ -56,16 +56,14 @@ def __init__(self,
         self.custom_map_update = custom_map_update
         self.custom_map_finalize = custom_map_finalize
         self.custom_reduce = custom_reduce
-        self.comm_group = comm_group
+        self.sample_indices = sample_indices
 
     def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
         metric_results = []
         for m_idx in range(len(metric_names)):
             metric_name, metric_type, metric_dtype = metric_names[m_idx], \
                 metric_types[m_idx], metric_dtypes[m_idx]
-            assert metric_dtype not in [
-                np.float64, np.double
-            ], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)."
+            assert metric_dtype in valid_dtypes, f"metric_dtype {metric_dtype} not supported. Supported dtypes {valid_dtypes}"
             metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/"
             os.makedirs(metric_save_path, exist_ok=True)
             if metric_type == 'single_value_per_sample':
@@ -86,18 +84,34 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtyp
             metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
         return metric_results
 
-    def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results):
+    def update_metric_results(self,
+                              data,
+                              metric_types,
+                              metric_dtypes,
+                              metric_functions,
+                              metric_results,
+                              batch_start_idx=0):
         for m_idx in range(len(metric_types)):
             metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \
                 metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx]
             metric_values = metric_function(data)
-            assert metric_values.numpy().dtype == metric_dtype, \
-                f"dtype {type(m_value)} returned by metric_function {metric_function} is not consistent with the metric_dtype {metric_dtype}"
+
+            assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \
+                "metric_function must return a tensor or array"
+            assert metric_values.dtype == metric_dtype, \
+                f"metric_function result dtype {metric_values.dtype} does not match metric_dtype {metric_dtype}"
+            if isinstance(metric_values, np.ndarray):
+                metric_values = torch.from_numpy(metric_values)
+
             if metric_type == 'single_value_per_sample':
                 for row in range(metric_values.size()[0]):
+                    sample_idx = batch_start_idx + row  # sample idx following dataset iteration order
+                    if isinstance(data, dict) and 'index' in data:  # Megatron use case, idx provided in 'index' field
+                        sample_idx = data['index'][row][0].item()
+                    elif self.sample_indices is not None:  # user defined shuffling of indices
+                        sample_idx = self.sample_indices[sample_idx]
                     metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
-                    metric_result["metric_to_sample_dict"][metric_values[row].item()].append(
-                        data['index'][row][0].item())
+                    metric_result["metric_to_sample_dict"][metric_values[row].item()].append(sample_idx)
                 for m_value in metric_result["metric_to_sample_dict"]:
                     if len(metric_result["metric_to_sample_dict"][m_value]) > 100:
                         metric_fname = metric_result["metric_to_sample_fname"]
@@ -139,15 +153,12 @@ def run_map_helper(self, thread_id):
                     f"on data subset {start_idx} to {end_idx}")
         thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
         sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False)
-        if self.collate_fn is None:
-            iterator = iter(DataLoader(thread_dataset, batch_sampler=sampler, num_workers=0, pin_memory=False))
-        else:
-            iterator = iter(
-                DataLoader(thread_dataset,
-                           batch_sampler=sampler,
-                           num_workers=0,
-                           collate_fn=self.collate_fn,
-                           pin_memory=False))
+        iterator = iter(
+            DataLoader(thread_dataset,
+                       batch_sampler=sampler,
+                       num_workers=0,
+                       collate_fn=self.collate_fn,
+                       pin_memory=False))
         if self.custom_map_init is None:
             metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types,
                                                       self.metric_dtypes, self.save_path, self.worker_id)
@@ -160,10 +171,13 @@ def run_map_helper(self, thread_id):
         while True:
             try:
                 data = next(iterator)
+                batch_start_idx = start_idx + processed_sample
                 if self.custom_map_update is None:
-                    self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions, metric_results)
+                    self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions,
+                                               metric_results, batch_start_idx)
                 else:
-                    self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results)
+                    self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions,
+                                           metric_results, batch_start_idx)
                 processed_sample += self.batch_size
                 duration = (time.time() - start) / 3600.0
                 remain_duration = duration * total_sample / processed_sample - duration
@@ -198,7 +212,6 @@ def run_map(self):
         else:
             assert self.num_threads == 1
             self.run_map_helper(0)
-        dist.barrier(group=self.comm_group)
 
     def get_metric_value_percentiles(self, metric_name, num_sample_per_value, total_num_samples):
         logger.info(f"Checking the value percentiles of metric {metric_name}...")
@@ -413,12 +426,9 @@ def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_
             close_mmap_dataset_builder(metric_value_builder, metric_value_fname)
 
     def run_reduce(self):
-        if self.worker_id == 0:  # only one node does merging of files
-            if self.custom_reduce is None:
+        if self.custom_reduce is None:
             self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
                                    self.num_workers, self.num_threads, self.num_threads_reduce)
-            else:
+        else:
             self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
                                self.num_threads, self.num_threads_reduce)
-        dist.barrier(group=self.comm_group)
-

From f6c5c18d77c0eea6a08c3eb0fc44e2ef2dfa218e Mon Sep 17 00:00:00 2001
From: Bruno Magalhaes
Date: Fri, 8 Mar 2024 09:18:20 +0000
Subject: [PATCH 4/6] master in line with remote
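
Brings indexed_dataset.py in line with the dtype handling on the remote
branch: each on-disk dtype code now maps to a (numpy, torch) pair and code()
accepts either flavour. Expected behaviour, as a hand-written sketch:

    >>> code(np.int32)       # numpy dtype resolves to code 4
    4
    >>> code(torch.int32)    # the paired torch dtype resolves to the same code
    4
    >>> code(np.float64)     # floats are no longer in the table
    ValueError: <class 'numpy.float64'> not supported. Supported types: {...}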
---
 .../data_sampling/indexed_dataset.py               | 50 +++++++++++--------
 1 file changed, 30 insertions(+), 20 deletions(-)

diff --git a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py
index 60115fa6efef..1c56f5f503c5 100644
--- a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py
+++ b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py
@@ -98,25 +98,26 @@ def write_longs(f, a):
     f.write(np.array(a, dtype=np.int64))
 
 
+# valid metric_dtypes as numpy and torch types
 dtypes = {
-    1: np.uint8,
-    2: np.int8,
-    3: np.int16,
-    4: np.int32,
-    5: np.int64,
-    6: np.float64,
-    7: np.double,
-    8: np.uint16,
-    9: np.uint32,
-    10: np.uint64
+    1: (np.uint8, torch.uint8),
+    2: (np.int8, torch.int8),
+    3: (np.int16, torch.int16),
+    4: (np.int32, torch.int32),
+    5: (np.int64, torch.int64),
+    6: (np.uint16, None),
+    7: (np.uint32, None),
+    8: (np.uint64, None),
 }
 
+valid_dtypes = set([dt[0] for dt in dtypes.values()] + [dt[1] for dt in dtypes.values() if dt[1] is not None])
+
 
 def code(dtype):
-    for k in dtypes.keys():
-        if dtypes[k] == dtype:
-            return k
-    raise ValueError(dtype)
+    for c, (np_dt, torch_dt) in dtypes.items():
+        if dtype in [np_dt, torch_dt]:
+            return c
+    raise ValueError(f"{dtype} not supported. Supported types: {valid_dtypes}")
 
 
 def index_file_path(prefix_path):
@@ -153,7 +154,7 @@ def read_index(self, path):
             version = f.read(8)
             assert struct.unpack('<Q', version) == (1,)

From: Bruno Magalhaes
Date: Sun, 17 Mar 2024 15:21:08 +0000
Subject: [PATCH 5/6] bug fix, see PR msg
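
The old math.ceil partitioning in split_index() could produce empty trailing
partitions, leaving some threads with no data. A hand-computed illustration
of the fix:

    # before: split_index(0, 10, 6): partition_size = ceil(10/6) = 2
    #   -> [0,2],[2,4],[4,6],[6,8],[8,10],[10,10]   (last partition empty)
    # after:  np.linspace(0, 10, num=7, dtype=int) -> [0, 1, 3, 5, 6, 8, 10]
    #   -> (0,1),(1,3),(3,5),(5,6),(6,8),(8,10)     (all partitions non-empty)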
---
 deepspeed/runtime/data_pipeline/data_sampling/utils.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/deepspeed/runtime/data_pipeline/data_sampling/utils.py b/deepspeed/runtime/data_pipeline/data_sampling/utils.py
index 9c643f3705de..dc55f96e222d 100644
--- a/deepspeed/runtime/data_pipeline/data_sampling/utils.py
+++ b/deepspeed/runtime/data_pipeline/data_sampling/utils.py
@@ -3,7 +3,6 @@
 
 # DeepSpeed Team
 
-import math
 import numpy as np
 
 from deepspeed.utils import logger
@@ -32,10 +31,8 @@ def find_fit_int_dtype(min_value, max_value):
 
 
 def split_index(start_idx, end_idx, num_partitions):
-    partition_size = math.ceil((end_idx - start_idx) / num_partitions)
-    partitions = [[start_idx + x * partition_size,
-                   min(end_idx, start_idx + (x + 1) * partition_size)] for x in range(num_partitions)]
-    return partitions
+    partition_boundaries = np.linspace(start_idx, end_idx, dtype=int, num=num_partitions + 1)
+    return [(partition_boundaries[i], partition_boundaries[i + 1]) for i in range(num_partitions)]
 
 
 def split_dataset(dataset, num_workers, worker_id, num_threads):

From a92b00fa58565abc3e6e1e817dd0ed490a85b125 Mon Sep 17 00:00:00 2001
From: Bruno Magalhaes
Date: Sun, 17 Mar 2024 15:38:01 +0000
Subject: [PATCH 6/6] newline

---
 .../runtime/data_pipeline/data_sampling/indexed_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py
index 1c56f5f503c5..453e6ba6039d 100644
--- a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py
+++ b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py
@@ -624,4 +624,4 @@ def finalize(self, index_file):
         self._data_file.close()
 
         with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
-            index.write(self._sizes, self._doc_idx)
\ No newline at end of file
+            index.write(self._sizes, self._doc_idx)