From f113ab0737ac32d5c2f0d9faefff0832b1559338 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Mon, 22 Apr 2024 12:50:15 -0700
Subject: [PATCH] Improve parallel process of universal checkpoint conversion
 (#5343)

The script that converts a regular checkpoint to a universal one runs the
following steps in parallel:
1. extract the ZeRO-sharded optimizer states
2. merge the shards

However, it passes `map()` only a small batch of tasks at a time (the batch
size equals the number of workers), so every batch has to wait for its
slowest task to finish before the next batch can start.

This PR submits all tasks to the pool up front and waits for the futures to
complete, which keeps all workers busy.

---------

Co-authored-by: Olatunji Ruwase
---
 deepspeed/checkpoint/ds_to_universal.py | 34 ++++++++++------------------------
 1 file changed, 10 insertions(+), 24 deletions(-)

diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index 63fa866718de..b1a8276589b6 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -10,7 +10,7 @@
 import argparse
 import glob
 import itertools
-import multiprocessing
+from concurrent.futures import ProcessPoolExecutor
 import os
 import re
 import shutil
@@ -292,27 +292,18 @@ def get_matched_sub_params_pattern(name_):
     return unmatched_patterns
 
 
-def _get_chunks(l, n):
-    for i in range(0, len(l), n):
-        yield l[i:i + n]
-
-
 def _do_parallel_work(do_work, work_chunks, num_workers):
+    results = []
     if num_workers > 1:
-        pool = multiprocessing.Pool(num_workers)
-        results = []
-        for batch in tqdm.tqdm(work_chunks):
-            res = pool.map(do_work, batch)
-            results.extend(res)
-        pool.close()
-        pool.join()
+        with ProcessPoolExecutor(max_workers=num_workers) as executor:
+            future_list = [executor.submit(do_work, work) for work in work_chunks]
+            for f in tqdm.tqdm(future_list):
+                results.append(f.result())
     else:
         # No parallel pass for unit testing
         # We can't create child processes in tests
-        results = []
-        for batch in tqdm.tqdm(work_chunks):
-            res = [do_work(x) for x in batch]
-            results.extend(res)
+        for work in tqdm.tqdm(work_chunks):
+            results.append(do_work(work))
     return results
 
 
@@ -321,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
         itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                           range(ds_checkpoint.dp_degree)))
     #pprint(f'{_3d_range_list=}')
-    work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
-    #pprint(f'{work_chunks=}')
 
-    # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
     do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
-    _do_parallel_work(do_work, work_chunks, args.num_extract_workers)
+    _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
 
 
 def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
-    work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
-    #pprint(work_chunks)
     zero_output_folder = os.path.join(args.output_folder, "zero")
     do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
-    unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
+    unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)
 
     # verify that all patterns were used
     # if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
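
The standalone sketch below (not part of the patch itself) illustrates the scheduling change the commit message describes: all tasks are submitted to a `ProcessPoolExecutor` up front and results are collected as the futures complete, instead of calling `pool.map()` on fixed-size batches and stalling on each batch's slowest task. The `square` function and the `tasks` list are hypothetical stand-ins for the real work items (the zero-shard extraction and TP-slice merge calls); only the scheduling logic mirrors the patch.

# Minimal sketch of the submit-all-then-collect pattern, stdlib only.
from concurrent.futures import ProcessPoolExecutor


def square(x):
    # Placeholder for a CPU-bound unit of work (extract/merge in the real script).
    return x * x


def run_all(do_work, tasks, num_workers):
    # Submit every task up front so the pool always has work queued;
    # no worker sits idle waiting for the slowest task of a batch.
    results = []
    if num_workers > 1:
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(do_work, t) for t in tasks]
            # Collect in submission order; each result() blocks only until
            # that particular task is done, while the others keep running.
            for f in futures:
                results.append(f.result())
    else:
        # Serial fallback, e.g. for unit tests that cannot spawn child processes.
        for t in tasks:
            results.append(do_work(t))
    return results


if __name__ == "__main__":
    print(run_all(square, list(range(8)), num_workers=4))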