Merge branch 'master' into gma/add_autotp_workflow
tjruwase authored Mar 16, 2024
2 parents 42211a4 + 3dd3d51 commit 68fb2ec
Showing 15 changed files with 110 additions and 31 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -131,7 +131,8 @@ DeepSpeed has been integrated with several different popular open-source DL fram
| ----------- | ------ |
| NVIDIA | [![nv-torch110-p40](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch110-p40.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch110-p40.yml) [![nv-torch110-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch110-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch110-v100.yml) [![nv-torch-latest-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-latest-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-latest-v100.yml) [![nv-h100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-h100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-h100.yml) [![nv-inference](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-inference.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-inference.yml) [![nv-nightly](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-nightly.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-nightly.yml) |
| AMD | [![amd-mi200](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml) |
| CPU | [![nv-torch-latest-cpu](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml) [![cpu-inference](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml) |
| CPU | [![torch-latest-cpu](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml) [![cpu-inference](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml) |
| Habana | [![hpu-gaudi2](https://github.com/microsoft/DeepSpeed/actions/workflows/hpu-gaudi2.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/hpu-gaudi2.yml) |
| PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) |
| Integrations | [![nv-transformers-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml) [![nv-mii](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml) [![nv-ds-chat](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml) [![nv-sd](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml) |
| Misc | [![Formatting](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml) |
deepspeed/checkpoint/universal_checkpoint.py (33 changes: 22 additions & 11 deletions)
@@ -4,22 +4,26 @@
# DeepSpeed Team

import os
import re
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
hp_mapping = self._hp_mapping
optim_state_keys = hp_mapping.get_optim_state_keys()
hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
#print(f'{hp_keys=}')
checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
for file in checkpoint_files.values():
assert os.path.isfile(file), f'{file} is not a valid file'
hp_mapping.optim_fragment = {}

hp_keys = []
for file in os.listdir(folder):
# We expect files named something like "exp_avg.pt", "exp_avg_sq.pt", "fp32.pt"
pattern = r'(.+).pt'
match = re.search(pattern, file)
if match:
hp_keys.append(match.group(1))

for key in hp_keys:
ckpt_file = checkpoint_files[key]
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)
full_hp_param = ckpt_dict[PARAM]

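For context, here is a minimal standalone sketch of the key-discovery step introduced above: instead of assuming a fixed key list, the loader scans the checkpoint folder for `*.pt` files. The folder contents below are hypothetical, and the regex here escapes the dot (the diff uses `r'(.+).pt'`).

```python
import os
import re
import tempfile

# Hypothetical per-parameter checkpoint folder, holding the files a universal
# checkpoint typically contains for one parameter.
folder = tempfile.mkdtemp()
for name in ("fp32.pt", "exp_avg.pt", "exp_avg_sq.pt"):
    open(os.path.join(folder, name), "wb").close()

# Discover the state keys from the file names, as the new code does.
hp_keys = []
for file in os.listdir(folder):
    match = re.search(r'(.+)\.pt$', file)
    if match:
        hp_keys.append(match.group(1))

print(sorted(hp_keys))  # ['exp_avg', 'exp_avg_sq', 'fp32']
```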
@@ -62,7 +66,6 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)

# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@@ -84,13 +87,21 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

lp_frag_address = hp_mapping.lp_fragment_address
tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'

# print(f"{key} SHAPE: {tp_hp_slice.shape=}")
# print(f"{key} SHAPE: {dst_tensor.shape=}")
# print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
dst_tensor.data.copy_(tp_hp_fragment.data)

if key == FP32_WEIGHT_KEY:
dst_tensor = hp_mapping.get_hp_fragment()
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
dst_tensor.data.copy_(tp_hp_fragment.data)
else:
assert tp_hp_fragment.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst numel {tp_hp_fragment.numel()} != src numel {lp_frag_address.numel}'

hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()

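The fragment copy above addresses a sub-range of a flat tensor-parallel slice with `Tensor.narrow`; a short illustration with made-up sizes and a made-up fragment address, showing why the fp32 path copies into a view while the optimizer-state path clones an independent buffer.

```python
import torch

# A flat high-precision slice for one rank; size and fragment address are made up.
tp_hp_slice = torch.arange(16, dtype=torch.float32)
start, numel = 4, 6

# narrow() returns a view, so copying into it writes through to tp_hp_slice,
# analogous to what the FP32_WEIGHT_KEY branch relies on.
tp_hp_fragment = tp_hp_slice.narrow(0, start, numel)

# The optimizer-state branch instead keeps an independent copy per key.
optim_fragment = tp_hp_fragment.clone().detach()
optim_fragment.zero_()

print(tp_hp_fragment)     # tensor([4., 5., 6., 7., 8., 9.])
print(tp_hp_slice[4:10])  # unchanged by zeroing the clone
```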

def enable_universal_checkpoint(param_list):
deepspeed/runtime/bf16_optimizer.py (10 changes: 8 additions & 2 deletions)
@@ -18,7 +18,7 @@
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process)

from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
Expand Down Expand Up @@ -457,12 +457,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
tp_world_size = self.mpu.get_slice_parallel_world_size()

for i, _ in enumerate(self.optimizer.param_groups):
for i, param_group in enumerate(self.optimizer.param_groups):
# We assume that all params in the same param_group have the same optimizer state keys
opt_keys = set()

for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)

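For context on the new `map_to_flat_opt_states` call: a torch optimizer keys its state dictionary by parameter tensor, and in this optimizer each param group holds a single flat high-precision tensor, so `param_group['params'][0]` is the tensor whose state entry must receive the reassembled buffers. A minimal sketch of that keying (the optimizer and sizes are illustrative):

```python
import torch

# One flat fp32 tensor standing in for a whole group of bf16 params (size made up).
flat_hp = torch.nn.Parameter(torch.zeros(8))
opt = torch.optim.Adam([flat_hp])

# One step so the optimizer materializes state for the flat tensor.
flat_hp.grad = torch.ones_like(flat_hp)
opt.step()

# State is keyed by the parameter object itself; these are the entries that
# map_to_flat_opt_states overwrites with buffers rebuilt from the checkpoint.
print(list(opt.state[flat_hp].keys()))  # e.g. ['step', 'exp_avg', 'exp_avg_sq']
```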
def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.immediate_grad_update
deepspeed/runtime/engine.py (9 changes: 4 additions & 5 deletions)
@@ -2785,7 +2785,7 @@ def load_checkpoint(self,
if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'] + 1)
self.update_optimizer_step(step=client_states['iteration'])

return load_path, client_states

@@ -2966,7 +2966,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
def update_optimizer_step(self, step):

def set_step(d):
if isinstance(d['step'], torch.Tensor):
if 'step' in d and isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step
@@ -2975,10 +2975,9 @@ def set_step(d):
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
if 'step' in group:
set_step(group)
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0 and 'step' in state[p]:
if p in state and len(state[p]) > 0:
set_step(state[p])

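The relaxed guards above reflect that recent torch versions keep Adam's `step` as a tensor in per-parameter state, while a `step` entry in the param group, when present, may be a plain int; a small self-contained sketch of the same update (the optimizer here is illustrative):

```python
import torch

param = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([param])
param.grad = torch.ones_like(param)
opt.step()

def set_step(d, step):
    # Preserve tensor dtype/device when 'step' is stored as a tensor.
    if 'step' in d and isinstance(d['step'], torch.Tensor):
        d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
    else:
        d['step'] = step

for group in opt.param_groups:
    set_step(group, 100)  # harmless even if the group had no 'step' entry
    for p in group['params']:
        if p in opt.state and len(opt.state[p]) > 0:
            set_step(opt.state[p], 100)

print(opt.state[param]['step'])  # tensor(100.) on recent torch, 100 otherwise
```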
def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
deepspeed/runtime/zero/stage_1_and_2.py (12 changes: 9 additions & 3 deletions)
@@ -28,7 +28,7 @@
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint

from deepspeed.utils import groups
@@ -1360,7 +1360,7 @@ def reduce_ipg_grads(self):
self.average_tensor(extra_large_grad_reduc.view(-1))
self.extra_large_param_to_reduce = None
else:
self.average_tensor(self.ipg_buffer[self.ipg_index])
self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket))
else:
self.buffered_reduce_fallback(None,
self.grads_in_ipg_bucket,
@@ -2310,12 +2310,18 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()

for i, _ in enumerate(self.optimizer.param_groups):
for i, param_group in enumerate(self.optimizer.param_groups):
# We assume that all params in the same param_group have the same optimizer state keys
opt_keys = set()

for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)

def _load_global_state(self, sd):
self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
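On the `reduce_ipg_grads` change above: the IPG buffer is pre-allocated, so averaging the whole buffer also reduces unused tail elements; narrowing to `elements_in_ipg_bucket` restricts the reduction to the populated prefix. A standalone sketch of the pattern (buffer size, fill count, and the averaging stub are made up):

```python
import torch

# Pre-allocated gradient-accumulation buffer, only partially filled.
ipg_buffer = torch.zeros(1024)
elements_in_ipg_bucket = 300
ipg_buffer[:elements_in_ipg_bucket] = 1.0

def average_tensor(t):
    # Stand-in for the real all-reduce/averaging; just report what would be reduced.
    print(f"reducing {t.numel()} elements")

average_tensor(ipg_buffer)                                       # old: 1024 elements
average_tensor(ipg_buffer.narrow(0, 0, elements_in_ipg_bucket))  # new: 300 elements
```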
deepspeed/utils/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@
from .groups import *
from .nvtx import instrument_w_nvtx
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad, map_to_flat_opt_states
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
deepspeed/utils/tensor_fragment.py (15 changes: 15 additions & 0 deletions)
@@ -58,6 +58,21 @@ def get_hp_fragment(self, optim_state_key=None):
return self.get_optim_state_fragment(optim_state_key)


def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
for key in opt_keys:
hp_param = flat_hp_tensor
buffer = torch.zeros_like(hp_param)

for lp in lp_tensors:
if lp._hp_mapping is not None:
hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
lp._hp_mapping.hp_fragment = hp_fragment

optim_state[hp_param][key] = buffer

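A toy usage sketch of the new helper, with stand-in mapping objects: the real `_hp_mapping` comes from DeepSpeed's tensor-fragment machinery and carries more state, so everything here except the call signature is made up.

```python
from collections import defaultdict
from types import SimpleNamespace

import torch

from deepspeed.utils import map_to_flat_opt_states  # exported by this commit

# Two low-precision params whose fragments occupy different offsets of one flat tensor.
flat_hp = torch.zeros(10)

def make_mapping(start, numel, value):
    frag = torch.full((numel,), float(value))
    return SimpleNamespace(
        get_hp_fragment_address=lambda: SimpleNamespace(start=start, numel=numel),
        get_hp_fragment=lambda optim_state_key=None: frag,
        hp_fragment=None,
    )

lp0, lp1 = torch.nn.Parameter(torch.zeros(4)), torch.nn.Parameter(torch.zeros(6))
lp0._hp_mapping = make_mapping(start=0, numel=4, value=1)
lp1._hp_mapping = make_mapping(start=4, numel=6, value=2)

optim_state = defaultdict(dict)  # stands in for optimizer.state
map_to_flat_opt_states(flat_hp, [lp0, lp1], optim_state, {'exp_avg'})

print(optim_state[flat_hp]['exp_avg'])  # tensor([1., 1., 1., 1., 2., 2., 2., 2., 2., 2.])
```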

def get_full_hp_param(self, optim_state_key=None):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
tests/unit/common.py (25 changes: 20 additions & 5 deletions)
@@ -23,7 +23,7 @@
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout for tests that hang
DEEPSPEED_TEST_TIMEOUT = 600
DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))


def is_rocm_pytorch():
@@ -81,6 +81,11 @@ def set_accelerator_visible():
match = re.search('Device Type.*GPU', line)
if match:
num_accelerators += 1
elif get_accelerator().device_name() == 'hpu':
hl_smi = subprocess.check_output(['hl-smi', "-L"])
num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode())
num_accelerators = sorted(num_accelerators, key=int)
os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators)
elif get_accelerator().device_name() == 'npu':
npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
@@ -90,7 +95,10 @@ def set_accelerator_visible():
subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
num_accelerators = cpu_sockets

cuda_visible = ",".join(map(str, range(num_accelerators)))
if isinstance(num_accelerators, list):
cuda_visible = ",".join(num_accelerators)
else:
cuda_visible = ",".join(map(str, range(num_accelerators)))

# rotate list based on xdist worker id, example below
# wid=0 -> ['0', '1', '2', '3']
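The HPU branch above derives the visible modules by parsing `hl-smi -L`, and the rotation described in the comment spreads xdist workers across devices; a standalone sketch using a made-up `hl-smi` transcript and worker id (the real output format may differ in detail):

```python
import re

# Made-up excerpt of `hl-smi -L` output.
hl_smi_output = """
Module ID        : 0
Module ID        : 1
Module ID        : 2
Module ID        : 3
"""

module_ids = sorted(re.findall(r"Module ID\s+:\s+(\d+)", hl_smi_output), key=int)
visible = ",".join(module_ids)
print(visible)  # 0,1,2,3

# Rotate by xdist worker id so concurrent workers start on different devices,
# e.g. wid=0 -> ['0', '1', '2', '3'], wid=1 -> ['1', '2', '3', '0'].
wid = 1
devices = visible.split(",")
shift = wid % len(devices)
print(",".join(devices[shift:] + devices[:shift]))  # 1,2,3,0
```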
@@ -149,6 +157,12 @@ def _get_fixture_kwargs(self, request, func):
def _launch_daemonic_procs(self, num_procs):
# Create process pool or use cached one
master_port = None

if get_accelerator().device_name() == 'hpu':
if self.reuse_dist_env:
print("Ignoring reuse_dist_env for hpu")
self.reuse_dist_env = False

if self.reuse_dist_env:
if num_procs not in self._pool_cache:
self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
@@ -169,9 +183,10 @@ def _launch_daemonic_procs(self, num_procs):
# usually means an environment error and the rest of tests will
# hang (causing super long unit test runtimes)
pytest.exit("Test hanged, exiting", returncode=1)

# Tear down distributed environment and close process pools
self._close_pool(pool, num_procs)
finally:
# Regardless of the outcome, ensure proper teardown
# Tear down distributed environment and close process pools
self._close_pool(pool, num_procs)

# If we skipped a test, propagate that to this process
if any(skip_msgs):
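The `try/finally` added above guarantees the worker pool is torn down even when a test hangs or raises; the same pattern in isolation (the timeout, worker count, and work function are placeholders):

```python
import multiprocessing as mp

def _work(x):
    return x * x

def run_with_timeout(num_procs=2, timeout=5):
    pool = mp.Pool(processes=num_procs)
    try:
        # get() raises multiprocessing.TimeoutError if the workers hang.
        return pool.map_async(_work, range(num_procs)).get(timeout)
    finally:
        # Regardless of the outcome, ensure proper teardown of the pool.
        pool.close()
        pool.join()

if __name__ == "__main__":
    print(run_with_timeout())  # [0, 1]
```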
tests/unit/inference/test_inference.py (11 changes: 9 additions & 2 deletions)
@@ -653,8 +653,15 @@ def no_pool_bootstrap_stderr(f, xs, iters):
setattr(lm, model_family, getattr(lm, model_family).half().to(device))
lm._device = device
else:
lm = lm_eval.models.get_model(model_family).create_from_arg_string(
f"pretrained={model_name}", {"device": get_accelerator().device_name()})
if get_accelerator().device_name() == 'hpu':
# lm_eval does not support the HPU device, so create the model on CPU and move it to HPU.
lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}",
{"device": "cpu"})
setattr(lm, model_family, getattr(lm, model_family).to(device))
lm._device = device
else:
lm = lm_eval.models.get_model(model_family).create_from_arg_string(
f"pretrained={model_name}", {"device": get_accelerator().device_name()})

get_accelerator().synchronize()
start = time.time()
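The HPU branch above works around lm_eval not accepting the HPU device string by constructing the model on CPU and moving it afterwards; the general shape of that workaround, shown with a plain torch module and a hypothetical device string:

```python
import torch

device = "cpu"  # stand-in for an accelerator device such as "hpu" or "cuda:0"

# Build the model on CPU first (what lm_eval is asked to do in the diff)...
model = torch.nn.Linear(8, 8)

# ...then move the finished module to the target device and remember the device,
# mirroring the setattr(...) and lm._device assignments in the test.
model = model.to(device)

print(next(model.parameters()).device)  # cpu here; hpu/cuda when such a device exists
```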
tests/unit/runtime/compile/test_compile_wrapper.py (2 changes: 2 additions & 0 deletions)
@@ -34,6 +34,8 @@ def base_config():
"backend": "inductor"
}
}
if get_accelerator().device_name() == 'hpu':
config_dict['compile']['backend'] = 'hpu_backend'
return config_dict


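The test-config tweaks above and in the two files below switch the `torch.compile` backend to `hpu_backend` on Gaudi; a sketch of the same selection in a user-side DeepSpeed config dict (the accelerator check mirrors the tests, the remaining keys are illustrative):

```python
from deepspeed.accelerator import get_accelerator

config_dict = {
    "train_batch_size": 1,
    "zero_optimization": {"stage": 1},
    "compile": {
        "backend": "inductor",  # default used by the tests
    },
}

# Gaudi needs its own torch.compile backend.
if get_accelerator().device_name() == 'hpu':
    config_dict['compile']['backend'] = 'hpu_backend'

print(config_dict['compile']['backend'])
```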
tests/unit/runtime/compile/test_compile_zero.py (2 changes: 2 additions & 0 deletions)
@@ -55,6 +55,8 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
}
}

if get_accelerator().device_name() == 'hpu':
config_dict['compile']['backend'] = 'hpu_backend'
if offload_device == OffloadDeviceEnum.cpu:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
elif offload_device == OffloadDeviceEnum.nvme:
tests/unit/runtime/compile/test_load_config.py (3 changes: 3 additions & 0 deletions)
@@ -50,6 +50,9 @@ def base_config():
"backend": "inductor"
}
}

if get_accelerator().device_name() == 'hpu':
config_dict['compile']['backend'] = 'hpu_backend'
return config_dict


tests/unit/runtime/half_precision/onebit/test_onebit.py (3 changes: 3 additions & 0 deletions)
@@ -33,6 +33,9 @@
pytest.skip("NCCL-based 1-bit compression is not yet supported w. ROCm 5 until cupy supports ROCm 5",
allow_module_level=True)

if get_accelerator().device_name() == 'hpu':
pytest.skip("1-bit compression is not supported by HPU.", allow_module_level=True)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
class TestOneBitAdamBasic(DistributedTest):