diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index 542d1125c566..a1314e004969 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -4,6 +4,7 @@
 # DeepSpeed Team
 
 import os
+import re
 import torch
 import types
 from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)
@@ -11,15 +12,18 @@
 
 def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
     hp_mapping = self._hp_mapping
-    optim_state_keys = hp_mapping.get_optim_state_keys()
-    hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
-    #print(f'{hp_keys=}')
-    checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
-    for file in checkpoint_files.values():
-        assert os.path.isfile(file), f'{file} is not a valid file'
+    hp_mapping.optim_fragment = {}
+
+    hp_keys = []
+    for file in os.listdir(folder):
+        # We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
+        pattern = r'(.+).pt'
+        match = re.search(pattern, file)
+        if match:
+            hp_keys.append(match.group(1))
 
     for key in hp_keys:
-        ckpt_file = checkpoint_files[key]
+        ckpt_file = os.path.join(folder, f"{key}.pt")
         ckpt_dict = torch.load(ckpt_file)
         full_hp_param = ckpt_dict[PARAM]
 
@@ -62,7 +66,6 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
 
         assert full_param_numel == tp_world_size * tp_slice_numel, \
             f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
-        dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)
 
         # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
         # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@@ -84,13 +87,21 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
 
         lp_frag_address = hp_mapping.lp_fragment_address
         tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
-        assert dst_tensor.numel() == lp_frag_address.numel, \
-            f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
 
         # print(f"{key} SHAPE: {tp_hp_slice.shape=}")
         # print(f"{key} SHAPE: {dst_tensor.shape=}")
         # print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
-        dst_tensor.data.copy_(tp_hp_fragment.data)
+
+        if key == FP32_WEIGHT_KEY:
+            dst_tensor = hp_mapping.get_hp_fragment()
+            assert dst_tensor.numel() == lp_frag_address.numel, \
+                f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
+            dst_tensor.data.copy_(tp_hp_fragment.data)
+        else:
+            assert tp_hp_fragment.numel() == lp_frag_address.numel, \
+                f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'
+
+            hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()
 
 
 def enable_universal_checkpoint(param_list):
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index aaa836bf1c31..82c8dda423a6 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -18,7 +18,7 @@
                                      align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
                                      is_model_parallel_parameter, see_memory_usage, graph_process)
 
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -457,12 +457,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         tp_world_size = self.mpu.get_slice_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bf16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)
 
     def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
         assert self.immediate_grad_update
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 5c1202ba06ae..174e699c5202 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
         if self.load_universal_checkpoint():
             self.optimizer.update_lp_params()
             if load_zero_checkpoint:
-                self.update_optimizer_step(step=client_states['iteration'] + 1)
+                self.update_optimizer_step(step=client_states['iteration'])
 
         return load_path, client_states
 
@@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
     def update_optimizer_step(self, step):
 
         def set_step(d):
-            if isinstance(d['step'], torch.Tensor):
+            if 'step' in d and isinstance(d['step'], torch.Tensor):
                 d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
             else:
                 d['step'] = step
@@ -2975,10 +2975,9 @@ def set_step(d):
         base_optimizer = optimizer.optimizer
         state = base_optimizer.state
         for group in optimizer.param_groups:
-            if 'step' in group:
-                set_step(group)
+            set_step(group)
             for p in group['params']:
-                if p in state and len(state[p]) > 0 and 'step' in state[p]:
+                if p in state and len(state[p]) > 0:
                     set_step(state[p])
 
     def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index b1d94a4459d9..6cfcc418e71a 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -28,7 +28,7 @@
 from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
                                             SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
                                             BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
 from deepspeed.checkpoint import enable_universal_checkpoint
 
 from deepspeed.utils import groups
@@ -1360,7 +1360,7 @@ def reduce_ipg_grads(self):
                 self.average_tensor(extra_large_grad_reduc.view(-1))
                 self.extra_large_param_to_reduce = None
             else:
-                self.average_tensor(self.ipg_buffer[self.ipg_index])
+                self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket))
         else:
             self.buffered_reduce_fallback(None,
                                           self.grads_in_ipg_bucket,
@@ -2310,12 +2310,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
         tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
             else self.mpu.get_tensor_model_parallel_world_size()
 
-        for i, _ in enumerate(self.optimizer.param_groups):
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            # We have an assumption that all params in the same param_group have the same keys
+            opt_keys = set()
+
             for lp in self.bit16_groups[i]:
                 if lp._hp_mapping is not None:
                     #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                     lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                 tp_world_size)
+                    for key in lp._hp_mapping.get_optim_state_keys():
+                        opt_keys.add(key)
+            map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)
 
     def _load_global_state(self, sd):
         self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py
index 33ea8ba60818..75fb6aa9d30a 100644
--- a/deepspeed/utils/__init__.py
+++ b/deepspeed/utils/__init__.py
@@ -10,7 +10,7 @@
 from .groups import *
 from .nvtx import instrument_w_nvtx
 # TODO: Move tensor fragment and mixed precision to zero utils
-from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
+from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
 from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
 from .tensor_fragment import set_full_hp_param
 from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py
index 49eefafcfbcc..b34722580ddd 100644
--- a/deepspeed/utils/tensor_fragment.py
+++ b/deepspeed/utils/tensor_fragment.py
@@ -58,6 +58,21 @@ def get_hp_fragment(self, optim_state_key=None):
         return self.get_optim_state_fragment(optim_state_key)
 
 
+def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
+    for key in opt_keys:
+        hp_param = flat_hp_tensor
+        buffer = torch.zeros_like(hp_param)
+
+        for lp in lp_tensors:
+            if lp._hp_mapping is not None:
+                hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
+                hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
+                hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
+                lp._hp_mapping.hp_fragment = hp_fragment
+
+        optim_state[hp_param][key] = buffer
+
+
 def get_full_hp_param(self, optim_state_key=None):
     reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
     if self._hp_mapping is not None:
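
As background, the new `map_to_flat_opt_states` helper rebuilds each flat optimizer-state tensor (e.g. `exp_avg`, `exp_avg_sq`) from the per-parameter fragments that `load_hp_checkpoint_state` now collects, and registers the result under `optim_state[flat_hp_tensor][key]`. Below is a minimal, self-contained sketch of the same narrow-and-copy pattern; the offsets and fragment values are hypothetical, and it uses plain PyTorch tensors rather than DeepSpeed's `_hp_mapping` objects.

```python
import torch

# Hypothetical layout: a flat "exp_avg" partition of 10 elements covered by
# two parameter fragments at offsets [0, 6) and [6, 10).
flat_exp_avg = torch.zeros(10)
fragments = [
    (0, torch.full((6,), 1.0)),  # (start offset, loaded fragment values)
    (6, torch.full((4,), 2.0)),
]

# Same pattern as map_to_flat_opt_states: narrow a view into the flat buffer
# and copy each fragment into place, so the optimizer keeps a single
# contiguous state tensor per param group.
for start, frag in fragments:
    flat_exp_avg.narrow(0, start, frag.numel()).copy_(frag)

print(flat_exp_avg)  # tensor([1., 1., 1., 1., 1., 1., 2., 2., 2., 2.])
```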