Skip to content

Commit

Permalink
Fix sorting of shard optimizer states files for universal checkpoint (#…
Browse files Browse the repository at this point in the history
…5395)

This PR resolves the issue reported in #5283.
To resolve the issue, the files of sharded optimizer states are now sorted
numerically by their DP indices rather than lexicographically by file path.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
tohtana and tjruwase authored Apr 19, 2024
1 parent c632ea0 commit 99951ca
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
cnt = 0


def dp_index_to_str(dp_index: int) -> str:
    """Return *dp_index* as a decimal string left-padded with zeros to width 2.

    This is the suffix used in the shard file names built in this module
    (``f"{state_name}.{dp_index_to_str(dp_index)}"``), so the writer and the
    reader of those files share a single formatting rule.

    Indices that already need more than two digits are returned unpadded
    (e.g. ``123`` -> ``"123"``).

    Raises:
        ValueError: if *dp_index* is a type that does not support the ``d``
            (integer) presentation format, e.g. ``str`` or ``float``.
    """
    return f"{dp_index:0>2d}"


def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

global cnt # temp hack
Expand All @@ -140,9 +144,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,
os.makedirs(param_base_path, exist_ok=True)

cnt += 1
counter = f"{dp_index:0>2d}"

path = os.path.join(param_base_path, f"{state_name}.{counter}")
path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")

#print(f"{param_name}: {offset}: {numel} => {path}")

Expand All @@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
paths = glob.glob(f"{prefix_path}.*")

if len(paths) == 0:
continue

pattern = re.compile(f"{prefix_path}\\.([0-9]+)")
dp_indices = set()
for p in paths:
m = pattern.match(p)
if m:
dp_indices.add(int(m.group(1)))
else:
raise ValueError(f"Cannot parse dp_rank from {p}")

paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
shards = [torch.load(p) for p in paths]

if state == "step":
Expand Down

0 comments on commit 99951ca

Please sign in to comment.