Skip to content

Commit

Permalink
Improve parallel process of universal checkpoint conversion (microsof…
Browse files Browse the repository at this point in the history
…t#5343)

The script that converts a regular checkpoint to a universal one
runs the following steps in parallel.

1. extract the ZeRO-sharded optimizer states
2. merge the shards

However, it passes `map()` only a small batch of tasks at a time (as many
tasks as the specified number of workers). It therefore has to wait for the
slowest task in every batch to finish before it can start the next batch.
This PR submits all the tasks to the pool at once and waits for the futures
to complete, so all workers can be kept running.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
2 people authored and dbyoung18 committed Jun 11, 2024
1 parent a5c5df8 commit f113ab0
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import argparse
import glob
import itertools
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import os
import re
import shutil
Expand Down Expand Up @@ -292,27 +292,18 @@ def get_matched_sub_params_pattern(name_):
return unmatched_patterns


def _get_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]


def _do_parallel_work(do_work, work_chunks, num_workers):
results = []
if num_workers > 1:
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
with ProcessPoolExecutor(max_workers=num_workers) as executor:
future_list = [executor.submit(do_work, work) for work in work_chunks]
for f in tqdm.tqdm(future_list):
results.append(f.result())
else:
# No parallel pass for unit testing
# We can't create child processes in tests
results = []
for batch in tqdm.tqdm(work_chunks):
res = [do_work(x) for x in batch]
results.extend(res)
for work in tqdm.tqdm(work_chunks):
results.append(do_work(work))
return results


Expand All @@ -321,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
#pprint(f'{_3d_range_list=}')
work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
#pprint(f'{work_chunks=}')

# extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
_do_parallel_work(do_work, work_chunks, args.num_extract_workers)
_do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)


def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)

# verify that all patterns were used
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
Expand Down

0 comments on commit f113ab0

Please sign in to comment.