Improve universal checkpoint #5289

Merged: 21 commits, Mar 28, 2024
1 change: 0 additions & 1 deletion deepspeed/__init__.py
@@ -42,7 +42,6 @@
from .comm.comm import init_distributed

from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule
1 change: 1 addition & 0 deletions deepspeed/checkpoint/constants.py
@@ -16,6 +16,7 @@
BASE_OPTIMIZER_STATE = 'base_optimizer_state'
BASE_OPTIMIZER_STATE_STEP = 'base_optimizer_state_step'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
PARAM_GROUPS = 'param_groups'
GROUP_PADDINGS = 'group_paddings'
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
54 changes: 41 additions & 13 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -22,6 +22,7 @@
OPTIMIZER_STATE_DICT,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
PARAM_GROUPS,
PARAM_SLICE_MAPPINGS,
PARAM_SHAPES,
PARAM,
@@ -110,6 +111,9 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
fp32=fp32_groups[param_group_id],
)

if "step" in state_groups[param_group_id]:
flat_state["step"] = state_groups[param_group_id]["step"]

for name, fragment_mapping in param_slice_mappings[param_group_id].items():
if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
# Skip tied weights that are replicated in first and last pp stages
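Note on the hunk above: the optimizer's per-group `step` is now captured into `flat_state`, so it is written out alongside each parameter's shards. A minimal sketch of the data being copied, assuming an Adam-style per-group state where `step` may be a python int or a 0-dim tensor depending on the torch version:

```python
import torch

# Illustrative shapes only; "step" may be an int or a 0-dim tensor.
state_groups = [{
    "step": 1000,
    "exp_avg": torch.zeros(4096),     # flat, DP-partitioned first moment
    "exp_avg_sq": torch.zeros(4096),  # flat, DP-partitioned second moment
}]
param_group_id = 0
flat_state = dict(
    exp_avg=state_groups[param_group_id]["exp_avg"],
    exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
)
if "step" in state_groups[param_group_id]:
    flat_state["step"] = state_groups[param_group_id]["step"]
```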
@@ -138,17 +142,28 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,

#print(f"{param_name}: {offset}: {numel} => {path}")

t = state_flat_tensor.narrow(0, offset, numel).clone()
_save_checkpoint(path, t)
# State might be a python int or a tensor
if state_name != "step" and torch.is_tensor(state_flat_tensor):
state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
_save_checkpoint(path, state_flat_tensor)


def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
if len(paths) == 0:
continue

shards = [torch.load(p) for p in paths]
slice = torch.cat(shards, dim=0).reshape(slice_shape)

if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
slice = shards[0]
else:
slice = torch.cat(shards, dim=0).reshape(slice_shape)

slices.append(slice)
return slices
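Reviewer note on `_merge_zero_shards`: tensor states are concatenated across the DP shards and reshaped to the parameter's slice, while `step` is a scalar that must agree across shards and is passed through as-is. A standalone sketch of that merge rule (not the library function itself):

```python
import torch

def merge_state(state_name, shards, slice_shape):
    # "step" is a scalar replicated across shards; everything else is a
    # flat tensor fragment that is concatenated and reshaped.
    if state_name == "step":
        assert all(s == shards[0] for s in shards), "step must match across shards"
        return shards[0]
    return torch.cat(shards, dim=0).reshape(slice_shape)

print(merge_state("exp_avg", [torch.ones(6), torch.zeros(6)], (3, 4)).shape)  # (3, 4)
print(merge_state("step", [10, 10, 10], None))                                # 10
```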

@@ -177,6 +192,10 @@ def get_matched_pattern(patterns_, name_):
return pattern_
return None

step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
if step_merged:
_save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])

for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")
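With the addition above, each parameter directory in the universal checkpoint gains a `step.pt` next to the merged optimizer states. A hedged sketch of reading a converted checkpoint back; the directory name is illustrative, `step.pt` holds the bare step value, and the other files are assumed to store a dict keyed by `'param'` as used by `load_hp_checkpoint_state` in this PR:

```python
import os
import torch

# Illustrative path; the layout follows merge_tp_slices above.
param_dir = "universal_ckpt/zero/module.layers.0.weight"
for fname in ("fp32.pt", "exp_avg.pt", "exp_avg_sq.pt", "step.pt"):
    path = os.path.join(param_dir, fname)
    if not os.path.isfile(path):
        continue
    obj = torch.load(path)
    # step.pt stores the raw step value; the other files store a dict
    # with the merged tensor under the 'param' key.
    print(fname, obj if fname == "step.pt" else obj["param"].shape)
```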
@@ -227,13 +246,21 @@ def _get_chunks(l, n):


def _do_parallel_work(do_work, work_chunks, num_workers):
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
if num_workers > 1:
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
else:
# No parallel pass for unit testing
# We can't create child processes in tests
results = []
for batch in tqdm.tqdm(work_chunks):
res = [do_work(x) for x in batch]
results.extend(res)
return results
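The serial branch matters for the unit tests that exercise this converter: a test process may not be able to fork a multiprocessing pool, so `num_workers=1` now runs the work inline. A small usage sketch, assuming this module's `_do_parallel_work` is in scope:

```python
# Dummy work items; with num_workers=1 nothing is pickled and no child
# processes are created, so this also works from inside a test process.
work_chunks = [[0, 1, 2], [3, 4]]
results = _do_parallel_work(lambda x: x * x, work_chunks, 1)
assert results == [0, 1, 4, 9, 16]
```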


@@ -273,6 +300,7 @@ def _save_optimizer_state(args, ds_checkpoint):

optim_sd = sd[OPTIMIZER_STATE_DICT]
output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
output_sd[PARAM_GROUPS] = optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero")
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
_save_checkpoint(output_file_path, output_sd)
@@ -283,10 +311,9 @@ def _check_for_required_state(ds_checkpoint):
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'


def main():
def main(args):
print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')

args = parse_arguments()
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
@@ -332,4 +359,5 @@ def main():


if __name__ == "__main__":
main()
args = parse_arguments()
main(args)
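Splitting `parse_arguments()` out of `main()` lets tests and scripts drive the conversion programmatically. A hedged sketch; the attribute names below mirror this script's CLI flags, but check `parse_arguments()` for the authoritative set and defaults:

```python
from types import SimpleNamespace
from deepspeed.checkpoint import ds_to_universal

args = SimpleNamespace(
    input_folder="checkpoints/global_step1000",             # DeepSpeed checkpoint
    output_folder="checkpoints/global_step1000_universal",  # universal output
    num_extract_workers=1,   # 1 => run inline, no multiprocessing pool
    num_merge_workers=1,
    keep_temp_folder=False,
    strict=True,
)
ds_to_universal.main(args)
```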
8 changes: 8 additions & 0 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -22,9 +22,15 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
if match:
hp_keys.append(match.group(1))

step = None
for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)

if key == "step":
step = ckpt_dict
continue

full_hp_param = ckpt_dict[PARAM]

# need to deal with slices that were averaged.
@@ -103,6 +109,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()

return step


def enable_universal_checkpoint(param_list):
for param in param_list:
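`load_hp_checkpoint_state` now skips `step.pt` while rebuilding the high-precision parameter and returns the step it found (or `None`). Restoring this value matters because Adam's bias correction depends on the step count. A self-contained sketch with placeholder values of what seeding a wrapped `torch.optim.Adam` with a restored step looks like (the real wiring is in `base_optimizer.py` below):

```python
import torch

p = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([p])

restored_step = 1000  # placeholder: value read back from step.pt
opt.state[p] = {
    "step": torch.tensor(float(restored_step)),
    "exp_avg": torch.zeros_like(p),
    "exp_avg_sq": torch.zeros_like(p),
}

p.grad = torch.ones_like(p)
opt.step()  # bias correction now behaves as if 1000 steps had already run
```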
4 changes: 3 additions & 1 deletion deepspeed/checkpoint/zero_checkpoint.py
@@ -105,9 +105,11 @@ def _strip_tensor_paddings(self, sd):
if group_paddings[key] == 0:
continue
for state_name, state_value in group_state.items():
if torch.is_tensor(state_value):
if state_name != "step" and torch.is_tensor(state_value):
raw_length = state_value.numel() - group_paddings[key]
group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
else:
group_state[state_name] = state_value

def _clear_group_paddings(self, sd):
group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
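`_strip_tensor_paddings` now leaves `step` untouched: tensor states carry ZeRO alignment padding that has to be narrowed away, while the scalar step has no padding and is copied through. A toy illustration of the rule:

```python
import torch

group_state = {"exp_avg": torch.arange(10.0), "step": torch.tensor(8.0)}
padding = 2  # illustrative alignment padding at the end of the flat tensor

for state_name, state_value in group_state.items():
    if state_name != "step" and torch.is_tensor(state_value):
        raw_length = state_value.numel() - padding
        group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()

print(group_state["exp_avg"].numel())  # 8; "step" is left as-is
```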
8 changes: 0 additions & 8 deletions deepspeed/runtime/__init__.py
@@ -2,11 +2,3 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team


class DeepSpeedOptimizer(object):
pass


class ZeROOptimizer(DeepSpeedOptimizer):
pass
63 changes: 63 additions & 0 deletions deepspeed/runtime/base_optimizer.py
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
pass


class ZeROOptimizer(DeepSpeedOptimizer):

def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path)

self._load_global_state(optim_sd)

tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
if self.mpu is None:
logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
tp_world_size = 1
else:
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()

for i, (param_group,
loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()
steps = []

lp_groups = getattr(self, lp_groups_name)
for lp in lp_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
steps.append(step)

hp_param = param_group['params'][0]
assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
if steps[0] is not None:
self.optimizer.state[hp_param]['step'] = steps[0]

map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

for key, value in loaded_param_group.items():
if key == 'params':
continue
param_group[key] = value
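Reviewer note: `DeepSpeedOptimizer`/`ZeROOptimizer` move out of `deepspeed/runtime/__init__.py` so the universal-checkpoint loading logic lives in one place and can be shared by the bf16 optimizer and the ZeRO stage 1/2 optimizer. A structural sketch of what a subclass is expected to provide, assuming it already defines `self.optimizer`, `self.mpu`, `self.param_names`, and the named low-precision groups attribute:

```python
from deepspeed.runtime.base_optimizer import ZeROOptimizer

class MyWrappedOptimizer(ZeROOptimizer):
    def _load_global_state(self, sd):
        # Restore non-sharded globals (e.g. loss scale, clip_grad) from
        # zero/optimizer_state.pt; the bf16 optimizer simply passes here.
        pass

    def _load_universal_checkpoint(self, checkpoint_folder,
                                   load_optimizer_states, load_from_fp32_weights):
        # "bf16_groups" is the attribute name used by bf16_optimizer.py below;
        # a ZeRO stage 1/2 optimizer would pass its own groups attribute.
        self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder)
```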
29 changes: 7 additions & 22 deletions deepspeed/runtime/bf16_optimizer.py
@@ -6,19 +6,18 @@
from collections import OrderedDict
import torch
import sys
import os
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -493,6 +492,7 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

if load_optimizer_states:
print(f"_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]")
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

if load_from_fp32_weights:
@@ -505,31 +505,16 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
self._link_all_hp_params()

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self._load_hp_checkpoint_state(checkpoint_folder)
self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder)

def _load_global_state(self, sd):
pass

@property
def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups

def _load_hp_checkpoint_state(self, checkpoint_dir):
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
tp_world_size = self.mpu.get_slice_parallel_world_size()

for i, param_group in enumerate(self.optimizer.param_groups):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()

for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)

def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.immediate_grad_update
self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False)
21 changes: 1 addition & 20 deletions deepspeed/runtime/engine.py
@@ -2769,7 +2769,7 @@ def load_checkpoint(self,

load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
if load_zero_checkpoint:
if load_optimizer_states and not load_module_only:
if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
else:
success = False
@@ -2794,8 +2794,6 @@

if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'])

return load_path, client_states

@@ -2973,23 +2971,6 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
return True

def update_optimizer_step(self, step):

def set_step(d):
if 'step' in d and isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step

optimizer = self.optimizer
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0:
set_step(state[p])

def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
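With the per-parameter `step.pt` restored by the optimizer itself, the engine no longer needs to patch the step from `client_states['iteration']`, so `update_optimizer_step` is removed and the ZeRO checkpoint path is taken whenever universal loading is requested. A hedged usage sketch of resuming; `model` is assumed to be an already-constructed `torch.nn.Module`, and the `load_universal` config key is my reading of DeepSpeed's checkpoint config and should be verified:

```python
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    # Assumed config key for universal-checkpoint loading; verify against
    # DeepSpeed's checkpoint documentation.
    "checkpoint": {"load_universal": True},
}

engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config,
                                               model_parameters=model.parameters())
# No manual step patching is needed any more; the optimizer step is read from
# the universal checkpoint's per-parameter step.pt.
engine.load_checkpoint("checkpoints", tag="global_step1000_universal",
                       load_optimizer_states=True, load_module_only=False)
```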
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/fused_optimizer.py
@@ -10,7 +10,7 @@
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime import DeepSpeedOptimizer
from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version, get_norm_with_moe_layers
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger, log_dist
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/unfused_optimizer.py
@@ -11,7 +11,7 @@
import torch
from torch._utils import _flatten_dense_tensors

from deepspeed.runtime import DeepSpeedOptimizer
from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm, required_torch_version
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -11,7 +11,7 @@
from deepspeed.utils import groups

from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime import ZeROOptimizer
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce