From 99951caa3d2155a3bb84109a0828543793e088cc Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 19 Apr 2024 14:19:47 -0700 Subject: [PATCH] Fix sorting of shard optimizer states files for universal checkpoint (#5395) This PR resolves the issue reported in #5283. To resolve the issue, we sort files of sharded optimizer states based on DP indices. --------- Co-authored-by: Olatunji Ruwase --- deepspeed/checkpoint/ds_to_universal.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index d5eca81c804f..63fa866718de 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): cnt = 0 +def dp_index_to_str(dp_index): + return f"{dp_index:0>2d}" + + def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): global cnt # temp hack os.makedirs(param_base_path, exist_ok=True) cnt += 1 - counter = f"{dp_index:0>2d}" - path = os.path.join(param_base_path, f"{state_name}.{counter}") + path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") #print(f"{param_name}: {offset}: {numel} => {path}") @@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): slices = [] for tp_index in range(tp_degree): prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") - paths = sorted(list(glob.glob(f"{prefix_path}.*"))) + paths = glob.glob(f"{prefix_path}.*") + if len(paths) == 0: continue + pattern = re.compile(f"{prefix_path}\\.([0-9]+)") + dp_indices = set() + for p in paths: + m = pattern.match(p) + if m: + dp_indices.add(int(m.group(1))) + else: + raise ValueError(f"Cannot parse dp_rank from {p}") + + paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))] shards = [torch.load(p) for p in paths] if state == "step":