Skip to content

Commit

Permalink
simplify argument
Browse files: browse the repository at this point in the history
  • Loading branch information
tohtana committed Mar 18, 2024
1 parent ccfba1a commit 24484ad
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -148,14 +148,14 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,
_save_checkpoint(path, state_flat_tensor)


def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape, step=False):
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
shards = [torch.load(p) for p in paths]

if step:
if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
slice = shards[0]
else:
Expand Down Expand Up @@ -189,7 +189,7 @@ def get_matched_pattern(patterns_, name_):
return pattern_
return None

step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape, step=True)
step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
_save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])

for state in ("fp32", "exp_avg", "exp_avg_sq"):
Expand Down

0 comments on commit 24484ad

Please sign in to comment.