diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 6c7aa8b15ef9..f1d99e1b0e43 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -42,7 +42,6 @@ from .comm.comm import init_distributed from .runtime import zero -from .runtime import DeepSpeedOptimizer, ZeROOptimizer from .runtime.compiler import is_compile_supported from .pipe import PipelineModule diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index f809a0c39270..b3f199a67b98 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -16,6 +16,7 @@ BASE_OPTIMIZER_STATE = 'base_optimizer_state' BASE_OPTIMIZER_STATE_STEP = 'base_optimizer_state_step' SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups" +PARAM_GROUPS = 'param_groups' GROUP_PADDINGS = 'group_paddings' PARTITION_COUNT = 'partition_count' ZERO_STAGE = 'zero_stage' diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index f40c5630899d..9ec5d0b169e4 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -22,6 +22,7 @@ OPTIMIZER_STATE_DICT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, + PARAM_GROUPS, PARAM_SLICE_MAPPINGS, PARAM_SHAPES, PARAM, @@ -110,6 +111,9 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): fp32=fp32_groups[param_group_id], ) + if "step" in state_groups[param_group_id]: + flat_state["step"] = state_groups[param_group_id]["step"] + for name, fragment_mapping in param_slice_mappings[param_group_id].items(): if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params): # Skip tied weights that are replicated in first and last pp stages @@ -138,8 +142,10 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, #print(f"{param_name}: {offset}: {numel} => {path}") - t = state_flat_tensor.narrow(0, offset, numel).clone() - _save_checkpoint(path, t) + # State might be a python int or a tensor + if state_name != "step" and torch.is_tensor(state_flat_tensor): + state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone() + _save_checkpoint(path, state_flat_tensor) def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): @@ -147,8 +153,17 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape): for tp_index in range(tp_degree): prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") paths = sorted(list(glob.glob(f"{prefix_path}.*"))) + if len(paths) == 0: + continue + shards = [torch.load(p) for p in paths] - slice = torch.cat(shards, dim=0).reshape(slice_shape) + + if state == "step": + assert all(v == shards[0] for v in shards), "All shards must have the same step value" + slice = shards[0] + else: + slice = torch.cat(shards, dim=0).reshape(slice_shape) + slices.append(slice) return slices @@ -177,6 +192,10 @@ def get_matched_pattern(patterns_, name_): return pattern_ return None + step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape) + if step_merged: + _save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0]) + for state in ("fp32", "exp_avg", "exp_avg_sq"): slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape) final_path = os.path.join(param_base_path, f"{state}.pt") @@ -227,13 +246,21 @@ def _get_chunks(l, n): def _do_parallel_work(do_work, work_chunks, num_workers): - pool = multiprocessing.Pool(num_workers) - results = [] - for batch in tqdm.tqdm(work_chunks): - res = 
pool.map(do_work, batch) - results.extend(res) - pool.close() - pool.join() + if num_workers > 1: + pool = multiprocessing.Pool(num_workers) + results = [] + for batch in tqdm.tqdm(work_chunks): + res = pool.map(do_work, batch) + results.extend(res) + pool.close() + pool.join() + else: + # No parallel pass for unit testing + # We can't create child processes in tests + results = [] + for batch in tqdm.tqdm(work_chunks): + res = [do_work(x) for x in batch] + results.extend(res) return results @@ -273,6 +300,7 @@ def _save_optimizer_state(args, ds_checkpoint): optim_sd = sd[OPTIMIZER_STATE_DICT] output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states} + output_sd[PARAM_GROUPS] = optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS] zero_output_folder = os.path.join(args.output_folder, "zero") output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt") _save_checkpoint(output_file_path, output_sd) @@ -283,10 +311,9 @@ def _check_for_required_state(ds_checkpoint): assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.' -def main(): +def main(args): print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint') - args = parse_arguments() print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}') ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) @@ -332,4 +359,5 @@ def main(): if __name__ == "__main__": - main() + args = parse_arguments() + main(args) diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py index a1314e004969..86c8dc904b8c 100644 --- a/deepspeed/checkpoint/universal_checkpoint.py +++ b/deepspeed/checkpoint/universal_checkpoint.py @@ -22,9 +22,15 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): if match: hp_keys.append(match.group(1)) + step = None for key in hp_keys: ckpt_file = os.path.join(folder, f"{key}.pt") ckpt_dict = torch.load(ckpt_file) + + if key == "step": + step = ckpt_dict + continue + full_hp_param = ckpt_dict[PARAM] # need to deal with slices that were averaged. 
@@ -103,6 +109,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach() + return step + def enable_universal_checkpoint(param_list): for param in param_list: diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py index c65745d3dd0c..6730b93dfd4f 100644 --- a/deepspeed/checkpoint/zero_checkpoint.py +++ b/deepspeed/checkpoint/zero_checkpoint.py @@ -105,9 +105,11 @@ def _strip_tensor_paddings(self, sd): if group_paddings[key] == 0: continue for state_name, state_value in group_state.items(): - if torch.is_tensor(state_value): + if state_name != "step" and torch.is_tensor(state_value): raw_length = state_value.numel() - group_paddings[key] group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone() + else: + group_state[state_name] = state_value def _clear_group_paddings(self, sd): group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS) diff --git a/deepspeed/runtime/__init__.py b/deepspeed/runtime/__init__.py index 347ff7993d82..208299fb8c50 100644 --- a/deepspeed/runtime/__init__.py +++ b/deepspeed/runtime/__init__.py @@ -2,11 +2,3 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team - - -class DeepSpeedOptimizer(object): - pass - - -class ZeROOptimizer(DeepSpeedOptimizer): - pass diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py new file mode 100644 index 000000000000..6cfd66f1cc38 --- /dev/null +++ b/deepspeed/runtime/base_optimizer.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch + +from deepspeed.utils import logger +from deepspeed.utils.tensor_fragment import map_to_flat_opt_states +from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank + + +class DeepSpeedOptimizer(object): + pass + + +class ZeROOptimizer(DeepSpeedOptimizer): + + def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None: + checkpoint_dir = os.path.join(checkpoint_dir, "zero") + optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") + assert os.path.isfile( + optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' 
+ optim_sd = torch.load(optim_state_path) + + self._load_global_state(optim_sd) + + tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) + if self.mpu is None: + logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.") + tp_world_size = 1 + else: + tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ + else self.mpu.get_tensor_model_parallel_world_size() + + for i, (param_group, + loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])): + # We have an assumption that all params in the same param_group have the same keys + opt_keys = set() + steps = [] + + lp_groups = getattr(self, lp_groups_name) + for lp in lp_groups[i]: + if lp._hp_mapping is not None: + #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") + step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, + tp_world_size) + for key in lp._hp_mapping.get_optim_state_keys(): + opt_keys.add(key) + steps.append(step) + + hp_param = param_group['params'][0] + assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal" + if steps[0] is not None: + self.optimizer.state[hp_param]['step'] = steps[0] + + map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys) + + for key, value in loaded_param_group.items(): + if key == 'params': + continue + param_group[key] = value diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 4ec603af1505..7b98216c1cba 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -6,19 +6,18 @@ from collections import OrderedDict import torch import sys -import os from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed import comm as dist from deepspeed.runtime.constants import PIPE_REPLICATED -from deepspeed.runtime import ZeROOptimizer +from deepspeed.runtime.base_optimizer import ZeROOptimizer from packaging import version as pkg_version from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, is_model_parallel_parameter, see_memory_usage, graph_process, get_norm_with_moe_layers) +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups from deepspeed.moe.utils import is_moe_param, is_moe_param_group -from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, @@ -493,6 +492,7 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) if load_optimizer_states: + print(f"_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]") self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE]) if load_from_fp32_weights: @@ -505,31 +505,16 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l self._link_all_hp_params() def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): - 
self._load_hp_checkpoint_state(checkpoint_folder) + self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder) + + def _load_global_state(self, sd): + pass @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups - def _load_hp_checkpoint_state(self, checkpoint_dir): - checkpoint_dir = os.path.join(checkpoint_dir, "zero") - tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) - tp_world_size = self.mpu.get_slice_parallel_world_size() - - for i, param_group in enumerate(self.optimizer.param_groups): - # We have an assumption that all params in the same param_group have the same keys - opt_keys = set() - - for lp in self.bf16_groups[i]: - if lp._hp_mapping is not None: - #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") - lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, - tp_world_size) - for key in lp._hp_mapping.get_optim_state_keys(): - opt_keys.add(key) - map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys) - def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): assert self.immediate_grad_update self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index bd2e91431aff..3ad37baeedcb 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2769,7 +2769,7 @@ def load_checkpoint(self, load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) if load_zero_checkpoint: - if load_optimizer_states and not load_module_only: + if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) else: success = False @@ -2794,8 +2794,6 @@ def load_checkpoint(self, if self.load_universal_checkpoint(): self.optimizer.update_lp_params() - if load_zero_checkpoint: - self.update_optimizer_step(step=client_states['iteration']) return load_path, client_states @@ -2973,23 +2971,6 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}") return True - def update_optimizer_step(self, step): - - def set_step(d): - if 'step' in d and isinstance(d['step'], torch.Tensor): - d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device) - else: - d['step'] = step - - optimizer = self.optimizer - base_optimizer = optimizer.optimizer - state = base_optimizer.state - for group in optimizer.param_groups: - set_step(group) - for p in group['params']: - if p in state and len(state[p]) > 0: - set_step(state[p]) - def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode): zero_ckpt_names = [] for dp_rank in range(dp_world_size): diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 416642a89901..9ed250252e17 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -10,7 +10,7 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from deepspeed.runtime import DeepSpeedOptimizer +from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, 
get_weight_norm, required_torch_version, get_norm_with_moe_layers from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger, log_dist diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py index 14271255df2e..a7fd1910d7b2 100755 --- a/deepspeed/runtime/fp16/unfused_optimizer.py +++ b/deepspeed/runtime/fp16/unfused_optimizer.py @@ -11,7 +11,7 @@ import torch from torch._utils import _flatten_dense_tensors -from deepspeed.runtime import DeepSpeedOptimizer +from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm, required_torch_version from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 03813afa4ed1..c6ff216edfcb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -11,7 +11,7 @@ from deepspeed.utils import groups from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from deepspeed.runtime import ZeROOptimizer +from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 6cfcc418e71a..2f98379aa14d 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -4,13 +4,12 @@ # DeepSpeed Team import torch -import os from deepspeed import comm as dist from packaging import version as pkg_version from collections import OrderedDict from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from deepspeed.runtime import ZeROOptimizer +from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, empty_cache, see_memory_usage, inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups) @@ -28,7 +27,7 @@ from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) -from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.utils import groups @@ -2287,42 +2286,13 @@ def load_state_dict(self, self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights) def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): - self._load_hp_checkpoint_state(checkpoint_folder) + self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder) @property def param_groups(self): """Forward the wrapped optimizer's parameters.""" return self.optimizer.param_groups - def _load_hp_checkpoint_state(self, checkpoint_dir): - checkpoint_dir = os.path.join(checkpoint_dir, "zero") - optim_state_path = os.path.join(checkpoint_dir, 
"optimizer_state.pt") - assert os.path.isfile( - optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' - optim_sd = torch.load(optim_state_path) - self._load_global_state(optim_sd) - - tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) - if self.mpu is None: - logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.") - tp_world_size = 1 - else: - tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ - else self.mpu.get_tensor_model_parallel_world_size() - - for i, param_group in enumerate(self.optimizer.param_groups): - # We have an assumption that all params in the same param_group have the same keys - opt_keys = set() - - for lp in self.bit16_groups[i]: - if lp._hp_mapping is not None: - #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") - lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, - tp_world_size) - for key in lp._hp_mapping.get_optim_state_keys(): - opt_keys.add(key) - map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys) - def _load_global_state(self, sd): self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale) diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py index 08fa1eb671bd..3fb13b214ea0 100644 --- a/tests/unit/checkpoint/common.py +++ b/tests/unit/checkpoint/common.py @@ -86,15 +86,20 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True, load def compare_state_dicts(state0, state1, expected_mismatch_keys=[]): - for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()): - assert k0 == k1, f'failure due to key mismatch {k0} != {k1}' - if k0 in expected_mismatch_keys: + key_set0 = set(k for k in state0.keys() if k not in expected_mismatch_keys) + key_set1 = set(k for k in state1.keys() if k not in expected_mismatch_keys) + assert key_set0 == key_set1, f'failure due to key mismatch {key_set0} != {key_set1}' + + for k in key_set0: + s0 = state0[k] + s1 = state1[k] + if k in expected_mismatch_keys: continue if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor): assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}' assert torch.equal(s0.to('cpu'), s1.to('cpu')) else: - assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}' + assert s0 == s1, f'failures with keys = {k}, {k}, values = {s0} and {s1}' def compare_opt_state_dicts(state0, state1, expected_mismatch_keys=[]): diff --git a/tests/unit/checkpoint/test_universal_checkpoint.py b/tests/unit/checkpoint/test_universal_checkpoint.py new file mode 100644 index 000000000000..7adfe8410b55 --- /dev/null +++ b/tests/unit/checkpoint/test_universal_checkpoint.py @@ -0,0 +1,215 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +from types import SimpleNamespace +from torch.utils._pytree import tree_map + +from deepspeed.runtime.utils import required_torch_version +from deepspeed.checkpoint import UNIVERSAL_CHECKPOINT_INFO +from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal + +from unit.common import DistributedTest, DistributedFixture +from unit.simple_model import * +from unit.util import bf16_required_version_check + +from unit.checkpoint.common import compare_opt_state_dicts, compare_state_dicts + +import pytest +import deepspeed.comm as dist + + +def get_expected_mismatch_keys(): + # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to + # false positive mismatches in checkpoint state comparisons. + # Newer torch versions store tensor ids as 0, 1, 2, ... + return [] if required_torch_version(min_version=1.4) else ['params'] + + +def maybe_step(t): + return not torch.is_tensor(t) or (t.device.type == 'cpu' and t.numel() == 1) + + +def gather_opt_state(optimizer_state): + + def gather_tensor(t): + + if maybe_step(t): + return t + else: + buffer = [torch.zeros_like(t.flatten()) for _ in range(dist.get_world_size())] + dist.all_gather(buffer, t.flatten()) + return torch.cat(buffer) + + return tree_map(gather_tensor, optimizer_state) + + +def remove_pad_in_opt_state(optimizer_state, num_params): + + def remove_pad(t): + if maybe_step(t): + return t + else: + return t[:num_params] + + return tree_map(remove_pad, optimizer_state) + + +CP_TAG = "test_tag" + + +def init_ds_engine(model, ds_config, use_torch_adam): + + if use_torch_adam: + ds_optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + del ds_config["optimizer"] + model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, optimizer=ds_optimizer) + else: + model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters()) + + return model + + +def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir): + if dtype == torch.bfloat16 and not bf16_required_version_check(): + return + + test_step = 8 + + model = SimpleModel(hidden_dim) + model = init_ds_engine(model, ds_config, use_torch_adam) + data_loader = random_dataloader(model=model, + total_samples=test_step, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + for batch in data_loader: + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + sd = model.optimizer.optimizer.state_dict() if load_optim else None + + client_state = {} + client_state[UNIVERSAL_CHECKPOINT_INFO] = {} + client_state['iteration'] = test_step + model.save_checkpoint(tmpdir, tag=CP_TAG, client_state=client_state) + + cp_dir = os.path.join(tmpdir, CP_TAG) + univ_cp_dir = f"{cp_dir}_universal" + + args = SimpleNamespace(input_folder=cp_dir, + output_folder=univ_cp_dir, + num_extract_workers=1, + num_merge_workers=1, + keep_temp_folder=False, + strict=True) + + dist.barrier() + if dist.get_rank() == 0: + convert_to_universal(args) + + model_state = model.state_dict() + optimizer_state = None + if load_optim: + optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict()) + + if dist.get_rank() == 0: + torch.save((model_state, optimizer_state), os.path.join(tmpdir, "baseline_state.pt")) + + dist.barrier() + + return model, sd + + +@pytest.fixture +def ds_config(zero_stage, dtype): + ds_config = { + "train_batch_size": 8, + "optimizer": { + "type": 'Adam' + }, + "zero_optimization": { + 
"stage": zero_stage, + } + } + if dtype == torch.float16: + ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + ds_config["bf16"] = {"enabled": True} + return ds_config + + +class _baseline(DistributedFixture): + world_size = None + + def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam): + hidden_dim = 10 + train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir) + + +class baseline_ws2(_baseline): + world_size = 2 + + +class baseline_ws4(_baseline): + world_size = 4 + + +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) +@pytest.mark.parametrize("zero_stage", [1]) +@pytest.mark.parametrize("use_torch_adam", [False, True]) +@pytest.mark.parametrize("load_optim", [False, True]) +class TestZeROUniversalCheckpointDP(DistributedTest): + + def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam): + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + hidden_dim = 10 + loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt") + + ds_config["checkpoint"] = {"load_universal": True} + univ_model = SimpleModel(hidden_dim) + univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam) + univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim) + + model_state = univ_model.state_dict() + compare_state_dicts(model_state, loaded_model_state) + + if load_optim: + optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict()) + # padding sizes may differ when dp sizes are different + param_count = sum(p.numel() for p in univ_model.parameters()) + optimizer_state = remove_pad_in_opt_state(optimizer_state, param_count) + loaded_optimizer_state = remove_pad_in_opt_state(loaded_optimizer_state, param_count) + + compare_opt_state_dicts(optimizer_state, loaded_optimizer_state, get_expected_mismatch_keys()) + + # Run training again to verify that the optimizer has necessary states + test_step = 8 + data_loader = random_dataloader(model=univ_model, + total_samples=test_step, + hidden_dim=hidden_dim, + device=univ_model.device, + dtype=dtype) + for batch in data_loader: + loss = univ_model(batch[0], batch[1]) + univ_model.backward(loss) + univ_model.step() + + @pytest.mark.world_size(2) + def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam): + self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam) + + @pytest.mark.world_size(2) + def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam): + self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam) + + @pytest.mark.world_size(4) + def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam): + self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam) diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py index aebad4227358..2312425c8aed 100644 --- a/tests/unit/checkpoint/test_zero_optimizer.py +++ b/tests/unit/checkpoint/test_zero_optimizer.py @@ -243,7 +243,11 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l model, _, _, _ = deepspeed.initialize(config=ds_config, model=models[0], model_parameters=models[0].parameters()) - 
data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, device=model.device) + run_steps = 8 + data_loader = random_dataloader(model=model, + total_samples=run_steps, + hidden_dim=hidden_dim, + device=model.device) for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) model.backward(loss)
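
Usage note (not part of the patch): a minimal sketch of driving the new `main(args)` entry point in `deepspeed/checkpoint/ds_to_universal.py` programmatically, mirroring what the added `tests/unit/checkpoint/test_universal_checkpoint.py` does with `SimpleNamespace`. The checkpoint paths below are placeholders.

    # Hypothetical driver script; folder paths are placeholders.
    from types import SimpleNamespace
    from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal

    args = SimpleNamespace(
        input_folder="/path/to/checkpoints/global_step1000",             # regular DeepSpeed checkpoint
        output_folder="/path/to/checkpoints/global_step1000_universal",  # universal checkpoint output
        num_extract_workers=1,  # 1 takes the serial (non-multiprocessing) path added in this patch
        num_merge_workers=1,
        keep_temp_folder=False,
        strict=True,
    )
    convert_to_universal(args)

Loading the converted checkpoint then goes through the usual `load_checkpoint` path with `"checkpoint": {"load_universal": true}` set in the DeepSpeed config, as exercised by the new unit test.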