Commit

Merge branch 'master' into loadams/hpu-uts
loadams committed Mar 14, 2024
2 parents 820afbf + b112c99 commit 7f963d3
Showing 6 changed files with 59 additions and 22 deletions.
33 changes: 22 additions & 11 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -4,22 +4,26 @@
# DeepSpeed Team

import os
import re
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
hp_mapping = self._hp_mapping
optim_state_keys = hp_mapping.get_optim_state_keys()
hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
#print(f'{hp_keys=}')
checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
for file in checkpoint_files.values():
assert os.path.isfile(file), f'{file} is not a valid file'
hp_mapping.optim_fragment = {}

hp_keys = []
for file in os.listdir(folder):
# We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
pattern = r'(.+).pt'
match = re.search(pattern, file)
if match:
hp_keys.append(match.group(1))

for key in hp_keys:
ckpt_file = checkpoint_files[key]
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)
full_hp_param = ckpt_dict[PARAM]

@@ -62,7 +66,6 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)

# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@@ -84,13 +87,21 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

lp_frag_address = hp_mapping.lp_fragment_address
tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'

# print(f"{key} SHAPE: {tp_hp_slice.shape=}")
# print(f"{key} SHAPE: {dst_tensor.shape=}")
# print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
dst_tensor.data.copy_(tp_hp_fragment.data)

if key == FP32_WEIGHT_KEY:
dst_tensor = hp_mapping.get_hp_fragment()
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
dst_tensor.data.copy_(tp_hp_fragment.data)
else:
assert tp_hp_fragment.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'

hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()


def enable_universal_checkpoint(param_list):
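
Note: the loader above now discovers the high-precision (hp) state keys by listing the checkpoint folder for "*.pt" files instead of asserting a fixed set of files, and it stores each non-fp32 key as a cloned optimizer fragment. A minimal, self-contained sketch of the key-discovery step, assuming a hypothetical helper name and an anchored variant of the regex used in the diff:

    import os
    import re

    def discover_hp_keys(folder):
        """Collect state keys from files named like 'exp_avg.pt', 'exp_avg_sq.pt', 'fp32.pt'."""
        hp_keys = []
        for file_name in os.listdir(folder):
            # Anchored variant of the diff's r'(.+).pt' pattern; keeps the stem before '.pt'.
            match = re.match(r'(.+)\.pt$', file_name)
            if match:
                hp_keys.append(match.group(1))
        return hp_keys

    # e.g. discover_hp_keys('/tmp/param_folder') might return ['exp_avg', 'exp_avg_sq', 'fp32']
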
10 changes: 8 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
@@ -18,7 +18,7 @@
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)

from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -457,12 +457,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
tp_world_size = self.mpu.get_slice_parallel_world_size()

for i, _ in enumerate(self.optimizer.param_groups):
for i, param_group in enumerate(self.optimizer.param_groups):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()

for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)

def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.immediate_grad_update
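
Note: the loading loop above now unions the optimizer-state keys exposed by each parameter's _hp_mapping and passes them, together with the group's flat fp32 tensor (param_group['params'][0]), to map_to_flat_opt_states. A small sketch of the key-collection step, using types.SimpleNamespace as a hypothetical stand-in for real parameters and mappings:

    from types import SimpleNamespace

    def collect_group_opt_keys(lp_params):
        """Union of optimizer-state keys across all mapped low-precision params in one group."""
        opt_keys = set()
        for lp in lp_params:
            if lp._hp_mapping is not None:
                opt_keys.update(lp._hp_mapping.get_optim_state_keys())
        return opt_keys

    # Hypothetical params whose _hp_mapping exposes get_optim_state_keys(), mirroring the loop above.
    mapping = SimpleNamespace(get_optim_state_keys=lambda: ['exp_avg', 'exp_avg_sq'])
    params = [SimpleNamespace(_hp_mapping=mapping), SimpleNamespace(_hp_mapping=None)]
    assert collect_group_opt_keys(params) == {'exp_avg', 'exp_avg_sq'}
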
9 changes: 4 additions & 5 deletions deepspeed/runtime/engine.py
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'] + 1)
self.update_optimizer_step(step=client_states['iteration'])

return load_path, client_states

@@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def update_optimizer_step(self, step):

def set_step(d):
if isinstance(d['step'], torch.Tensor):
if 'step' in d and isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step
@@ -2975,10 +2975,9 @@ def set_step(d):
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
if 'step' in group:
set_step(group)
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0 and 'step' in state[p]:
if p in state and len(state[p]) > 0:
set_step(state[p])

def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
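
Note: set_step above now only preserves tensor dtype/device when a tensor-valued 'step' already exists, and update_optimizer_step no longer requires 'step' to be present in the param group or in the per-param state before writing it. A standalone sketch of the same update rule; unlike the diff, step is passed as an argument rather than captured from the enclosing scope:

    import torch

    def set_step(d, step):
        """Overwrite d['step'], keeping tensor dtype/device when it is already a tensor."""
        if 'step' in d and isinstance(d['step'], torch.Tensor):
            d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
        else:
            d['step'] = step

    tensor_state = {'step': torch.tensor(0, dtype=torch.float32)}
    plain_state = {}
    set_step(tensor_state, 10)  # stays a float32 tensor on the original device
    set_step(plain_state, 10)   # key is created with a plain int
    assert tensor_state['step'].dtype == torch.float32 and plain_state['step'] == 10
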
12 changes: 9 additions & 3 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -28,7 +28,7 @@
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint

from deepspeed.utils import groups
@@ -1360,7 +1360,7 @@ def reduce_ipg_grads(self):
self.average_tensor(extra_large_grad_reduc.view(-1))
self.extra_large_param_to_reduce = None
else:
self.average_tensor(self.ipg_buffer[self.ipg_index])
self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket))
else:
self.buffered_reduce_fallback(None,
self.grads_in_ipg_bucket,
@@ -2310,12 +2310,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()

for i, _ in enumerate(self.optimizer.param_groups):
for i, param_group in enumerate(self.optimizer.param_groups):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()

for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)

def _load_global_state(self, sd):
self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
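
Note: reduce_ipg_grads above now averages only the filled prefix of the preallocated IPG bucket by narrowing it to elements_in_ipg_bucket, and the checkpoint-loading change mirrors the bf16 optimizer's key collection. A small sketch of taking that filled prefix as a view; the buffer size and fill count are illustrative:

    import torch

    ipg_buffer = torch.zeros(16)   # preallocated flat gradient bucket
    elements_in_ipg_bucket = 10    # only this many elements hold real gradients
    ipg_buffer[:elements_in_ipg_bucket] = 1.0

    # narrow(dim=0, start=0, length=N) returns a view of the filled prefix,
    # so an averaging/reduction op never touches the stale tail of the buffer.
    filled = ipg_buffer.narrow(0, 0, elements_in_ipg_bucket)
    assert filled.numel() == elements_in_ipg_bucket
    assert filled.data_ptr() == ipg_buffer.data_ptr()  # a view, not a copy
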
2 changes: 1 addition & 1 deletion deepspeed/utils/__init__.py
@@ -10,7 +10,7 @@
from .groups import *
from .nvtx import instrument_w_nvtx
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
15 changes: 15 additions & 0 deletions deepspeed/utils/tensor_fragment.py
@@ -58,6 +58,21 @@ def get_hp_fragment(self, optim_state_key=None):
return self.get_optim_state_fragment(optim_state_key)


def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
for key in opt_keys:
hp_param = flat_hp_tensor
buffer = torch.zeros_like(hp_param)

for lp in lp_tensors:
if lp._hp_mapping is not None:
hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
lp._hp_mapping.hp_fragment = hp_fragment

optim_state[hp_param][key] = buffer


def get_full_hp_param(self, optim_state_key=None):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
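
Note: map_to_flat_opt_states above builds, for each optimizer-state key, one flat buffer the size of the group's fp32 tensor, copies every parameter's fragment into its slice via narrow, re-points the mapping's hp_fragment at that slice, and installs the buffer in optim_state[flat_hp_tensor][key]. A self-contained sketch of the narrow-and-copy step; the fragment addresses and values are illustrative:

    import torch

    # One flat fp32 tensor holding two logical fragments: elements [0:3) and [3:8).
    flat_hp = torch.zeros(8)
    fragments = {'frag_a': (0, 3, torch.ones(3)),
                 'frag_b': (3, 5, torch.full((5,), 2.0))}

    buffer = torch.zeros_like(flat_hp)
    for start, numel, frag in fragments.values():
        # narrow() yields a writable view into the flat buffer, so copy_() fills it in place.
        buffer.narrow(0, start, numel).copy_(frag)

    assert torch.equal(buffer, torch.tensor([1., 1., 1., 2., 2., 2., 2., 2.]))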
