
Commit

Merge branch 'master' into optim-linear
jeffra authored Apr 20, 2024
2 parents 980db87 + 99951ca commit ffe2223
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -132,6 +132,10 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
 cnt = 0


+def dp_index_to_str(dp_index):
+    return f"{dp_index:0>2d}"
+
+
 def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

     global cnt  # temp hack
@@ -140,9 +144,8 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):
     os.makedirs(param_base_path, exist_ok=True)

     cnt += 1
-    counter = f"{dp_index:0>2d}"

-    path = os.path.join(param_base_path, f"{state_name}.{counter}")
+    path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")

     #print(f"{param_name}: {offset}: {numel} => {path}")

@@ -156,10 +159,21 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
     slices = []
     for tp_index in range(tp_degree):
         prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
-        paths = sorted(list(glob.glob(f"{prefix_path}.*")))
+        paths = glob.glob(f"{prefix_path}.*")
+
         if len(paths) == 0:
             continue

+        pattern = re.compile(f"{prefix_path}\\.([0-9]+)")
+        dp_indices = set()
+        for p in paths:
+            m = pattern.match(p)
+            if m:
+                dp_indices.add(int(m.group(1)))
+            else:
+                raise ValueError(f"Cannot parse dp_rank from {p}")
+
+        paths = [f"{prefix_path}.{dp_index_to_str(dp_index)}" for dp_index in sorted(list(dp_indices))]
         shards = [torch.load(p) for p in paths]

         if state == "step":
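The change in _merge_zero_shards replaces a plain sorted(glob.glob(...)) with explicit parsing of the dp rank from each shard file name. Below is a minimal standalone sketch, not part of the repository: the exp_avg.* file names and the numeric_dp_order helper are hypothetical, and only the dp_index_to_str formatting mirrors the commit. It illustrates one way the old lexicographic ordering can go wrong, namely that once dp indices outgrow the two-digit zero padding, string order interleaves rank 100 between ranks 10 and 11, and shows how parsing the numeric suffix restores dp-rank order and rejects files whose suffix is not a number.

# Hypothetical sketch: why lexicographic sorting of shard paths can misorder
# dp ranks, and how parsing the numeric suffix (as the commit does) avoids it.
import glob
import os
import re
import tempfile


def dp_index_to_str(dp_index):
    # Same zero-padded formatting as the helper added in this commit.
    return f"{dp_index:0>2d}"


def numeric_dp_order(prefix_path):
    # Parse the dp rank out of each file name instead of relying on string order.
    pattern = re.compile(f"{re.escape(prefix_path)}\\.([0-9]+)")
    dp_indices = set()
    for p in glob.glob(f"{prefix_path}.*"):
        m = pattern.match(p)
        if m:
            dp_indices.add(int(m.group(1)))
        else:
            raise ValueError(f"Cannot parse dp_rank from {p}")
    return [f"{prefix_path}.{dp_index_to_str(i)}" for i in sorted(dp_indices)]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        prefix = os.path.join(d, "exp_avg")
        # Create shard files for 101 hypothetical dp ranks: exp_avg.00 ... exp_avg.100
        for rank in range(101):
            open(f"{prefix}.{dp_index_to_str(rank)}", "w").close()

        lexicographic = sorted(glob.glob(f"{prefix}.*"))
        numeric = numeric_dp_order(prefix)

        # With 100 or more ranks the zero padding is no longer uniform, so the
        # lexicographic order interleaves ".100" between ".10" and ".11".
        print(lexicographic[10:13])  # ends with '.10', '.100', '.11' -> wrong merge order
        print(numeric[10:13])        # ends with '.10', '.11', '.12' -> dp-rank order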
