diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml
index a64a337d50af..1881c968b560 100644
--- a/.github/workflows/hpu-gaudi2.yml
+++ b/.github/workflows/hpu-gaudi2.yml
@@ -7,6 +7,8 @@ on:
pull_request:
paths:
- ".github/workflows/hpu-gaudi2.yml"
+ - "accelerator/hpu_accelerator.py"
+
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@@ -21,11 +23,63 @@ jobs:
# The type of runner that the job will run on
runs-on: [self-hosted, intel, gaudi2]
container:
- image: vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
+ image: vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest
ports:
- 80
options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice
+ env:
+ PT_HPU_LAZY_MODE: 0
+ TEST_LIST: |
+ test_accelerator.py
+ test_autotuning.py
+ test_compression.py
+ test_dist.py
+ test_elastic.py
+ (test_intX_quantization.py and test_quantized_linear)
+ test_ds_arguments.py
+ test_run.py
+ test_multinode_runner.py
+ test_moe_tp.py
+ test_monitor.py
+ (test_zero_optimizer.py and (TestSaveTensorClone or TestZeRONonDistributed))
+ (test_latest_checkpoint.py and test_missing_latest)
+ test_reshape_checkpoint.py
+ test_shared_weights.py
+ test_sparse.py
+ test_tag_validation.py
+ test_pipe_module.py
+ (test_flops_profiler.py and test_flops_profiler_in_inference)
+ test_get_optim_files.py
+ test_groups.py
+ test_init_on_device.py
+ test_partition_balanced.py
+ (test_adamw.py and TestAdamConfigs)
+ test_coalesced_collectives.py
+ test_activation_checkpointing_non_reentrant.py
+ test_activation_checkpointing.py
+ test_data.py
+ (test_ds_config_dict.py and (TestBasicConfig or TestBatchConfig))
+ test_ds_config_model.py
+ test_mup_optimizers.py
+ (test_pld.py and test_pld_schedule)
+ test_runtime_utils.py
+ test_pipe_schedule.py
+ test_topology.py
+ (test_ds_initialize.py and (TestClientOptimizer or TestClientLrScheduler))
+ test_csr.py
+ (test_fp16.py and (TestZeroEmptyGrad or TestZeroAllowUntestedOptimizer))
+ (test_bf16.py and TestZeroDtypeCocktail)
+ test_partition.py
+ test_ignore_unused_parameters.py
+ test_zero_config.py
+ test_zero_context_ancestry.py
+ (test_zero_context.py and not TestSerialContext)
+ test_zero_dynamic_class.py
+ test_zero_nesting_init.py
+ test_zeropp.py
+ (test_zero.py and (TestZero3ParamPartitioningLargeParam or TestZero3ParamPartitioningLargeParam))
+
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
@@ -38,11 +92,28 @@ jobs:
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+ - name: Install transformers
+ run: |
+ git clone https://github.com/huggingface/transformers
+ cd transformers
+ git rev-parse --short HEAD
+ pip install .
+
- name: Install deepspeed
run: |
- pip install .[dev]
+ pip install .[dev,autotuning]
ds_report
- name: Python environment
run: |
pip list
+
+ - name: Unit tests
+ run: |
+ unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
+ cd tests
+ export PT_HPU_LAZY_MODE=${PT_HPU_LAZY_MODE}
+ TEST_LIST=$(echo "$TEST_LIST" | awk 'NF{printf "%s%s", (NR>1 ? " or " : ""), $0} END{if (NR>1) print ""}')
+ echo "TEST_LIST ${TEST_LIST}"
+ echo "PT_HPU_LAZY_MODE ${PT_HPU_LAZY_MODE}"
+ pytest --verbose unit/ -k "${TEST_LIST}"
diff --git a/blogs/deepspeed-fp6/03-05-2024/README.md b/blogs/deepspeed-fp6/03-05-2024/README.md
index dbd6b2d081aa..0285dd79b87d 100755
--- a/blogs/deepspeed-fp6/03-05-2024/README.md
+++ b/blogs/deepspeed-fp6/03-05-2024/README.md
@@ -43,7 +43,7 @@ To cite DeepSpeed-FP6, please cite the following two arxiv reports - ZeroQuant(4
In the evolving landscape of Large Language Models (LLMs) like GPT, our research aims to boost computational efficiency and storage while preserving model quality. This focus brings us to tackle the complex challenges of 4-bit quantization, where optimizing performance, efficiency, and accuracy is crucial.
-**Exploring the Challenges of 4-bit Quantization** In our recent research findings -- ZeroQuant (4+2)[1], we explore the capabilities of INT4 quantization techniques (like the GPTQ algorithm) for serving Large Language Models (LLMs). While these techniques reduce memory and computational requirements, they often perform poorly on a broad array of tasks, including generative tasks such as code generation and summarization, due to overfitting issues. This highlights the urgent need for new quantization approaches that simultanenously improve both the efficiency and effectiveness of LLMs.
+**Exploring the Challenges of 4-bit Quantization** In our recent research findings -- ZeroQuant (4+2)[1], we explore the capabilities of INT4 quantization techniques (like the GPTQ algorithm) for serving Large Language Models (LLMs). While these techniques reduce memory and computational requirements, they often perform poorly on a broad array of tasks, including generative tasks such as code generation and summarization, due to overfitting issues. This highlights the urgent need for new quantization approaches that simultaneously improve both the efficiency and effectiveness of LLMs.
**Breakthroughs with FP6 Precision** Our exploration of different quantization methods led us to the FP6 precision standard. Despite the challenges in integrating and accelerating FP6 with current AI hardware -- which we will address in the next section - this format excels in performance and flexibility across various tasks. Notably, we observe that for generative tasks, FP6 quantization can match the performance of the half-precision (FP16) format. For example, with FP6 quantization, StarCoder-15B achieves comparable code generation results to the FP16 variant, while a smaller model, such as BART-460M, achieves comparable summarization performance to the standard FP16 equivalent. In order to preserve these quality gains, while matching the system efficiency of INT4 quantization on AI hardware, we propose a novel 4+2 FP6 scheme. This innovation makes FP6 a promising direction for improving the efficiency of LLMs, marking a significant leap in AI technology advancement. For more details, please refer to our research paper - ZeroQuant (4+2)[1].
diff --git a/blogs/deepspeed-ulysses/README.md b/blogs/deepspeed-ulysses/README.md
index aa4416521dd1..375eb1190325 100644
--- a/blogs/deepspeed-ulysses/README.md
+++ b/blogs/deepspeed-ulysses/README.md
@@ -233,7 +233,7 @@ at different sequence length and GPU count.*
Next, we evaluate Ulysses on 7 billion (7B) and 30 billion (30B) parameter
GPT dense attention models and compare against Megatron-LM's sequence
-parallelism (Megatron LM) and Colosal AI sequence parallelism (ColAI-SP) on
+parallelism (Megatron LM) and Colossal AI sequence parallelism (ColAI-SP) on
32 and 64 A100 GPUs respectively. The results of these evaluations are shown
in Figures 3 and 4.
diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 6c7aa8b15ef9..fe0043547860 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -26,6 +26,7 @@
from . import module_inject
from .accelerator import get_accelerator
+from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -42,7 +43,6 @@
from .comm.comm import init_distributed
from .runtime import zero
-from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported
from .pipe import PipelineModule
@@ -72,6 +72,7 @@ def initialize(args=None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
+ distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
mpu=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
@@ -96,6 +97,8 @@ def initialize(args=None,
lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods
+ distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training
+
mpu: Optional: A model parallelism unit object that implements
get_{model,data}_parallel_{rank,group,world_size}()
@@ -137,7 +140,9 @@ def initialize(args=None,
global dist
from deepspeed import comm as dist
dist_backend = get_accelerator().communication_backend_name()
- dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)
+ dist.init_distributed(dist_backend=dist_backend,
+ distributed_port=distributed_port,
+ dist_init_required=dist_init_required)
# Set config using config_params for backwards compat
if config is None and config_params is not None:
diff --git a/deepspeed/autotuning/utils.py b/deepspeed/autotuning/utils.py
index 8c9a5fa85bf2..b851353520fb 100644
--- a/deepspeed/autotuning/utils.py
+++ b/deepspeed/autotuning/utils.py
@@ -42,7 +42,7 @@ def find_replace_str(value, replace_dict):
if not isinstance(value, str):
return str(value)
- matches = re.findall(r"\$[A-Za-z0-9_]+", value)
+ matches = re.findall(r"\$[\w]+", value)
for var in matches:
var_key = var.replace("$", "").lower()
if var_key == "nvme_path":
diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py
index f809a0c39270..b3f199a67b98 100644
--- a/deepspeed/checkpoint/constants.py
+++ b/deepspeed/checkpoint/constants.py
@@ -16,6 +16,7 @@
BASE_OPTIMIZER_STATE = 'base_optimizer_state'
BASE_OPTIMIZER_STATE_STEP = 'base_optimizer_state_step'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
+PARAM_GROUPS = 'param_groups'
GROUP_PADDINGS = 'group_paddings'
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index f40c5630899d..9ec5d0b169e4 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -22,6 +22,7 @@
OPTIMIZER_STATE_DICT,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
+ PARAM_GROUPS,
PARAM_SLICE_MAPPINGS,
PARAM_SHAPES,
PARAM,
@@ -110,6 +111,9 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
fp32=fp32_groups[param_group_id],
)
+ if "step" in state_groups[param_group_id]:
+ flat_state["step"] = state_groups[param_group_id]["step"]
+
for name, fragment_mapping in param_slice_mappings[param_group_id].items():
if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
# Skip tied weights that are replicated in first and last pp stages
@@ -138,8 +142,10 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,
#print(f"{param_name}: {offset}: {numel} => {path}")
- t = state_flat_tensor.narrow(0, offset, numel).clone()
- _save_checkpoint(path, t)
+ # State might be a python int or a tensor
+ if state_name != "step" and torch.is_tensor(state_flat_tensor):
+ state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
+ _save_checkpoint(path, state_flat_tensor)
def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
@@ -147,8 +153,17 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
+ if len(paths) == 0:
+ continue
+
shards = [torch.load(p) for p in paths]
- slice = torch.cat(shards, dim=0).reshape(slice_shape)
+
+ if state == "step":
+ assert all(v == shards[0] for v in shards), "All shards must have the same step value"
+ slice = shards[0]
+ else:
+ slice = torch.cat(shards, dim=0).reshape(slice_shape)
+
slices.append(slice)
return slices
@@ -177,6 +192,10 @@ def get_matched_pattern(patterns_, name_):
return pattern_
return None
+ step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
+ if step_merged:
+ _save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])
+
for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")
@@ -227,13 +246,21 @@ def _get_chunks(l, n):
def _do_parallel_work(do_work, work_chunks, num_workers):
- pool = multiprocessing.Pool(num_workers)
- results = []
- for batch in tqdm.tqdm(work_chunks):
- res = pool.map(do_work, batch)
- results.extend(res)
- pool.close()
- pool.join()
+ if num_workers > 1:
+ pool = multiprocessing.Pool(num_workers)
+ results = []
+ for batch in tqdm.tqdm(work_chunks):
+ res = pool.map(do_work, batch)
+ results.extend(res)
+ pool.close()
+ pool.join()
+ else:
+ # No parallel pass for unit testing
+ # We can't create child processes in tests
+ results = []
+ for batch in tqdm.tqdm(work_chunks):
+ res = [do_work(x) for x in batch]
+ results.extend(res)
return results
@@ -273,6 +300,7 @@ def _save_optimizer_state(args, ds_checkpoint):
optim_sd = sd[OPTIMIZER_STATE_DICT]
output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
+ output_sd[PARAM_GROUPS] = optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero")
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
_save_checkpoint(output_file_path, output_sd)
@@ -283,10 +311,9 @@ def _check_for_required_state(ds_checkpoint):
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'
-def main():
+def main(args):
print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')
- args = parse_arguments()
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')
ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
@@ -332,4 +359,5 @@ def main():
if __name__ == "__main__":
- main()
+ args = parse_arguments()
+ main(args)
diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py
index 15b6ce28b2fd..137607721ebf 100644
--- a/deepspeed/checkpoint/reshape_utils.py
+++ b/deepspeed/checkpoint/reshape_utils.py
@@ -4,9 +4,10 @@
# DeepSpeed Team
import os
+import re
import torch
from collections import OrderedDict
-from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)
+from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX, MODEL_FILE_PREFIX)
def basic_folder_validation(dir):
@@ -38,12 +39,28 @@ def get_files(dir):
return file_list
+def sort_zero_files(files, prefix):
+ pattern = f"{prefix}([0-9]+)_{MODEL_FILE_PREFIX}([0-9]+)"
+ rank_pairs = []
+ for f in files:
+ m = re.search(pattern, f)
+ if m:
+ dp_rank = int(m.group(1))
+ mp_rank = int(m.group(2))
+ rank_pairs.append((dp_rank, mp_rank, f))
+ else:
+ raise ValueError(f"Cannot parse dp_rank and mp_rank from {f}")
+
+ sorted_files = sorted(rank_pairs, key=lambda x: (x[0], x[1]))
+ return [f for _, _, f in sorted_files]
+
+
def get_zero_files(dir):
file_list = get_files(dir)
for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]:
zero_files = get_files_with_prefix(file_list, prefix)
if len(zero_files) > 0:
- return zero_files
+ return sort_zero_files(zero_files, prefix)
return []
diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index a1314e004969..86c8dc904b8c 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -22,9 +22,15 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
if match:
hp_keys.append(match.group(1))
+ step = None
for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)
+
+ if key == "step":
+ step = ckpt_dict
+ continue
+
full_hp_param = ckpt_dict[PARAM]
# need to deal with slices that were averaged.
@@ -103,6 +109,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()
+ return step
+
def enable_universal_checkpoint(param_list):
for param in param_list:
diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py
index c65745d3dd0c..6730b93dfd4f 100644
--- a/deepspeed/checkpoint/zero_checkpoint.py
+++ b/deepspeed/checkpoint/zero_checkpoint.py
@@ -105,9 +105,11 @@ def _strip_tensor_paddings(self, sd):
if group_paddings[key] == 0:
continue
for state_name, state_value in group_state.items():
- if torch.is_tensor(state_value):
+ if state_name != "step" and torch.is_tensor(state_value):
raw_length = state_value.numel() - group_paddings[key]
group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
+ else:
+ group_state[state_name] = state_value
def _clear_group_paddings(self, sd):
group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index 44e694952ffe..ce58deadc281 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -74,7 +74,8 @@ def name(self):
def get_cmd(self, environment, active_resources):
environment['PDSH_RCMD_TYPE'] = 'ssh'
if self.args.ssh_port is not None: # only specify ssh port if it is specified
- environment["PDSH_SSH_ARGS_APPEND"] += f" -p {self.args.ssh_port}"
+ environment["PDSH_SSH_ARGS_APPEND"] = f"{environment.get('PDSH_SSH_ARGS_APPEND', '')} \
+ -p {self.args.ssh_port}"
active_workers = ",".join(active_resources.keys())
logger.info("Running on the following workers: %s" % active_workers)
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index bf9c2d74c635..88f7086518e8 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -133,7 +133,7 @@ def is_load_module(module):
load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
- "MistralRMSNorm", "T5LayerNorm"
+ "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names
@@ -303,6 +303,9 @@ def tp_parser(model):
elif 'self_attention.dense' in layer and 'falcon' in str(
type(module)): # this is a hack to get the right linear layer for this model!
gem_list = gem_list + [layer]
+ # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
+ elif 'w2' in layer and 'Mixtral' in str(type(module)):
+ gem_list = gem_list + [layer]
layer_list = []
if gem_list != []:
@@ -322,6 +325,9 @@ def _replace(self, child, name, conv_linear_layer):
return
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
+ # For mixtral-7x8b, need to skip MoE gate linear replace.
+ if name == "block_sparse_moe.gate":
+ return child
if name in self.all_reduce_linears:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]
diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py
index 8e1faffc3541..f52fe2e3442d 100644
--- a/deepspeed/moe/utils.py
+++ b/deepspeed/moe/utils.py
@@ -146,3 +146,7 @@ def split_params_into_different_moe_groups_for_optimizer(
param_groups.append(param_group)
return param_groups
+
+
+def is_moe_param_group(param_group):
+ return param_group.get('moe', False)
diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
index d6f72b4efb0b..c77d8a8e11c0 100644
--- a/deepspeed/ops/transformer/inference/triton/matmul_ext.py
+++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
@@ -13,12 +13,41 @@
import deepspeed
from pathlib import Path
import atexit
+import subprocess
# -----------------------------------------------------------------------------
# util class/functions for triton
-def _default_cache_dir():
- return os.path.join(Path.home(), ".triton", "autotune")
+def is_nfs_path(path):
+ # Normalize the path to get the absolute path
+ path = os.path.abspath(path)
+
+ # Use the 'df' command to find the file system type for the given path
+ try:
+ output = subprocess.check_output(['df', '-T', path], encoding='utf-8')
+ except subprocess.CalledProcessError:
+ return False # Command failed
+
+ # Process the output of 'df -T' to check for 'nfs' in the filesystem type column
+ lines = output.strip().split('\n')
+ if len(lines) > 1: # The first line is headers
+ fs_type = lines[1].split()[1].lower() # File system type is the second column
+ return 'nfs' in fs_type
+ return False
+
+
+class TritonCacheDir:
+ _warning_printed = False
+
+ @staticmethod
+ def default_cache_dir():
+ tmp_path = os.path.join(Path.home(), ".triton", "autotune")
+ if is_nfs_path(tmp_path) and not TritonCacheDir._warning_printed:
+ print(
+ f"Warning: The default cache directory for DeepSpeed Triton autotune, {tmp_path}, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path."
+ )
+ TritonCacheDir._warning_printed = True
+ return tmp_path
def bias_add_activation(C, bias=None, activation=""):
@@ -50,7 +79,7 @@ def __init__(self, key):
self.file_path = None
self.lock_path = None
# if caching is enabled, get the lock and bin path
- self.cache_dir = os.environ.get('TRITON_CACHE_DIR', _default_cache_dir())
+ self.cache_dir = os.environ.get('TRITON_CACHE_DIR', TritonCacheDir.default_cache_dir())
if self.cache_dir:
os.makedirs(self.cache_dir, exist_ok=True)
if self.cache_dir:
diff --git a/deepspeed/runtime/__init__.py b/deepspeed/runtime/__init__.py
index 347ff7993d82..208299fb8c50 100644
--- a/deepspeed/runtime/__init__.py
+++ b/deepspeed/runtime/__init__.py
@@ -2,11 +2,3 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
-
-
-class DeepSpeedOptimizer(object):
- pass
-
-
-class ZeROOptimizer(DeepSpeedOptimizer):
- pass
diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
new file mode 100644
index 000000000000..6cfd66f1cc38
--- /dev/null
+++ b/deepspeed/runtime/base_optimizer.py
@@ -0,0 +1,63 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import os
+import torch
+
+from deepspeed.utils import logger
+from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
+from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank
+
+
+class DeepSpeedOptimizer(object):
+ pass
+
+
+class ZeROOptimizer(DeepSpeedOptimizer):
+
+ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
+ checkpoint_dir = os.path.join(checkpoint_dir, "zero")
+ optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
+ assert os.path.isfile(
+ optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
+ optim_sd = torch.load(optim_state_path)
+
+ self._load_global_state(optim_sd)
+
+ tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
+ if self.mpu is None:
+ logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
+ tp_world_size = 1
+ else:
+ tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
+ else self.mpu.get_tensor_model_parallel_world_size()
+
+ for i, (param_group,
+ loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
+ # We have an assumption that all params in the same param_group have the same keys
+ opt_keys = set()
+ steps = []
+
+ lp_groups = getattr(self, lp_groups_name)
+ for lp in lp_groups[i]:
+ if lp._hp_mapping is not None:
+ #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
+ step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
+ tp_world_size)
+ for key in lp._hp_mapping.get_optim_state_keys():
+ opt_keys.add(key)
+ steps.append(step)
+
+ hp_param = param_group['params'][0]
+ assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
+ if steps[0] is not None:
+ self.optimizer.state[hp_param]['step'] = steps[0]
+
+ map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)
+
+ for key, value in loaded_param_group.items():
+ if key == 'params':
+ continue
+ param_group[key] = value
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 82c8dda423a6..d076035604e3 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -6,19 +6,18 @@
from collections import OrderedDict
import torch
import sys
-import os
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
-from deepspeed.runtime import ZeROOptimizer
+from deepspeed.runtime.base_optimizer import ZeROOptimizer
from packaging import version as pkg_version
-
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
- is_model_parallel_parameter, see_memory_usage, graph_process)
-
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, map_to_flat_opt_states
+ is_model_parallel_parameter, see_memory_usage, graph_process,
+ get_norm_with_moe_layers)
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
+from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -40,7 +39,8 @@ def __init__(self,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False,
- immediate_grad_update=False):
+ immediate_grad_update=False,
+ has_moe_layers=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
@@ -59,7 +59,11 @@ def __init__(self,
self.allgather_bucket_size = int(allgather_bucket_size)
self.dp_process_group = dp_process_group
self.dp_rank = dist.get_rank(group=self.dp_process_group)
+ self.has_moe_layers = has_moe_layers
+ self.non_expert_gradients = []
self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
+ if self.has_moe_layers:
+ self._configure_moe_settings()
# Use torch (un)flatten ops
self.flatten = _flatten_dense_tensors
@@ -90,11 +94,26 @@ def __init__(self,
see_memory_usage('end bf16_optimizer', force=True)
+ def _configure_moe_settings(self):
+ assert any(
+ [is_moe_param_group(group) for group in self.optimizer.param_groups]
+ ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
+
+ for i, group in enumerate(self.optimizer.param_groups):
+ if is_moe_param_group(group):
+ assert all([is_moe_param(param)
+ for param in group['params']]), "All params in MoE group must be MoE params"
+ self.real_dp_process_group[i] = groups._get_expert_data_parallel_group(group['name'])
+ self.expert_gradients = {}
+ if self.has_moe_layers:
+ for key in groups._get_expert_data_parallel_group_dict().keys():
+ self.expert_gradients[key] = []
+
def _setup_for_real_optimizer(self):
- dp_world_size = dist.get_world_size(group=self.dp_process_group)
- self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
+ self.partition_count = [dist.get_world_size(group=pg) for pg in self.real_dp_process_group]
for i, param_group in enumerate(self.optimizer.param_groups):
+ real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
see_memory_usage(f'before initializing group {i}', force=True)
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
@@ -106,17 +125,16 @@ def _setup_for_real_optimizer(self):
# create flat bf16 params
self.bf16_groups_flat.append(
self._flatten_dense_tensors_aligned(self.bf16_groups[i],
- self.nccl_start_alignment_factor * dp_world_size))
-
+ self.nccl_start_alignment_factor * real_dp_world_size))
# Make bf16 params point to flat tensor storage
self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
flat_tensor=self.bf16_groups_flat[i])
# divide flat weights into equal sized partitions
- partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
+ partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
bf16_dp_partitions = [
self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
- for dp_index in range(dp_world_size)
+ for dp_index in range(real_dp_world_size)
]
self.bf16_partitioned_groups.append(bf16_dp_partitions)
@@ -127,8 +145,12 @@ def _setup_for_real_optimizer(self):
num_elem_list = [t.numel() for t in self.bf16_groups[i]]
# create fp32 gradients
- self.fp32_groups_gradients_flat.append(
- torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype))
+ fp32_flat_buffer = torch.zeros_like(self.bf16_groups_flat[i], dtype=self.grad_acc_dtype)
+ self.fp32_groups_gradients_flat.append(fp32_flat_buffer)
+ if self.has_moe_layers and is_moe_param_group(param_group):
+ self.expert_gradients[param_group['name']].append(fp32_flat_buffer)
+ else:
+ self.non_expert_gradients.append(fp32_flat_buffer)
# track individual fp32 gradients for entire model
fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
@@ -191,11 +213,12 @@ def _create_param_mapping(self):
return param_mapping
def _link_all_hp_params(self):
- dp_world_size = dist.get_world_size(group=self.dp_process_group)
for i, _ in enumerate(self.optimizer.param_groups):
+ real_dp_world_size = dist.get_world_size(group=self.real_dp_process_group[i])
+
# Link bf16 and fp32 params in partition
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
- partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
+ partition_size = self.bf16_groups_flat[i].numel() // real_dp_world_size
flat_hp_partition = self.fp32_groups_flat_partition[i]
link_hp_params(lp_param_list=self.bf16_groups[i],
flat_hp_partition=flat_hp_partition,
@@ -257,10 +280,18 @@ def step(self, closure=None):
if closure is not None:
raise NotImplementedError(f'{self.__class__} does not support closure.')
- all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
- mpu=self.mpu,
- norm_type=self.norm_type,
- use_graph=self.graph_harvesting)
+ non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
+ non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
+ mpu=self.mpu,
+ norm_type=self.norm_type,
+ use_graph=self.graph_harvesting)
+ all_groups_norm = non_expert_groups_norm
+ if self.has_moe_layers:
+ all_groups_norm = get_norm_with_moe_layers(non_expert_groups_norm,
+ mpu=self.mpu,
+ expert_tensors=expert_grads_for_norm,
+ norm_type=self.norm_type)
+
self._global_grad_norm = all_groups_norm
assert all_groups_norm > 0.
@@ -336,27 +367,55 @@ def update_hp_grads(self, clear_lp_grads=False):
@torch.no_grad()
def get_grads_for_reduction(self):
- return self.fp32_groups_gradients_flat
+ if self.has_moe_layers:
+ return self.non_expert_gradients, self.expert_gradients
+ return self.non_expert_gradients, {}
@torch.no_grad()
def get_grads_for_norm(self, for_clipping=False):
- grads = []
+ """
+ Returns:
+ tuple[list[Tensor], dict[ep_name, List[Tensor]] | list:
+ If for_clipping, return all gradients.
+ Otherwise, separate and return dict of expert_grad and list of non_expert_grad
+ """
+ # (grads, expert_group_name)
+ expert_grads_for_norm = {}
+
+ # grads
+ non_expert_grads_for_norm = []
+ all_grads_for_clip = []
+
tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
+ assert len(self.bf16_groups) == len(self.optimizer.param_groups)
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if not for_clipping:
if hasattr(lp, PIPE_REPLICATED) and lp.ds_pipe_replicated:
continue
- if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp)):
+ # skip duplicated parameters. perform norm only on cards with tp_rank=0.
+ # non-duplicated parameters include:
+ # - Parameters with tp: Use allreducesum of mp_group.
+ # - Moe Parameters with ep: Use allreducesum of ep_group.
+ if not (tensor_mp_rank == 0 or is_model_parallel_parameter(lp) or is_moe_param(lp)):
continue
if not self.fp32_groups_has_gradients[i][j]:
continue
-
- grads.append(self.fp32_groups_gradients[i][j])
-
- return grads
+ if not for_clipping:
+ param_group = self.optimizer.param_groups[i]
+ if self.has_moe_layers and is_moe_param_group(param_group):
+ if param_group['name'] not in expert_grads_for_norm:
+ expert_grads_for_norm[param_group['name']] = []
+ expert_grads_for_norm[param_group['name']].append(self.fp32_groups_gradients[i][j])
+ else:
+ non_expert_grads_for_norm.append(self.fp32_groups_gradients[i][j])
+ else:
+ all_grads_for_clip.append(self.fp32_groups_gradients[i][j])
+ if not for_clipping:
+ return non_expert_grads_for_norm, expert_grads_for_norm
+ return all_grads_for_clip
@torch.no_grad()
def update_lp_params(self):
@@ -433,6 +492,7 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)
if load_optimizer_states:
+ print(f"_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]")
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
if load_from_fp32_weights:
@@ -445,34 +505,19 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
self._link_all_hp_params()
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
- self._load_hp_checkpoint_state(checkpoint_folder)
+ self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder)
+
+ def _load_global_state(self, sd):
+ pass
@property
def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups
- def _load_hp_checkpoint_state(self, checkpoint_dir):
- checkpoint_dir = os.path.join(checkpoint_dir, "zero")
- tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
- tp_world_size = self.mpu.get_slice_parallel_world_size()
-
- for i, param_group in enumerate(self.optimizer.param_groups):
- # We have an assumption that all params in the same param_group have the same keys
- opt_keys = set()
-
- for lp in self.bf16_groups[i]:
- if lp._hp_mapping is not None:
- #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
- lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
- tp_world_size)
- for key in lp._hp_mapping.get_optim_state_keys():
- opt_keys.add(key)
- map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)
-
def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.immediate_grad_update
- self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False)
+ self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=True)
def create_grad_acc_hooks(self):
self.grad_accs = []
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 975fb1f21501..19b169086be1 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -258,10 +258,10 @@ def get_communication_data_type(param_dict,
return torch.float32
elif val == "fp16":
return torch.float16
- elif val == "bfp16":
+ elif val == "bf16":
return torch.bfloat16
- raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bfp16', 'fp32']. Got: {val}")
+ raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bf16', 'fp32']. Got: {val}")
def get_prescale_gradients(param_dict):
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 174e699c5202..3ad37baeedcb 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1478,7 +1478,8 @@ def _configure_bf16_optimizer(self, optimizer):
timers=timers,
grad_acc_dtype=self.get_data_types()[1],
graph_harvesting=self.graph_harvesting(),
- immediate_grad_update=self._config.bfloat16_immediate_grad_update)
+ immediate_grad_update=self._config.bfloat16_immediate_grad_update,
+ has_moe_layers=self.has_moe_layers)
return optimizer
@@ -1924,9 +1925,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
else:
grads = None
- if hasattr(self.optimizer, "get_grads_for_reduction"):
- # This is currently for BF16 optimizer
- grads = self.optimizer.get_grads_for_reduction()
self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)
@instrument_w_nvtx
@@ -2335,7 +2333,7 @@ def _report_progress(self, step):
mom = self.get_mom()
log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0])
- def allreduce_bucket(self, bucket, dp_group):
+ def allreduce_bucket(self, bucket, dp_group, dp_world_size=None):
tensor = self.flatten(bucket)
tensor_to_allreduce = tensor
@@ -2343,16 +2341,18 @@ def allreduce_bucket(self, bucket, dp_group):
if self.communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(self.communication_data_type)
+ if dp_world_size is None:
+ dp_world_size = dist.get_world_size(group=dp_group)
if self.postscale_gradients():
if self.gradient_predivide_factor() != 1.0:
tensor_to_allreduce.mul_(1.0 / self.gradient_predivide_factor())
dist.all_reduce(tensor_to_allreduce, group=dp_group)
if self.gradient_average:
- if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group):
- tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
+ if self.gradient_predivide_factor() != dp_world_size:
+ tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dp_world_size)
else:
- tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group))
+ tensor_to_allreduce.mul_(1. / dp_world_size)
dist.all_reduce(tensor_to_allreduce, group=dp_group)
if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
@@ -2360,23 +2360,23 @@ def allreduce_bucket(self, bucket, dp_group):
return tensor
- def allreduce_and_copy(self, small_bucket, dp_group):
- allreduced = self.allreduce_bucket(small_bucket, dp_group)
+ def allreduce_and_copy(self, small_bucket, dp_group, dp_world_size=None):
+ allreduced = self.allreduce_bucket(small_bucket, dp_group, dp_world_size)
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
- def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000):
+ def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000, dp_world_size=None):
small_bucket = []
numel = 0
for tensor in bucket:
small_bucket.append(tensor)
numel = numel + tensor.numel()
if numel > numel_per_bucket:
- self.allreduce_and_copy(small_bucket, dp_group)
+ self.allreduce_and_copy(small_bucket, dp_group, dp_world_size)
small_bucket = []
numel = 0
if len(small_bucket) > 0:
- self.allreduce_and_copy(small_bucket, dp_group)
+ self.allreduce_and_copy(small_bucket, dp_group, dp_world_size)
def _get_gradients_for_reduction(self):
non_expert_grads = []
@@ -2427,26 +2427,35 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
+ # to maintain the gradients value unaffected by ep_size setting,
+ # utilize dp_world_size for allreduce average
+ dp_world_size = dist.get_world_size(groups._get_data_parallel_group())
for ep_name, expert_grads_group in expert_grads.items():
+ ep_dp_group = groups._get_expert_data_parallel_group(ep_name)
split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(
expert_grads_group)
for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets):
if sparse_bucket_tuple:
bucket_type, sparse_bucket = sparse_bucket_tuple
- self.sparse_allreduce_no_retain(sparse_bucket, groups._get_expert_data_parallel_group(ep_name))
+ self.sparse_allreduce_no_retain(sparse_bucket, dp_group=ep_dp_group, dp_world_size=dp_world_size)
for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets):
if dense_bucket_tuple:
bucket_type, dense_bucket = dense_bucket_tuple
# Separate between diff groups
self.allreduce_no_retain(dense_bucket,
- dp_group=groups._get_expert_data_parallel_group(ep_name),
- numel_per_bucket=elements_per_buffer)
+ dp_group=ep_dp_group,
+ numel_per_bucket=elements_per_buffer,
+ dp_world_size=dp_world_size)
def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
if grads is None:
- non_expert_grads, expert_grads = self._get_gradients_for_reduction()
+ if hasattr(self.optimizer, "get_grads_for_reduction"):
+ # This is currently for BF16 optimizer
+ non_expert_grads, expert_grads = self.optimizer.get_grads_for_reduction()
+ else:
+ non_expert_grads, expert_grads = self._get_gradients_for_reduction()
else:
assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE"
non_expert_grads = grads
@@ -2456,8 +2465,8 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000)
if self.has_moe_layers:
self._reduce_expert_gradients(expert_grads, elements_per_buffer)
- def sparse_allreduce_no_retain(self, bucket, dp_group):
- allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)
+ def sparse_allreduce_no_retain(self, bucket, dp_group, dp_world_size=None):
+ allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group, dp_world_size)
# Densify sparse tensor and copy back to original location
for tensor in allreduced_sparses:
if tensor.is_sparse:
@@ -2465,13 +2474,13 @@ def sparse_allreduce_no_retain(self, bucket, dp_group):
else:
tensor.orig_dense_tensor.copy_(tensor.to_dense())
- def sparse_allreduce_bucket(self, bucket, dp_group):
+ def sparse_allreduce_bucket(self, bucket, dp_group, dp_world_size=None):
sparse_list = []
for sparse in bucket:
- sparse_list.append(self.sparse_allreduce(sparse, dp_group))
+ sparse_list.append(self.sparse_allreduce(sparse, dp_group, dp_world_size))
return sparse_list
- def sparse_allreduce(self, sparse, dp_group):
+ def sparse_allreduce(self, sparse, dp_group, dp_world_size=None):
original_data_type = sparse.values.dtype
if self.communication_data_type != sparse.values.dtype:
if self.communication_data_type in (torch.float16, torch.bfloat16):
@@ -2483,12 +2492,13 @@ def sparse_allreduce(self, sparse, dp_group):
indices = sparse.indices
values = sparse.values
+ if dp_world_size is None:
+ dp_world_size = dist.get_world_size(group=dp_group)
if self.postscale_gradients():
if self.gradient_average:
- values.mul_(self.gradient_predivide_factor() /
- (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
+ values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size)))
else:
- values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
+ values.mul_(1. / (dp_world_size / float(self.sequence_parallel_size)))
indices_device_list = self.sparse_all_gather(indices, dp_group)
values_device_list = self.sparse_all_gather(values, dp_group)
@@ -2759,7 +2769,7 @@ def load_checkpoint(self,
load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
if load_zero_checkpoint:
- if load_optimizer_states and not load_module_only:
+ if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
else:
success = False
@@ -2784,8 +2794,6 @@ def load_checkpoint(self,
if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
- if load_zero_checkpoint:
- self.update_optimizer_step(step=client_states['iteration'])
return load_path, client_states
@@ -2963,23 +2971,6 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
return True
- def update_optimizer_step(self, step):
-
- def set_step(d):
- if 'step' in d and isinstance(d['step'], torch.Tensor):
- d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
- else:
- d['step'] = step
-
- optimizer = self.optimizer
- base_optimizer = optimizer.optimizer
- state = base_optimizer.state
- for group in optimizer.param_groups:
- set_step(group)
- for p in group['params']:
- if p in state and len(state[p]) > 0:
- set_step(state[p])
-
def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index 182f806c839c..9ed250252e17 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -10,13 +10,13 @@
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from deepspeed.runtime import DeepSpeedOptimizer
-from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version
+from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
+from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, required_torch_version, get_norm_with_moe_layers
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
-from deepspeed.utils import groups, logger, log_dist
-from deepspeed import comm as dist
+from deepspeed.utils import logger, log_dist
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
from deepspeed.accelerator import get_accelerator
+from deepspeed.moe.utils import is_moe_param_group
OVERFLOW_CHECK_TIMER = 'overflow_check'
COMPUTE_NORM_TIMER = 'compute_norm'
@@ -237,6 +237,10 @@ def step(self, closure=None):
return self.overflow
grads_groups_flat = []
+ non_experts_grads_for_norm = []
+ expert_grads_for_norm = {}
+ assert len(self.fp16_groups) == len(self.optimizer.param_groups)
+
for i, group in enumerate(self.fp16_groups):
data_type = self.fp32_groups_flat[i].dtype
@@ -250,15 +254,25 @@ def step(self, closure=None):
p.grad = None
self.fp32_groups_flat[i].grad = grads_groups_flat[i]
+ param_group = self.optimizer.param_groups[i]
+ if self.has_moe_layers and is_moe_param_group(param_group):
+ if param_group['name'] not in expert_grads_for_norm:
+ expert_grads_for_norm[param_group['name']] = []
+ expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i])
+ else:
+ non_experts_grads_for_norm.append(self.fp32_groups_flat[i])
self.timers(COMPUTE_NORM_TIMER).start()
- all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
+ all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu)
self.timers(COMPUTE_NORM_TIMER).stop()
if self.has_moe_layers:
- all_groups_norm = self._get_norm_with_moe_layers(all_groups_norm)
+ all_groups_norm = get_norm_with_moe_layers(all_groups_norm,
+ mpu=self.mpu,
+ expert_tensors=expert_grads_for_norm,
+ norm_type=self.norm_type)
scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
@@ -290,20 +304,6 @@ def step(self, closure=None):
return self.overflow
- def _get_norm_with_moe_layers(self, all_groups_norm):
- #all_groups_norm_old = all_groups_norm
- # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
- if self.using_pipeline:
- pg = self.deepspeed.mpu.get_data_parallel_group()
- else:
- pg = groups._get_data_parallel_group()
- scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
- scaled_norm_tensor = torch.tensor(scaled_norm, device=self.fp32_groups_flat[0].device, dtype=torch.float)
- dist.all_reduce(scaled_norm_tensor, group=pg)
- all_groups_norm = scaled_norm_tensor.item()
- #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
- return all_groups_norm
-
def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py
index 14271255df2e..a7fd1910d7b2 100755
--- a/deepspeed/runtime/fp16/unfused_optimizer.py
+++ b/deepspeed/runtime/fp16/unfused_optimizer.py
@@ -11,7 +11,7 @@
import torch
from torch._utils import _flatten_dense_tensors
-from deepspeed.runtime import DeepSpeedOptimizer
+from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, CheckOverflow, get_weight_norm, required_torch_version
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index d1ebe4b2f83d..e068f4a48b4a 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -205,6 +205,17 @@ def move_to_device(item, device, criterion_func):
return item
+def get_norm_with_moe_layers_fast(all_groups_norm, group):
+ # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'.
+ # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
+ scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group))
+ scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float)
+ dist.all_reduce(scaled_norm_tensor, group=group)
+ all_groups_norm = scaled_norm_tensor.item()
+ #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
+ return all_groups_norm
+
+
class CheckOverflow(object):
'''Checks for overflow in gradient across parallel process'''
@@ -861,7 +872,7 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep
return global_grad_norm
-def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False):
+def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
"""Get norm of an iterable of tensors.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -884,7 +895,9 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
- total_norm = total_norm_cuda[0].item()
+ if moe_ep_group is not None:
+ dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
+ total_norm = total_norm_cuda[0].item()
else:
if use_graph:
if 'norm_tensors_compute_buffer' not in graph_cache:
@@ -906,6 +919,9 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
if mpu is not None:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+ if moe_ep_group is not None:
+ dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=moe_ep_group)
+
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
@@ -1048,3 +1064,45 @@ def required_torch_version(min_version=None, max_version=None):
return False
return True
+
+
+def get_norm_with_moe_layers(non_expert_norm, mpu, expert_tensors, norm_type=2):
+ """ Compute the global norm with MoE experts
+
+ Inputs:
+ non_expert_norm (float) : the calculated norm of the non-expert params
+ expert_tensors (Dict[ep_name, List[Tensor]): Dictionary of expert group name to list of grad tensors
+ norm_type (int): the norm to use
+
+ Returns:
+ if norm is (-/+) inf, returns -1
+ otherwise the global norm (float)
+ """
+
+ def to_tensor(v):
+ return get_accelerator().FloatTensor(float(v)).detach()
+
+ group_norms = [non_expert_norm]
+ for exp_name, tensors in expert_tensors.items():
+ group_norm = get_global_norm_of_tensors(input_tensors=tensors,
+ mpu=mpu,
+ norm_type=norm_type,
+ use_graph=False,
+ moe_ep_group=groups._get_expert_parallel_group(exp_name))
+ group_norms.append(group_norm)
+
+ # check if all norms are valid
+ group_norms = torch.stack([to_tensor(norm) for norm in group_norms])
+ if group_norms.eq(-1).any():
+ return -1
+
+ # combine norms
+ if norm_type == inf:
+ total_norm = group_norms.max().item()
+ else:
+ total_norm = group_norms.pow(norm_type).sum()
+ total_norm = total_norm.item()**(1. / norm_type)
+ if total_norm == float('inf') or total_norm == -float('inf'):
+ total_norm = -1
+
+ return total_norm
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 03813afa4ed1..c6ff216edfcb 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -11,7 +11,7 @@
from deepspeed.utils import groups
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from deepspeed.runtime import ZeROOptimizer
+from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index a3b2c59952a1..8593380422c4 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -4,13 +4,12 @@
# DeepSpeed Team
import torch
-import os
from deepspeed import comm as dist
from packaging import version as pkg_version
from collections import OrderedDict
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from deepspeed.runtime import ZeROOptimizer
+from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, empty_cache, see_memory_usage, inf,
is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)
@@ -28,7 +27,7 @@
from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER,
SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE,
BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS)
-from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, map_to_flat_opt_states
+from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.utils import groups
@@ -2288,42 +2287,13 @@ def load_state_dict(self,
self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
- self._load_hp_checkpoint_state(checkpoint_folder)
+ self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder)
@property
def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups
- def _load_hp_checkpoint_state(self, checkpoint_dir):
- checkpoint_dir = os.path.join(checkpoint_dir, "zero")
- optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
- assert os.path.isfile(
- optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
- optim_sd = torch.load(optim_state_path)
- self._load_global_state(optim_sd)
-
- tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
- if self.mpu is None:
- logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
- tp_world_size = 1
- else:
- tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
- else self.mpu.get_tensor_model_parallel_world_size()
-
- for i, param_group in enumerate(self.optimizer.param_groups):
- # We have an assumption that all params in the same param_group have the same keys
- opt_keys = set()
-
- for lp in self.bit16_groups[i]:
- if lp._hp_mapping is not None:
- #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
- lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
- tp_world_size)
- for key in lp._hp_mapping.get_optim_state_keys():
- opt_keys.add(key)
- map_to_flat_opt_states(param_group['params'][0], self.bit16_groups[i], self.optimizer.state, opt_keys)
-
def _load_global_state(self, sd):
self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
self.dynamic_loss_scale = sd.get('dynamic_loss_scale', self.dynamic_loss_scale)
diff --git a/docs/_tutorials/accelerator-abstraction-interface.md b/docs/_tutorials/accelerator-abstraction-interface.md
index db1a6005f793..88a43236ce9d 100644
--- a/docs/_tutorials/accelerator-abstraction-interface.md
+++ b/docs/_tutorials/accelerator-abstraction-interface.md
@@ -79,13 +79,13 @@ torch.distributed.init_process_group(get_accelerator().communication_backend_nam
```
# Run DeepSpeed model on different accelerators
-Once a model is ported with DeepSpeed Accelerator Abstraction Interface, we can run this model on different accelerators using extension to DeepSpeed. DeepSpeed check whether certain extension is installed in the environment to decide whether to use the Accelerator backend in that extension. For example if we wish to run model on Intel GPU, we can install _Intel Extension for DeepSpeed_ following the instruction in [link](https://github.com/intel/intel-extension-for-deepspeed/)
+Once a model is ported with DeepSpeed Accelerator Abstraction Interface, we can run this model on different accelerators using an extension to DeepSpeed. DeepSpeed checks whether a certain extension is installed in the environment to decide whether to use the Accelerator backend in that extension. For example, if we wish to run a model on Intel GPU, we can install _Intel Extension for DeepSpeed_ following the instructions in the following [link](https://github.com/intel/intel-extension-for-deepspeed/)
-After the extension is installed, install DeepSpeed and run model. The model will be running on top of DeepSpeed. Because DeepSpeed installation is also accelerator related, it is recommended to install DeepSpeed accelerator extension before install DeepSpeed.
+After the extension is installed, install DeepSpeed and run the model. The model will be running on top of DeepSpeed. Because DeepSpeed installation is also accelerator related, it is recommended to install DeepSpeed accelerator extension before installing DeepSpeed.
`CUDA_Accelerator` is the default accelerator in DeepSpeed. If no other DeepSpeed accelerator extension is installed, `CUDA_Accelerator` will be used.
-When run a model on different accelerator in a cloud environment, the recommended practice is provision environment for each accelerator in different env with tool such as _anaconda/miniconda/virtualenv_. When run model on different Accelerator, load the env accordingly.
+When running a model on different accelerators in a cloud environment, the recommended practice is to provision an environment for each accelerator in a different env with tools such as _anaconda/miniconda/virtualenv_. When running models on different Accelerator, load the env accordingly.
Note that different accelerator may have different 'flavor' of float16 or bfloat16. So it is recommended to make the model configurable for both float16 and bfloat16, in that way model code does not need to be changed when running on different accelerators.
diff --git a/docs/index.md b/docs/index.md
index e3351ee1a3d7..1ca92019bff2 100755
--- a/docs/index.md
+++ b/docs/index.md
@@ -94,7 +94,7 @@ DeepSpeed has been integrated with several different popular open-source DL fram
| | Documentation |
| ---------------------------------------------------------------------------------------------- | -------------------------------------------- |
| | [Transformers with DeepSpeed](https://huggingface.co/docs/transformers/main/main_classes/deepspeed) |
-| | [Accelerate with DeepSpeed](https://huggingface.co/docs/accelerate/main/en/deepspeed) |
+| | [Accelerate with DeepSpeed](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) |
| | [Lightning with DeepSpeed](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.strategies.DeepSpeedStrategy.html) |
| | [MosaicML with DeepSpeed](https://docs.mosaicml.com/en/latest/trainer/using_the_trainer.html?highlight=deepspeed#deepspeed-integration) |
diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py
index 08fa1eb671bd..3fb13b214ea0 100644
--- a/tests/unit/checkpoint/common.py
+++ b/tests/unit/checkpoint/common.py
@@ -86,15 +86,20 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True, load
def compare_state_dicts(state0, state1, expected_mismatch_keys=[]):
- for (k0, s0), (k1, s1) in zip(state0.items(), state1.items()):
- assert k0 == k1, f'failure due to key mismatch {k0} != {k1}'
- if k0 in expected_mismatch_keys:
+ key_set0 = set(k for k in state0.keys() if k not in expected_mismatch_keys)
+ key_set1 = set(k for k in state1.keys() if k not in expected_mismatch_keys)
+ assert key_set0 == key_set1, f'failure due to key mismatch {key_set0} != {key_set1}'
+
+ for k in key_set0:
+ s0 = state0[k]
+ s1 = state1[k]
+ if k in expected_mismatch_keys:
continue
if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
assert torch.equal(s0.to('cpu'), s1.to('cpu'))
else:
- assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}'
+ assert s0 == s1, f'failures with keys = {k}, {k}, values = {s0} and {s1}'
def compare_opt_state_dicts(state0, state1, expected_mismatch_keys=[]):
diff --git a/tests/unit/checkpoint/test_lr_scheduler.py b/tests/unit/checkpoint/test_lr_scheduler.py
index 4891b4f6fa9b..89c4dd1b49f7 100644
--- a/tests/unit/checkpoint/test_lr_scheduler.py
+++ b/tests/unit/checkpoint/test_lr_scheduler.py
@@ -53,7 +53,7 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload):
}
if get_accelerator().is_fp16_supported():
config_dict["fp16"] = {"enabled": True}
- elif get_accelerator().is_fp16_supported():
+ elif get_accelerator().is_bf16_supported():
config_dict["bf16"] = {"enabled": True}
hidden_dim = 10
diff --git a/tests/unit/checkpoint/test_universal_checkpoint.py b/tests/unit/checkpoint/test_universal_checkpoint.py
new file mode 100644
index 000000000000..7adfe8410b55
--- /dev/null
+++ b/tests/unit/checkpoint/test_universal_checkpoint.py
@@ -0,0 +1,215 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed
+from types import SimpleNamespace
+from torch.utils._pytree import tree_map
+
+from deepspeed.runtime.utils import required_torch_version
+from deepspeed.checkpoint import UNIVERSAL_CHECKPOINT_INFO
+from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal
+
+from unit.common import DistributedTest, DistributedFixture
+from unit.simple_model import *
+from unit.util import bf16_required_version_check
+
+from unit.checkpoint.common import compare_opt_state_dicts, compare_state_dicts
+
+import pytest
+import deepspeed.comm as dist
+
+
+def get_expected_mismatch_keys():
+ # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to
+ # false positive mismatches in checkpoint state comparisons.
+ # Newer torch versions store tensor ids as 0, 1, 2, ...
+ return [] if required_torch_version(min_version=1.4) else ['params']
+
+
+def maybe_step(t):
+ return not torch.is_tensor(t) or (t.device.type == 'cpu' and t.numel() == 1)
+
+
+def gather_opt_state(optimizer_state):
+
+ def gather_tensor(t):
+
+ if maybe_step(t):
+ return t
+ else:
+ buffer = [torch.zeros_like(t.flatten()) for _ in range(dist.get_world_size())]
+ dist.all_gather(buffer, t.flatten())
+ return torch.cat(buffer)
+
+ return tree_map(gather_tensor, optimizer_state)
+
+
+def remove_pad_in_opt_state(optimizer_state, num_params):
+
+ def remove_pad(t):
+ if maybe_step(t):
+ return t
+ else:
+ return t[:num_params]
+
+ return tree_map(remove_pad, optimizer_state)
+
+
+CP_TAG = "test_tag"
+
+
+def init_ds_engine(model, ds_config, use_torch_adam):
+
+ if use_torch_adam:
+ ds_optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+ del ds_config["optimizer"]
+ model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, optimizer=ds_optimizer)
+ else:
+ model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters())
+
+ return model
+
+
+def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir):
+ if dtype == torch.bfloat16 and not bf16_required_version_check():
+ return
+
+ test_step = 8
+
+ model = SimpleModel(hidden_dim)
+ model = init_ds_engine(model, ds_config, use_torch_adam)
+ data_loader = random_dataloader(model=model,
+ total_samples=test_step,
+ hidden_dim=hidden_dim,
+ device=model.device,
+ dtype=dtype)
+ for batch in data_loader:
+ loss = model(batch[0], batch[1])
+ model.backward(loss)
+ model.step()
+
+ sd = model.optimizer.optimizer.state_dict() if load_optim else None
+
+ client_state = {}
+ client_state[UNIVERSAL_CHECKPOINT_INFO] = {}
+ client_state['iteration'] = test_step
+ model.save_checkpoint(tmpdir, tag=CP_TAG, client_state=client_state)
+
+ cp_dir = os.path.join(tmpdir, CP_TAG)
+ univ_cp_dir = f"{cp_dir}_universal"
+
+ args = SimpleNamespace(input_folder=cp_dir,
+ output_folder=univ_cp_dir,
+ num_extract_workers=1,
+ num_merge_workers=1,
+ keep_temp_folder=False,
+ strict=True)
+
+ dist.barrier()
+ if dist.get_rank() == 0:
+ convert_to_universal(args)
+
+ model_state = model.state_dict()
+ optimizer_state = None
+ if load_optim:
+ optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
+
+ if dist.get_rank() == 0:
+ torch.save((model_state, optimizer_state), os.path.join(tmpdir, "baseline_state.pt"))
+
+ dist.barrier()
+
+ return model, sd
+
+
+@pytest.fixture
+def ds_config(zero_stage, dtype):
+ ds_config = {
+ "train_batch_size": 8,
+ "optimizer": {
+ "type": 'Adam'
+ },
+ "zero_optimization": {
+ "stage": zero_stage,
+ }
+ }
+ if dtype == torch.float16:
+ ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
+ elif dtype == torch.bfloat16:
+ ds_config["bf16"] = {"enabled": True}
+ return ds_config
+
+
+class _baseline(DistributedFixture):
+ world_size = None
+
+ def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam):
+ hidden_dim = 10
+ train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir)
+
+
+class baseline_ws2(_baseline):
+ world_size = 2
+
+
+class baseline_ws4(_baseline):
+ world_size = 4
+
+
+@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
+@pytest.mark.parametrize("zero_stage", [1])
+@pytest.mark.parametrize("use_torch_adam", [False, True])
+@pytest.mark.parametrize("load_optim", [False, True])
+class TestZeROUniversalCheckpointDP(DistributedTest):
+
+ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
+ if dtype == torch.bfloat16 and not bf16_required_version_check():
+ pytest.skip(
+ " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
+ )
+
+ hidden_dim = 10
+ loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt")
+
+ ds_config["checkpoint"] = {"load_universal": True}
+ univ_model = SimpleModel(hidden_dim)
+ univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam)
+ univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim)
+
+ model_state = univ_model.state_dict()
+ compare_state_dicts(model_state, loaded_model_state)
+
+ if load_optim:
+ optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
+ # padding sizes may differ when dp sizes are different
+ param_count = sum(p.numel() for p in univ_model.parameters())
+ optimizer_state = remove_pad_in_opt_state(optimizer_state, param_count)
+ loaded_optimizer_state = remove_pad_in_opt_state(loaded_optimizer_state, param_count)
+
+ compare_opt_state_dicts(optimizer_state, loaded_optimizer_state, get_expected_mismatch_keys())
+
+ # Run training again to verify that the optimizer has necessary states
+ test_step = 8
+ data_loader = random_dataloader(model=univ_model,
+ total_samples=test_step,
+ hidden_dim=hidden_dim,
+ device=univ_model.device,
+ dtype=dtype)
+ for batch in data_loader:
+ loss = univ_model(batch[0], batch[1])
+ univ_model.backward(loss)
+ univ_model.step()
+
+ @pytest.mark.world_size(2)
+ def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
+ self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+
+ @pytest.mark.world_size(2)
+ def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
+ self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+
+ @pytest.mark.world_size(4)
+ def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
+ self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py
index aebad4227358..2312425c8aed 100644
--- a/tests/unit/checkpoint/test_zero_optimizer.py
+++ b/tests/unit/checkpoint/test_zero_optimizer.py
@@ -243,7 +243,11 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l
model, _, _, _ = deepspeed.initialize(config=ds_config,
model=models[0],
model_parameters=models[0].parameters())
- data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, device=model.device)
+ run_steps = 8
+ data_loader = random_dataloader(model=model,
+ total_samples=run_steps,
+ hidden_dim=hidden_dim,
+ device=model.device)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
diff --git a/tests/unit/runtime/half_precision/test_bf16.py b/tests/unit/runtime/half_precision/test_bf16.py
index d42a4b62cd10..0af14abc3be5 100644
--- a/tests/unit/runtime/half_precision/test_bf16.py
+++ b/tests/unit/runtime/half_precision/test_bf16.py
@@ -288,8 +288,8 @@ def test(self, stage=2):
model.step()
-@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bfp16", "fp32"])
-@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bfp16", "default"])
+@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bf16", "fp32"])
+@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bf16", "default"])
class TestZeroDtypeCocktail(DistributedTest):
world_size = 2
@@ -304,7 +304,7 @@ def test(self, comp_type, comm_type):
if not get_accelerator().is_fp16_supported():
pytest.skip("fp16 is not supported")
- type_str = {torch.float16: "fp16", torch.bfloat16: "bfp16"}
+ type_str = {torch.float16: "fp16", torch.bfloat16: "bf16"}
config_dict = {
"train_micro_batch_size_per_gpu": 2,
diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py
index e54fe352bf5b..9229794b39f8 100644
--- a/tests/unit/runtime/half_precision/test_fp16.py
+++ b/tests/unit/runtime/half_precision/test_fp16.py
@@ -13,6 +13,7 @@
from deepspeed.runtime.utils import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import CPUAdamBuilder
+from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
try:
from apex import amp # noqa: F401 # type: ignore
@@ -215,8 +216,10 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True)
# initialize MoE
model = SimpleMoEModel(hidden_dim, ep_size=2)
+ param_group = {'params': [p for p in model.parameters()], 'name': 'random-unique-name'}
+ params = split_params_into_different_moe_groups_for_optimizer(param_group)
# optimizer = torch.optim.AdamW(params=model.parameters())
- optimizer = FusedAdam(params=model.parameters())
+ optimizer = FusedAdam(params=params)
engine, optimizer, _, _ = deepspeed.initialize(config=config_dict,
model=model,
optimizer=optimizer,