diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml
index 0272829e8286..4e9ceb32b6b1 100644
--- a/.github/workflows/hpu-gaudi2.yml
+++ b/.github/workflows/hpu-gaudi2.yml
@@ -68,7 +68,6 @@ jobs:
(test_flops_profiler.py and test_flops_profiler_in_inference)
test_get_optim_files.py
test_groups.py
- test_init_on_device.py
test_partition_balanced.py
(test_adamw.py and TestAdamConfigs)
test_coalesced_collectives.py
diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml
index 3ce406948432..484948b28e34 100644
--- a/.github/workflows/nv-a6000.yml
+++ b/.github/workflows/nv-a6000.yml
@@ -47,7 +47,8 @@ jobs:
- name: Install deepspeed
run: |
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
- python -m pip install pydantic==1.10.11
+ # Update packages included in the container that do not support pydantic 2+ to versions that do
+ python -m pip install thinc spacy confection --upgrade
python -m pip install .[dev,1bit,autotuning,inf]
ds_report
- name: Python environment
diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml
index b1e8c042214f..8658ff5d2348 100644
--- a/.github/workflows/nv-nightly.yml
+++ b/.github/workflows/nv-nightly.yml
@@ -2,6 +2,9 @@ name: nv-nightly
on:
workflow_dispatch:
+ pull_request:
+ paths:
+ - '.github/workflows/nv-nightly.yml'
schedule:
- cron: "0 0 * * *"
@@ -25,7 +28,7 @@ jobs:
- name: Install pytorch
run: |
- pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --index-url https://download.pytorch.org/whl/cu117
+ pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
@@ -34,7 +37,7 @@ jobs:
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
- # git checkout 1cc453d33
+ git checkout v4.42.4
git rev-parse --short HEAD
pip install .
@@ -55,7 +58,7 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
- pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="1.13" --cuda_ver="11.7"
+ pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="2.4" --cuda_ver="11.8"
- name: Open GitHub issue if nightly CI fails
if: ${{ failure() && (github.event_name == 'schedule') }}
diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml
index a506bb27fda4..72ba8abbd95d 100644
--- a/.github/workflows/nv-pre-compile-ops.yml
+++ b/.github/workflows/nv-pre-compile-ops.yml
@@ -36,7 +36,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
- DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
+ DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
diff --git a/.github/workflows/xpu-max1100.yml b/.github/workflows/xpu-max1100.yml
index c5a23fe3f53f..adeeb0acade2 100644
--- a/.github/workflows/xpu-max1100.yml
+++ b/.github/workflows/xpu-max1100.yml
@@ -21,7 +21,7 @@ on:
- "deepspeed/runtime/zero/parameter_offload.py"
- "deepspeed/runtime/pipe/engine.py"
- "deepspeed/runtime/utils.py"
- - "opbuilder/xpu/**"
+ - "op_builder/xpu/**"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@@ -36,38 +36,36 @@ jobs:
unit-tests:
runs-on: [self-hosted, intel, xpu]
container:
- image: intel/intel-extension-for-pytorch:2.1.30-xpu
+ image: intel/oneapi-basekit:2024.1.1-devel-ubuntu22.04
ports:
- 80
options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL
steps:
- uses: actions/checkout@v4
- - name: Check container state
- shell: bash
- run: |
- ldd --version
- python -c "import torch; print('torch:', torch.__version__, torch)"
- python -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())"
-
- - name: Install deepspeed
+ - name: Install prerequisite
run: |
- pip install py-cpuinfo
+ apt-get update
+ apt-get install clinfo libaio-dev python3-pip -y
+ pip install torch==2.1.0.post2 -f https://developer.intel.com/ipex-whl-stable-xpu
+ pip install intel-extension-for-pytorch==2.1.30+xpu -f https://developer.intel.com/ipex-whl-stable-xpu
+ pip install intel-extension-for-pytorch-deepspeed==2.1.30 -f https://developer.intel.com/ipex-whl-stable-xpu
+ pip install oneccl_bind_pt==2.1.300+xpu -f https://developer.intel.com/ipex-whl-stable-xpu
+ pip install torchvision==0.16.0.post2 -f https://developer.intel.com/ipex-whl-stable-xpu
+ pip install py-cpuinfo numpy==1.26
pip install .[dev,autotuning]
- ds_report
- python -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)"
- - name: Python environment
+ - name: Check container state
run: |
+ ldd --version
+ ds_report
+ python3 -c "import torch; print('torch:', torch.__version__, torch)"
+ python3 -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())"
+ python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)"
pip list
- name: Unit tests
run: |
- pip install pytest pytest-timeout tabulate tensorboard wandb
- export ONEAPI_ROOT=/opt/intel/oneapi/redist
- export FI_PROVIDER_PATH=$ONEAPI_ROOT/opt/mpi/libfabric/lib/prov
- export LD_LIBRARY_PATH=$ONEAPI_ROOT/opt/mpi/libfabric/lib:$LD_LIBRARY_PATH
- export LD_LIBRARY_PATH=$ONEAPI_ROOT/lib:$LD_LIBRARY_PATH
cd tests/unit
pytest --verbose accelerator/*
pytest --verbose autotuning/*
@@ -75,8 +73,10 @@ jobs:
pytest --verbose checkpoint/test_moe_checkpoint.py
pytest --verbose checkpoint/test_shared_weights.py
pytest --verbose launcher/test_ds_arguments.py launcher/test_run.py
+ pytest --verbose model_parallelism/*
pytest --verbose moe/test_moe_tp.py
pytest --verbose monitor/*
+ pytest --verbose utils/*
pytest --verbose runtime/test_ds_config_model.py
pytest --verbose runtime/pipe/test_pipe_schedule.py
pytest --verbose runtime/zero/test_zero_config.py
diff --git a/README.md b/README.md
index 304169b56777..2f6661ef5860 100755
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) .
-
+* [2024/08] [DeepSpeed on Windows](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/README.md) [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/windows/08-2024/japanese/README.md)]
* [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/README.md) [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-gds/japanese/README.md)]
* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md) [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)]
* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py
index 485b205f3e67..1f407e86787e 100644
--- a/accelerator/hpu_accelerator.py
+++ b/accelerator/hpu_accelerator.py
@@ -42,9 +42,8 @@ def handles_memory_backpressure(self):
return True
def device_name(self, device_index=None):
- if device_index is None:
- return 'hpu'
- return 'hpu:{}'.format(device_index)
+ # ignoring device_index.
+ return 'hpu'
def device(self, device_index=None):
return torch.device(self.device_name(device_index))
diff --git a/blogs/windows/08-2024/README.md b/blogs/windows/08-2024/README.md
new file mode 100644
index 000000000000..34e11bd47792
--- /dev/null
+++ b/blogs/windows/08-2024/README.md
@@ -0,0 +1,101 @@
+
+
+# DeepSpeed on Windows
+
+
+
+# Introduction
+
+DeepSpeed is a popular open-source deep learning optimization library that makes distributed training and inference easy, efficient, and effective. DeepSpeed has been widely used to train a variety of state-of-the-art models, including Phi-3, Megatron-Turing-530B, BLOOM-176B, and Arctic because of its rich suite of sophisticated optimizations (e.g., ZeRO, 3D parallelism, MoE, etc.). However, the lack of native support for Microsoft Windows, the most popular operating system, means that DeepSpeed innovations are inaccessible to many AI developers and users. To address this problem, we started an effort to make DeepSpeed run natively with full features on Windows, while ensuring the same ease-of-use enjoyed on Linux.
+
+In this blog, we are pleased to announce some early achievements on this journey: DeepSpeed can now be installed in Windows and run natively for single-GPU training, finetuning, and inferencing. Importantly, both the installation and usage experiences are identical to those on Linux. Furthermore, the finetuning and inferencing workloads demonstrate the functioning of three critical DeepSpeed features, HuggingFace Transformers integration, LoRA support, and CPU Offloading. DeepSpeed on Windows is available in DeepSpeed versions 0.14.5 and above. In the rest of this blog, we present examples to demonstrate these achievements.
+
+# Evaluation Environment
+We conducted the experiments on a Surface Laptop Studio 2 running Windows 11 Version 23H2 and Build 22631.3880. The laptop is equipped with a single NVIDIA RTX A2000 GPU with 4GB VRAM. We used Pytorch version 2.3.0 and HuggingFace Transformers version 4.41.2. The example scripts used are from the [DeepSpeedExamples repo](https://github.com/microsoft/DeepSpeedExamples), therefore you need to clone the repo before running any of the following examples.
+
+# Installation
+DeepSpeed can be installed on Windows in one of two ways. The easier way is to use the pip package manager, while the other is to build from source. The prerequisites for in both cases are Python 3.x and Pytorch with CUDA support.
+
+## Installing via pip
+To install DeepSpeed, simply run: `pip install deepspeed`. This will install the latest version of DeepSpeed (0.14.5 at this time). Unlike the Linux counterpart, the Windows version comes with all the operators already prebuilt, so there is no need to have a CUDA SDK or C++ compiler installed.
+
+
+
+
+
+
+ pip installation of DeepSpeed on Windows.
+
+
+
+## Building from Source
+To build DeepSpeed from source, you need to clone the DeepSpeed repository and run the `build_win.bat` compilation script.
+
+
+## Validating Installation
+Regardless of the installation choice, you can check that the installation was successful by running ds_report. The output should look like this:
+
+
+
+
+
+
+
+ ds_report output confirming Windows installation of DeepSpeed.
+
+
+# Pretraining Examples
+We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed.
+
+## Pretraining CIFAR10
+The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py –deepspeed`. The final output should look something like this:
+
+
+
+
+
+ Pretraining CIFAR10 model on Windows using DeepSpeed.
+
+
+## Pretraining BERT
+The scripts and codes for the BERT pretraining example are available in the following path: DeepSpeedExamples\training\HelloDeepSpeed. You can launch the BERT pretraining experiment using the following command: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed`. The final output should look like this:
+
+
+
+
+
+
+ Pretraining BERT model on Windows using DeepSpeed.
+
+
+# Fine Tuning Example
+We demonstrate fine tuning capability by using the supervised fine tuning (SFT) step of DeepSpeed-Chat application. We conduct SFT of the HuggingFace facebook/opt-125m model while enabling LoRA and CPU offloading memory optimizations. The command line for running this example is as follows:
+deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output
+The output should look like this:
+
+
+
+
+
+
+ Supervised Finetuning of facebook/opt-125m model on Windows using DeepSpeed.
+
+
+# Inference Example
+We demonstrate inference capability by using ZeRO-Inference for token generation. ZeRO-Inference reduces hardware cost of inferencing by offloading to CPU or NVMe memories. We use the example scripts here to run token generation using Llama-2-7B model from HuggingFace. We offload the model weights to CPU memory since the 4GB VRAM is insufficient to host both the model and the generation working set. We use the following command line to generate 32 tokens from a prompt of 8 tokens:
+deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload
+The output will look something like this:
+
+
+
+
+
+
+ LLAMA2-7B token generation on Windows using ZeRO-Inference.
+
+
+# Summary
+Enabling DeepSpeed, a popular deep learning framework, to run natively on Windows, the most popular operating system, is a crucial step towards empowering every person and every organization to benefit from the ongoing AI revolution. In this blog, we have shared early results of our work towards this goal. Although Windows support of DeepSpeed is a work-in-progress, we hope that the above updates are encouraging and already useful to users. The next items on our roadmap include running on multiple GPUs, weight quantization, and performance studies.
+
+# Acknowledgements
+This work is a result of significant contributions from current and former DeepSpeed members including Costin Eseanu, Logan Adams, Elton Zheng, Reza Yazdani Aminabadi, Martin Cai, and Olatunji Ruwase. We also acknowledge the valuable contributions of DeepSpeed users who righteously demanded this feature, provided critical workarounds, partial solutions, and constructive feedback, and most importantly, stuck with us.
diff --git a/blogs/windows/08-2024/japanese/README.md b/blogs/windows/08-2024/japanese/README.md
new file mode 100644
index 000000000000..7e437f737f58
--- /dev/null
+++ b/blogs/windows/08-2024/japanese/README.md
@@ -0,0 +1,123 @@
+
+
+# DeepSpeedのWindowsサポート
+
+
+
+# はじめに
+
+DeepSpeedは、分散学習と推論を簡単かつ効率的に行うための人気のあるオープンソースの深層学習最適化ライブラリです。DeepSpeedは、その豊富かつ高度な最適化機能(例:ZeRO、3D parallelism, MoEなど)のおかげで、Phi-3、Megatron-Turing-530B、BLOOM-176B、Arcticなどの最先端モデルの学習に広く利用されています。しかし、最も普及しているオペレーティングシステムであるMicrosoft Windowsをネイティブにサポートしていなかったため、多くのAI開発者やユーザーが、DeepSpeedの革新的な機能を利用できない状態でした。この問題を解決するため、DeepSpeedの完全な機能をWindows上でネイティブに実行し、Linux上と同じ使いやすさを実現するための取り組みを開始しました。
+
+このブログでは、この取り組みの最初の成果をお知らせします。現在、DeepSpeedはWindowsにインストールし、単一GPUでの学習、ファインチューニング、および推論をネイティブに実行できるようになりました。ここで重要なこととして、インストールと利用は、Linuxとまったく同じように行えます。ファインチューニングと推論のワークロードを通じて、HuggingFace Transformers との統合、LoRAのサポート、CPUオフロードの3つの重要なDeepSpeedの機能が、正しく動作していることが確認できました。このWindowsサポートは、バージョン0.14.5以降で利用可能です。このブログの残りの部分では、これらの成果を示す例を紹介します。
+
+# テスト環境
+
+Windows 11 Version 23H2 および Build 22631.3880 を実行している Surface Laptop Studio 2 でテストを行いました。このハードウェアには、4GBのVRAMを搭載した NVIDIA RTX A2000 GPU が1つ搭載されています。また、PyTorchバージョン 2.3.0 および HuggingFace Transformersバージョン 4.41.2 を使用しました。使用したサンプルスクリプトは[DeepSpeedExamplesリポジトリ](https://github.com/microsoft/DeepSpeedExamples)から取得できます。以下の例を実行する前にリポジトリをクローンしてください。
+
+# インストール
+
+DeepSpeedは、2つの方法でWindowsにインストールできます。より簡単な方法は、pipパッケージマネージャーを使用することで、もう一方はソースからビルドする方法です。どちらの場合も、Python 3.xとCUDAサポート付きのPyTorchが必要です。
+
+## pipを使用したインストール
+
+DeepSpeedをインストールするには、単に次のコマンドを実行します: `pip install deepspeed`。
+これにより、最新バージョンのDeepSpeed(現時点では0.14.5)がインストールされます。Linux版とは異なり、Windows版ではすべてのオペレーターがすでにビルド済みであるため、CUDA SDKやC++コンパイラをインストールする必要はありません。
+
+
+
+
+
+
+ pipによるWindowsへのDeepSpeedのインストール
+
+
+
+## ソースからのビルド
+
+ソースからDeepSpeedをビルドするには、DeepSpeedリポジトリをクローンし、コンパイルスクリプトである `build_win.bat` を実行する必要があります。
+
+## インストールの検証
+
+インストール方法にかかわらず、`ds_report`を実行してインストールが成功したかどうかを確認できます。出力は次のようになります:
+
+
+
+
+
+
+ DeepSpeedのWindowsインストールを確認するds_reportの出力
+
+
+# 事前学習の例
+
+Windows上でDeepSpeedを使用した事前学習の例として、画像分類モデルCIFAR10と言語モデルBERTの実行例を示します。
+
+## CIFAR10の事前学習
+
+CIFAR10の事前学習に必要なスクリプトとコードは、次のパスにあります: `DeepSpeedExamples\training\cifar`
+
+以下のコマンドを使用してCIFAR10の事前学習を開始できます: `deepspeed cifar10_deepspeed.py –deepspeed`
+
+出力は次のようになります。
+
+
+
+
+
+
+ DeepSpeedによるWindowsでのCIFAR10モデルの事前学習
+
+
+## BERTの事前学習
+
+BERTの事前学習に必要なスクリプトとコードは、次のパスにあります: `DeepSpeedExamples\training\HelloDeepSpeed`
+
+以下のコマンドを使用してBERTの事前学習を開始できます: `deepspeed train_bert_ds.py --checkpoint_dir experiment_deepspeed`
+
+出力は次のようになります。
+
+
+
+
+
+
+ DeepSpeedによるWindowsでのBERTモデルの事前学習
+
+
+# ファインチューニングの例
+
+DeepSpeed-Chatアプリケーションの教師ありファインチューニング(supervised fine tuning; SFT)を使用して、ファインチューニングの機能を示します。LoRAおよびCPUオフロードメモリ最適化を有効にして、 HuggingFace の `facebook/opt-125m` モデルのSFTを実施します。この例を実行するためのコマンドラインは次のとおりです: `deepspeed training\step1_supervised_finetuning\main.py --model_name_or_path facebook/opt-125m --gradient_accumulation_steps 8 --lora_dim 128 --only_optimize_lora --print_loss --zero_stage 2 --deepspeed --dtype bf16 --offload --output_dir output`
+
+出力は次のようになります。
+
+
+
+
+
+
+ DeepSpeedを使用したWindowsでの facebook/opt-125m モデルのファインチューニング
+
+
+# 推論の例
+
+推論の機能を示すために、トークン生成のためのZeRO-Inferenceを使用します。ZeRO-Inferenceは、CPUまたはNVMeメモリにオフロードすることで推論のハードウェアコストを削減します。ここでは、サンプルスクリプトを使用して、HuggingFaceのLlama-2-7Bモデルを使用したトークン生成を実行します。4GBのVRAMではモデルと生成処理の両方を実効するのに十分ではないため、モデルパラメータをCPUメモリにオフロードします。
+
+次のコマンドラインを使用して、8トークンのプロンプトから32トークンを生成します: `deepspeed run_model.py --model meta-llama/Llama-2-7b-hf --batch-size 64 --prompt-len 8 --gen-len 32 --cpu-offload`
+
+出力は次のようになります。
+
+
+
+
+
+
+ DeepSpeedのZeRO-InferenceによるWindowsでのLLAMA2-7Bのトークン生成
+
+
+# まとめ
+
+最も広く使われているオペレーティングシステムであるWindowsで、深層学習フレームワークであるDeepSpeedをネイティブに実行できるようにすることは、多くの人と組織が、今まさに進行中のAI革命の恩恵を受けるための重要な一歩です。このブログでは、この目標に向けたプロジェクトの、最初の成果を共有しました。Windowsのサポートは現在進行中のプロジェクトですが、今回の成果が多くのユーザにとって活用され、またさらに発展していけることを願っています。次のロードマップには、複数のGPUでの実行、モデルパラメータの量子化、パフォーマンスの詳細な分析が含まれます。
+
+# 謝辞
+
+このプロジェクトは、Costin Eseanu、Logan Adams、Elton Zheng、Reza Yazdani Aminabadi、Martin Cai、Olatunji Ruwaseを含むDeepSpeedメンバーによる大きな貢献の結果です。また、この機能を必要とし、様々な問題の解決策や、建設的なフィードバックを提供し、私たちと共に歩んでくれたDeepSpeedユーザーの重要な貢献に感謝します。
diff --git a/blogs/windows/08-2024/media/bert_training.png b/blogs/windows/08-2024/media/bert_training.png
new file mode 100644
index 000000000000..c5935e47747e
Binary files /dev/null and b/blogs/windows/08-2024/media/bert_training.png differ
diff --git a/blogs/windows/08-2024/media/cifar10_training.png b/blogs/windows/08-2024/media/cifar10_training.png
new file mode 100644
index 000000000000..99f3fa25bc70
Binary files /dev/null and b/blogs/windows/08-2024/media/cifar10_training.png differ
diff --git a/blogs/windows/08-2024/media/ds_report.png b/blogs/windows/08-2024/media/ds_report.png
new file mode 100644
index 000000000000..43d82d724ed2
Binary files /dev/null and b/blogs/windows/08-2024/media/ds_report.png differ
diff --git a/blogs/windows/08-2024/media/llama2-7b_inference.png b/blogs/windows/08-2024/media/llama2-7b_inference.png
new file mode 100644
index 000000000000..f5874468a854
Binary files /dev/null and b/blogs/windows/08-2024/media/llama2-7b_inference.png differ
diff --git a/blogs/windows/08-2024/media/opt125m_finetuning.png b/blogs/windows/08-2024/media/opt125m_finetuning.png
new file mode 100644
index 000000000000..ed6d1522e3b3
Binary files /dev/null and b/blogs/windows/08-2024/media/opt125m_finetuning.png differ
diff --git a/blogs/windows/08-2024/media/win_pip_install_deepspeed.png b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png
new file mode 100644
index 000000000000..3b87c95ef144
Binary files /dev/null and b/blogs/windows/08-2024/media/win_pip_install_deepspeed.png differ
diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp
new file mode 100644
index 000000000000..dc820be528d0
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.cpp
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "deepspeed_aio_op_desc.h"
+
+using namespace std;
+
+io_op_desc_t::io_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate)
+ : _read_op(read_op),
+ _buffer(buffer),
+ _fd(fd),
+ _filename(filename),
+ _file_num_bytes(file_num_bytes),
+ _num_threads(num_threads),
+ _num_bytes_per_thread(file_num_bytes / num_threads),
+ _validate(validate)
+{
+}
+
+char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
+
+void io_op_desc_t::finish() {}
+
+void io_op_desc_t::validate() {}
+
+void io_op_desc_t::run(const int tid,
+ std::unique_ptr& aio_ctxt,
+ deepspeed_aio_config_t* aio_config)
+{
+}
diff --git a/csrc/aio/py_lib/deepspeed_aio_op_desc.h b/csrc/aio/py_lib/deepspeed_aio_op_desc.h
new file mode 100644
index 000000000000..350d28d29d58
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_aio_op_desc.h
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#ifndef _IO_OP_DESC_T_
+#define _IO_OP_DESC_T_
+#include
+#include
+#include "deepspeed_py_aio.h"
+
+struct io_op_desc_t {
+ const bool _read_op;
+ torch::Tensor _buffer;
+ int _fd;
+ const std::string _filename;
+ const long long int _file_num_bytes;
+ const int _num_threads;
+ const long long int _num_bytes_per_thread;
+ torch::Tensor _contiguous_buffer;
+ const bool _validate;
+
+ io_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate);
+
+ virtual void run(const int tid,
+ std::unique_ptr& aio_ctxt,
+ deepspeed_aio_config_t* aio_config);
+
+ virtual char* data_ptr() const;
+
+ virtual void validate();
+
+ virtual void finish();
+};
+#endif // _IO_OP_DESC_T_
diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp
index c852711a28c0..30c3b4914397 100644
--- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp
+++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp
@@ -9,50 +9,8 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include "deepspeed_aio_thread.h"
-#if defined(__ENABLE_CANN__)
-#include "torch_npu/csrc/framework/utils/OpAdapter.h"
-#include "torch_npu/csrc/framework/utils/UtilForOpAdapter.h"
-#endif
-
using namespace std;
-io_op_desc_t::io_op_desc_t(const bool read_op,
- const torch::Tensor& buffer,
- const int fd,
- const char* filename,
- const long long int num_bytes,
- const bool validate)
- : _read_op(read_op),
- _buffer(buffer),
- _fd(fd),
- _filename(filename),
- _num_bytes(num_bytes),
- _validate(validate)
-{
- _cpu_buffer = (_buffer.is_cuda() || _buffer.is_xpu()
-#if defined(__ENABLE_CANN__)
- || torch_npu::utils::is_npu(_buffer)
-#endif
- )
- ? _buffer.to(torch::kCPU).pin_memory()
- : _buffer;
- _contiguous_buffer = _cpu_buffer.contiguous();
-}
-
-char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
-
-void io_op_desc_t::fini()
-{
- if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
- if (_read_op && _buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
-#if defined(__ENABLE_CANN__)
- if (_read_op && torch_npu::utils::is_npu(_buffer)) {
- auto device = at::Device("npu:0");
- _buffer.copy_(_cpu_buffer.to(device));
- }
-#endif
-}
-
deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config)
: _tid(tid),
_aio_config(aio_config),
@@ -79,18 +37,7 @@ void deepspeed_aio_thread_t::run()
}
if (next_io_op) {
- const auto base_offset = next_io_op->_num_bytes * _tid;
-
- std::unique_ptr xfer_ctxt(new io_xfer_ctxt(
- next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr()));
-
- if (_aio_config._overlap_events) {
- do_aio_operation_overlap(
- next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- } else {
- do_aio_operation_sequential(
- next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- }
+ next_io_op->run(_tid, _aio_ctxt, &_aio_config);
{
std::lock_guard lock(_complete_sync._mutex);
diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h
index 20799ecbb018..a192804db13d 100644
--- a/csrc/aio/py_lib/deepspeed_aio_thread.h
+++ b/csrc/aio/py_lib/deepspeed_aio_thread.h
@@ -10,28 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include
#include
#include
-#include "deepspeed_py_aio.h"
-
-struct io_op_desc_t {
- const bool _read_op;
- torch::Tensor _buffer;
- int _fd;
- const std::string _filename;
- const long long int _num_bytes;
- torch::Tensor _cpu_buffer;
- torch::Tensor _contiguous_buffer;
- const bool _validate;
-
- io_op_desc_t(const bool read_op,
- const torch::Tensor& buffer,
- const int fd,
- const char* filename,
- const long long int num_bytes,
- const bool validate);
-
- char* data_ptr() const;
- void fini();
-};
+#include "deepspeed_cpu_op.h"
struct thread_sync_t {
std::mutex _mutex;
diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.cpp b/csrc/aio/py_lib/deepspeed_cpu_op.cpp
new file mode 100644
index 000000000000..41790b99bb88
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_cpu_op.cpp
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "deepspeed_cpu_op.h"
+
+using namespace std;
+
+cpu_op_desc_t::cpu_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate)
+ : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate),
+ _cpu_buffer(buffer)
+{
+ // Need to use CPU bounce buffer if buffer is not a page-locked DRAM memory.
+ _use_bounce_buffer = !(_buffer.is_cpu() && _buffer.is_pinned());
+ if (_use_bounce_buffer) {
+ if (_read_op) {
+ auto options = torch::TensorOptions()
+ .dtype(_buffer.dtype())
+ .layout(_buffer.layout())
+ .device(torch::kCPU);
+ _cpu_buffer = torch::empty(_buffer.nbytes(), options).pin_memory();
+ } else {
+ _cpu_buffer = _buffer.to(torch::kCPU).pin_memory();
+ }
+ }
+ _contiguous_buffer = _cpu_buffer.contiguous();
+}
+
+char* cpu_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
+
+void cpu_op_desc_t::finish()
+{
+ if (_read_op) {
+ if (_buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); }
+ if (_buffer.is_xpu()) { _buffer.copy_(_cpu_buffer.to(torch::kXPU)); }
+#if defined(__ENABLE_CANN__)
+ if (torch_npu::utils::is_npu(_buffer)) {
+ auto device = at::Device("npu:0");
+ _buffer.copy_(_cpu_buffer.to(device));
+ }
+#endif
+ }
+}
+
+void cpu_op_desc_t::validate()
+{
+ validate_aio_operation(_read_op, _filename.c_str(), data_ptr(), _file_num_bytes);
+}
+
+void cpu_op_desc_t::run(const int tid,
+ std::unique_ptr& aio_ctxt,
+ deepspeed_aio_config_t* aio_config)
+{
+ assert(tid < _num_threads);
+ const auto base_offset = _num_bytes_per_thread * tid;
+
+ std::unique_ptr xfer_ctxt(
+ new io_xfer_ctxt(_fd, base_offset, _num_bytes_per_thread, data_ptr()));
+
+ if (aio_config->_overlap_events) {
+ do_aio_operation_overlap(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
+ } else {
+ do_aio_operation_sequential(_read_op, aio_ctxt, xfer_ctxt, aio_config, nullptr);
+ }
+}
diff --git a/csrc/aio/py_lib/deepspeed_cpu_op.h b/csrc/aio/py_lib/deepspeed_cpu_op.h
new file mode 100644
index 000000000000..da96dd2b1d50
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_cpu_op.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include
+#include
+#include "deepspeed_aio_op_desc.h"
+
+struct cpu_op_desc_t : io_op_desc_t {
+ torch::Tensor _cpu_buffer;
+ bool _use_bounce_buffer;
+
+ cpu_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate);
+
+ void run(const int tid,
+ std::unique_ptr& aio_ctxt,
+ deepspeed_aio_config_t* aio_config);
+
+ char* data_ptr() const;
+
+ void validate();
+
+ void finish();
+};
diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp
index 0556f5aa8168..eac268d33433 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp
@@ -4,9 +4,6 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
diff --git a/csrc/aio/py_lib/deepspeed_py_aio.h b/csrc/aio/py_lib/deepspeed_py_aio.h
index 11d5225de9f1..ba794db5440d 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio.h
+++ b/csrc/aio/py_lib/deepspeed_py_aio.h
@@ -4,10 +4,7 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
-Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+Functionality for swapping tensors to/from (NVMe) storage devices.
*/
#include
diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
index 23ddabe260d4..c7ca5e82afde 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp
@@ -4,293 +4,21 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
#include "deepspeed_py_aio_handle.h"
+#include
using namespace std;
-static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); }
-
deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size,
const int queue_depth,
const bool single_submit,
const bool overlap_events,
const int num_threads)
- : _aio_ctxt(new aio_context(block_size, queue_depth)),
- _single_submit(single_submit),
- _overlap_events(overlap_events),
- _num_threads(num_threads),
- _aio_config(block_size, queue_depth, single_submit, overlap_events, false),
- _num_pending_ops(0),
- _pinned_tensor_mgr(new deepspeed_pin_tensor_t())
-{
- for (auto i = 0; i < num_threads; ++i) {
- _thread_contexts.push_back(std::make_shared(i, _aio_config));
- }
-
- for (auto& ctxt : _thread_contexts) {
- _threads.push_back(std::thread(_start_aio_thread, ctxt));
- }
-}
-
-deepspeed_aio_handle_t::~deepspeed_aio_handle_t()
-{
- _stop_threads();
- for (auto& thr : _threads) { thr.join(); }
-}
-
-const int deepspeed_aio_handle_t::get_block_size() const
-{
- return _aio_ctxt ? _aio_ctxt->_block_size : -1;
-}
-
-const int deepspeed_aio_handle_t::get_queue_depth() const
-{
- return _aio_ctxt ? _aio_ctxt->_queue_depth : -1;
-}
-
-const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; }
-
-const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; }
-
-const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; }
-
-int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate)
-{
- const auto start_time = std::chrono::high_resolution_clock::now();
-
- assert(_aio_ctxt);
-
- long long num_file_bytes;
- if (-1 == get_file_size(filename, num_file_bytes)) {
- const auto error_code = errno;
- report_file_error(filename, " fstat for read", error_code);
- return -1;
- }
- assert(static_cast(buffer.nbytes()) == num_file_bytes);
-
- const auto fd = open_file(filename, true);
- if (fd == -1) { return -1; }
-
- auto read_buffer = (char*)buffer.data_ptr();
- std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer));
-
- if (_aio_config._overlap_events) {
- do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- } else {
- do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- }
-
- close(fd);
- const std::chrono::duration aio_time =
- std::chrono::high_resolution_clock::now() - start_time;
-
- if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); }
- const std::chrono::duration fn_time =
- std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
- << " call = " << fn_time.count() * 1e6 << std::endl;
- return 0;
-}
-
-int deepspeed_aio_handle_t::write(const torch::Tensor& buffer,
- const char* filename,
- const bool validate)
+ : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads)
{
- assert(_aio_ctxt);
-
- const auto start_time = std::chrono::high_resolution_clock::now();
-
- const auto fd = open_file(filename, false);
- if (fd == -1) { return -1; }
-
- auto write_buffer = (char*)buffer.data_ptr();
- const auto num_write_bytes = static_cast(buffer.nbytes());
- std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer));
-
- if (_aio_config._overlap_events) {
- do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- } else {
- do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
- }
- const std::chrono::duration aio_time =
- std::chrono::high_resolution_clock::now() - start_time;
-
- close(fd);
-
- if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); }
-
- const std::chrono::duration fn_time =
- std::chrono::high_resolution_clock::now() - start_time;
- std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
- << " call = " << fn_time.count() * 1e6 << std::endl;
- return 0;
}
-void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op)
-{
- for (auto& ctxt : _thread_contexts) {
- {
- std::lock_guard lock(ctxt->_work_sync._mutex);
- ctxt->_work_queue.push(scheduled_op);
- }
- ctxt->_work_sync._cond_var.notify_one();
- }
- _num_pending_ops++;
-}
-
-std::shared_ptr deepspeed_aio_handle_t::_wait_for_aio_work()
-{
- std::shared_ptr completed_op = nullptr;
- for (auto& ctxt : _thread_contexts) {
- std::unique_lock lock(ctxt->_complete_sync._mutex);
- ctxt->_complete_sync._cond_var.wait(lock,
- [ctxt] { return !ctxt->_complete_queue.empty(); });
- completed_op = ctxt->_complete_queue.front();
- ctxt->_complete_queue.pop();
- }
- return completed_op;
-}
-
-void deepspeed_aio_handle_t::_stop_threads()
-{
- assert(0 == _num_pending_ops);
- for (auto& ctxt : _thread_contexts) {
- {
- std::lock_guard lock(ctxt->_work_sync._mutex);
- ctxt->_time_to_exit = true;
- }
- ctxt->_work_sync._cond_var.notify_one();
- }
-}
-
-int deepspeed_aio_handle_t::wait()
-{
- assert(_num_pending_ops > 0);
- auto num_completed_ops = 0;
-
- while (_num_pending_ops > 0) {
- auto completed_op = _wait_for_aio_work();
-
- completed_op->fini();
-
- close(completed_op->_fd);
-
- if (completed_op->_validate) {
- validate_aio_operation(completed_op->_read_op,
- completed_op->_filename.c_str(),
- completed_op->data_ptr(),
- _num_threads * completed_op->_num_bytes);
- }
- --_num_pending_ops;
- ++num_completed_ops;
- }
-
- return num_completed_ops;
-}
-
-bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op,
- const long long int num_bytes)
-{
- const auto op_string = read_op ? "Read" : "Write";
- if (num_bytes % get_thread_count()) {
- std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes
- << " not divisible by thread count = " << get_thread_count() << std::endl;
- return false;
- }
-
- return true;
-}
-
-int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async)
-{
- long long num_file_bytes;
- if (-1 == get_file_size(filename, num_file_bytes)) {
- const auto error_code = errno;
- report_file_error(filename, " fstat for read", error_code);
- return -1;
- }
- const auto buffer_bytes = static_cast(buffer.nbytes());
- if (buffer_bytes != num_file_bytes) {
- std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes
- << " != " << num_file_bytes << std::endl;
- }
- assert(static_cast(buffer.nbytes()) == num_file_bytes);
- assert((num_file_bytes % _num_threads) == 0);
-
- if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; }
-
- const auto fd = open_file(filename, true);
- if (fd == -1) { return -1; }
-
- auto scheduled_op = std::make_shared(
- true, buffer, fd, filename, (num_file_bytes / _num_threads), validate);
-
- _schedule_aio_work(scheduled_op);
-
- if (async) { return 0; }
-
- return wait();
-}
-
-int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async)
-{
- const auto num_write_bytes = static_cast(buffer.nbytes());
- assert((num_write_bytes % _num_threads) == 0);
-
- if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; }
-
- const auto fd = open_file(filename, false);
- if (fd == -1) { return -1; }
-
- auto scheduled_op = std::make_shared(
- false, buffer, fd, filename, (num_write_bytes / _num_threads), validate);
-
- _schedule_aio_work(scheduled_op);
-
- if (async) { return 0; }
-
- return wait();
-}
-
-int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename)
-{
- return pread(buffer, filename, false, false);
-}
-
-int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename)
-{
- return pwrite(buffer, filename, false, false);
-}
-
-int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename)
-{
- return pread(buffer, filename, false, true);
-}
-
-int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename)
-{
- return pwrite(buffer, filename, false, true);
-}
-
-at::Tensor deepspeed_aio_handle_t::new_cpu_locked_tensor(const size_t num_elem,
- const torch::Tensor& example_tensor)
-{
- return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type());
-}
-
-bool deepspeed_aio_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor)
-{
- return _pinned_tensor_mgr->free(locked_tensor);
-}
+deepspeed_aio_handle_t::~deepspeed_aio_handle_t() {}
diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h
index 3a254c3814a2..eb6b90ea22f0 100644
--- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h
+++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h
@@ -9,21 +9,9 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include
#include
-#include "deepspeed_aio_thread.h"
-#include "deepspeed_pin_tensor.h"
-
-struct deepspeed_aio_handle_t {
- std::unique_ptr _aio_ctxt;
- const bool _single_submit;
- const bool _overlap_events;
- const int _num_threads;
- deepspeed_aio_config_t _aio_config;
-
- std::vector> _thread_contexts;
- std::vector _threads;
- int _num_pending_ops;
- std::unique_ptr _pinned_tensor_mgr;
+#include "deepspeed_py_io_handle.h"
+struct deepspeed_aio_handle_t : deepspeed_io_handle_t {
deepspeed_aio_handle_t(const int block_size,
const int queue_depth,
const bool single_submit,
@@ -31,47 +19,4 @@ struct deepspeed_aio_handle_t {
const int num_threads);
~deepspeed_aio_handle_t();
-
- const int get_block_size() const;
- const int get_queue_depth() const;
- const bool get_single_submit() const;
- const bool get_overlap_events() const;
- const int get_thread_count() const;
-
- int read(torch::Tensor& buffer, const char* filename, const bool validate);
-
- int write(const torch::Tensor& buffer, const char* filename, const bool validate);
-
- int pread(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async);
-
- int pwrite(const torch::Tensor& buffer,
- const char* filename,
- const bool validate,
- const bool async);
-
- int sync_pread(torch::Tensor& buffer, const char* filename);
-
- int sync_pwrite(const torch::Tensor& buffer, const char* filename);
-
- int async_pread(torch::Tensor& buffer, const char* filename);
-
- int async_pwrite(const torch::Tensor& buffer, const char* filename);
-
- // TODO: Make API's args to be shape and dtype.
- torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor);
-
- bool free_cpu_locked_tensor(torch::Tensor&);
-
- int wait();
-
- void _stop_threads();
-
- void _schedule_aio_work(std::shared_ptr scheduled_op);
-
- std::shared_ptr _wait_for_aio_work();
-
- bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes);
};
diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp
index c597b91d05c9..f5480e9d9d83 100644
--- a/csrc/aio/py_lib/deepspeed_py_copy.cpp
+++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp
@@ -4,7 +4,7 @@
// DeepSpeed Team
/*
-Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+Functionality for swapping tensors to/from (NVMe) storage devices.
*/
#include "deepspeed_py_copy.h"
diff --git a/csrc/aio/py_lib/deepspeed_py_copy.h b/csrc/aio/py_lib/deepspeed_py_copy.h
index 19ba28317d00..f443571a3e7b 100644
--- a/csrc/aio/py_lib/deepspeed_py_copy.h
+++ b/csrc/aio/py_lib/deepspeed_py_copy.h
@@ -4,9 +4,6 @@
// DeepSpeed Team
/*
-Copyright 2020 The Microsoft DeepSpeed Team
-Licensed under the MIT license.
-
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.cpp b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp
new file mode 100644
index 000000000000..bdf2a858d797
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_py_io_handle.cpp
@@ -0,0 +1,300 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include "deepspeed_py_io_handle.h"
+#include
+
+using namespace std;
+
+static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); }
+
+deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size,
+ const int queue_depth,
+ const bool single_submit,
+ const bool overlap_events,
+ const int num_threads)
+ : _aio_ctxt(new aio_context(block_size, queue_depth)),
+ _single_submit(single_submit),
+ _overlap_events(overlap_events),
+ _num_threads(num_threads),
+ _aio_config(block_size, queue_depth, single_submit, overlap_events, false),
+ _num_pending_ops(0),
+ _pinned_tensor_mgr(new deepspeed_pin_tensor_t())
+{
+ for (auto i = 0; i < num_threads; ++i) {
+ _thread_contexts.push_back(std::make_shared(i, _aio_config));
+ }
+
+ for (auto& ctxt : _thread_contexts) {
+ _threads.push_back(std::thread(_start_aio_thread, ctxt));
+ }
+}
+
+deepspeed_io_handle_t::~deepspeed_io_handle_t()
+{
+ _stop_threads();
+ for (auto& thr : _threads) { thr.join(); }
+}
+
+const int deepspeed_io_handle_t::get_block_size() const
+{
+ return _aio_ctxt ? _aio_ctxt->_block_size : -1;
+}
+
+const int deepspeed_io_handle_t::get_queue_depth() const
+{
+ return _aio_ctxt ? _aio_ctxt->_queue_depth : -1;
+}
+
+const bool deepspeed_io_handle_t::get_single_submit() const { return _single_submit; }
+
+const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_events; }
+
+const int deepspeed_io_handle_t::get_thread_count() const { return _num_threads; }
+
+int deepspeed_io_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate)
+{
+ const auto start_time = std::chrono::high_resolution_clock::now();
+
+ assert(_aio_ctxt);
+
+ long long num_file_bytes;
+ if (-1 == get_file_size(filename, num_file_bytes)) {
+ const auto error_code = errno;
+ report_file_error(filename, " fstat for read", error_code);
+ return -1;
+ }
+ assert(static_cast(buffer.nbytes()) == num_file_bytes);
+
+ const auto fd = open_file(filename, true);
+ if (fd == -1) { return -1; }
+
+ auto read_buffer = (char*)buffer.data_ptr();
+ std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer));
+
+ if (_aio_config._overlap_events) {
+ do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ } else {
+ do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ }
+
+ close(fd);
+ const std::chrono::duration aio_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+
+ if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); }
+ const std::chrono::duration fn_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
+ return 0;
+}
+
+int deepspeed_io_handle_t::write(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate)
+{
+ assert(_aio_ctxt);
+
+ const auto start_time = std::chrono::high_resolution_clock::now();
+
+ const auto fd = open_file(filename, false);
+ if (fd == -1) { return -1; }
+
+ auto write_buffer = (char*)buffer.data_ptr();
+ const auto num_write_bytes = static_cast(buffer.nbytes());
+ std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer));
+
+ if (_aio_config._overlap_events) {
+ do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ } else {
+ do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr);
+ }
+ const std::chrono::duration aio_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+
+ close(fd);
+
+ if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); }
+
+ const std::chrono::duration fn_time =
+ std::chrono::high_resolution_clock::now() - start_time;
+ std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6
+ << " call = " << fn_time.count() * 1e6 << std::endl;
+ return 0;
+}
+
+void deepspeed_io_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op)
+{
+ for (auto& ctxt : _thread_contexts) {
+ {
+ std::lock_guard lock(ctxt->_work_sync._mutex);
+ ctxt->_work_queue.push(scheduled_op);
+ }
+ ctxt->_work_sync._cond_var.notify_one();
+ }
+ _num_pending_ops++;
+}
+
+std::shared_ptr deepspeed_io_handle_t::_wait_for_aio_work()
+{
+ std::shared_ptr completed_op = nullptr;
+ for (auto& ctxt : _thread_contexts) {
+ std::unique_lock lock(ctxt->_complete_sync._mutex);
+ ctxt->_complete_sync._cond_var.wait(lock,
+ [ctxt] { return !ctxt->_complete_queue.empty(); });
+ completed_op = ctxt->_complete_queue.front();
+ ctxt->_complete_queue.pop();
+ }
+ return completed_op;
+}
+
+void deepspeed_io_handle_t::_stop_threads()
+{
+ assert(0 == _num_pending_ops);
+ for (auto& ctxt : _thread_contexts) {
+ {
+ std::lock_guard lock(ctxt->_work_sync._mutex);
+ ctxt->_time_to_exit = true;
+ }
+ ctxt->_work_sync._cond_var.notify_one();
+ }
+}
+
+int deepspeed_io_handle_t::wait()
+{
+ assert(_num_pending_ops > 0);
+ auto num_completed_ops = 0;
+
+ while (_num_pending_ops > 0) {
+ auto completed_op = _wait_for_aio_work();
+
+ if (completed_op->_validate) { completed_op->validate(); }
+
+ completed_op->finish();
+
+ close(completed_op->_fd);
+
+ --_num_pending_ops;
+ ++num_completed_ops;
+ }
+
+ return num_completed_ops;
+}
+
+bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op,
+ const long long int num_bytes)
+{
+ const auto op_string = read_op ? "Read" : "Write";
+ if (num_bytes % get_thread_count()) {
+ std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes
+ << " not divisible by thread count = " << get_thread_count() << std::endl;
+ return false;
+ }
+
+ return true;
+}
+
+std::shared_ptr deepspeed_io_handle_t::_create_io_op_desc(
+ const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const bool validate)
+{
+ return std::make_shared(
+ read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate);
+}
+
+int deepspeed_io_handle_t::pread(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async)
+{
+ long long num_file_bytes;
+ if (-1 == get_file_size(filename, num_file_bytes)) {
+ const auto error_code = errno;
+ report_file_error(filename, " fstat for read", error_code);
+ return -1;
+ }
+ const auto buffer_bytes = static_cast(buffer.nbytes());
+ if (buffer_bytes != num_file_bytes) {
+ std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes
+ << " != " << num_file_bytes << std::endl;
+ }
+ assert(static_cast(buffer.nbytes()) == num_file_bytes);
+ assert((num_file_bytes % _num_threads) == 0);
+
+ if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; }
+
+ const auto fd = open_file(filename, true);
+ if (fd == -1) { return -1; }
+
+ auto scheduled_op = _create_io_op_desc(true, buffer, fd, filename, num_file_bytes, validate);
+
+ _schedule_aio_work(scheduled_op);
+
+ if (async) { return 0; }
+
+ return wait();
+}
+
+int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async)
+{
+ const auto num_write_bytes = static_cast(buffer.nbytes());
+ assert((num_write_bytes % _num_threads) == 0);
+
+ if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; }
+
+ const auto fd = open_file(filename, false);
+ if (fd == -1) { return -1; }
+
+ auto scheduled_op = _create_io_op_desc(false, buffer, fd, filename, num_write_bytes, validate);
+
+ _schedule_aio_work(scheduled_op);
+
+ if (async) { return 0; }
+
+ return wait();
+}
+
+int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer, const char* filename)
+{
+ return pread(buffer, filename, false, false);
+}
+
+int deepspeed_io_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename)
+{
+ return pwrite(buffer, filename, false, false);
+}
+
+int deepspeed_io_handle_t::async_pread(torch::Tensor& buffer, const char* filename)
+{
+ return pread(buffer, filename, false, true);
+}
+
+int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename)
+{
+ return pwrite(buffer, filename, false, true);
+}
+
+at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const size_t num_elem,
+ const torch::Tensor& example_tensor)
+{
+ return _pinned_tensor_mgr->alloc(num_elem, example_tensor.scalar_type());
+}
+
+bool deepspeed_io_handle_t::free_cpu_locked_tensor(torch::Tensor& locked_tensor)
+{
+ return _pinned_tensor_mgr->free(locked_tensor);
+}
diff --git a/csrc/aio/py_lib/deepspeed_py_io_handle.h b/csrc/aio/py_lib/deepspeed_py_io_handle.h
new file mode 100644
index 000000000000..2974ebe87bfc
--- /dev/null
+++ b/csrc/aio/py_lib/deepspeed_py_io_handle.h
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include
+#include
+#include "deepspeed_aio_thread.h"
+#include "deepspeed_pin_tensor.h"
+
+struct deepspeed_io_handle_t {
+ std::unique_ptr _aio_ctxt;
+ const bool _single_submit;
+ const bool _overlap_events;
+ const int _num_threads;
+ deepspeed_aio_config_t _aio_config;
+
+ std::vector> _thread_contexts;
+ std::vector _threads;
+ int _num_pending_ops;
+ std::unique_ptr _pinned_tensor_mgr;
+
+ deepspeed_io_handle_t(const int block_size,
+ const int queue_depth,
+ const bool single_submit,
+ const bool overlap_events,
+ const int num_threads);
+
+ virtual ~deepspeed_io_handle_t() = 0;
+
+ const int get_block_size() const;
+ const int get_queue_depth() const;
+ const bool get_single_submit() const;
+ const bool get_overlap_events() const;
+ const int get_thread_count() const;
+
+ int read(torch::Tensor& buffer, const char* filename, const bool validate);
+
+ int write(const torch::Tensor& buffer, const char* filename, const bool validate);
+
+ int pread(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async);
+
+ int pwrite(const torch::Tensor& buffer,
+ const char* filename,
+ const bool validate,
+ const bool async);
+
+ int sync_pread(torch::Tensor& buffer, const char* filename);
+
+ int sync_pwrite(const torch::Tensor& buffer, const char* filename);
+
+ int async_pread(torch::Tensor& buffer, const char* filename);
+
+ int async_pwrite(const torch::Tensor& buffer, const char* filename);
+
+ // TODO: Make API's args to be shape and dtype.
+ torch::Tensor new_cpu_locked_tensor(const size_t num_elem, const torch::Tensor& example_tensor);
+
+ bool free_cpu_locked_tensor(torch::Tensor&);
+
+ int wait();
+
+ void _stop_threads();
+
+ void _schedule_aio_work(std::shared_ptr scheduled_op);
+
+ std::shared_ptr _wait_for_aio_work();
+
+ bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes);
+
+ virtual std::shared_ptr _create_io_op_desc(
+ const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const bool validate);
+};
diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp
old mode 100755
new mode 100644
index 9033549bc0d2..3171d0c6bf3c
--- a/csrc/aio/py_lib/py_ds_aio.cpp
+++ b/csrc/aio/py_lib/py_ds_aio.cpp
@@ -10,6 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include
#include "deepspeed_py_aio_handle.h"
#include "deepspeed_py_copy.h"
+using namespace pybind11::literals;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
@@ -20,7 +21,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy");
py::class_(m, "aio_handle")
- .def(py::init())
+ .def(py::init(),
+ "AIO handle constructor",
+ "block_size"_a = 1024 * 1024,
+ "queue_depth"_a = 128,
+ "single_submit"_a = false,
+ "overlap_events"_a = false,
+ "num_threads"_a = 1)
.def("get_block_size", &deepspeed_aio_handle_t::get_block_size)
.def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth)
@@ -28,19 +35,74 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events)
.def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count)
- .def("read", &deepspeed_aio_handle_t::read)
- .def("write", &deepspeed_aio_handle_t::write)
+ .def("read",
+ &deepspeed_aio_handle_t::read,
+ "Synchronous and non-parallel file read. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a)
- .def("pread", &deepspeed_aio_handle_t::pread)
- .def("pwrite", &deepspeed_aio_handle_t::pwrite)
+ .def("write",
+ &deepspeed_aio_handle_t::write,
+ "Synchronous and non-parallel file write. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a)
- .def("sync_pread", &deepspeed_aio_handle_t::sync_pread)
- .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite)
- .def("async_pread", &deepspeed_aio_handle_t::async_pread)
- .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite)
+ .def("pread",
+ &deepspeed_aio_handle_t::pread,
+ "Parallel file read with option of parallelism. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a,
+ "async"_a)
- .def("new_cpu_locked_tensor", &deepspeed_aio_handle_t::new_cpu_locked_tensor)
- .def("free_cpu_locked_tensor", &deepspeed_aio_handle_t::free_cpu_locked_tensor)
+ .def("pwrite",
+ &deepspeed_aio_handle_t::pwrite,
+ "Parallel file write with option of parallelism. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a,
+ "async"_a)
- .def("wait", &deepspeed_aio_handle_t::wait);
+ .def("sync_pread",
+ &deepspeed_aio_handle_t::sync_pread,
+ "Synchrononous parallel file read. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("sync_pwrite",
+ &deepspeed_aio_handle_t::sync_pwrite,
+ "Synchronous parallel file write. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("async_pread",
+ &deepspeed_aio_handle_t::async_pread,
+ "Asynchronous parallel file read. Returns 0 on success. Returns 0 on success, and "
+ "following wait() returns count of completed ops.",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("async_pwrite",
+ &deepspeed_aio_handle_t::async_pwrite,
+ "Asynchronous parallel file write. Returns 0 on success, and following wait() returns "
+ "count of completed ops.",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("new_cpu_locked_tensor",
+ &deepspeed_aio_handle_t::new_cpu_locked_tensor,
+ "Allocate pinned CPU tensor.",
+ "num_elem"_a,
+ "example_tenosr"_a)
+
+ .def("free_cpu_locked_tensor",
+ &deepspeed_aio_handle_t::free_cpu_locked_tensor,
+ "Free pinned CPU tensor.",
+ "tensor"_a)
+
+ .def("wait",
+ &deepspeed_aio_handle_t::wait,
+ "Wait for (ongoing) asynchronous operations to complete");
}
diff --git a/csrc/aio/py_test/aio_bench_generate_param.py b/csrc/aio/py_test/aio_bench_generate_param.py
index 09d0e03c7ef6..7a0ab59ed73d 100644
--- a/csrc/aio/py_test/aio_bench_generate_param.py
+++ b/csrc/aio/py_test/aio_bench_generate_param.py
@@ -41,9 +41,9 @@ def convert_to_param(key):
return {
"single_submit": "true" if key[0] == "single" else "false",
"overlap_events": "true" if key[1] == "overlap" else "false",
- "thread_count": int(key[3]),
- "queue_depth": int(key[4]),
- "block_size": int(key[5])
+ "thread_count": int(key[5]),
+ "queue_depth": int(key[3]),
+ "block_size": int(key[4])
}
diff --git a/csrc/aio/py_test/aio_bench_perf_sweep.py b/csrc/aio/py_test/aio_bench_perf_sweep.py
index 7d55f7ded65c..ba95150b11e1 100644
--- a/csrc/aio/py_test/aio_bench_perf_sweep.py
+++ b/csrc/aio/py_test/aio_bench_perf_sweep.py
@@ -10,75 +10,47 @@
import argparse
import json
import itertools
-import subprocess
import shutil
-from test_ds_aio_utils import refine_integer_value
+from ds_aio_job import Job, run_job
from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \
- READ_IO_DIR, WRITE_IO_DIR, READ_LOG_DIR, WRITE_LOG_DIR
+ READ_LOG_DIR, WRITE_LOG_DIR
from deepspeed.ops.op_builder import AsyncIOBuilder
OTHER_OPTIONS = '--handle'
PERF_SCRIPT = 'test_ds_aio.py'
DEFAULT_SWEEP_CONFIG = {
- "block_size": ["128K", "256K"],
- "queue_depth": [4, 16, 32],
- "overlap_events": [True, False],
- "io_parallel": [2, 8],
- "single_submit": [False]
+ "block_size": ["128K", "1M"],
+ "queue_depth": [32, 64, 128],
+ "sequential_requests": [True, False],
+ "single_submit": [False],
+ "io_parallel": [1, 2, 8],
}
-class Job(object):
-
- def __init__(self, cmd_line, output_file=None, work_dir=None):
- self.cmd_line = cmd_line
- self.output_file = output_file
- self.work_dir = work_dir
- self.output_fd = None
-
- def cmd(self):
- return self.cmd_line
-
- def get_stdout(self):
- return self.output_fd
-
- def get_stderr(self):
- return self.output_fd
-
- def get_cwd(self):
- return self.work_dir
-
- def open_output_file(self):
- if self.output_file is not None:
- self.output_fd = open(self.output_file, 'w')
-
- def close_output_file(self):
- if self.output_fd is not None:
- self.output_fd.close()
- self.output_fd = None
-
-
class SweepConfig(object):
def __init__(self, args):
- self.nvme_dir = args.nvme_dir
- self.io_size = args.io_size
+ self.folder_to_device_mapping = get_ftd_map(args.nvme_dir)
self.search_space = get_sweep_config_dict(args.sweep_config)
+ self.search_space.update(self.folder_to_device_mapping)
self.read = not args.no_read
self.write = not args.no_write
self.flush_cache = not args.no_sudo
self.log_dir = args.log_dir
- self.loops = args.loops
- self.other_options = f'{OTHER_OPTIONS} --loops {args.loops}'
+ self.other_options = f'{OTHER_OPTIONS} --loops {args.loops} --io_size {args.io_size}'
+ if args.gpu:
+ self.other_options += ' --gpu'
+ if args.gds:
+ self.other_options += ' --use_gds'
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--nvme_dir',
+ nargs='+',
required=True,
- type=str,
help='Directory in which to perform I/O tests. A writeable directory on a NVMe device.')
parser.add_argument('--sweep_config', type=str, default=None, help='Performance sweep configuration json file.')
@@ -92,6 +64,10 @@ def parse_arguments():
default="400M",
help='Number of I/O bytes to read/write for performance measurements.')
+ parser.add_argument('--gpu', action='store_true', help='Test tensor transfers between GPU device and NVME device.')
+
+ parser.add_argument('--gds', action='store_true', help='Run the sweep over NVIDIA GPUDirectStorage operator')
+
parser.add_argument(
'--no_sudo',
action='store_true',
@@ -118,6 +94,12 @@ def dump_cmd_lines(cmd_lines):
print(f'{i}: {cmd}')
+def get_ftd_map(nvme_dir_list):
+ ftd_list = [f'{dir}:{dev}' for dev, dir in enumerate(nvme_dir_list)]
+ ftd_arg = [' '.join(ftd for ftd in ftd_list)]
+ return {'folder_to_device_mapping': ftd_arg}
+
+
def get_sweep_config_dict(sweep_config_json):
if sweep_config_json is None:
return DEFAULT_SWEEP_CONFIG
@@ -148,16 +130,6 @@ def flatten_options(key, value_list):
return cmd_list
-def run_job(job):
- args = ' '.join(job.cmd())
- print(f'args = {args}')
- job.open_output_file()
- proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
- job.close_output_file()
- assert proc.returncode == 0, \
- f"This command failed: {job.cmd()}"
-
-
def launch_sweep(sweep_jobs, sync_job, flush_cache_job):
for perf_job in sweep_jobs:
if flush_cache_job is not None:
@@ -176,7 +148,12 @@ def create_cmd_tags(cmd_line):
if len(fields) == 1:
tags[fields[0]] = None
elif len(fields) == 2:
- tags[fields[0]] = fields[1]
+ if fields[0] == '--folder_to_device_mapping':
+ tags[fields[0]] = len(fields[1:])
+ else:
+ tags[fields[0]] = fields[1]
+ elif len(fields) > 2:
+ tags[fields[0]] = len(fields[1:])
return tags
@@ -184,16 +161,16 @@ def get_log_file(io_op_desc, cmd_line):
QUEUE_DEPTH = "--queue_depth"
BLOCK_SIZE = "--block_size"
SINGLE_SUBMIT = "--single_submit"
- OVERLAP_EVENTS = "--overlap_events"
- THREAD_COUNT = "--threads"
+ SEQUENTIAL_REQUESTS = "--sequential_requests"
+ FTD_MAP = "--folder_to_device_mapping"
IO_PARALLEL = "--io_parallel"
tag_map = {
QUEUE_DEPTH: "d",
BLOCK_SIZE: "bs",
SINGLE_SUBMIT: "single",
- OVERLAP_EVENTS: "overlap",
- THREAD_COUNT: "t",
+ SEQUENTIAL_REQUESTS: "sequential",
+ FTD_MAP: "ftd",
IO_PARALLEL: "p"
}
@@ -201,14 +178,14 @@ def get_log_file(io_op_desc, cmd_line):
QUEUE_DEPTH: 1,
BLOCK_SIZE: "1M",
SINGLE_SUBMIT: "block",
- OVERLAP_EVENTS: "sequential",
- THREAD_COUNT: 1,
+ SEQUENTIAL_REQUESTS: "overlap",
+ FTD_MAP: 1,
IO_PARALLEL: 1
}
def get_default_value(tag):
value = tag_default[tag]
- if tag in [SINGLE_SUBMIT, OVERLAP_EVENTS]:
+ if tag in [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS]:
return value
return f'{tag_map[tag]}{value}'
@@ -218,7 +195,7 @@ def get_config_value(tag, value):
return tag_key
return f'{tag_key}{value}'
- tag_list = [SINGLE_SUBMIT, OVERLAP_EVENTS, THREAD_COUNT, IO_PARALLEL, QUEUE_DEPTH, BLOCK_SIZE]
+ tag_list = [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS, FTD_MAP, QUEUE_DEPTH, BLOCK_SIZE, IO_PARALLEL]
log_tags = [io_op_desc]
cmd_tags = create_cmd_tags(cmd_line)
for tag in tag_list:
@@ -252,40 +229,14 @@ def async_io_setup():
return AsyncIOBuilder().is_compatible()
-def get_block_size_and_count(io_bytes):
- block_size = 1
- block_count = io_bytes
- bytes_in_KB = 1024
-
- while block_count % bytes_in_KB == 0:
- block_size *= bytes_in_KB
- block_count /= bytes_in_KB
-
- return int(block_size), int(block_count)
-
-
-def create_read_file(sweep_config):
- read_folder = os.path.join(sweep_config.nvme_dir, f'{READ_IO_DIR}')
- os.makedirs(read_folder, exist_ok=True)
- read_file_name = os.path.join(read_folder, f'random_{sweep_config.io_size}B.pt')
- block_size, block_count = get_block_size_and_count(refine_integer_value(sweep_config.io_size))
- dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={read_file_name} bs={block_size} count={block_count}'])
- print(f'[Start] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....')
- run_job(dd_job)
- print(f'[Done] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....')
- return read_folder, read_file_name
-
-
def remove_folder(folder):
assert os.path.isdir(folder), f"Error: cannot remove {folder} - folder not found"
shutil.rmtree(folder)
def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
- read_folder, read_file_name = create_read_file(sweep_config)
- read_option = f'--read_file {read_file_name}'
- read_cmd_lines = [[f'{read_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
- #dump_cmd_lines(read_cmd_lines)
+ read_cmd_lines = [[f'--read {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
+ #dump_cmd_lines(cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{READ_LOG_DIR}')
os.makedirs(log_folder, exist_ok=True)
@@ -294,15 +245,9 @@ def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job)
- remove_folder(read_folder)
-
def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
- write_folder = os.path.join(sweep_config.nvme_dir, f'{WRITE_IO_DIR}')
- os.makedirs(write_folder, exist_ok=True)
- write_file_name = os.path.join(write_folder, f'random_{sweep_config.io_size}B.pt')
- write_option = f'--write_size {sweep_config.io_size} --write_file {write_file_name}'
- write_cmd_lines = [[f'{write_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
+ write_cmd_lines = [[f'{sweep_config.other_options}'] + cmd for cmd in cmd_lines]
#dump_cmd_lines(write_cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{WRITE_LOG_DIR}')
@@ -312,8 +257,6 @@ def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job)
- remove_folder(write_folder)
-
def main():
print("Running performance sweep of deepspeed nvme library")
diff --git a/csrc/aio/py_test/ds_aio_args.py b/csrc/aio/py_test/ds_aio_args.py
new file mode 100644
index 000000000000..346feabe4810
--- /dev/null
+++ b/csrc/aio/py_test/ds_aio_args.py
@@ -0,0 +1,175 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
+"""
+
+import argparse
+import os
+from test_ds_aio_utils import refine_integer_value
+from deepspeed.accelerator import get_accelerator
+
+MAPPING_DELIMITER = ':'
+
+
+def refine_args(args):
+ if args.io_size and type(args.io_size) == str:
+ args.io_size = refine_integer_value(args.io_size)
+
+ if args.block_size and type(args.block_size) == str:
+ args.block_size = refine_integer_value(args.block_size)
+
+ return args
+
+
+def _get_mapping_dict(args):
+ if args.folder is not None:
+ d = {i: args.folder for i in range(args.multi_process)}
+ else:
+ d = {}
+ for m in args.folder_to_device_mapping:
+ fields = m.split(MAPPING_DELIMITER)
+ d[fields[1]] = fields[0]
+
+ return d
+
+
+def _validate_folder_mapping(args):
+ no_error = True
+ error_messages = []
+ invalid_mappings = [m for m in args.folder_to_device_mapping if MAPPING_DELIMITER not in m]
+ if len(invalid_mappings) > 0:
+ error_messages.append(
+ f'Missing delimiter ({MAPPING_DELIMITER}) in folder_to_device_mapping {invalid_mappings}')
+ no_error = False
+
+ folder_list = [m.split(MAPPING_DELIMITER)[0] for m in args.folder_to_device_mapping]
+ invalid_folders = [d for d in folder_list if not os.path.exists(d)]
+ if len(invalid_folders) > 0:
+ error_messages.append(f'Invalid folders in folder_to_device_mapping: {invalid_folders}')
+ no_error = False
+
+ if args.gpu:
+ device_list = [int(m.split(MAPPING_DELIMITER)[1]) for m in args.folder_to_device_mapping]
+ invalid_device_list = [dev_id for dev_id in device_list if not dev_id < get_accelerator().device_count()]
+ if len(invalid_device_list) > 0:
+ error_messages.append(f'Invalid device ids in folder_to_device_mapping: {invalid_device_list}')
+ no_error = False
+
+ return no_error, error_messages
+
+
+def validate_args(args):
+ no_error = True
+ error_messages = []
+
+ if args.folder is not None and len(args.folder_to_device_mapping) > 0:
+ error_messages.append(f'--folder and --folder_to_device_mapping cannot be specified together.')
+ no_error = False
+ elif args.folder is None and len(args.folder_to_device_mapping) == 0:
+ error_messages.append(f'At least one of --folder or --folder_to_device_mapping must be specified.')
+ no_error = False
+
+ # Validate --folder
+ if args.folder is not None and not os.path.exists(args.folder):
+ no_error = False
+ error_messages.append(f'Invalid folder in --folder: {args.folder} ')
+
+ # Validate --folder_mapping_to_device
+ if len(args.folder_to_device_mapping) > 0:
+ no_mapping_error, mapping_error_messages = _validate_folder_mapping(args)
+ no_error = no_error and no_mapping_error
+ error_messages += mapping_error_messages
+
+ # Validate --gpu, --use_gds
+ if args.use_gds and not args.gpu:
+ error_messages.append(f'--gpu must be set to transfer with --use_gds')
+ no_error = False
+
+ if not no_error:
+ print(f'Found {len(error_messages)} validation errors')
+ for i, msg in enumerate(error_messages):
+ print(f'{i+1}: {msg}')
+
+ return no_error
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--folder', default=None, type=str, help='Folder to use for I/O.')
+
+ parser.add_argument('--folder_to_device_mapping',
+ default=[],
+ nargs='+',
+ help='Specification of mapping of folder to (gpu) device id, (ignored for cpu accesses).'
+ 'Can be specified multiple times for multi-process runs,'
+ 'e.g. --folder_to_device_mapping /mnt/nvme0:0 --folder_to_device_mapping /mnt/nvme1:15 --gpu'
+ 'means access /mnt/nvme0 with gpu 0 and /mnt/nvme1 with gpu 15')
+
+ parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.')
+
+ parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)')
+
+ parser.add_argument('--multi_process',
+ type=int,
+ default=1,
+ help='Number of parallel processes doing I/O (default 1).')
+
+ parser.add_argument('--block_size',
+ type=str,
+ default='1M',
+ help='I/O block size. Can use K, M, or G suffix (default 1M for 1 megabytes).')
+
+ parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth (default 32).')
+
+ parser.add_argument('--single_submit',
+ action='store_true',
+ help='Submit I/O requests in singles (default is submit queue_depth amount at once.).')
+
+ parser.add_argument(
+ '--sequential_requests',
+ action='store_true',
+ help=
+ 'Delay I/O request submission until completion of prior requests (default is overlap I/O submission and completion requests.).'
+ )
+
+ parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.')
+
+ parser.add_argument('--handle', action='store_true', help='Use AIO handle.')
+
+ parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions')
+
+ parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism')
+
+ parser.add_argument('--gpu', action='store_true', help='Use GPU memory')
+
+ parser.add_argument('--use_gds', action='store_true', help='Enable GDS AIO')
+
+ parser.add_argument('--slow_bounce_buffer',
+ action='store_true',
+ help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.')
+
+ args = parser.parse_args()
+ print(f'args = {args}')
+ return args
+
+
+def get_validated_args():
+ args = parse_arguments()
+ args = refine_args(args)
+ if not validate_args(args):
+ quit()
+ print(f'Successful validation of command line arguments')
+
+ peer_tag = 'gpu' if args.gpu else 'process'
+ args.mapping_dict = _get_mapping_dict(args)
+ args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()]
+ assert len(args.mapping_dict) == len(args.mapping_list)
+ print(f'Configuring {len(args.mapping_list)} {peer_tag} to folder mapping')
+ for i, (device_id, folder) in enumerate(args.mapping_list):
+ print(f'[{i}]: {peer_tag} {device_id} <----> {folder}')
+
+ return args
diff --git a/csrc/aio/py_test/ds_aio_basic.py b/csrc/aio/py_test/ds_aio_basic.py
index ad2a4349cd0c..9b3c7cbfc49f 100755
--- a/csrc/aio/py_test/ds_aio_basic.py
+++ b/csrc/aio/py_test/ds_aio_basic.py
@@ -9,10 +9,9 @@
import torch
import os
import time
+from deepspeed.ops.aio import AsyncIOBuilder
from multiprocessing import Pool, Barrier
from test_ds_aio_utils import report_results, task_log, task_barrier
-from deepspeed.accelerator import get_accelerator
-from deepspeed.ops.op_builder import AsyncIOBuilder
def pre_basic(args, tid, read_op):
@@ -21,7 +20,7 @@ def pre_basic(args, tid, read_op):
file = args.read_file if read_op else f'{args.write_file}.{tid}'
task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
- buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
+ buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
ctxt = {}
@@ -56,7 +55,7 @@ def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
- args.single_submit, args.overlap_events, args.validate)
+ args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -67,7 +66,7 @@ def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
- args.single_submit, args.overlap_events, args.validate)
+ args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -90,16 +89,17 @@ def get_schedule(args, read_op):
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
+ num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
@@ -107,14 +107,14 @@ def _aio_handle_tasklet(pool_params):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
@@ -125,9 +125,10 @@ def _init_tasklet(b):
def aio_basic_multiprocessing(args, read_op):
- b = Barrier(args.threads)
- pool_params = [(args, p, read_op) for p in range(args.threads)]
- with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p:
+ num_processes = len(args.mapping_dict)
+ b = Barrier(num_processes)
+ pool_params = [(args, p, read_op) for p in range(num_processes)]
+ with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)
diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py
index d35b2713edae..f4a179deb9ec 100755
--- a/csrc/aio/py_test/ds_aio_handle.py
+++ b/csrc/aio/py_test/ds_aio_handle.py
@@ -10,40 +10,56 @@
import os
import time
from multiprocessing import Pool, Barrier
-from test_ds_aio_utils import report_results, task_log, task_barrier
+from deepspeed.ops.aio import AsyncIOBuilder
+from deepspeed.ops.op_builder import GDSBuilder
+from test_ds_aio_utils import report_results, task_log, task_barrier, create_filename, create_file
from deepspeed.accelerator import get_accelerator
-from deepspeed.ops.op_builder import AsyncIOBuilder
+
+BUFFER = 'buffer'
+BOUNCE_BUFFER = 'bounce_buffer'
def pre_handle(args, tid, read_op):
io_string = "Read" if read_op else "Write"
- num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
- file = args.read_file if read_op else f'{args.write_file}.{tid}'
-
- io_parallel = args.io_parallel if args.io_parallel else 1
- handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
- args.overlap_events, io_parallel)
- task_log(tid, f'Created deepspeed aio handle')
-
+ gds = True if args.use_gds else False
+ device_id, folder = args.mapping_list[tid]
+ filename = create_filename(folder, args.read, args.io_size, tid)
+ if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
+ create_file(filename, args.io_size)
+
+ task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
+ bounce_buffer = None
if args.gpu:
- buffer = torch.empty(num_bytes, dtype=torch.uint8, device=get_accelerator().device_name())
+ device_name = get_accelerator().device_name(device_id)
+ buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name)
+ if not (args.slow_bounce_buffer or gds):
+ bounce_buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8,
+ device='cpu').pin_memory()
else:
- if args.use_accelerator_pin_memory:
- buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
- else:
- buffer = handle.new_cpu_locked_tensor(num_bytes, torch.empty(0, dtype=torch.uint8))
+ buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device='cpu').pin_memory()
+ task_log(tid,
+ f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
+ force=True)
- task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
+ io_parallel = args.io_parallel if args.io_parallel else 1
+ if gds:
+ handle = GDSBuilder().load().gds_handle(args.block_size, args.queue_depth, args.single_submit,
+ not args.sequential_requests, io_parallel)
+ handle.pin_device_tensor(buffer)
+ else:
+ handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
+ not args.sequential_requests, io_parallel)
+ task_log(tid, f'created deepspeed aio handle')
ctxt = {}
- ctxt['file'] = file
- ctxt['num_bytes'] = num_bytes
+ ctxt['file'] = filename
+ ctxt['num_bytes'] = args.io_size
ctxt['handle'] = handle
- ctxt['buffer'] = buffer
+ ctxt['gds'] = gds
+ ctxt[BUFFER] = buffer
+ ctxt[BOUNCE_BUFFER] = bounce_buffer
ctxt['elapsed_sec'] = 0
- task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
-
return ctxt
@@ -61,8 +77,12 @@ def pre_handle_write(pool_params):
def post_handle(pool_params):
_, _, ctxt = pool_params
- ctxt["buffer"].detach()
- ctxt["buffer"] = None
+ for buf in [BUFFER, BOUNCE_BUFFER]:
+ if ctxt[buf] is not None:
+ if ctxt['gds']:
+ ctxt['handle'].unpin_device_tensor(ctxt[buf])
+ ctxt[buf].detach()
+ ctxt[buf] = None
return ctxt
@@ -71,20 +91,31 @@ def main_parallel_read(pool_params):
handle = ctxt['handle']
start_time = time.time()
- ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True)
+ dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
+ ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
+ if dest_buffer == BOUNCE_BUFFER:
+ ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
-
return ctxt
def main_parallel_write(pool_params):
args, tid, ctxt = pool_params
+ # Avoid overwriting existing files as it could be artificially faster
+ if os.path.isfile(ctxt['file']):
+ os.remove(ctxt['file'])
+
handle = ctxt['handle']
start_time = time.time()
- ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True)
+ if ctxt[BOUNCE_BUFFER] is not None:
+ source_buffer = BOUNCE_BUFFER
+ ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
+ else:
+ source_buffer = BUFFER
+ ret = handle.pwrite(ctxt[source_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
end_time = time.time()
@@ -98,8 +129,11 @@ def main_handle_read(pool_parms):
handle = ctxt['handle']
start_time = time.time()
- ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate)
+ dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
+ ret = handle.read(ctxt[dest_buffer], ctxt['file'], args.validate)
assert ret != -1
+ if dest_buffer == BOUNCE_BUFFER:
+ ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -108,9 +142,18 @@ def main_handle_read(pool_parms):
def main_handle_write(pool_parms):
args, tid, ctxt = pool_parms
+ # Avoid overwriting existing files as it could be artificially faster
+ if os.path.isfile(ctxt['file']):
+ os.remove(ctxt['file'])
+
handle = ctxt['handle']
start_time = time.time()
- ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate)
+ if ctxt[BOUNCE_BUFFER] is not None:
+ source_buffer = BOUNCE_BUFFER
+ ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
+ else:
+ source_buffer = BUFFER
+ ret = handle.write(ctxt[source_buffer], ctxt['file'], args.validate)
assert ret != -1
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@@ -123,27 +166,28 @@ def get_schedule(args, read_op):
if read_op:
schedule['pre'] = pre_handle_read
schedule['post'] = post_handle
- schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read
+ schedule['main'] = main_parallel_read
else:
schedule['pre'] = pre_handle_write
schedule['post'] = post_handle
- schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write
+ schedule['main'] = main_parallel_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
+ num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
@@ -151,14 +195,14 @@ def _aio_handle_tasklet(pool_params):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
- task_barrier(aio_barrier, args.threads)
+ task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
@@ -169,9 +213,10 @@ def _init_tasklet(b):
def aio_handle_multiprocessing(args, read_op):
- b = Barrier(args.threads)
- pool_params = [(args, p, read_op) for p in range(args.threads)]
- with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p:
+ num_processes = len(args.mapping_dict)
+ b = Barrier(num_processes)
+ pool_params = [(args, p, read_op) for p in range(num_processes)]
+ with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)
diff --git a/csrc/aio/py_test/ds_aio_job.py b/csrc/aio/py_test/ds_aio_job.py
new file mode 100644
index 000000000000..bbddee1bf26d
--- /dev/null
+++ b/csrc/aio/py_test/ds_aio_job.py
@@ -0,0 +1,48 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+Functionality of swapping tensors to/from (NVMe) storage devices.
+"""
+import subprocess
+
+
+class Job(object):
+
+ def __init__(self, cmd_line, output_file=None, work_dir=None):
+ self.cmd_line = cmd_line
+ self.output_file = output_file
+ self.work_dir = work_dir
+ self.output_fd = None
+
+ def cmd(self):
+ return self.cmd_line
+
+ def get_stdout(self):
+ return self.output_fd
+
+ def get_stderr(self):
+ return self.output_fd
+
+ def get_cwd(self):
+ return self.work_dir
+
+ def open_output_file(self):
+ if self.output_file is not None:
+ self.output_fd = open(self.output_file, 'w')
+
+ def close_output_file(self):
+ if self.output_fd is not None:
+ self.output_fd.close()
+ self.output_fd = None
+
+
+def run_job(job):
+ args = ' '.join(job.cmd())
+ print(f'args = {args}')
+ job.open_output_file()
+ proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
+ job.close_output_file()
+ assert proc.returncode == 0, \
+ f"This command failed: {job.cmd()}"
diff --git a/csrc/aio/py_test/run_read_sweep.sh b/csrc/aio/py_test/run_read_sweep.sh
index b9d7e050454a..59d82996a0e2 100755
--- a/csrc/aio/py_test/run_read_sweep.sh
+++ b/csrc/aio/py_test/run_read_sweep.sh
@@ -1,13 +1,22 @@
#!/bin/bash
-if [[ $# -ne 2 ]]; then
- echo "Usage: $0 "
+if [[ $# -lt 2 ]]; then
+ echo "Usage: $0 "
exit 1
fi
+function prep_folder()
+{
+ folder=$1
+ if [[ -d ${folder} ]]; then
+ rm -f ${folder}/*
+ else
+ mkdir -p ${folder}
+ fi
+}
function validate_environment()
{
- validate_cmd="python ./validate_async_io.py"
+ validate_cmd="TORCH_EXTENSIONS_DIR=./torch_extentions python3 ./validate_async_io.py"
eval ${validate_cmd}
res=$?
if [[ $res != 0 ]]; then
@@ -17,18 +26,27 @@ function validate_environment()
fi
}
+function fileExists() {
+ local file="$1"
+ if [[ -f "$file" ]]; then
+ return 0
+ else
+ return 1
+ fi
+}
validate_environment
-INPUT_FILE=$1
-if [[ ! -f ${INPUT_FILE} ]]; then
- echo "Input file not found: ${INPUT_FILE}"
- exit 1
-fi
-
-LOG_DIR=$2/aio_perf_sweep
+IO_SIZE=$1
+LOG_DIR=./aio_perf_sweep
+MAP_DIR=$2/aio
+GPU_MEM=$3
+USE_GDS=$4
RUN_SCRIPT=./test_ds_aio.py
-READ_OPT="--read_file ${INPUT_FILE}"
+READ_OPT="--read"
+
+prep_folder ${MAP_DIR}
+prep_folder ${LOG_DIR}
if [[ -d ${LOG_DIR} ]]; then
rm -f ${LOG_DIR}/*
@@ -36,37 +54,60 @@ else
mkdir -p ${LOG_DIR}
fi
-DISABLE_CACHE="sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' "
-SYNC="sync"
+if [[ ${GPU_MEM} == "gpu" ]]; then
+ gpu_opt="--gpu"
+else
+ gpu_opt=""
+fi
+if [[ ${USE_GDS} == "gds" ]]; then
+ gds_opt="--use_gds"
+else
+ gds_opt=""
+fi
+
+DISABLE_CACHE="sudo sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' "
+SYNC="sudo sync"
-for sub in single block; do
- if [[ $sub == "single" ]]; then
- sub_opt="--single_submit"
+for xtype in cpu gpu gds; do
+ if [[ $xtype == "cpu" ]]; then
+ gpu_opt=""
+ gds_opt=""
+ elif [[ $xtype == "gpu" ]]; then
+ gpu_opt="--gpu"
+ gds_opt=""
else
- sub_opt=""
+ gpu_opt="--gpu"
+ gds_opt="--use_gds"
fi
- for ov in overlap sequential; do
- if [[ $ov == "overlap" ]]; then
- ov_opt="--overlap_events"
+ for sub in single block; do
+ if [[ $sub == "single" ]]; then
+ sub_opt="--single_submit"
else
- ov_opt=""
+ sub_opt=""
fi
- for t in 1 2 4 8; do
- for p in 1 ; do
- for d in 1 2 4 8 16 32; do
- for bs in 128K 256K 512K 1M; do
- SCHED_OPTS="${sub_opt} ${ov_opt} --handle --threads ${t}"
- OPTS="--io_parallel ${p} --queue_depth ${d} --block_size ${bs}"
- LOG="${LOG_DIR}/read_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt"
- cmd="python ${RUN_SCRIPT} ${READ_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}"
- echo ${DISABLE_CACHE}
- echo ${cmd}
- echo ${SYNC}
+ for ov in overlap sequential; do
+ if [[ $ov == "sequential" ]]; then
+ ov_opt="--sequential_requests"
+ else
+ ov_opt=""
+ fi
+ for p in 1 2 4 8; do
+ for t in 1 2 4 8; do
+ for d in 8 16 32 64 128; do
+ for bs in 128K 256K 512K 1M 2M 4M 8M 16M; do
+ SCHED_OPTS="${sub_opt} ${ov_opt} --handle ${gpu_opt} ${gds_opt} --folder_to_device_mapping /mnt/nvme01:0"
+ OPTS="--queue_depth ${d} --block_size ${bs} --io_size ${IO_SIZE} --io_parallel ${t}"
+ LOG="${LOG_DIR}/read_${xtype}_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt"
+ cmd="/usr/bin/time python ${RUN_SCRIPT} ${READ_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}"
- eval ${DISABLE_CACHE}
- eval ${cmd}
- eval ${SYNC}
- sleep 2
+ echo ${DISABLE_CACHE}
+ echo ${cmd}
+ echo ${SYNC}
+ eval ${DISABLE_CACHE}
+ eval ${cmd}
+ eval ${SYNC}
+ sleep 2
+ done
done
done
done
diff --git a/csrc/aio/py_test/run_write_sweep.sh b/csrc/aio/py_test/run_write_sweep.sh
index 99f2113dda6f..a54d1c8d7bed 100755
--- a/csrc/aio/py_test/run_write_sweep.sh
+++ b/csrc/aio/py_test/run_write_sweep.sh
@@ -25,25 +25,33 @@ function validate_environment()
validate_environment
-if [[ $# -ne 3 ]]; then
- echo "Usage: $0 "
- exit 1
-fi
-
-SIZE="$1M"
-WRITE_DIR=$2
-LOG_DIR=$3/aio_perf_sweep
+IO_SIZE=$1
+LOG_DIR=$2/aio_perf_sweep
+MAP_DIR=$2/aio
+GPU_MEM=$3
+USE_GDS=$4
+RUN_SCRIPT=./test_ds_aio.py
-OUTPUT_FILE=${WRITE_DIR}/ds_aio_write_${SIZE}B.pt
-WRITE_OPT="--write_file ${OUTPUT_FILE} --write_size ${SIZE}"
+OUTPUT_FILE=${MAP_DIR}/ds_aio_write_${SIZE}B.pt
+WRITE_OPT=""
-prep_folder ${WRITE_DIR}
+prep_folder ${MAP_DIR}
prep_folder ${LOG_DIR}
-RUN_SCRIPT=./test_ds_aio.py
-DISABLE_CACHE="sync; sudo bash -c 'echo 1 > /proc/sys/vm/drop_caches' "
+if [[ ${GPU_MEM} == "gpu" ]]; then
+ gpu_opt="--gpu"
+else
+ gpu_opt=""
+fi
+if [[ ${USE_GDS} == "gds" ]]; then
+ gds_opt="--use_gds"
+else
+ gds_opt=""
+fi
+
+DISABLE_CACHE="sync; bash -c 'echo 1 > /proc/sys/vm/drop_caches' "
SYNC="sync"
for sub in single block; do
@@ -53,19 +61,19 @@ for sub in single block; do
sub_opt=""
fi
for ov in overlap sequential; do
- if [[ $ov == "overlap" ]]; then
- ov_opt="--overlap_events"
+ if [[ $ov == "sequential" ]]; then
+ ov_opt="--sequential_requests"
else
ov_opt=""
fi
- for t in 1 2 4 8; do
- for p in 1; do
- for d in 1 2 4 8 16 32; do
- for bs in 128K 256K 512K 1M; do
- SCHED_OPTS="${sub_opt} ${ov_opt} --handle --threads ${t}"
- OPTS="--io_parallel ${p} --queue_depth ${d} --block_size ${bs}"
+ for p in 1 2 4 8; do
+ for t in 1 2 4 8; do
+ for d in 32 64 128; do
+ for bs in 256K 512K 1M; do
+ SCHED_OPTS="${sub_opt} ${ov_opt} --handle ${gpu_opt} ${gds_opt} --folder ${MAP_DIR}"
+ OPTS="--queue_depth ${d} --block_size ${bs} --io_size ${IO_SIZE} --multi_process ${p} --io_parallel ${t}"
LOG="${LOG_DIR}/write_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt"
- cmd="python ${RUN_SCRIPT} ${WRITE_OPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}"
+ cmd="python ${RUN_SCRIPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}"
echo ${DISABLE_CACHE}
echo ${cmd}
echo ${SYNC}
diff --git a/csrc/aio/py_test/test_ds_aio.py b/csrc/aio/py_test/test_ds_aio.py
index e6242cb35789..6de72755e9e5 100755
--- a/csrc/aio/py_test/test_ds_aio.py
+++ b/csrc/aio/py_test/test_ds_aio.py
@@ -6,79 +6,19 @@
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
-import os
-import argparse
import multiprocessing as mp
from ds_aio_basic import aio_basic_multiprocessing
from ds_aio_handle import aio_handle_multiprocessing
-from test_ds_aio_utils import refine_args
-
-
-def parse_arguments():
- parser = argparse.ArgumentParser()
-
- parser.add_argument('--read_file', type=str, default=None, help='Read file.')
-
- parser.add_argument('--write_file', type=str, default=None, help='Write file.')
-
- parser.add_argument('--write_size', type=str, default=None, help='Number of bytes to write.')
-
- parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.')
-
- parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.')
-
- parser.add_argument('--threads', type=int, default=1, help='Thread parallelism count.')
-
- parser.add_argument('--single_submit',
- action='store_true',
- help='Submit I/O requests in singles (default is submit queue_depth amount at once.).')
-
- parser.add_argument('--overlap_events',
- action='store_true',
- help='Overlap I/O submission and completion requests.')
-
- parser.add_argument('--validate', action='store_true', help='Perform validation in library.')
-
- parser.add_argument('--handle', action='store_true', help='Use AIO handle.')
-
- parser.add_argument('--loops', type=int, default=1, help='Count of operation repetitions')
-
- parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism')
-
- parser.add_argument('--gpu', action='store_true', help='Use GPU memory')
-
- parser.add_argument('--use_accelerator_pin_memory',
- action='store_true',
- help='Obtain pinned (CPU page-locked) tensors from accelerator')
-
- args = parser.parse_args()
- print(f'args = {args}')
- return args
-
-
-def validate_args(args):
- if args.read_file and not os.path.isfile(args.read_file):
- print(f'args validation error: {args.read_file} not found')
- return False
-
- return True
+from ds_aio_args import get_validated_args
def main():
print(f'Testing deepspeed_aio python frontend')
- args = parse_arguments()
- refine_args(args)
- if not validate_args(args):
- quit()
-
+ args = get_validated_args()
mp.set_start_method('spawn')
multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing
- if args.read_file:
- multiprocess_function(args, True)
-
- if args.write_file:
- multiprocess_function(args, False)
+ multiprocess_function(args, args.read)
if __name__ == "__main__":
diff --git a/csrc/aio/py_test/test_ds_aio_utils.py b/csrc/aio/py_test/test_ds_aio_utils.py
index 6aad114c0bdc..968ff4a60ef9 100755
--- a/csrc/aio/py_test/test_ds_aio_utils.py
+++ b/csrc/aio/py_test/test_ds_aio_utils.py
@@ -6,12 +6,17 @@
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
+import os
+from ds_aio_job import Job, run_job
+
BYTES_PER_GB = 1024**3
+BYTES_PER_MB = 1024**2
+BYTES_PER_KB = 1024
LOG_TIDS = [0]
-def task_log(tid, msg):
- if tid in LOG_TIDS:
+def task_log(tid, msg, force=False):
+ if force or tid in LOG_TIDS:
print(f'tid {tid}: {msg}')
@@ -31,16 +36,29 @@ def report_results(args, read_op, pool_results):
total_bytes = sum([num_bytes for _, _, num_bytes in pool_results])
task_latency_sec = max([sec for _, sec, _ in pool_results])
- task_speed_GB = total_bytes / task_latency_sec / BYTES_PER_GB
+ task_speed_GB = 0 if task_latency_sec == 0 else total_bytes / task_latency_sec / BYTES_PER_GB
print(f'Task {io_string} Latency = {task_latency_sec} sec')
print(f'Task {io_string} Speed = {task_speed_GB} GB/sec')
e2e_latency_sec = max([sec for sec, _, _ in pool_results])
- e2e_speed_GB = total_bytes / e2e_latency_sec / BYTES_PER_GB
+ e2e_speed_GB = 0 if e2e_latency_sec == 0 else total_bytes / e2e_latency_sec / BYTES_PER_GB
print(f'E2E {io_string} Latency = {e2e_latency_sec} sec')
print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec')
+def get_block_size_and_count(io_bytes):
+ if io_bytes > BYTES_PER_MB and io_bytes % BYTES_PER_MB == 0:
+ block_size = BYTES_PER_MB
+ block_size_string = '1M'
+ else:
+ assert io_bytes % BYTES_PER_KB == 0
+ block_size = BYTES_PER_KB
+ block_size_string = '1K'
+ block_count = io_bytes / block_size
+
+ return block_size_string, int(block_count)
+
+
def refine_integer_value(value):
unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3}
@@ -50,9 +68,14 @@ def refine_integer_value(value):
return int(value)
-def refine_args(args):
- if args.write_size and type(args.write_size) == str:
- args.write_size = refine_integer_value(args.write_size)
+def create_filename(folder, read_op, size, tid):
+ io_string = "read" if read_op else "write"
+ return os.path.join(folder, f'_aio_{io_string}_{size}.pt.{tid}')
+
- if args.block_size and type(args.block_size) == str:
- args.block_size = refine_integer_value(args.block_size)
+def create_file(filename, num_bytes):
+ block_size, block_count = get_block_size_and_count(num_bytes)
+ dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={filename} bs={block_size} count={block_count}'])
+ print(f'[Start] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....')
+ run_job(dd_job)
+ print(f'[Done] Create read file of {num_bytes} bytes by running {dd_job.cmd()} ....')
diff --git a/csrc/aio/py_test/validate_async_io.py b/csrc/aio/py_test/validate_async_io.py
index 019ec05d49d3..10fb638347bc 100644
--- a/csrc/aio/py_test/validate_async_io.py
+++ b/csrc/aio/py_test/validate_async_io.py
@@ -7,3 +7,4 @@
"""
from deepspeed.ops.op_builder import AsyncIOBuilder
assert AsyncIOBuilder().is_compatible()
+assert AsyncIOBuilder().load()
diff --git a/csrc/gds/py_lib/deepspeed_gds_op.cpp b/csrc/gds/py_lib/deepspeed_gds_op.cpp
new file mode 100644
index 000000000000..c370a448e5a2
--- /dev/null
+++ b/csrc/gds/py_lib/deepspeed_gds_op.cpp
@@ -0,0 +1,154 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include "deepspeed_gds_op.h"
+
+using namespace std;
+
+// For when there is more than 1 device
+static std::map> base_ptr_registry;
+
+static void _safe_handle_register(const int fd, CUfileDescr_t& cf_descr, CUfileHandle_t& cf_handle)
+{
+ memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
+ cf_descr.handle.fd = fd;
+ cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
+ CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
+ if (status.err != CU_FILE_SUCCESS) {
+ std::cerr << "file register error:" << cuFileGetErrorString(status) << std::endl;
+ close(fd);
+ exit(EXIT_FAILURE);
+ }
+}
+
+static void* _find_base_ptr(const int64_t device, char* buf_ptr)
+{
+ void* base_ptr = nullptr;
+ int64_t last = -1;
+ int64_t ptr_diff;
+ for (const auto& value : base_ptr_registry[device]) {
+ ptr_diff = buf_ptr - (char*)value;
+ if (last == -1 && ptr_diff >= 0) {
+ last = ptr_diff;
+ base_ptr = value;
+ } else if (ptr_diff < last && ptr_diff >= 0) {
+ last = ptr_diff;
+ base_ptr = value;
+ }
+ }
+ if (!base_ptr || buf_ptr < base_ptr) {
+ std::cerr << "BASE PTR ERROR :" << base_ptr << " BUF PTR " << (void*)buf_ptr << std::endl;
+ for (const auto& value : base_ptr_registry[device]) {
+ std::cerr << "BASE PTR AVAIL :" << value << std::endl;
+ }
+ exit(EXIT_FAILURE);
+ }
+
+ return base_ptr;
+}
+
+void gds_op_desc_t::add_buffer_to_registry(const torch::Tensor& buffer)
+{
+ const int64_t device = buffer.get_device();
+ void* reg_ptr = buffer.data_ptr();
+
+ // std::cout << "REG PTR " << reg_ptr << std::endl;
+ // TODO: add checking to make sure pointer isn't already in set
+ const auto it = base_ptr_registry.find(device);
+ if (it == base_ptr_registry.end()) {
+ std::set new_ptr_set;
+ new_ptr_set.insert(reg_ptr);
+ base_ptr_registry.insert(std::pair>(device, new_ptr_set));
+ } else {
+ base_ptr_registry[device].insert(reg_ptr);
+ }
+
+ check_cudaruntimecall(cudaSetDevice(device));
+ CUfileError_t status = cuFileBufRegister(reg_ptr, buffer.nbytes(), 0);
+ if (status.err != CU_FILE_SUCCESS) {
+ std::cerr << "buffer register failed:" << cuFileGetErrorString(status) << std::endl;
+ exit(EXIT_FAILURE);
+ }
+}
+
+void gds_op_desc_t::remove_buffer_from_registry(const torch::Tensor& buffer)
+{
+ const int64_t device = buffer.get_device();
+ void* reg_ptr = buffer.data_ptr();
+
+ // std::cout << "DEREG PTR " << reg_ptr << std::endl;
+ check_cudaruntimecall(cudaSetDevice(device));
+ cuFileBufDeregister(reg_ptr);
+
+ // Remove from tracked registry
+ base_ptr_registry[device].erase(reg_ptr);
+}
+
+gds_op_desc_t::gds_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate)
+ : io_op_desc_t(read_op, buffer, fd, filename, file_num_bytes, num_threads, validate)
+{
+ _contiguous_buffer = _buffer.contiguous();
+ const int64_t device = _buffer.get_device();
+ check_cudaruntimecall(cudaSetDevice(device));
+ _base_ptr = _find_base_ptr(device, (char*)_contiguous_buffer.data_ptr());
+
+ _safe_handle_register(fd, _cf_descr, _cf_handle);
+}
+
+char* gds_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }
+
+void gds_op_desc_t::finish() { cuFileHandleDeregister(_cf_handle); }
+
+void gds_op_desc_t::validate()
+{
+ check_cudaruntimecall(cudaSetDevice(_buffer.get_device()));
+ const auto cpu_buffer = _buffer.to(torch::kCPU);
+ validate_aio_operation(
+ _read_op, _filename.c_str(), (char*)(cpu_buffer.data_ptr()), _file_num_bytes);
+}
+
+void gds_op_desc_t::run(const int tid,
+ std::unique_ptr& aio_ctxt,
+ deepspeed_aio_config_t* aio_config)
+{
+ assert(tid < _num_threads);
+ check_cudaruntimecall(cudaSetDevice(_buffer.get_device()));
+ int64_t buf_offset = data_ptr() + (_num_bytes_per_thread * tid) - (char*)_base_ptr;
+ const auto file_offset = _num_bytes_per_thread * tid;
+
+ if (_read_op) {
+ auto ret =
+ cuFileRead(_cf_handle, _base_ptr, _num_bytes_per_thread, file_offset, buf_offset);
+ if (ret < 0) { _report_error(ret, errno, buf_offset); }
+ } else {
+ auto ret =
+ cuFileWrite(_cf_handle, _base_ptr, _num_bytes_per_thread, file_offset, buf_offset);
+ if (ret < 0) { _report_error(ret, errno, buf_offset); }
+ }
+}
+
+void gds_op_desc_t::_report_error(const ssize_t return_code,
+ const int error_num,
+ const off_t offset)
+{
+ const auto op_string = _read_op ? "read failed with " : "write failed with ";
+ const auto error_string = IS_CUFILE_ERR(return_code) ? "cuFile error: " : "posix error: ";
+ const auto error_code = IS_CUFILE_ERR(return_code) ? cuFileGetErrorString(return_code)
+ : cuFileGetErrorString(error_num);
+ std::cerr << op_string << error_string << error_code << " return code = " << return_code
+ << " filename = " << _filename.c_str() << " num bytes = " << _num_bytes_per_thread
+ << " offset = " << offset << std::endl;
+ exit(EXIT_FAILURE);
+}
diff --git a/csrc/gds/py_lib/deepspeed_gds_op.h b/csrc/gds/py_lib/deepspeed_gds_op.h
new file mode 100644
index 000000000000..b7fab64d4054
--- /dev/null
+++ b/csrc/gds/py_lib/deepspeed_gds_op.h
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "deepspeed_aio_op_desc.h"
+#include "deepspeed_gds_utils.h"
+
+struct gds_op_desc_t : io_op_desc_t {
+ CUfileDescr_t _cf_descr;
+ CUfileHandle_t _cf_handle;
+ void* _base_ptr;
+
+ gds_op_desc_t(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const int num_threads,
+ const bool validate);
+
+ void run(const int tid,
+ std::unique_ptr& aio_ctxt,
+ deepspeed_aio_config_t* aio_config);
+
+ char* data_ptr() const;
+
+ void validate();
+
+ void finish();
+
+ void _report_error(const ssize_t return_code, const int error_num, const off_t offset);
+
+ static void add_buffer_to_registry(const torch::Tensor& buffer);
+
+ static void remove_buffer_from_registry(const torch::Tensor& buffer);
+};
diff --git a/csrc/gds/py_lib/deepspeed_gds_utils.h b/csrc/gds/py_lib/deepspeed_gds_utils.h
new file mode 100644
index 000000000000..12b014d90988
--- /dev/null
+++ b/csrc/gds/py_lib/deepspeed_gds_utils.h
@@ -0,0 +1,91 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include
+
+// CUDA/cuFile includes
+#include
+#include
+#include "cufile.h"
+
+// Macro for checking cuda errors following a cuda launch or api call
+#define cudaCheckError() \
+ { \
+ cudaError_t e = cudaGetLastError(); \
+ if (e != cudaSuccess) { \
+ printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
+ exit(EXIT_FAILURE); \
+ } \
+ }
+
+#define check_cudadrivercall(fn) \
+ do { \
+ CUresult res = fn; \
+ if (res != CUDA_SUCCESS) { \
+ const char* str = nullptr; \
+ cuGetErrorName(res, &str); \
+ std::cerr << "cuda driver api call failed " << #fn << " res : " << res << ", " \
+ << __LINE__ << ":" << str << std::endl; \
+ std::cerr << "EXITING program!!!" << std::endl; \
+ exit(1); \
+ } \
+ } while (0)
+
+#define check_cudaruntimecall(fn) \
+ do { \
+ cudaError_t res = fn; \
+ if (res != cudaSuccess) { \
+ const char* str = cudaGetErrorName(res); \
+ std::cerr << "cuda runtime api call failed " << #fn << __LINE__ << ":" << str \
+ << std::endl; \
+ std::cerr << "EXITING program!!!" << std::endl; \
+ exit(1); \
+ } \
+ } while (0)
+
+#define check_cuFileCall(fn, api_msg) \
+ do { \
+ CUfileError_t status = fn; \
+ if (status.err != CU_FILE_SUCCESS) { \
+ std::cout << api_msg << " failed with error " << CUFILE_ERRSTR(status.err) \
+ << std::endl; \
+ exit(EXIT_FAILURE); \
+ } \
+ } while (0)
+
+//
+// cuda driver error description
+//
+static inline const char* GetCuErrorString(CUresult curesult)
+{
+ const char* descp;
+ if (cuGetErrorName(curesult, &descp) != CUDA_SUCCESS) descp = "unknown cuda error";
+ return descp;
+}
+
+//
+// cuFile APIs return both cuFile specific error codes as well as POSIX error codes
+// for ease, the below template can be used for getting the error description depending
+// on its type.
+
+// POSIX
+template ::value, std::nullptr_t>::type = nullptr>
+std::string cuFileGetErrorString(T status)
+{
+ status = std::abs(status);
+ return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status))
+ : std::string(std::strerror(status));
+}
+
+// CUfileError_t
+template ::value, std::nullptr_t>::type = nullptr>
+std::string cuFileGetErrorString(T status)
+{
+ std::string errStr = cuFileGetErrorString(static_cast(status.err));
+ if (IS_CUDA_ERR(status)) errStr.append(".").append(GetCuErrorString(status.cu_err));
+ return errStr;
+}
diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp
new file mode 100644
index 000000000000..3a35ad3145a0
--- /dev/null
+++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.cpp
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+ GPUDirect Storage functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include "deepspeed_py_gds_handle.h"
+#include
+#include "deepspeed_gds_op.h"
+
+using namespace std;
+
+int deepspeed_gds_handle_t::s_cuFile_init = 0;
+
+deepspeed_gds_handle_t::deepspeed_gds_handle_t(const int block_size,
+ const int queue_depth,
+ const bool single_submit,
+ const bool overlap_events,
+ const int num_threads)
+ : deepspeed_io_handle_t(block_size, queue_depth, single_submit, overlap_events, num_threads)
+{
+ _init_cuFile(block_size, queue_depth, num_threads);
+}
+
+deepspeed_gds_handle_t::~deepspeed_gds_handle_t() { _close_cuFile(); }
+
+void deepspeed_gds_handle_t::_init_cuFile(const int block_size,
+ const int queue_depth,
+ const int num_threads)
+{
+ if (deepspeed_gds_handle_t::s_cuFile_init == 0) {
+ std::string depthStr = std::to_string(queue_depth);
+ std::string threadsStr = std::to_string(num_threads);
+ std::string json1 = R"({"execution": {"max_io_queue_depth": )" + depthStr + ", ";
+ std::string json2 = R"("max_request_parallelism": )" + threadsStr + ", ";
+ std::string json3 = R"("max_io_threads": )" + threadsStr + ", ";
+ std::string json4 = R"("parallel_io": true, "min_io_threshold_size_kb": 8192}})";
+ std::ofstream outFile("local_cufile.json");
+ if (outFile.is_open()) {
+ outFile << json1 + json2 + json3 + json4;
+ outFile.close();
+ } else {
+ std::cerr << "Can't open local cufile" << std::endl;
+ exit(EXIT_FAILURE);
+ }
+ // TODO: Address the following issues with this code
+ // (1) Fix C++14 warning
+ // (2) Create file in a different location than PWD
+ // (3) Handle multi-GPU/multi-rank scenarios: should cufile be shared, is per-rank cufile
+ // safe?
+ putenv("CUFILE_ENV_PATH_JSON=$PWD/local_cufile.json");
+ cuFileDriverOpen();
+ cudaCheckError();
+ size_t direct_io_size = (size_t)block_size / 1024;
+ CUfileError_t status = cuFileDriverSetMaxDirectIOSize(direct_io_size);
+ if (status.err != CU_FILE_SUCCESS) {
+ std::cerr << "file register error:" << cuFileGetErrorString(status) << std::endl;
+ exit(EXIT_FAILURE);
+ }
+ }
+ deepspeed_gds_handle_t::s_cuFile_init++;
+}
+
+void deepspeed_gds_handle_t::_close_cuFile()
+{
+ deepspeed_gds_handle_t::s_cuFile_init--;
+ if (deepspeed_gds_handle_t::s_cuFile_init == 0) { cuFileDriverClose(); }
+}
+
+torch::Tensor deepspeed_gds_handle_t::new_pinned_device_tensor(const size_t num_elem,
+ const torch::Tensor& example_tensor)
+{
+ auto options = torch::TensorOptions().dtype(example_tensor.scalar_type()).device(torch::kCUDA);
+ auto dev_tensor = torch::empty(num_elem, options);
+ pin_device_tensor(dev_tensor);
+ return dev_tensor;
+}
+
+bool deepspeed_gds_handle_t::free_pinned_device_tensor(torch::Tensor& buffer)
+{
+ unpin_device_tensor(buffer);
+ return true;
+}
+
+bool deepspeed_gds_handle_t::pin_device_tensor(const torch::Tensor& buffer)
+{
+ gds_op_desc_t::add_buffer_to_registry(buffer);
+ return true;
+}
+
+bool deepspeed_gds_handle_t::unpin_device_tensor(const torch::Tensor& buffer)
+{
+ gds_op_desc_t::remove_buffer_from_registry(buffer);
+ return true;
+}
+
+std::shared_ptr deepspeed_gds_handle_t::_create_io_op_desc(
+ const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const bool validate)
+{
+ if (buffer.is_cuda()) {
+ return std::make_shared(
+ read_op, buffer, fd, filename, file_num_bytes, _num_threads, validate);
+ }
+ return deepspeed_io_handle_t::_create_io_op_desc(
+ read_op, buffer, fd, filename, file_num_bytes, validate);
+}
diff --git a/csrc/gds/py_lib/deepspeed_py_gds_handle.h b/csrc/gds/py_lib/deepspeed_py_gds_handle.h
new file mode 100644
index 000000000000..f324e6b65e80
--- /dev/null
+++ b/csrc/gds/py_lib/deepspeed_py_gds_handle.h
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include
+#include
+#include "deepspeed_py_io_handle.h"
+
+struct deepspeed_gds_handle_t : deepspeed_io_handle_t {
+ deepspeed_gds_handle_t(const int block_size,
+ const int queue_depth,
+ const bool single_submit,
+ const bool overlap_events,
+ const int num_threads);
+
+ ~deepspeed_gds_handle_t();
+
+ torch::Tensor new_pinned_device_tensor(const size_t num_elem,
+ const torch::Tensor& example_tensor);
+
+ bool free_pinned_device_tensor(torch::Tensor&);
+
+ bool pin_device_tensor(const torch::Tensor& buffer);
+
+ bool unpin_device_tensor(const torch::Tensor& buffer);
+
+ void _init_cuFile(const int block_size, const int queue_length, const int num_threads);
+
+ void _close_cuFile();
+
+ std::shared_ptr _create_io_op_desc(const bool read_op,
+ const torch::Tensor& buffer,
+ const int fd,
+ const char* filename,
+ const long long int file_num_bytes,
+ const bool validate);
+
+ static int s_cuFile_init;
+};
diff --git a/csrc/gds/py_lib/py_ds_gds.cpp b/csrc/gds/py_lib/py_ds_gds.cpp
new file mode 100644
index 000000000000..66eb34d4ea8c
--- /dev/null
+++ b/csrc/gds/py_lib/py_ds_gds.cpp
@@ -0,0 +1,122 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+/*
+Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
+*/
+
+#include
+#include "deepspeed_py_gds_handle.h"
+using namespace pybind11::literals;
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ py::class_(m, "gds_handle")
+ .def(py::init(),
+ "GDS handle constructor",
+ "block_size"_a = 1024 * 1024,
+ "queue_depth"_a = 128,
+ "single_submit"_a = false,
+ "overlap_events"_a = false,
+ "num_threads"_a = 1)
+
+ .def("get_block_size", &deepspeed_gds_handle_t::get_block_size)
+ .def("get_queue_depth", &deepspeed_gds_handle_t::get_queue_depth)
+ .def("get_single_submit", &deepspeed_gds_handle_t::get_single_submit)
+ .def("get_overlap_events", &deepspeed_gds_handle_t::get_overlap_events)
+ .def("get_thread_count", &deepspeed_gds_handle_t::get_thread_count)
+
+ .def("read",
+ &deepspeed_gds_handle_t::read,
+ "Synchronous and non-parallel file read. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a)
+
+ .def("write",
+ &deepspeed_gds_handle_t::write,
+ "Synchronous and non-parallel file write. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a)
+
+ .def("pread",
+ &deepspeed_gds_handle_t::pread,
+ "Parallel file read with option of parallelism. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a,
+ "async"_a)
+
+ .def("pwrite",
+ &deepspeed_gds_handle_t::pwrite,
+ "Parallel file write with option of parallelism. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a,
+ "validate"_a,
+ "async"_a)
+
+ .def("sync_pread",
+ &deepspeed_gds_handle_t::sync_pread,
+ "Synchrononous parallel file read. Returns count of completed read ops",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("sync_pwrite",
+ &deepspeed_gds_handle_t::sync_pwrite,
+ "Synchronous parallel file write. Returns count of completed write ops",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("async_pread",
+ &deepspeed_gds_handle_t::async_pread,
+ "Asynchronous parallel file read. Returns 0 on success. Returns 0 on success, and "
+ "following wait() returns count of completed ops.",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("async_pwrite",
+ &deepspeed_gds_handle_t::async_pwrite,
+ "Asynchronous parallel file write. Returns 0 on success, and following wait() returns "
+ "count of completed ops.",
+ "buffer"_a,
+ "filename"_a)
+
+ .def("new_cpu_locked_tensor",
+ &deepspeed_gds_handle_t::new_cpu_locked_tensor,
+ "Allocate pinned CPU tensor.",
+ "num_elem"_a,
+ "example_tenosr"_a)
+
+ .def("free_cpu_locked_tensor",
+ &deepspeed_gds_handle_t::free_cpu_locked_tensor,
+ "Free pinned CPU tensor.",
+ "tensor"_a)
+
+ .def("new_pinned_device_tensor",
+ &deepspeed_gds_handle_t::new_pinned_device_tensor,
+ "Allocate pinned device tensor.",
+ "num_elem"_a,
+ "example_tenosr"_a)
+
+ .def("free_pinned_device_tensor",
+ &deepspeed_gds_handle_t::free_pinned_device_tensor,
+ "Free pinned device tensor.",
+ "tensor"_a)
+
+ .def("pin_device_tensor",
+ &deepspeed_gds_handle_t::pin_device_tensor,
+ "Pin device tensor.",
+ "tensor"_a)
+
+ .def("unpin_device_tensor",
+ &deepspeed_gds_handle_t::unpin_device_tensor,
+ "Unpin device tensor.",
+ "tensor"_a)
+
+ .def("wait",
+ &deepspeed_gds_handle_t::wait,
+ "Wait for (ongoing) asynchronous operations to complete");
+}
diff --git a/csrc/gds/py_test/validate_gds.py b/csrc/gds/py_test/validate_gds.py
new file mode 100644
index 000000000000..b34b1194f582
--- /dev/null
+++ b/csrc/gds/py_test/validate_gds.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
+"""
+from deepspeed.ops.op_builder import GDSBuilder
+assert GDSBuilder().is_compatible(True)
+assert GDSBuilder().load(True)
diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index fe0043547860..d8655299282f 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -77,6 +77,7 @@ def initialize(args=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
+ mesh_param=None,
config_params=None):
"""Initialize the DeepSpeed Engine.
@@ -144,10 +145,22 @@ def initialize(args=None,
distributed_port=distributed_port,
dist_init_required=dist_init_required)
+ ##TODO: combine reuse mpu as mesh device and vice versa
# Set config using config_params for backwards compat
if config is None and config_params is not None:
config = config_params
+ mesh_device = None
+ if mesh_param:
+ logger.info(f"mesh_param to Initialize mesh device: {mesh_param}")
+ mesh_device = dist.initialize_mesh_device(mesh_param, ("data_parallel", "sequence_parallel"))
+ #if config file has sequence parallelize and data parallelize, then use them to initialize mesh device
+ elif config is not None:
+ if "sequence_parallel_size" in config and "data_parallel_size" in config:
+ logger.info(f"config to Initialize mesh device: {config}")
+ mesh_device = dist.initialize_mesh_device((config["data_parallel_size"], config["sequence_parallel_size"]), \
+ ("data_parallel", "sequence_parallel"))
+
# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
@@ -162,9 +175,8 @@ def initialize(args=None,
assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
config = args.deepspeed_config
assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
-
if not isinstance(model, PipelineModule):
- config_class = DeepSpeedConfig(config, mpu)
+ config_class = DeepSpeedConfig(config, mpu, mesh_device=mesh_device)
if config_class.hybrid_engine.enabled:
engine = DeepSpeedHybridEngine(args=args,
model=model,
@@ -188,6 +200,7 @@ def initialize(args=None,
dist_init_required=dist_init_required,
collate_fn=collate_fn,
config=config,
+ mesh_device=mesh_device,
config_class=config_class)
else:
assert mpu is None, "mpu must be None with pipeline parallelism"
@@ -208,7 +221,12 @@ def initialize(args=None,
# Restore zero.Init context if necessary
zero.partition_parameters.restore_init_context()
- return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
+ return_items = [
+ engine,
+ engine.optimizer,
+ engine.training_dataloader,
+ engine.lr_scheduler,
+ ]
return tuple(return_items)
diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py
old mode 100644
new mode 100755
index 85b7fab2c548..2895e0f2e011
--- a/deepspeed/comm/comm.py
+++ b/deepspeed/comm/comm.py
@@ -600,6 +600,21 @@ def get_all_ranks_from_group(group=None):
return group_ranks
+def initialize_mesh_device(mesh_shape, mesh_dim_names):
+ global cdb
+ assert cdb is not None and cdb.is_initialized(
+ ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ mesh_device = None
+ if hasattr(cdb, 'init_device_mesh'):
+ utils.logger.info(f"Initializing mesh device with backend {cdb.name} \
+ with shape {mesh_shape} and dim names {mesh_dim_names}")
+ mesh_device = cdb.init_device_mesh(mesh_shape, mesh_dim_names)
+ else:
+ if get_rank() == 0:
+ utils.logger.warning_once(f"Backend {cdb.name} does not support mesh device initialization")
+ return mesh_device
+
+
# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py
index 1c441bb6bfe9..57501c9dd237 100644
--- a/deepspeed/comm/config.py
+++ b/deepspeed/comm/config.py
@@ -3,20 +3,12 @@
# DeepSpeed Team
-from .constants import *
-from ..pydantic_v1 import BaseModel
-
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
-class CommsConfig(BaseModel):
-
- class Config:
- validate_all = True
- validate_assignment = True
- use_enum_values = True
- extra = 'forbid'
+from .constants import *
-class CommsLoggerConfig(CommsConfig):
+class CommsLoggerConfig(DeepSpeedConfigModel):
enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
old mode 100644
new mode 100755
index 83754e98f033..ed2645d415c4
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -386,6 +386,14 @@ def _reduce_op(self, op):
op = torch.distributed.ReduceOp.BXOR
return op
+ def init_device_mesh(self, mesh_shape, mesh_dim_names):
+ if not required_torch_version(min_version=2.2):
+ raise RuntimeError(f"Current torch version does not have device mesh"
+ f"api (torch.__version__: {torch.__version__})")
+ return torch.distributed.device_mesh.init_device_mesh(get_accelerator().current_device_name(),
+ mesh_shape,
+ mesh_dim_names=mesh_dim_names)
+
# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 1d5018aaa75b..c7c7684fff79 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -5,38 +5,25 @@
import torch
import deepspeed
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, field_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
-from typing import Dict, Union
+from typing import Dict, Union, Optional
from enum import Enum
class DtypeEnum(Enum):
- # The torch dtype must always be the first value (so we return torch.dtype)
- fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
- fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
- bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
- int8 = torch.int8, "torch.int8", "int8"
-
- # Copied from https://stackoverflow.com/a/43210118
- # Allows us to use multiple values for each Enum index and returns first
- # listed value when Enum is called
- def __new__(cls, *values):
- obj = object.__new__(cls)
- # first value is canonical value
- obj._value_ = values[0]
- for other_value in values[1:]:
- cls._value2member_map_[other_value] = obj
- obj._all_values = values
- return obj
-
- def __repr__(self):
- return "<%s.%s: %s>" % (
- self.__class__.__name__,
- self._name_,
- ", ".join([repr(v) for v in self._all_values]),
- )
+ fp16 = (torch.float16, "torch.float16", "fp16", "float16", "half")
+ fp32 = (torch.float32, "torch.float32", "fp32", "float32", "float")
+ bf16 = (torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat")
+ int8 = (torch.int8, "torch.int8", "int8")
+
+ @classmethod
+ def from_str(cls, value: str):
+ for dtype in cls:
+ if value in dtype.value:
+ return dtype
+ raise ValueError(f"'{value}' is not a valid DtypeEnum")
class MoETypeEnum(str, Enum):
@@ -91,24 +78,24 @@ class QuantTypeEnum(str, Enum):
class BaseQuantConfig(DeepSpeedConfigModel):
- enabled = True
- num_bits = 8
+ enabled: bool = True
+ num_bits: int = 8
q_type: QuantTypeEnum = QuantTypeEnum.sym
q_groups: int = 1
class WeightQuantConfig(BaseQuantConfig):
- enabled = True
+ enabled: bool = True
quantized_initialization: Dict = {}
post_init_quant: Dict = {}
class ActivationQuantConfig(BaseQuantConfig):
- enabled = True
+ enabled: bool = True
class QKVQuantConfig(DeepSpeedConfigModel):
- enabled = True
+ enabled: bool = True
class QuantizationConfig(DeepSpeedConfigModel):
@@ -120,9 +107,9 @@ class QuantizationConfig(DeepSpeedConfigModel):
# todo: brainstorm on how to do ckpt loading for DS inference
class InferenceCheckpointConfig(DeepSpeedConfigModel):
- checkpoint_dir: str = None
- save_mp_checkpoint_path: str = None
- base_dir: str = None
+ checkpoint_dir: Optional[str] = None
+ save_mp_checkpoint_path: Optional[str] = None
+ base_dir: Optional[str] = None
class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
@@ -136,7 +123,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
`(attention_output projection, transformer output projection)`
"""
- dtype: DtypeEnum = torch.float16
+ dtype: torch.dtype = torch.float16
"""
Desired model data type, will convert model to this type.
Supported target types: `torch.half`, `torch.int8`, `torch.float`
@@ -198,7 +185,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
"""
#todo: refactor the following 3 into the new checkpoint_config
- checkpoint: Union[str, Dict] = None
+ checkpoint: Optional[Union[str, Dict]] = None
"""
Path to deepspeed compatible checkpoint or path to JSON with load policy.
"""
@@ -214,7 +201,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
specifying whether the inference-module is created with empty or real Tensor
"""
- save_mp_checkpoint_path: str = None
+ save_mp_checkpoint_path: Optional[str] = None
"""
The path for which we want to save the loaded model with a checkpoint. This
feature is used for adjusting the parallelism degree to help alleviate the
@@ -243,19 +230,21 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
replace_method: str = Field(
"auto",
- deprecated=True,
- deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
+ json_schema_extra={
+ "deprecated": True,
+ "deprecated_msg": "This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
+ })
- injection_policy: Dict = Field(None, alias="injection_dict")
+ injection_policy: Optional[Dict] = Field(None, alias="injection_dict")
"""
Dictionary mapping a client nn.Module to its corresponding injection
policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
"""
- injection_policy_tuple: tuple = None
+ injection_policy_tuple: Optional[tuple] = None
""" TODO: Add docs """
- config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor
+ config: Optional[Dict] = Field(None, alias="args") # todo: really no need for this field if we can refactor
max_out_tokens: int = Field(1024, alias="max_tokens")
"""
@@ -274,31 +263,49 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
transposed_mode: bool = Field(False, alias="transposed_mode")
- mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
+ mp_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.tp_size"})
"""
Desired model parallel size, default is 1 meaning no model parallelism.
Deprecated, please use the ``tensor_parallel` config to control model
parallelism.
"""
- mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
- ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
- ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
- ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
- moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
- moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
-
- @validator("moe")
+ mpu: object = Field(None, json_schema_extra={"deprecated": True, "new_param": "tensor_parallel.mpu"})
+ ep_size: int = Field(1, json_schema_extra={"deprecated": True, "new_param": "moe.ep_size"})
+ ep_group: object = Field(None,
+ alias="expert_group",
+ json_schema_extra={
+ "deprecated": True,
+ "new_param": "moe.ep_group"
+ })
+ ep_mp_group: object = Field(None,
+ alias="expert_mp_group",
+ json_schema_extra={
+ "deprecated": True,
+ "new_param": "moe.ep_mp_group"
+ })
+ moe_experts: list = Field([1], json_schema_extra={"deprecated": True, "new_param": "moe.moe_experts"})
+ moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
+ json_schema_extra={
+ "deprecated": True,
+ "new_param": "moe.type"
+ })
+
+ @field_validator("dtype", mode="before")
+ def validate_dtype(cls, field_value, values):
+ if isinstance(field_value, str):
+ return DtypeEnum.from_str(field_value).value[0]
+ if isinstance(field_value, torch.dtype):
+ return field_value
+ raise TypeError(f"Invalid type for dtype: {type(field_value)}")
+
+ @field_validator("moe")
def moe_backward_compat(cls, field_value, values):
if isinstance(field_value, bool):
return DeepSpeedMoEConfig(moe=field_value)
return field_value
- @validator("use_triton")
+ @field_validator("use_triton")
def has_triton(cls, field_value, values):
if field_value and not deepspeed.HAS_TRITON:
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
return field_value
-
- class Config:
- # Get the str representation of the datatype for serialization
- json_encoders = {torch.dtype: lambda x: str(x)}
diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
index 46a84c61f884..d88d99ebebfd 100644
--- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py
+++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -15,13 +15,13 @@
class HuggingFaceCheckpointEngine(CheckpointEngineBase):
- def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
+ def __init__(self, model_name_or_path: str, auth_token: str = None, **hf_kwargs) -> None:
super().__init__()
from transformers import AutoConfig, GenerationConfig
self.model_name_or_path = model_name_or_path
self.auth_token = auth_token
- self.model_config = AutoConfig.from_pretrained(self.model_name_or_path)
+ self.model_config = AutoConfig.from_pretrained(self.model_name_or_path, **hf_kwargs)
# Define this property here so we can use it in the model implementation
if not hasattr(self.model_config, "max_seq_length"):
if hasattr(self.model_config, "max_position_embeddings"):
@@ -108,6 +108,12 @@ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
for checkpoint in self._all_ckpt_paths:
inference_logger().info(f"Loading checkpoint: {checkpoint}")
checkpoint_sd = self._checkpoint_load_fn(checkpoint)
+
+ # If the model has tied embeddings, we need to make sure the lm_head weights are tied to the embeddings weights
+ if hasattr(self.model_config, "tie_word_embeddings") and self.model_config.tie_word_embeddings:
+ if self.model_config.model_type == "qwen2":
+ checkpoint_sd["lm_head.weight"] = checkpoint_sd["model.embed_tokens.weight"]
+
param_keys = list(checkpoint_sd.keys())
for param_name in param_keys:
param = checkpoint_sd[param_name]
diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py
index 85e4b7a0e0a0..325b57d8f56a 100644
--- a/deepspeed/inference/v2/config_v2.py
+++ b/deepspeed/inference/v2/config_v2.py
@@ -3,8 +3,9 @@
# DeepSpeed Team
+from pydantic import Field
from typing import Optional
-from deepspeed.pydantic_v1 import Field
+
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from .ragged import DSStateManagerConfig
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
index 2cc430ccfe34..f5104f899d9c 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
+++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
@@ -14,5 +14,8 @@
} else if (4 == N_TOP_K) { \
constexpr int CONST_TOP_K = 4; \
__VA_ARGS__(); \
+ } else if (8 == N_TOP_K) { \
+ constexpr int CONST_TOP_K = 8; \
+ __VA_ARGS__(); \
} \
}()
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py
index 7e1ec1a13cb9..aacbec0bd3ae 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py
+++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py
@@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):
supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 80, 96, 128]
- supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]
+ supported_q_ratios = [1, 2, 4, 5, 6, 7, 8, 16, 29, 35, 36, 71]
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int,
theta_base: float) -> None:
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu
index fbafece5ccf2..f7bc693eefee 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu
+++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary_cuda.cu
@@ -265,6 +265,8 @@ void launch_kv_rotary_kernel(T* kv_cache,
LAUNCH_KV_ROTARY_FOR_Q_RATIO(2)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(4)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(5)
+ LAUNCH_KV_ROTARY_FOR_Q_RATIO(6)
+ LAUNCH_KV_ROTARY_FOR_Q_RATIO(7)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(8)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(16)
LAUNCH_KV_ROTARY_FOR_Q_RATIO(29)
diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py
index ebdb59bca920..c5e02adaffc4 100644
--- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py
+++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py
@@ -27,9 +27,9 @@ class TensorMetadata(DeepSpeedConfigModel):
"""
A class to represent a tensor specification.
"""
- dtype: Optional[str]
- shape: Optional[Tuple[int, ...]]
- strides: Optional[Tuple[int, ...]]
+ dtype: Optional[str] = None
+ shape: Optional[Tuple[int, ...]] = None
+ strides: Optional[Tuple[int, ...]] = None
offset: int
@@ -37,7 +37,7 @@ class ParameterMetadata(DeepSpeedConfigModel):
"""
A class to represent a parameter specification.
"""
- core_param: TensorMetadata = None
+ core_param: Optional[TensorMetadata] = None
aux_params: Dict[str, TensorMetadata] = {}
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
index b4621257ff82..e499379da7e3 100644
--- a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
@@ -8,45 +8,45 @@
from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
- # HF Qwen1.5-MoE-A2.7B model looks like this:
+ # HF Qwen2-57B-A14B model looks like this:
Qwen2MoeForCausalLM(
(model): Qwen2MoeModel(
- (embed_tokens): Embedding(151936, 2048)
+ (embed_tokens): Embedding(151936, 3584)
(layers): ModuleList(
- (0-23): 24 x Qwen2MoeDecoderLayer(
+ (0-27): 28 x Qwen2MoeDecoderLayer(
(self_attn): Qwen2MoeSdpaAttention(
- (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
- (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
- (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
- (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
+ (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
+ (k_proj): Linear(in_features=3584, out_features=512, bias=True)
+ (v_proj): Linear(in_features=3584, out_features=512, bias=True)
+ (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
(rotary_emb): Qwen2MoeRotaryEmbedding()
)
(mlp): Qwen2MoeSparseMoeBlock(
- (gate): Linear(in_features=2048, out_features=60, bias=False)
+ (gate): Linear(in_features=3584, out_features=64, bias=False)
(experts): ModuleList(
- (0-59): 60 x Qwen2MoeMLP(
- (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
- (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
- (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
+ (0-63): 64 x Qwen2MoeMLP(
+ (gate_proj): Linear(in_features=3584, out_features=2560, bias=False)
+ (up_proj): Linear(in_features=3584, out_features=2560, bias=False)
+ (down_proj): Linear(in_features=2560, out_features=3584, bias=False)
(act_fn): SiLU()
)
)
(shared_expert): Qwen2MoeMLP(
- (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
- (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
- (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
+ (gate_proj): Linear(in_features=3584, out_features=20480, bias=False)
+ (up_proj): Linear(in_features=3584, out_features=20480, bias=False)
+ (down_proj): Linear(in_features=20480, out_features=3584, bias=False)
(act_fn): SiLU()
)
- (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
+ (shared_expert_gate): Linear(in_features=3584, out_features=1, bias=False)
)
- (input_layernorm): Qwen2MoeRMSNorm()
- (post_attention_layernorm): Qwen2MoeRMSNorm()
+ (input_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
+ (post_attention_layernorm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
)
)
- (norm): Qwen2MoeRMSNorm()
+ (norm): Qwen2MoeRMSNorm((3584,), eps=1e-06)
)
- (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
+ (lm_head): Linear(in_features=3584, out_features=151936, bias=False)
)
'''
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
index 7cddbf978369..c7841b24e5fc 100644
--- a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
@@ -73,7 +73,7 @@ def n_heads(self) -> int:
@property
def intermediate_dim(self) -> int:
- return self._config.intermediate_size
+ return self._config.shared_expert_intermediate_size
@property
def n_heads_kv(self) -> int:
diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
index bd90cbd5d697..a9b01d1233cd 100644
--- a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
+++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
@@ -42,7 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool:
if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
return False
- if config.top_k != 1 and config.top_k != 2 and config.top_k != 4:
+ if config.top_k != 1 and config.top_k != 2 and config.top_k != 4 and config.top_k != 8:
return False
return True
diff --git a/deepspeed/inference/v2/ragged/manager_configs.py b/deepspeed/inference/v2/ragged/manager_configs.py
index a5e98e5bcef1..17283b8bc0c4 100644
--- a/deepspeed/inference/v2/ragged/manager_configs.py
+++ b/deepspeed/inference/v2/ragged/manager_configs.py
@@ -6,7 +6,7 @@
from enum import Enum
from typing import Tuple
-from deepspeed.pydantic_v1 import PositiveInt, validator
+from pydantic import PositiveInt, model_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from ..inference_utils import DtypeEnum
@@ -173,11 +173,9 @@ class DSStateManagerConfig(DeepSpeedConfigModel):
Enable tracking for offloading KV-cache to host memory. Currently unsupported.
"""
- @validator("max_ragged_sequence_count")
- def max_ragged_sequence_count_validator(cls, v: int, values: dict):
+ @model_validator(mode="after")
+ def max_ragged_sequence_count_validator(self):
# If the attributes below failed their validation they won't appear in the values dict.
- if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]:
- raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences")
- if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]:
- raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size")
- return v
+ assert self.max_ragged_sequence_count <= self.max_tracked_sequences, "max_ragged_sequence_count must be less than max_tracked_sequences"
+ assert self.max_ragged_sequence_count <= self.max_ragged_batch_size, "max_ragged_sequence_count must be less than max_ragged_batch_size"
+ return self
diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 416f01b82e3d..340bc82de508 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -124,6 +124,8 @@ def einsum(rule, a, b):
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == 'se,se->s':
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+ elif rule == 'se,sec->sec':
+ return a.unsqueeze(2) * b
elif rule == 'sec,sm->ecm':
s = a.shape[0]
e = a.shape[1]
@@ -191,8 +193,8 @@ def top1gating(logits: Tensor,
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# everything is in fp32 in this function
- gates = F.softmax(logits, dim=1)
+ gates = F.softmax(logits, dim=1)
capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))
# Create a mask for 1st's expert per token
@@ -206,7 +208,7 @@ def top1gating(logits: Tensor,
mask1 = einsum("s,se->se", used_token, mask1)
# gating decisions
- exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+ exp_counts = torch.sum(mask1, dim=0).detach().to(logits.device)
# if we don't want to drop any tokens
if not drop_tokens:
@@ -322,7 +324,7 @@ def top2gating(logits: Tensor,
l_aux = torch.mean(me * ce) * num_experts * num_experts
# gating decisions
- exp_counts = torch.sum(mask1 + mask2, dim=0)
+ exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
if drop_tokens:
# Calculate configured capacity and remove locations outside capacity from mask
@@ -366,7 +368,82 @@ def top2gating(logits: Tensor,
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()
- return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')
+ return l_aux, combine_weights, dispatch_mask, exp_counts
+
+
+def topkgating(
+ logits: Tensor,
+ k: int,
+ capacity_factor: float,
+ min_capacity: int,
+ drop_tokens: bool = True,
+ ep_group: Union[torch.distributed.ProcessGroup, None] = None,
+ drop_policy: str = "probs",
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Implements TopKGating on logits."""
+
+ # everything is in fp32 in this function
+ # get topk gates
+ top_gate, top_idx = torch.topk(logits, k=k, dim=1)
+ # gating decisions
+ gates = F.softmax(logits, dim=1)
+ num_experts = int(gates.shape[1])
+
+ # get topk mask
+ topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate)
+
+ mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1)
+
+ exp_counts = torch.sum(mask, dim=0).detach().to(logits.device)
+
+ # Compute l_aux
+ me = torch.mean(gates, dim=0)
+ ce = torch.mean(mask.float(), dim=0)
+ l_aux = torch.mean(me * ce) * num_experts * num_experts / k
+
+ if drop_tokens:
+ # Calculate configured capacity and remove locations outside capacity from mask
+ capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))
+ # update mask and locations by capacity
+
+ if drop_policy == 'probs':
+ capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False)
+ capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
+ mask = torch.logical_and(mask, capacity_mask)
+ locations = torch.cumsum(mask, dim=0) - 1
+
+ elif drop_policy == "position":
+ locations = torch.cumsum(mask, dim=0) - 1
+ mask *= torch.lt(locations, capacity)
+ else:
+ raise ValueError(f"Invalid drop_policy: {drop_policy}")
+
+ else:
+ # Do not drop tokens - set capacity according to current expert assignments
+ new_capacity = torch.max(exp_counts)
+ if ep_group is not None:
+ dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
+ if groups._get_expert_model_parallel_world_size() == 1:
+ # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
+ # This is since we are going to activate drop_tokens() to drop duplicate tokens.
+ tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
+ new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
+ capacity = new_capacity
+
+ # normalize gates
+ gates_masked = gates * mask
+ gates_s = torch.sum(gates_masked, dim=-1, keepdim=True)
+ denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps)
+ gates_masked = gates_masked / denom_s
+
+ # dispatch_mask
+ locations_sc = _one_hot_to_float((locations * mask), capacity)
+
+ combine_weights = torch.einsum("se,sec->sec", gates_masked, locations_sc)
+
+ dispatch_mask = combine_weights.bool()
+
+ return l_aux, combine_weights, dispatch_mask, exp_counts
class TopKGate(Module):
@@ -401,9 +478,6 @@ def __init__(self,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()
- # Only top-1 and top-2 are supported at the moment.
- if k != 1 and k != 2:
- raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.ep_group = ep_group
self.k = k
@@ -441,9 +515,13 @@ def forward(self,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, self.ep_group, use_tutel)
- else:
+ elif self.k == 2:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+ else:
+ gate_output = topkgating(logits, self.k,
+ self.capacity_factor if self.training else self.eval_capacity_factor,
+ self.min_capacity, self.drop_tokens, self.ep_group)
if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).stop()
diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py
index c4200877089a..960ce1ba997a 100644
--- a/deepspeed/monitor/config.py
+++ b/deepspeed/monitor/config.py
@@ -5,7 +5,7 @@
from typing import Optional
-from deepspeed.pydantic_v1 import root_validator
+from pydantic import model_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -36,10 +36,10 @@ class WandbConfig(DeepSpeedConfigModel):
enabled: bool = False
""" Whether logging to WandB is enabled. Requires `wandb` package is installed. """
- group: str = None
+ group: Optional[str] = None
""" Name for the WandB group. This can be used to group together runs. """
- team: str = None
+ team: Optional[str] = None
""" Name for the WandB team. """
project: str = "deepspeed"
@@ -137,8 +137,8 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
csv_monitor: CSVConfig = {}
""" Local CSV output of monitoring data. """
- @root_validator
- def check_enabled(cls, values):
- values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get(
- "csv_monitor").enabled or values.get("comet").enabled
- return values
+ @model_validator(mode="after")
+ def check_enabled(self):
+ enabled = self.tensorboard.enabled or self.wandb.enabled or self.csv_monitor.enabled or self.comet.enabled
+ self.__dict__["enabled"] = enabled
+ return self
diff --git a/deepspeed/ops/gds/__init__.py b/deepspeed/ops/gds/__init__.py
new file mode 100755
index 000000000000..3c0762c81076
--- /dev/null
+++ b/deepspeed/ops/gds/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from ..op_builder import GDSBuilder
diff --git a/deepspeed/pydantic_v1.py b/deepspeed/pydantic_v1.py
deleted file mode 100644
index 6aba072ad929..000000000000
--- a/deepspeed/pydantic_v1.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-"""Pydantic v1 compatibility module.
-
-Pydantic v2 introduced breaking changes that hinder its adoption:
-https://docs.pydantic.dev/latest/migration/. To provide deepspeed users the option to
-migrate to pydantic v2 on their own timeline, deepspeed uses this compatibility module
-as a pydantic-version-agnostic alias for pydantic's v1 API.
-"""
-
-try:
- from pydantic.v1 import * # noqa: F401
-except ImportError:
- from pydantic import * # noqa: F401
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index b49b4a8b6086..8be2f7ac4055 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -705,7 +705,7 @@ def write_config(self, filename):
class DeepSpeedConfig(object):
- def __init__(self, config: Union[str, dict], mpu=None):
+ def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):
super(DeepSpeedConfig, self).__init__()
if isinstance(config, dict):
self._param_dict = config
@@ -721,14 +721,16 @@ def __init__(self, config: Union[str, dict], mpu=None):
)
try:
self.global_rank = dist.get_rank()
- if mpu is None:
- self.world_size = dist.get_world_size()
- else:
+ if mpu is not None:
self.world_size = mpu.get_data_parallel_world_size()
+ elif mesh_device is not None:
+ self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
+ else:
+ self.world_size = dist.get_world_size()
except:
self.global_rank = 0
self.world_size = 1
-
+ logger.info(f"Config mesh_device {mesh_device} world_size = {self.world_size}")
# If elastic-mode enabled, update compute + update _param_dict
self.elasticity_enabled = elasticity_enabled(self._param_dict)
if self.elasticity_enabled:
diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py
index 5522a8e79d69..d5c3a1548360 100755
--- a/deepspeed/runtime/config_utils.py
+++ b/deepspeed/runtime/config_utils.py
@@ -5,11 +5,12 @@
"""
Collection of DeepSpeed configuration utilities
"""
-import json
import collections
-import collections.abc
+import json
+import torch
from functools import reduce
-from deepspeed.pydantic_v1 import BaseModel
+from pydantic import BaseModel, ConfigDict, field_serializer
+
from deepspeed.utils import logger
@@ -54,67 +55,73 @@ def __init__(self, strict=False, **data):
if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models
data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
super().__init__(**data)
- self._deprecated_fields_check(self)
+ self._deprecated_fields_check()
- def _process_deprecated_field(self, pydantic_config, field):
+ def _process_deprecated_field(self, dep_field):
# Get information about the deprecated field
- fields_set = pydantic_config.__fields_set__
- dep_param = field.name
- kwargs = field.field_info.extra
+ pydantic_config = self
+ fields_set = pydantic_config.model_fields_set
+ kwargs = pydantic_config.model_fields[dep_field].json_schema_extra
new_param_fn = kwargs.get("new_param_fn", lambda x: x)
- param_value = new_param_fn(getattr(pydantic_config, dep_param))
- new_param = kwargs.get("new_param", "")
+ param_value = new_param_fn(getattr(pydantic_config, dep_field))
+ new_field = kwargs.get("new_param", "")
dep_msg = kwargs.get("deprecated_msg", "")
- if dep_param in fields_set:
- logger.warning(f"Config parameter {dep_param} is deprecated" +
- (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
+ if dep_field in fields_set:
+ logger.warning(f"Config parameter {dep_field} is deprecated" +
+ (f" use {new_field} instead" if new_field else "") + (f". {dep_msg}" if dep_msg else ""))
# Check if there is a new param and if it should be set with a value
- if new_param and kwargs.get("set_new_param", True):
+ if new_field and kwargs.get("set_new_param", True):
# Remove the deprecate field if there is a replacing field
try:
- delattr(pydantic_config, dep_param)
+ delattr(pydantic_config, dep_field)
except Exception as e:
- logger.error(f"Tried removing deprecated '{dep_param}' from config")
+ logger.error(f"Tried removing deprecated '{dep_field}' from config")
raise e
# Set new param value
- new_param_nested = new_param.split(".")
+ new_param_nested = new_field.split(".")
if len(new_param_nested) > 1:
# If the new param exists in a subconfig, we need to get
# the fields set for that subconfig
pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
- fields_set = pydantic_config.__fields_set__
+ fields_set = pydantic_config.model_fields_set
new_param_name = new_param_nested[-1]
assert (
new_param_name not in fields_set
- ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
+ ), f"Cannot provide deprecated parameter '{dep_field}' and replacing parameter '{new_field}' together"
# A custom function for converting the old param value to new param value can be provided
try:
setattr(pydantic_config, new_param_name, param_value)
except Exception as e:
- logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
+ logger.error(f"Tried setting value for '{new_field}' with value from deprecated '{dep_field}'")
raise e
- def _deprecated_fields_check(self, pydantic_config):
- fields = pydantic_config.__fields__
- for field in fields.values():
- if field.field_info.extra.get("deprecated", False):
- self._process_deprecated_field(pydantic_config, field)
+ def _deprecated_fields_check(self):
+ fields = self.model_fields
+ for field_name, field_info in fields.items():
+ if field_info.json_schema_extra and field_info.json_schema_extra.get("deprecated", False):
+ self._process_deprecated_field(field_name)
+
+ model_config = ConfigDict(
+ validate_default=True,
+ validate_assignment=True,
+ use_enum_values=True,
+ populate_by_name=True,
+ extra="forbid",
+ arbitrary_types_allowed=True,
+ protected_namespaces=(),
+ )
- class Config:
- validate_all = True
- validate_assignment = True
- use_enum_values = True
- allow_population_by_field_name = True
- extra = "forbid"
- arbitrary_types_allowed = True
+ @field_serializer("dtype", check_fields=False)
+ def serialize_torch_dtype(dtype: torch.dtype) -> str:
+ return str(dtype)
def get_config_default(config, field_name):
- assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
- assert not config.__fields__.get(
- field_name).required, f"'{field_name}' is a required field and does not have a default value"
- return config.__fields__.get(field_name).default
+ assert field_name in config.model_fields, f"'{field_name}' is not a field in {config}"
+ assert not config.model_fields.get(
+ field_name).is_required(), f"'{field_name}' is a required field and does not have a default value"
+ return config.model_fields.get(field_name).get_default()
class pp_int(int):
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
old mode 100644
new mode 100755
index d2839a8f5d7c..61e6da2663cf
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -194,6 +194,7 @@ def __init__(self,
collate_fn=None,
config=None,
config_class=None,
+ mesh_device=None,
dont_change_device=False):
super(DeepSpeedEngine, self).__init__()
self.dont_change_device = dont_change_device
@@ -233,10 +234,14 @@ def __init__(self,
self._is_gradient_accumulation_boundary = None
self.scale_wrt_gas = None
self.losses = None
+ self.mesh_device = mesh_device
# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)
+ if self.mesh_device:
+ groups.mesh_device = self.mesh_device
+
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
@@ -615,6 +620,9 @@ def random_ltd_initialize(self):
raise ValueError(f'not yet support')
#self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler)
+ def get_sequence_parallel_group(self):
+ return self.seq_parallel_group
+
def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown
@@ -1009,13 +1017,13 @@ def _set_distributed_vars(self, args):
device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank
if device_rank >= 0:
get_accelerator().set_device(device_rank)
- self.device = torch.device(get_accelerator().device_name(), device_rank)
+ self.device = torch.device(get_accelerator().device_name(device_rank))
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
else:
self.world_size = 1
self.global_rank = 0
- self.device = torch.device(get_accelerator().device_name())
+ self.device = get_accelerator().device()
# Configure based on command line arguments
def _configure_with_arguments(self, args, mpu):
@@ -1187,6 +1195,7 @@ def _configure_distributed_model(self, model):
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
if self.sequence_parallel_size > 1:
self.communication_data_type = self._config.seq_parallel_communication_data_type
+ self.seq_parallel_group = groups._get_sequence_parallel_group()
if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index c627846b743c..26196ff37ac4 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -5,6 +5,8 @@
from types import MethodType
from collections import OrderedDict
+from functools import reduce
+from operator import mul
import torch
from deepspeed import comm as dist
@@ -40,6 +42,9 @@
PIPE_RECV_INPUT_TIMER = 'pipe_recv_input'
PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad'
+# The buffer size to store the meta data for each tensor.
+TENSOR_META_SIZE = 256
+
def is_even(number):
return number % 2 == 0
@@ -179,6 +184,7 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
}
self.pipe_recv_buf = None
self.grad_layer = None
+ self._grad_layer_buf = []
self.meta_buffer = None
@@ -250,6 +256,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
self.timers(STEP_MICRO_TIMER).start()
self.timers(STEP_MICRO_TIMER).stop()
+ self.dynamic_shape = self.module.dynamic_shape
+
def set_has_attention_mask(self, value):
assert isinstance(value, bool)
self.has_attention_mask = value
@@ -318,6 +326,7 @@ def reset_activation_shape(self):
self.first_output_send = True
self.pipe_recv_buf = None
self.grad_layer = None
+ self._grad_layer_buf = []
self.meta_buffer = None
self.pipe_partition_input_meta_cache = None
@@ -926,51 +935,38 @@ def _send_tensor_meta(self, buffer, recv_stage):
* ndims
* shape
"""
- send_bytes = 0
+ meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
if isinstance(buffer, torch.Tensor):
- type_tensor = torch.LongTensor(data=[0]).to(self.device)
- p2p.send(type_tensor, recv_stage)
- send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
- send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
- p2p.send(send_ndims, recv_stage)
- p2p.send(send_shape, recv_stage)
- send_bytes += _tensor_bytes(buffer)
- elif isinstance(buffer, list):
- assert (False)
- type_tensor = torch.LongTensor(data=[1]).to(self.device)
- p2p.send(type_tensor, recv_stage)
- count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
- p2p.send(count_tensor, recv_stage)
- for tensor in buffer:
- assert isinstance(tensor, torch.Tensor)
- send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
- send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
- p2p.send(send_ndims, recv_stage)
- p2p.send(send_shape, recv_stage)
- send_bytes += _tensor_bytes(tensor)
+ meta_buf_list = [
+ 0, # type of data (0: tensor, 1: list (unused), 2: tuple)
+ self.DTYPE_TO_ID[buffer.dtype], # dtype
+ len(buffer.size()) # ndims
+ ]
+ meta_buf_list.extend(buffer.size())
+ assert len(
+ meta_buf_list
+ ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}"
+ meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
+ p2p.send(meta_buffer, recv_stage)
+
elif isinstance(buffer, tuple):
- type_tensor = torch.LongTensor(data=[2]).to(self.device)
- p2p.send(type_tensor, recv_stage)
- count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
- p2p.send(count_tensor, recv_stage)
- for idx, tensor in enumerate(buffer):
+ meta_buf_list = [
+ 2, # type of data (0: tensor, 1: list (unused), 2: tuple)
+ len(buffer) # num_tensors
+ ]
+
+ for tensor in buffer:
assert isinstance(tensor, torch.Tensor)
- send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
- send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
- send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
- p2p.send(send_dtype, recv_stage)
- p2p.send(send_ndims, recv_stage)
- p2p.send(send_shape, recv_stage)
- # Useful for performance debugging.
- '''
- new_bytes = _tensor_bytes(tensor)
- send_bytes += _tensor_bytes(tensor)
- # Useful for performance debugging.
- if self.grid.data_parallel_id == 0:
- print(
- f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
- )
- '''
+ meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype])
+ meta_buf_list.append(len(tensor.size()))
+ meta_buf_list.extend(tensor.size())
+
+ assert len(
+ meta_buf_list
+ ) <= TENSOR_META_SIZE, f"Buffer for metadata is too small. Current buffer size: {TENSOR_META_SIZE} but required {len(meta_buf_list)}"
+ meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32))
+ p2p.send(meta_buffer, recv_stage)
+
else:
raise NotImplementedError(f'Could not send meta type {type(buffer)}')
@@ -983,49 +979,35 @@ def _send_tensor_meta(self, buffer, recv_stage):
def _recv_tensor_meta(self, send_stage):
"""Receive metadata about upcoming p2p transfers and return allocated buffers.
- Metadata is communicated in this order:
- * type (0: tensor, 1: list)
- * num_tensors if type=list
- foreach tensor in buffer:
- * ndims
- * shape
-
Returns:
Allocated buffer for receiving from send_stage.
"""
+ buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device)
+ p2p.recv(buffer, send_stage)
- type_tensor = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(type_tensor, send_stage)
- recv_type = type_tensor.item()
+ recv_type = buffer[0].item()
# A single tensor will be sent.
if recv_type == 0:
- recv_ndims = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(recv_ndims, send_stage)
- recv_ndims = recv_ndims.item()
- recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
- p2p.recv(recv_shape, send_stage)
- recv_shape = recv_shape.tolist()
- return self._allocate_buffer(recv_shape, num_buffers=1)[0]
-
- # List or tuple of tensors
+ recv_dtype = self.ID_TO_DTYPE[buffer[1].item()]
+ recv_ndims = buffer[2].item()
+ recv_shape = buffer[3:3 + recv_ndims].tolist()
+ return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype)
+
+ # List or tuple of tensors (recv_type == 1 (list) is currently unused)
elif recv_type == 1 or recv_type == 2:
- count_tensor = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(count_tensor, send_stage)
- num_tensors = count_tensor.item()
- recv_shapes_and_dtypes = []
+ num_tensors = buffer[1].item()
+
+ buffers = []
+ offset = 2
for idx in range(num_tensors):
- recv_dtype = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(recv_dtype, send_stage)
- recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
- recv_ndims = torch.LongTensor(data=[0]).to(self.device)
- p2p.recv(recv_ndims, send_stage)
- recv_ndims = recv_ndims.item()
- recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
- p2p.recv(recv_shape, send_stage)
- recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))
-
- buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
+ recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()]
+ recv_ndims = buffer[offset + 1].item()
+ recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist()
+ offset += 2 + recv_ndims
+
+ buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype))
+
# Convert to tuples if requested.
if recv_type == 2:
buffers = tuple(buffers)
@@ -1048,7 +1030,7 @@ def _exec_send_activations(self, buffer_id):
outputs[-1] = outputs[-1].half()
outputs = tuple(outputs)
- if self.first_output_send:
+ if self.dynamic_shape or self.first_output_send:
self.first_output_send = False
self._send_tensor_meta(outputs, self.next_stage)
@@ -1133,7 +1115,7 @@ def _exec_recv_activations(self, buffer_id):
recvd = None
# Allocate the buffer if necessary
- if self.pipe_recv_buf is None:
+ if self.dynamic_shape or self.pipe_recv_buf is None:
self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)
if isinstance(self.pipe_recv_buf, torch.Tensor):
@@ -1188,10 +1170,9 @@ def _exec_recv_grads(self, buffer_id):
self.pipe_buffers['outputs'][buffer_id] = outputs
# Allocate gradient if necessary
- if self.grad_layer is None:
+ if self.dynamic_shape or self.grad_layer is None:
if isinstance(outputs, torch.Tensor):
- s = list(outputs.size())
- self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
+ self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype)
else:
# XXX This is a HACK
# When we exchange activations/gradients, the two pipe stages
@@ -1213,7 +1194,11 @@ def _exec_recv_grads(self, buffer_id):
for t in outputs[2:] if t.is_floating_point()]
else:
sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
- self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]
+
+ self.grad_layer = [
+ self._allocate_or_extend_buffers(i, size, dtype)
+ for i, (size, dtype) in enumerate(sizes_and_dtypes)
+ ]
if isinstance(self.grad_layer, torch.Tensor):
p2p.recv(self.grad_layer, self.next_stage)
@@ -1294,16 +1279,17 @@ def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
buffers.append(self._allocate_zeros(shape, **kwargs))
return buffers
- def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
- buffers = []
- if num_buffers == -1:
- num_buffers = self.num_pipe_buffers
- for count in range(num_buffers):
- buffer = []
- for shape, dtype in shapes_and_dtypes:
- buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad))
- buffers.append(buffer)
- return buffers
+ def _allocate_or_extend_buffers(self, idx, shape, dtype):
+ numel = reduce(mul, shape) if len(shape) > 0 else 1
+ if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel:
+ new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0]
+ if len(self._grad_layer_buf) <= idx:
+ self._grad_layer_buf.append(new_buf)
+ else:
+ self._grad_layer_buf[idx] = new_buf
+ return self._grad_layer_buf[idx]
+ else:
+ return self._grad_layer_buf[idx].flatten()[:numel].view(shape)
def forward(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 3c25cbee66ec..31fec30be788 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -117,6 +117,7 @@ def forward(self, inputs):
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
+ dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact.
"""
def __init__(self,
@@ -130,7 +131,8 @@ def __init__(self,
partition_method='parameters',
activation_checkpoint_interval=0,
activation_checkpoint_func=checkpointing.checkpoint,
- checkpointable_layers=None):
+ checkpointable_layers=None,
+ dynamic_shape=False):
super().__init__()
@@ -213,6 +215,8 @@ def __init__(self,
self.tied_comms = self._index_tied_modules()
self._synchronize_tied_weights()
+ self.dynamic_shape = dynamic_shape
+
def _precompute_checkpointable_values(self):
if self.activation_checkpoint_interval > 0 and self.is_checkpointable_results_interval != self.activation_checkpoint_interval:
num_layers = len(self.forward_funcs)
diff --git a/deepspeed/runtime/swap_tensor/aio_config.py b/deepspeed/runtime/swap_tensor/aio_config.py
index df4a38380089..46c3f2a0c954 100644
--- a/deepspeed/runtime/swap_tensor/aio_config.py
+++ b/deepspeed/runtime/swap_tensor/aio_config.py
@@ -5,25 +5,33 @@
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.swap_tensor.constants import *
+from deepspeed.accelerator import get_accelerator
AIO_DEFAULT_DICT = {
AIO_BLOCK_SIZE: AIO_BLOCK_SIZE_DEFAULT,
AIO_QUEUE_DEPTH: AIO_QUEUE_DEPTH_DEFAULT,
AIO_THREAD_COUNT: AIO_THREAD_COUNT_DEFAULT,
AIO_SINGLE_SUBMIT: AIO_SINGLE_SUBMIT_DEFAULT,
- AIO_OVERLAP_EVENTS: AIO_OVERLAP_EVENTS_DEFAULT
+ AIO_OVERLAP_EVENTS: AIO_OVERLAP_EVENTS_DEFAULT,
+ AIO_USE_GDS: AIO_USE_GDS_DEFAULT
}
def get_aio_config(param_dict):
if AIO in param_dict.keys() and param_dict[AIO] is not None:
aio_dict = param_dict[AIO]
- return {
+ aio_config = {
AIO_BLOCK_SIZE: get_scalar_param(aio_dict, AIO_BLOCK_SIZE, AIO_BLOCK_SIZE_DEFAULT),
AIO_QUEUE_DEPTH: get_scalar_param(aio_dict, AIO_QUEUE_DEPTH, AIO_QUEUE_DEPTH_DEFAULT),
AIO_THREAD_COUNT: get_scalar_param(aio_dict, AIO_THREAD_COUNT, AIO_THREAD_COUNT_DEFAULT),
AIO_SINGLE_SUBMIT: get_scalar_param(aio_dict, AIO_SINGLE_SUBMIT, AIO_SINGLE_SUBMIT_DEFAULT),
- AIO_OVERLAP_EVENTS: get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT)
+ AIO_OVERLAP_EVENTS: get_scalar_param(aio_dict, AIO_OVERLAP_EVENTS, AIO_OVERLAP_EVENTS_DEFAULT),
+ AIO_USE_GDS: get_scalar_param(aio_dict, AIO_USE_GDS, AIO_USE_GDS_DEFAULT)
}
+ if aio_config[AIO_USE_GDS]:
+ assert get_accelerator().device_name() == 'cuda', 'GDS currently only supported for CUDA accelerator'
+
+ return aio_config
+
return AIO_DEFAULT_DICT
diff --git a/deepspeed/runtime/swap_tensor/constants.py b/deepspeed/runtime/swap_tensor/constants.py
index 4c9722bc4e4f..cee20ac7b78c 100644
--- a/deepspeed/runtime/swap_tensor/constants.py
+++ b/deepspeed/runtime/swap_tensor/constants.py
@@ -11,7 +11,8 @@
"queue_depth": 8,
"thread_count": 1,
"single_submit": false,
- "overlap_events": true
+ "overlap_events": true,
+ "use_gds": false
}
'''
AIO = "aio"
@@ -25,3 +26,5 @@
AIO_SINGLE_SUBMIT_DEFAULT = False
AIO_OVERLAP_EVENTS = "overlap_events"
AIO_OVERLAP_EVENTS_DEFAULT = True
+AIO_USE_GDS = "use_gds"
+AIO_USE_GDS_DEFAULT = False
diff --git a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py
index fcc6a272883f..26fbf6164d54 100644
--- a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py
+++ b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py
@@ -13,6 +13,7 @@
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import AsyncIOBuilder
+from deepspeed.ops.op_builder import GDSBuilder
from .constants import *
from .utils import swap_in_tensors, swap_out_tensors, MIN_AIO_BYTES, AIO_ALIGNED_BYTES, print_object, SwapBufferPool
@@ -37,8 +38,6 @@ class AsyncPartitionedParameterSwapper(object):
def __init__(self, ds_config, model_dtype):
- aio_op = AsyncIOBuilder().load(verbose=False)
- self.aio_handle = aio_op.aio_handle
self.dtype = model_dtype
#set swap buffers, create aio handles
@@ -93,6 +92,10 @@ def _configure_aio(self, ds_config):
self.aio_config = ds_config.aio_config
+ self.use_gds = self.aio_config[AIO_USE_GDS]
+ self.aio_handle = GDSBuilder().load(verbose=False).gds_handle if self.use_gds else AsyncIOBuilder().load(
+ verbose=False).aio_handle
+
# Read/Write alignment for each thread during Intra-request parallelism
self.min_aio_bytes = max(MIN_AIO_BYTES, self.aio_config[AIO_BLOCK_SIZE])
self.aligned_bytes = AIO_ALIGNED_BYTES * self.aio_config[AIO_THREAD_COUNT]
@@ -104,11 +107,6 @@ def _configure_aio(self, ds_config):
self.available_buffer_ids = [i for i in range(self.param_buffer_count)]
self.reserved_buffer_ids = []
- self.buffers = get_accelerator().pin_memory(torch.empty(int(self.aligned_elements_per_buffer *
- self.param_buffer_count),
- dtype=self.dtype,
- requires_grad=False),
- align_bytes=0)
self.aio_read_handle = self.aio_handle(self.aio_config[AIO_BLOCK_SIZE], self.aio_config[AIO_QUEUE_DEPTH],
self.aio_config[AIO_SINGLE_SUBMIT], self.aio_config[AIO_OVERLAP_EVENTS],
@@ -118,6 +116,19 @@ def _configure_aio(self, ds_config):
self.aio_config[AIO_SINGLE_SUBMIT],
self.aio_config[AIO_OVERLAP_EVENTS], self.aio_config[AIO_THREAD_COUNT])
+ if self.use_gds:
+ self.buffers = torch.empty(int(self.aligned_elements_per_buffer * self.param_buffer_count),
+ dtype=self.dtype,
+ device=get_accelerator().device_name(),
+ requires_grad=False)
+ self.aio_read_handle.pin_device_tensor(self.buffers)
+ else:
+ self.buffers = get_accelerator().pin_memory(torch.empty(int(self.aligned_elements_per_buffer *
+ self.param_buffer_count),
+ dtype=self.dtype,
+ requires_grad=False),
+ align_bytes=0)
+
self.swap_out_params = []
#Check if partitioned param or numel in a tensor is swappable or not
diff --git a/deepspeed/runtime/zero/__init__.py b/deepspeed/runtime/zero/__init__.py
index 1ccca09a9e69..23fcf9ec13fb 100644
--- a/deepspeed/runtime/zero/__init__.py
+++ b/deepspeed/runtime/zero/__init__.py
@@ -13,3 +13,5 @@
from .tiling import TiledLinearReturnBias
from .mics import MiCS_Init
+
+from .stage3 import unwrap_model_for_generation
diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py
index 2089d59dbce4..1cfcd784e2ce 100644
--- a/deepspeed/runtime/zero/config.py
+++ b/deepspeed/runtime/zero/config.py
@@ -6,7 +6,7 @@
import sys
from typing import Optional
from enum import Enum
-from deepspeed.pydantic_v1 import Field, validator, root_validator
+from pydantic import Field, model_validator
from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel
from deepspeed.utils import logger
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
@@ -30,7 +30,7 @@
"reduce_bucket_size": 500000000,
"load_from_fp32_weights": [true|false],
"cpu_offload": [true|false] (deprecated),
- "cpu_offload_params" : [true|false] (deprecated),
+ "cpu_offload_param" : [true|false] (deprecated),
"cpu_offload_use_pin_memory": [true|false] (deprecated),
"sub_group_size" : 1000000000000,
"offload_param": {...},
@@ -128,7 +128,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
the allgather for large model sizes
"""
- overlap_comm: bool = None # None for dynamic default value (see validator `overlap_comm_valid` below)
+ overlap_comm: Optional[bool] = None # None for dynamic default value (see validator `overlap_comm_valid` below)
"""
Attempts to overlap the reduction of the gradients with backward computation
"""
@@ -168,27 +168,37 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
parameters). Used by ZeRO3-Offload and ZeRO-Infinity
"""
- cpu_offload_param: bool = Field(
+ cpu_offload_param: Optional[bool] = Field(
None,
- deprecated=True,
- new_param="offload_param",
- new_param_fn=(lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu) if val else None),
+ json_schema_extra={
+ "deprecated": True,
+ "new_param": "offload_param",
+ "new_param_fn": (lambda val: DeepSpeedZeroOffloadParamConfig(device=OffloadDeviceEnum.cpu)
+ if val else None)
+ },
)
""" Deprecated, please use ``offload_param`` """
- cpu_offload_use_pin_memory: bool = Field(
+ cpu_offload_use_pin_memory: Optional[bool] = Field(
None,
- deprecated=True,
- new_param="offload_param or offload_optimizer",
- set_new_param=False,
+ json_schema_extra={
+ "deprecated": True,
+ "new_param": "offload_param or offload_optimizer",
+ "set_new_param": False
+ },
)
""" Deprecated, please use ``offload_param`` or ``offload_optimizer`` """
- cpu_offload: bool = Field(
+ cpu_offload: Optional[bool] = Field(
None,
- deprecated=True,
- new_param="offload_optimizer",
- new_param_fn=(lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu) if val else None),
+ json_schema_extra={
+ "deprecated":
+ True,
+ "new_param":
+ "offload_optimizer",
+ "new_param_fn": (lambda val: DeepSpeedZeroOffloadOptimizerConfig(device=OffloadDeviceEnum.cpu)
+ if val else None)
+ },
)
""" Deprecated, please use ``offload_optimizer`` """
@@ -242,8 +252,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
"""
stage3_gather_fp16_weights_on_model_save: bool = Field(False,
- deprecated=True,
- new_param="gather_16bit_weights_on_model_save")
+ json_schema_extra={
+ "deprecated": True,
+ "new_param": "gather_16bit_weights_on_model_save"
+ })
""" Deprecated, please use ``gather_16bit_weights_on_model_save`` """
ignore_unused_parameters: bool = True
@@ -309,16 +321,15 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
"""
# Validators
- @validator("overlap_comm")
- def overlap_comm_valid(cls, field_value, values):
- if field_value is None:
- assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'"
- field_value = values["stage"] == ZeroStageEnum.weights
- return field_value
-
- @root_validator
- def offload_ratio_check(cls, values):
- offload_config = getattr(values, "offload_optimizer", {})
+ @model_validator(mode="after")
+ def overlap_comm_valid(self):
+ if self.overlap_comm is None:
+ self.overlap_comm = self.stage == ZeroStageEnum.weights
+ return self
+
+ @model_validator(mode="after")
+ def offload_ratio_check(self):
+ offload_config = self.offload_optimizer
if offload_config and offload_config.ratio < 1.0:
- assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
- return values
+ assert self.stage == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3."
+ return self
diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py
index e9dd78864cde..8c8db60768eb 100644
--- a/deepspeed/runtime/zero/linear.py
+++ b/deepspeed/runtime/zero/linear.py
@@ -16,6 +16,7 @@
#when implemented outside of torch.autograd.Function
import math
+import functools
import torch
from torch import Tensor
@@ -33,8 +34,14 @@ def print_rank_0(message, debug=False, force=False):
try:
- autocast_custom_fwd = get_accelerator().amp().custom_fwd
- autocast_custom_bwd = get_accelerator().amp().custom_bwd
+ # Fix `torch.[device].amp.custom_fwd/bwd` FutureWarning in torch 2.4
+ if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and hasattr(torch.amp, 'custom_bwd'):
+ autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
+ autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
+ else:
+ # original implementation
+ autocast_custom_fwd = get_accelerator().amp().custom_fwd
+ autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError) as exp:
autocast_custom_fwd = noop_decorator
autocast_custom_bwd = noop_decorator
diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py
index b7adc13a0ea2..74a5673bc1bc 100644
--- a/deepspeed/runtime/zero/offload_config.py
+++ b/deepspeed/runtime/zero/offload_config.py
@@ -5,7 +5,9 @@
from enum import Enum
from pathlib import Path
-from deepspeed.pydantic_v1 import Field, validator
+from pydantic import Field, model_validator
+from typing import Optional
+
from deepspeed.runtime.config_utils import DeepSpeedConfigModel, pp_int
@@ -25,7 +27,7 @@ class DeepSpeedZeroOffloadParamConfig(DeepSpeedConfigModel):
`nvme`.
"""
- nvme_path: Path = None
+ nvme_path: Optional[Path] = None
""" Filesystem path for NVMe device for parameter offloading. """
buffer_count: int = Field(5, ge=0)
@@ -56,7 +58,7 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
`nvme`. Optimizer computation is offload to CPU regardless of device option.
"""
- nvme_path: Path = None
+ nvme_path: Optional[Path] = None
""" Filesystem path for NVMe device for optimizer state offloading. """
buffer_count: int = Field(4, ge=0)
@@ -88,10 +90,11 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel):
fast_init: bool = False
""" Enable fast optimizer initialization when offloading to NVMe. """
- @validator("pipeline_read", "pipeline_write", always=True)
- def set_pipeline(cls, field_value, values):
- values["pipeline"] = field_value or values.get("pipeline", False)
- return field_value
-
ratio: float = Field(1.0, ge=0.0, le=1.0)
""" Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3."""
+
+ @model_validator(mode="after")
+ def set_pipeline(self):
+ pipeline = self.pipeline_read or self.pipeline_write
+ self.__dict__["pipeline"] = pipeline
+ return self
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 9b7645261eae..796957a4c6e5 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -7,6 +7,7 @@
import gc
import collections
from typing import Deque, Dict, Tuple
+from contextlib import contextmanager
from deepspeed import comm as dist
from deepspeed.utils import groups
@@ -15,7 +16,7 @@
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
-from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -69,6 +70,39 @@ def move_to_cpu(tensor_list):
tensor.data = tensor.data.cpu()
+@contextmanager
+def unwrap_model_for_generation(model):
+ """
+ For ZeRO-3 models, we gather the weights once to speed up generation.
+ """
+ with GatheredParameters(model.parameters()):
+ # Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.
+
+ # Remove hooks
+ if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+ optimizer_offload = model.optimizer.parameter_offload
+ elif model.optimizer is not None:
+ optimizer_offload = model.optimizer
+
+ for hook in optimizer_offload.forward_hooks:
+ hook.remove()
+ for hook in optimizer_offload.backward_hooks:
+ hook.remove()
+
+ optimizer_offload.forward_hooks = []
+ optimizer_offload.backward_hooks = []
+
+ yield model
+
+ # Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.
+ if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+ optimizer_offload = model.optimizer.parameter_offload
+ elif model.optimizer is not None:
+ optimizer_offload = model.optimizer
+ optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+ return
+
+
INITIAL_MICRO_STEP_ID = -1
@@ -215,14 +249,12 @@ def __init__(
self.module = module
self.elastic_checkpoint = elastic_checkpoint
- self.inf_or_nan_tracker: Tensor = torch.zeros(1,
- dtype=torch.bool,
- device=get_accelerator().current_device_name(),
- requires_grad=False)
+ self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu
+
+ self.inf_or_nan_tracker: Tensor = torch.zeros(1, dtype=torch.bool, device=self.device, requires_grad=False)
self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam)
- self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu
### streams used for overlapping computation with communication
self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator(
).Stream() if overlap_comm else get_accelerator().default_stream()
@@ -1413,7 +1445,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
- return total_norm
+ return total_norm.cpu()
@instrument_w_nvtx
def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
@@ -2028,7 +2060,7 @@ def step(self, closure=None):
return
norm_groups = self._get_norm_groups()
- scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
+ scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
@@ -2112,8 +2144,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm):
if self.clip_grad > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
- if clip > 1:
- combined_scale = clip * self.loss_scale
+ clip = torch.clamp(clip, min=1.0)
+ combined_scale = clip * self.loss_scale
self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
@@ -2148,7 +2180,8 @@ def has_overflow(self, partition_gradients=True):
self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0
- overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8)
+ overflow_gpu = self.inf_or_nan_tracker.clone().to(get_accelerator().current_device_name()).to(
+ torch.uint8)
self.inf_or_nan_tracker.zero_()
if not get_accelerator().resolves_data_dependency():
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 57e80911d645..83cf996ca019 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -725,8 +725,9 @@ def reduce_gradients(self, pipeline_parallel=False):
def get_first_param_index(self, group_id, param_group, partition_id):
for index, param in enumerate(param_group):
param_id = self.get_param_id(param)
- if partition_id in self.param_to_partition_ids[group_id][param_id]:
- return index
+ if group_id in self.param_to_partition_ids and param_id in self.param_to_partition_ids[group_id]:
+ if partition_id in self.param_to_partition_ids[group_id][param_id]:
+ return index
return None
def initialize_gradient_partitioning_data_structures(self):
diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py
index fae725819a6b..8f913d065934 100755
--- a/deepspeed/runtime/zero/utils.py
+++ b/deepspeed/runtime/zero/utils.py
@@ -68,7 +68,6 @@ def get_lst_from_rank0(lst: List[int]) -> None:
lst_tensor = torch.tensor(
lst if dist.get_rank() == 0 else [-1] * len(lst),
dtype=int,
- # device=get_accelerator().current_device_name(),
device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
requires_grad=False,
)
diff --git a/deepspeed/sequence/cross_entropy.py b/deepspeed/sequence/cross_entropy.py
new file mode 100644
index 000000000000..baa7bc1ea7a8
--- /dev/null
+++ b/deepspeed/sequence/cross_entropy.py
@@ -0,0 +1,60 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+import deepspeed.comm as dist
+
+
+class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, vocab_seq_parallel_logits, target, sp_group):
+ # vocab_seq_parallel_logits: [S/P, B, V]
+ # target: [S/P, B]
+ # return: [S, B]
+
+ # Need softmax for backward
+ softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1)
+ ctx.vocab_size = vocab_seq_parallel_logits.size(2)
+ loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction='none')
+
+ sp_world_size = dist.get_world_size(sp_group)
+ sp_rank = dist.get_rank(sp_group)
+ ctx.sp_world_size = sp_world_size
+ ctx.sp_rank = sp_rank
+ ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_world_size
+ batch_size = vocab_seq_parallel_logits.size(1)
+
+ loss_all = torch.empty(ctx.seqlen,
+ batch_size,
+ dtype=vocab_seq_parallel_logits.dtype,
+ device=vocab_seq_parallel_logits.device)
+ dist.all_gather_into_tensor(loss_all, loss, group=sp_group)
+
+ ctx.save_for_backward(softmax, target)
+
+ return loss_all
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ softmax, target = ctx.saved_tensors
+
+ step_seqlen = ctx.seqlen // ctx.sp_world_size
+ sp_rank = ctx.sp_rank
+ grad_output_part = grad_output[step_seqlen * sp_rank:step_seqlen * (sp_rank + 1), :]
+
+ grad_input = softmax
+ grad_2d = grad_input.view(-1, ctx.vocab_size)
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+
+ grad_2d[arange_1d, target.view(-1)] -= 1
+ grad_input.mul_(grad_output_part.unsqueeze(dim=-1))
+
+ return grad_input, None, None, None
+
+
+def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group):
+ return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group)
diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py
old mode 100644
new mode 100755
index c49f4520e16e..9dd288ef46db
--- a/deepspeed/utils/groups.py
+++ b/deepspeed/utils/groups.py
@@ -30,6 +30,7 @@
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size, bwc_pipeline_parallel_world_size
from deepspeed.utils.exceptions import DeprecatedException
from deepspeed.accelerator import get_accelerator
+
# Expert parallel group that the current rank belongs to.
_EXPERT_PARALLEL_GROUP = {}
# Expert data parallel group that the current rank belongs to.
@@ -47,6 +48,8 @@
_DATA_PARALLEL_GROUP = None
+mesh_device = None
+
# Deprecated groups initialize function.
def initialize(ep_size=1, mpu=None):
@@ -398,8 +401,11 @@ def _get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert dist.is_initialized(), 'dist is not initialized'
global mpu
+ if mesh_device is not None:
+ return mesh_device.get_group(mesh_dim="data_parallel")
if mpu is not None:
return mpu.get_data_parallel_group()
+
# Return the clone of dist world group
return _clone_world_group()
@@ -442,6 +448,8 @@ def _get_expert_data_parallel_rank(group_name):
def _get_data_parallel_world_size():
"""Return world size for the data parallel group."""
+ if mesh_device is not None:
+ return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
global mpu
if mpu is not None:
return mpu.get_data_parallel_world_size()
@@ -464,6 +472,8 @@ def _get_data_parallel_rank():
def _get_sequence_parallel_world_size():
"""Return world size for the model parallel group."""
global mpu
+ if mesh_device is not None:
+ return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel"))
if mpu is not None and hasattr(mpu, 'get_sequence_parallel_world_size'):
return mpu.get_sequence_parallel_world_size()
return 1
@@ -479,9 +489,11 @@ def _get_sequence_parallel_rank():
def _get_sequence_parallel_group():
global mpu
- if mpu is not None and hasattr(mpu, 'get_sequence_parallel_group'):
- return mpu.get_sequence_parallel_group()
- return None
+ if mpu is None or not hasattr(mpu, 'get_sequence_parallel_group'):
+ if mesh_device is None:
+ raise KeyError("No sequence parallel group found")
+ return mesh_device.get_group(mesh_dim="sequence_parallel")
+ return mpu.get_sequence_parallel_group()
def _get_sequence_data_parallel_world_size():
diff --git a/docs/_tutorials/accelerator-abstraction-interface.md b/docs/_tutorials/accelerator-abstraction-interface.md
index 88a43236ce9d..d7c153638c0d 100644
--- a/docs/_tutorials/accelerator-abstraction-interface.md
+++ b/docs/_tutorials/accelerator-abstraction-interface.md
@@ -12,7 +12,6 @@ tags: getting-started
- [Tensor operations](#tensor-operations)
- [Communication backend](#communication-backend)
- [Run DeepSpeed model on different accelerators](#run-deepspeed-model-on-different-accelerators)
-- [Run DeepSpeed model on CPU](#run-deepspeed-model-on-cpu)
- [Implement new accelerator extension](#implement-new-accelerator-extension)
# Introduction
@@ -79,69 +78,9 @@ torch.distributed.init_process_group(get_accelerator().communication_backend_nam
```
# Run DeepSpeed model on different accelerators
-Once a model is ported with DeepSpeed Accelerator Abstraction Interface, we can run this model on different accelerators using an extension to DeepSpeed. DeepSpeed checks whether a certain extension is installed in the environment to decide whether to use the Accelerator backend in that extension. For example, if we wish to run a model on Intel GPU, we can install _Intel Extension for DeepSpeed_ following the instructions in the following [link](https://github.com/intel/intel-extension-for-deepspeed/)
-
-After the extension is installed, install DeepSpeed and run the model. The model will be running on top of DeepSpeed. Because DeepSpeed installation is also accelerator related, it is recommended to install DeepSpeed accelerator extension before installing DeepSpeed.
-
-`CUDA_Accelerator` is the default accelerator in DeepSpeed. If no other DeepSpeed accelerator extension is installed, `CUDA_Accelerator` will be used.
-
-When running a model on different accelerators in a cloud environment, the recommended practice is to provision an environment for each accelerator in a different env with tools such as _anaconda/miniconda/virtualenv_. When running models on different Accelerator, load the env accordingly.
-
-Note that different accelerator may have different 'flavor' of float16 or bfloat16. So it is recommended to make the model configurable for both float16 and bfloat16, in that way model code does not need to be changed when running on different accelerators.
-
-# Run DeepSpeed model on CPU
-DeepSpeed support using CPU as accelerator. DeepSpeed model using DeepSpeed Accelerator Abstraction Interface could run on CPU without change to model code. DeepSpeed decide whether _Intel Extension for PyTorch_ is installed in the environment. If this packaged is installed, DeepSpeed will use CPU as accelerator. Otherwise CUDA device will be used as accelerator.
-
-To run DeepSpeed model on CPU, use the following steps to prepare environment:
-
-```
-python -m pip install intel_extension_for_pytorch
-python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu
-git clone https://github.com/oneapi-src/oneCCL
-cd oneCCL
-mkdir build
-cd build
-cmake ..
-make
-make install
-```
-
-Before run CPU workload, we need to source oneCCL environment variables
-```
-source /build/_install/env/setvars.sh
-```
-
-After environment is prepared, we can launch DeepSpeed inference with the following command
-```
-deepspeed --bind_cores_to_rank
-```
-
-This command would launch number of workers equal to number of CPU sockets on the system. Currently DeepSpeed support running inference model with AutoTP on top of CPU. The argument `--bind_cores_to_rank` distribute CPU cores on the system evenly among workers, to allow each worker running on a dedicated set of CPU cores.
-
-On CPU system, there might be daemon process that periodically activate which would increase variance of each worker. One practice is leave a couple of cores for daemon process using `--bind-core-list` argument:
-
-```
-deepspeed --bind_cores_to_rank --bind_core_list 0-51,56-107
-```
-
-The command above leave 4 cores on each socket to daemon process (assume two sockets, each socket has 56 cores).
-
-We can also set an arbitrary number of workers. Unlike GPU, CPU cores on host can be further divided into subgroups. When this number is not set, DeepSpeed would detect number of NUMA nodes on the system and launch one worker for each NUMA node.
-
-```
-deepspeed --num_accelerators 4 --bind_cores_to_rank
-```
-
-Launching DeepSpeed model on multiple CPU nodes is similar to other accelerators. We need to specify `impi` as launcher and specify `--bind_cores_to_rank` for better core binding. Also specify `slots` number according to number of CPU sockets in host file.
-
-```
-# hostfile content should follow the format
-# worker-1-hostname slots=<#sockets>
-# worker-2-hostname slots=<#sockets>
-# ...
-
-deepspeed --hostfile= --bind_cores_to_rank --launcher impi --master_addr
-```
+[Accelerator Setup Guide](accelerator-setup-guide.md) provides a guide on how to setup different accelerators for DeepSpeed. It also comes with simple example how to run deepspeed for different accelerators. The following guides are provided:
+1. Run DeepSpeed model on CPU
+2. Run DeepSpeed model on XPU
# Implement new accelerator extension
It is possible to implement a new DeepSpeed accelerator extension to support new accelerator in DeepSpeed. An example to follow is _[Intel Extension For DeepSpeed](https://github.com/intel/intel-extension-for-deepspeed/)_. An accelerator extension contains the following components:
diff --git a/docs/_tutorials/accelerator-setup-guide.md b/docs/_tutorials/accelerator-setup-guide.md
new file mode 100644
index 000000000000..cf2d01d2b25c
--- /dev/null
+++ b/docs/_tutorials/accelerator-setup-guide.md
@@ -0,0 +1,134 @@
+---
+title: DeepSpeed Accelerator Setup Guides
+tags: getting-started
+---
+
+# Contents
+- [Contents](#contents)
+- [Introduction](#introduction)
+- [Intel Architecture (IA) CPU](#intel-architecture-ia-cpu)
+- [Intel XPU](#intel-xpu)
+
+# Introduction
+DeepSpeed supports different accelerators from different companies. Setup steps to run DeepSpeed on certain accelerators might be different. This guide allows user to lookup setup instructions for the accelerator family and hardware they are using.
+
+# Intel Architecture (IA) CPU
+DeepSpeed supports CPU with Intel Architecture instruction set. It is recommended to have the CPU support at least AVX2 instruction set and recommend AMX instruction set.
+
+DeepSpeed has been verified on the following CPU processors:
+* 4th Gen Intel® Xeon® Scalarable Processors
+* 5th Gen Intel® Xeon® Scalarable Processors
+* 6th Gen Intel® Xeon® Scalarable Processors
+
+## Installation steps for Intel Architecture CPU
+To install DeepSpeed on Intel Architecture CPU, use the following steps:
+1. Install gcc compiler
+DeepSpeed requires gcc-9 or above to build kernels on Intel Architecture CPU, install gcc-9 or above.
+
+2. Install numactl
+DeepSpeed use `numactl` for fine grain CPU core allocation for load-balancing, install numactl on your system.
+For example, on Ubuntu system, use the following command:
+`sudo apt-get install numactl`
+
+3. Install PyTorch
+`pip install torch`
+
+4. Install DeepSpeed
+`pip install deepspeed`
+
+## How to launch DeepSpeed on Intel Architecture CPU
+DeepSpeed can launch on Intel Architecture CPU with default deepspeed command. However, for compute intensive workloads, Intel Architecture CPU works best when each worker process runs on different set of physical CPU cores, so worker process does not compete CPU cores with each other. To bind cores to each worker (rank), use the following command line switch for better performance.
+```
+deepspeed --bind_cores_to_rank
+```
+This switch would automatically detect the number of CPU NUMA node on the host, launch the same number of workers, and bind each worker to cores/memory of a different NUMA node. This improves performance by ensuring workers do not interfere with each other, and that all memory allocation is from local memory.
+
+If a user wishes to have more control on the number of workers and specific cores that can be used by the workload, user can use the following command line switches.
+```
+deepspeed --num_accelerators --bind_cores_to_rank --bind_core_list
+```
+For example:
+```
+deepspeed --num_accelerators 4 --bind_cores_to_rank --bind_core_list <0-27,32-59> inference.py
+```
+This would start 4 workers for the workload. The core list range will be divided evenly between 4 workers, with worker 0 take 0-13, worker 1, take 14-27, worker 2 take 32-45, and worker 3 take 46-59. Core 28-31,60-63 are left out because there might be some background process running on the system, leaving some idle cores will reduce performance jitting and straggler effect.
+
+Launching DeepSpeed model on multiple CPU nodes is similar to other accelerators. We need to specify `impi` as launcher and specify `--bind_cores_to_rank` for better core binding. Also specify `slots` number according to number of CPU sockets in host file.
+
+```
+# hostfile content should follow the format
+# worker-1-hostname slots=<#sockets>
+# worker-2-hostname slots=<#sockets>
+# ...
+
+deepspeed --hostfile= --bind_cores_to_rank --launcher impi --master_addr
+```
+
+## Install with Intel Extension for PyTorch and oneCCL
+Although not mandatory, Intel Extension for PyTorch and Intel oneCCL provide better optimizations for LLM models. Intel oneCCL also provide optimization when running LLM model on multi-node. To use DeepSpeed with Intel Extension for PyTorch and oneCCL, use the following steps:
+1. Install Intel Extension for PyTorch. This is suggested if you want to get better LLM inference performance on CPU.
+`pip install intel-extension-for-pytorch`
+
+The following steps are to install oneCCL binding for PyTorch. This is suggested if you are running DeepSpeed on multiple CPU node, for better communication performance. On single node with multiple CPU socket, these steps are not needed.
+
+2. Install Intel oneCCL binding for PyTorch
+`python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu`
+
+3. Install Intel oneCCL, this will be used to build direct oneCCL kernels (CCLBackend kernels)
+```
+pip install oneccl-devel
+pip install impi-devel
+```
+Then set the environment variables for Intel oneCCL (assuming using conda environment).
+```
+export CPATH=${CONDA_PREFIX}/include:$CPATH
+export CCL_ROOT=${CONDA_PREFIX}
+export I_MPI_ROOT=${CONDA_PREFIX}
+export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/ccl/cpu:${CONDA_PREFIX}/lib/libfabric:${CONDA_PREFIX}/lib
+```
+
+## Optimize LLM inference with Intel Extension for PyTorch
+Intel Extension for PyTorch compatible with DeepSpeed AutoTP tensor parallel inference. It allows CPU inference to benefit from both DeepSpeed Automatic Tensor Parallelism, and LLM optimizations of Intel Extension for PyTorch. To use Intel Extension for PyTorch, after calling deepspeed.init_inference, call
+```
+ipex_model = ipex.llm.optimize(deepspeed_model)
+```
+to get model optimzied by Intel Extension for PyTorch.
+
+## More example for using DeepSpeed with Intel Extension for PyTorch on Intel Architecture CPU
+Refer to https://github.com/intel/intel-extension-for-pytorch/tree/main/examples/cpu/inference/python/llm for more extensive guide.
+
+# Intel XPU
+DeepSpeed XPU accelerator supports Intel® Data Center GPU Max Series.
+
+DeepSpeed has been verified on the following GPU products:
+* Intel® Data Center GPU Max 1100
+* Intel® Data Center GPU Max 1550
+
+## Installation steps for Intel XPU
+To install DeepSpeed on Intel XPU, use the following steps:
+1. Install oneAPI base toolkit \
+The Intel® oneAPI Base Toolkit (Base Kit) is a core set of tools and libraries, including an DPC++/C++ Compiler for building Deepspeed XPU kernels like fusedAdam and CPUAdam, high performance computation libraries demanded by IPEX, etc.
+For easy download, usage and more details, check [Intel oneAPI base-toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
+2. Install PyTorch, Intel extension for pytorch, Intel oneCCL Bindings for PyTorch. These packages are required in `xpu_accelerator` for torch functionality and performance, also communication backend on Intel platform. The recommended installation reference:
+https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu.
+
+3. Install DeepSpeed \
+`pip install deepspeed`
+
+## How to use DeepSpeed on Intel XPU
+DeepSpeed can be launched on Intel XPU with deepspeed launch command. Before that, user needs activate the oneAPI environment by: \
+`source /setvars.sh`
+
+To validate the XPU availability and if the XPU accelerator is correctly chosen, here is an example:
+```
+$ python
+>>> import torch; print('torch:', torch.__version__)
+torch: 2.3.0
+>>> import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())
+XPU available: True
+>>> from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)
+accelerator: xpu
+```
+
+## More example for using DeepSpeed on Intel XPU
+Refer to https://github.com/intel/intel-extension-for-pytorch/tree/release/xpu/2.1.40/examples/gpu/inference/python/llm for more extensive guide.
diff --git a/docs/index.md b/docs/index.md
index 127c7226e6d4..1efdcea132d2 100755
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,11 +7,12 @@ title: "Latest News"
---
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) .
+* [2024/08] [DeepSpeed on Windows](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/japanese/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/chinese/README.md)]
+
* [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/japanese/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/chinese/README.md)]
* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)]
* [2024/03] [DeepSpeed-FP6: The Power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
-* [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html)
@@ -19,12 +20,12 @@ title: "Latest News"
More news
diff --git a/op_builder/async_io.py b/op_builder/async_io.py
index b55c821910b9..e7f16adbf2a3 100644
--- a/op_builder/async_io.py
+++ b/op_builder/async_io.py
@@ -3,13 +3,14 @@
# DeepSpeed Team
+import os
import distutils.spawn
import subprocess
-from .builder import OpBuilder
+from .builder import TorchCPUOpBuilder
-class AsyncIOBuilder(OpBuilder):
+class AsyncIOBuilder(TorchCPUOpBuilder):
BUILD_VAR = "DS_BUILD_AIO"
NAME = "async_io"
@@ -19,44 +20,54 @@ def __init__(self):
def absolute_name(self):
return f'deepspeed.ops.aio.{self.NAME}_op'
- def sources(self):
- return [
- 'csrc/aio/py_lib/deepspeed_py_copy.cpp', 'csrc/aio/py_lib/py_ds_aio.cpp',
- 'csrc/aio/py_lib/deepspeed_py_aio.cpp', 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp',
- 'csrc/aio/py_lib/deepspeed_aio_thread.cpp', 'csrc/aio/common/deepspeed_aio_utils.cpp',
- 'csrc/aio/common/deepspeed_aio_common.cpp', 'csrc/aio/common/deepspeed_aio_types.cpp',
+ def lib_sources(self):
+ src_list = [
+ 'csrc/aio/py_lib/deepspeed_py_io_handle.cpp', 'csrc/aio/py_lib/deepspeed_py_aio.cpp',
+ 'csrc/aio/py_lib/deepspeed_py_aio_handle.cpp', 'csrc/aio/py_lib/deepspeed_aio_thread.cpp',
+ 'csrc/aio/common/deepspeed_aio_utils.cpp', 'csrc/aio/common/deepspeed_aio_common.cpp',
+ 'csrc/aio/common/deepspeed_aio_types.cpp', 'csrc/aio/py_lib/deepspeed_cpu_op.cpp',
+ 'csrc/aio/py_lib/deepspeed_aio_op_desc.cpp', 'csrc/aio/py_lib/deepspeed_py_copy.cpp',
'csrc/aio/py_lib/deepspeed_pin_tensor.cpp'
]
+ return src_list
+
+ def sources(self):
+ return self.lib_sources() + ['csrc/aio/py_lib/py_ds_aio.cpp']
def include_paths(self):
- return ['csrc/aio/py_lib', 'csrc/aio/common']
+ import torch
+ if self.build_for_cpu:
+ CUDA_INCLUDE = []
+ elif not self.is_rocm_pytorch():
+ CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
+ else:
+ CUDA_INCLUDE = [
+ os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"),
+ os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "rocrand"),
+ os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include", "hiprand"),
+ ]
+ return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE
def cxx_args(self):
# -O0 for improved debugging, since performance is bound by I/O
- CPU_ARCH = self.cpu_arch()
- SIMD_WIDTH = self.simd_width()
- import torch # Keep this import here to avoid errors when building DeepSpeed wheel without torch installed
+ args = super().cxx_args()
+ import torch
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[0:2])
- if TORCH_MAJOR >= 2 and TORCH_MINOR >= 1:
- CPP_STD = '-std=c++17'
- else:
- CPP_STD = '-std=c++14'
- return [
- '-g',
- '-Wall',
- '-O0',
- CPP_STD,
- '-shared',
- '-fPIC',
- '-Wno-reorder',
- CPU_ARCH,
- '-fopenmp',
- SIMD_WIDTH,
- '-laio',
- ]
+ if not (TORCH_MAJOR >= 2 and TORCH_MINOR >= 1):
+ args.remove('-std=c++17')
+ args.append('-std=c++14')
+ args += ['-Wall', '-O0', '-shared', '-fPIC', '-Wno-reorder']
+ return args
def extra_ldflags(self):
- return ['-laio']
+ if self.build_for_cpu:
+ return ['-fopenmp']
+
+ import torch.utils.cpp_extension
+ CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME
+ CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64")
+ ldflags = [f'-L{CUDA_HOME}', f'-L{CUDA_LIB64}', '-laio', '-lcuda', '-lcudart']
+ return ldflags
def check_for_libaio_pkg(self):
libs = dict(
@@ -79,13 +90,13 @@ def check_for_libaio_pkg(self):
break
return found
- def is_compatible(self, verbose=True):
+ def is_compatible(self, verbose=False):
# Check for the existence of libaio by using distutils
# to compile and link a test program that calls io_submit,
# which is a function provided by libaio that is used in the async_io op.
# If needed, one can define -I and -L entries in CFLAGS and LDFLAGS
# respectively to specify the directories for libaio.h and libaio.so.
- aio_compatible = self.has_function('io_pgetevents', ('aio', ))
+ aio_compatible = self.has_function('io_submit', ('aio', ))
if verbose and not aio_compatible:
self.warning(f"{self.NAME} requires the dev libaio .so object and headers but these were not found.")
diff --git a/op_builder/builder.py b/op_builder/builder.py
index 8998fc0eddb8..ca4b339e2447 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -305,7 +305,7 @@ def is_compatible(self, verbose=True):
def extra_ldflags(self):
return []
- def has_function(self, funcname, libraries, verbose=False):
+ def has_function(self, funcname, libraries, library_dirs=None, verbose=False):
'''
Test for existence of a function within a tuple of libraries.
@@ -361,7 +361,8 @@ def has_function(self, funcname, libraries, verbose=False):
compiler.link_executable(objs,
os.path.join(tempdir, 'a.out'),
extra_preargs=self.strip_empty_entries(ldflags),
- libraries=libraries)
+ libraries=libraries,
+ library_dirs=library_dirs)
# Compile and link succeeded
return True
diff --git a/op_builder/evoformer_attn.py b/op_builder/evoformer_attn.py
index af3aa7429775..7f68ccf87290 100644
--- a/op_builder/evoformer_attn.py
+++ b/op_builder/evoformer_attn.py
@@ -52,11 +52,27 @@ def is_compatible(self, verbose=False):
if verbose:
self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH")
return False
- with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f:
- if '3.1.0' not in f.read():
+ if os.path.exists(f'{self.cutlass_path}/CHANGELOG.md'):
+ with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f:
+ if '3.1.0' not in f.read():
+ if verbose:
+ self.warning("Please use CUTLASS version >= 3.1.0")
+ return False
+ else:
+ # pip install nvidia-cutlass package
+ try:
+ import cutlass
+ except ImportError:
+ if verbose:
+ self.warning("Please pip install nvidia-cutlass if trying to pre-compile kernels")
+ return False
+ cutlass_major, cutlass_minor = cutlass.__version__.split('.')[:2]
+ cutlass_compatible = (int(cutlass_major) >= 3 and int(cutlass_minor) >= 1)
+ if not cutlass_compatible:
if verbose:
self.warning("Please use CUTLASS version >= 3.1.0")
return False
+
cuda_okay = True
if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda
sys_cuda_major, _ = installed_cuda_version()
diff --git a/op_builder/gds.py b/op_builder/gds.py
new file mode 100644
index 000000000000..01c2d5a245d1
--- /dev/null
+++ b/op_builder/gds.py
@@ -0,0 +1,56 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import os
+from .async_io import AsyncIOBuilder
+
+
+class GDSBuilder(AsyncIOBuilder):
+ BUILD_VAR = "DS_BUILD_GDS"
+ NAME = "gds"
+
+ def __init__(self):
+ super().__init__()
+
+ def absolute_name(self):
+ return f'deepspeed.ops.gds.{self.NAME}_op'
+
+ def lib_sources(self):
+ src_list = ['csrc/gds/py_lib/deepspeed_py_gds_handle.cpp', 'csrc/gds/py_lib/deepspeed_gds_op.cpp']
+ return super().lib_sources() + src_list
+
+ def sources(self):
+ return self.lib_sources() + ['csrc/gds/py_lib/py_ds_gds.cpp']
+
+ def cxx_args(self):
+ return super().cxx_args() + ['-lcufile']
+
+ def include_paths(self):
+ import torch
+ CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
+ return ['csrc/aio/py_lib', 'csrc/aio/common'] + CUDA_INCLUDE
+
+ def extra_ldflags(self):
+ return super().extra_ldflags() + ['-lcufile']
+
+ def is_compatible(self, verbose=False):
+ try:
+ import torch.utils.cpp_extension
+ except ImportError:
+ if verbose:
+ self.warning("Please install torch if trying to pre-compile GDS")
+ return False
+
+ CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME
+ CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64")
+ gds_compatible = self.has_function(funcname="cuFileDriverOpen",
+ libraries=("cufile", ),
+ library_dirs=(
+ CUDA_HOME,
+ CUDA_LIB64,
+ ),
+ verbose=verbose)
+
+ return gds_compatible and super().is_compatible(verbose)
diff --git a/op_builder/xpu/inference.py b/op_builder/xpu/inference.py
index 9114dcc2c315..a9ac4f84c2ca 100644
--- a/op_builder/xpu/inference.py
+++ b/op_builder/xpu/inference.py
@@ -30,7 +30,10 @@ def cxx_args(self):
def load(self):
try:
- import intel_extension_for_pytorch.deepspeed
- return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference
+ import intel_extension_for_pytorch
+ if hasattr(intel_extension_for_pytorch, "deepspeed"):
+ return intel_extension_for_pytorch.deepspeed.transformer_inference.transformer_inference
+ else:
+ return intel_extension_for_pytorch.xpu.deepspeed
except ImportError:
raise ImportError("Please install intel-extension-for-pytorch >= 2.1.30 to include DeepSpeed kernels.")
diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt
index 1a2ad18611e7..a48a47e4428d 100644
--- a/requirements/requirements-readthedocs.txt
+++ b/requirements/requirements-readthedocs.txt
@@ -1,10 +1,10 @@
-autodoc_pydantic
+autodoc_pydantic>=2.0.0
docutils<0.18
hjson
packaging
psutil
py-cpuinfo
-pydantic<2.0.0
+pydantic>=2.0.0
recommonmark
sphinx_rtd_theme
torch
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 05f88337f3a9..70c94a745435 100755
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,10 +1,9 @@
hjson
ninja
numpy
-nvidia-ml-py
packaging>=20.0
psutil
py-cpuinfo
-pydantic
+pydantic>=2.0.0
torch
tqdm
diff --git a/setup.py b/setup.py
index 2b7555361655..8707209526ad 100755
--- a/setup.py
+++ b/setup.py
@@ -92,6 +92,10 @@ def get_env_if_set(key, default: typing.Any = ""):
'triton': fetch_requirements('requirements/requirements-triton.txt'),
}
+# Only install pynvml on nvidia gpus.
+if torch_available and get_accelerator().device_name() == 'cuda' and not is_rocm_pytorch:
+ install_requires.append('nvidia-ml-py')
+
# Add specific cupy version to both onebit extension variants.
if torch_available and get_accelerator().device_name() == 'cuda':
cupy = None
diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py
index 25256d376eeb..dfab28aa7477 100644
--- a/tests/unit/alexnet_model.py
+++ b/tests/unit/alexnet_model.py
@@ -14,6 +14,7 @@
from deepspeed.utils.torch import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
+from .util import no_child_process_in_deepspeed_io
class AlexNet(nn.Module):
@@ -125,22 +126,11 @@ def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True,
trainset = cifar_trainset(fp16=fp16)
config['local_rank'] = dist.get_rank()
- # deepspeed_io defaults to creating a dataloader that uses a
- # multiprocessing pool. Our tests use pools and we cannot nest pools in
- # python. Therefore we're injecting this kwarg to ensure that no pools
- # are used in the dataloader.
- old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io
-
- def new_method(*args, **kwargs):
- kwargs["num_local_io_workers"] = 0
- return old_method(*args, **kwargs)
-
- deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method
-
- engine, _, _, _ = deepspeed.initialize(config=config,
- model=model,
- model_parameters=[p for p in model.parameters()],
- training_data=trainset)
+ with no_child_process_in_deepspeed_io():
+ engine, _, _, _ = deepspeed.initialize(config=config,
+ model=model,
+ model_parameters=[p for p in model.parameters()],
+ training_data=trainset)
losses = []
for step in range(num_steps):
diff --git a/tests/unit/inference/v2/ragged/test_manager_configs.py b/tests/unit/inference/v2/ragged/test_manager_configs.py
index a5f270cced8c..bdd513445ddb 100644
--- a/tests/unit/inference/v2/ragged/test_manager_configs.py
+++ b/tests/unit/inference/v2/ragged/test_manager_configs.py
@@ -5,7 +5,7 @@
import pytest
-from deepspeed.pydantic_v1 import ValidationError
+from pydantic import ValidationError
from deepspeed.inference.v2.ragged import DSStateManagerConfig
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index fdff9430a4e6..f65d5e2a03bc 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -11,7 +11,7 @@
from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
import deepspeed.comm as dist
from deepspeed import get_accelerator
-from deepspeed.moe.sharded_moe import top1gating
+from deepspeed.moe.sharded_moe import top1gating, topkgating
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
from deepspeed.utils.torch import required_torch_version
@@ -191,3 +191,50 @@ def test(self):
drop_tokens=False,
use_rts=True,
use_tutel=False)
+
+
+class TestTopkGate(DistributedTest):
+
+ def test(self):
+
+ def check_equal(logits, cap, sparse_truth, res):
+ m, n = logits.shape
+ dispatch_mask_truth = torch.zeros(m, n, cap)
+ i, j, k = sparse_truth.t()
+ dispatch_mask_truth[i, j, k] = 1
+ assert (torch.equal(dispatch_mask_truth, res))
+
+ #s=4 e=4 topk=2 cap=2(s*topk/e)
+ logits = torch.tensor([[0.11, 0.2, 0.1, 0.3], [0.3, 0.4, 0.11, 0.1], [0.11, 0.1, 0.6, 0.5],
+ [0.1, 0.11, 0.7, 0.8]])
+ logits *= dist.get_rank() + 1
+ probs_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='probs')[2]
+ probs_sec_sparse = torch.tensor([[0, 1, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 0], [3, 2, 1], [3, 3, 1]])
+ check_equal(logits, 2, probs_sec_sparse, probs_dispatch_res)
+
+ position_sec_sparse = torch.tensor([[0, 1, 0], [0, 3, 0], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 3, 1],
+ [3, 2, 1]])
+ position_dispatch_res = topkgating(logits, 2, 1, min_capacity=1, drop_policy='position')[2]
+ check_equal(logits, 2, position_sec_sparse, position_dispatch_res)
+
+ #s=4 e=6 topk=3 cap=2(s*topk/e)
+ logits2 = torch.tensor([[0.5858, 0.4801, 0.6269, 0.5397, 0.9722, 0.7034],
+ [0.5445, 0.6332, 0.4519, 0.6308, 0.0519, 0.6450],
+ [0.4874, 0.8110, 0.7467, 0.8474, 0.0277, 0.3068],
+ [0.8570, 0.6714, 0.5310, 0.3274, 0.4836, 0.9892]])
+ logits2 *= dist.get_rank() + 1
+
+ #top3 full mask #prob_mask #postion_mask
+ #0 0 1 0 1 1 #0 0 1 0 1 1 #0 0 1 0 1 1
+ #0 1 0 1 0 1 #0 0 0 1 0 0 #0 1 0 1 0 1
+ #0 1 1 1 0 0 #0 1 1 1 0 0 #0 1 1 1 0 0
+ #1 1 0 0 0 1 #1 1 0 0 0 1 #1 0 0 0 0 0
+ probs_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='probs')[2]
+ probs_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 3, 0], [2, 1, 0], [2, 2, 1], [2, 3, 1],
+ [3, 0, 0], [3, 1, 1], [3, 5, 1]])
+ check_equal(logits2, 2, probs_sec_sparse, probs_dispatch_res)
+
+ position_sec_sparse = torch.tensor([[0, 2, 0], [0, 4, 0], [0, 5, 0], [1, 1, 0], [1, 3, 0], [1, 5, 1],
+ [2, 1, 1], [2, 2, 1], [2, 3, 1], [3, 0, 0]])
+ position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
+ check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)
diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py
index f6d175ce67bc..e6927efc3824 100644
--- a/tests/unit/ops/aio/test_aio.py
+++ b/tests/unit/ops/aio/test_aio.py
@@ -78,7 +78,7 @@ def _validate_handle_state(handle, single_submit, overlap_events):
assert handle.get_queue_depth() == QUEUE_DEPTH
-@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False])
+@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken
@pytest.mark.parametrize("single_submit", [True, False])
@pytest.mark.parametrize("overlap_events", [True, False])
class TestRead(DistributedTest):
@@ -144,7 +144,7 @@ def test_async_read(self, tmpdir, use_cuda_pinned_tensor, single_submit, overlap
h.free_cpu_locked_tensor(aio_buffer)
-@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False])
+@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken
@pytest.mark.parametrize("single_submit", [True, False])
@pytest.mark.parametrize("overlap_events", [True, False])
class TestWrite(DistributedTest):
@@ -213,7 +213,7 @@ def test_async_write(self, tmpdir, use_cuda_pinned_tensor, single_submit, overla
@pytest.mark.sequential
-@pytest.mark.parametrize("use_cuda_pinned_tensor", [True, False])
+@pytest.mark.parametrize("use_cuda_pinned_tensor", [True]) # TODO: aio_handle pinned tensor API is broken
@pytest.mark.parametrize("cuda_device", [True, False])
class TestAsyncQueue(DistributedTest):
world_size = 1
diff --git a/tests/unit/ops/aio/test_gds.py b/tests/unit/ops/aio/test_gds.py
new file mode 100644
index 000000000000..53655994b560
--- /dev/null
+++ b/tests/unit/ops/aio/test_gds.py
@@ -0,0 +1,270 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import os
+import filecmp
+import torch
+import deepspeed
+import deepspeed.comm as dist
+from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.op_builder import GDSBuilder
+from unit.common import DistributedTest
+
+KILO_BYTE = 1024 * 256
+BLOCK_SIZE = KILO_BYTE
+QUEUE_DEPTH = 2
+IO_SIZE = 4 * BLOCK_SIZE
+IO_PARALLEL = 2
+
+if not deepspeed.ops.__compatible_ops__[GDSBuilder.NAME]:
+ pytest.skip('Skip tests since gds is not compatible', allow_module_level=True)
+
+
+def _get_local_rank():
+ if get_accelerator().is_available():
+ return dist.get_rank()
+ return 0
+
+
+def _do_ref_write(tmpdir, index=0):
+ file_suffix = f'{_get_local_rank()}_{index}'
+ ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt')
+ ref_buffer = os.urandom(IO_SIZE)
+ with open(ref_file, 'wb') as f:
+ f.write(ref_buffer)
+
+ return ref_file, ref_buffer
+
+
+def _get_test_write_file(tmpdir, index):
+ file_suffix = f'{_get_local_rank()}_{index}'
+ return os.path.join(tmpdir, f'_gds_write_random_{file_suffix}.pt')
+
+
+def _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, gds_handle, index=0):
+ test_file = _get_test_write_file(tmpdir, index)
+ test_buffer = get_accelerator().ByteTensor(list(ref_buffer))
+ gds_handle.pin_device_tensor(test_buffer)
+ return test_file, test_buffer
+
+
+def _validate_handle_state(handle, single_submit, overlap_events):
+ assert handle.get_single_submit() == single_submit
+ assert handle.get_overlap_events() == overlap_events
+ assert handle.get_thread_count() == IO_PARALLEL
+ assert handle.get_block_size() == BLOCK_SIZE
+ assert handle.get_queue_depth() == QUEUE_DEPTH
+
+
+@pytest.mark.parametrize("single_submit", [True, False])
+@pytest.mark.parametrize("overlap_events", [True, False])
+class TestRead(DistributedTest):
+ world_size = 1
+ reuse_dist_env = True
+ if not get_accelerator().is_available():
+ init_distributed = False
+ set_dist_env = False
+
+ def test_parallel_read(self, tmpdir, single_submit, overlap_events):
+
+ h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL)
+
+ gds_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name())
+ h.pin_device_tensor(gds_buffer)
+
+ _validate_handle_state(h, single_submit, overlap_events)
+
+ ref_file, _ = _do_ref_write(tmpdir)
+ read_status = h.sync_pread(gds_buffer, ref_file)
+ assert read_status == 1
+
+ with open(ref_file, 'rb') as f:
+ ref_buffer = list(f.read())
+ assert ref_buffer == gds_buffer.tolist()
+
+ h.unpin_device_tensor(gds_buffer)
+
+ def test_async_read(self, tmpdir, single_submit, overlap_events):
+
+ h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL)
+
+ gds_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name())
+ h.pin_device_tensor(gds_buffer)
+
+ _validate_handle_state(h, single_submit, overlap_events)
+
+ ref_file, _ = _do_ref_write(tmpdir)
+ read_status = h.async_pread(gds_buffer, ref_file)
+ assert read_status == 0
+
+ wait_status = h.wait()
+ assert wait_status == 1
+
+ with open(ref_file, 'rb') as f:
+ ref_buffer = list(f.read())
+ assert ref_buffer == gds_buffer.tolist()
+
+ h.unpin_device_tensor(gds_buffer)
+
+
+@pytest.mark.parametrize("single_submit", [True, False])
+@pytest.mark.parametrize("overlap_events", [True, False])
+class TestWrite(DistributedTest):
+ world_size = 1
+ reuse_dist_env = True
+ if not get_accelerator().is_available():
+ init_distributed = False
+ set_dist_env = False
+
+ def test_parallel_write(self, tmpdir, single_submit, overlap_events):
+
+ ref_file, ref_buffer = _do_ref_write(tmpdir)
+ h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL)
+
+ gds_file, gds_buffer = _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, h)
+
+ _validate_handle_state(h, single_submit, overlap_events)
+
+ write_status = h.sync_pwrite(gds_buffer, gds_file)
+ assert write_status == 1
+
+ h.unpin_device_tensor(gds_buffer)
+
+ assert os.path.isfile(gds_file)
+
+ filecmp.clear_cache()
+ assert filecmp.cmp(ref_file, gds_file, shallow=False)
+
+ def test_async_write(self, tmpdir, single_submit, overlap_events):
+ ref_file, ref_buffer = _do_ref_write(tmpdir)
+
+ h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL)
+ gds_file, gds_buffer = _get_test_write_file_and_device_buffer(tmpdir, ref_buffer, h)
+
+ _validate_handle_state(h, single_submit, overlap_events)
+
+ write_status = h.async_pwrite(gds_buffer, gds_file)
+ assert write_status == 0
+
+ wait_status = h.wait()
+ assert wait_status == 1
+
+ h.unpin_device_tensor(gds_buffer)
+
+ assert os.path.isfile(gds_file)
+
+ filecmp.clear_cache()
+ assert filecmp.cmp(ref_file, gds_file, shallow=False)
+
+
+@pytest.mark.sequential
+class TestAsyncQueue(DistributedTest):
+ world_size = 1
+ if not get_accelerator().is_available():
+ init_distributed = False
+ set_dist_env = False
+
+ @pytest.mark.parametrize("async_queue", [2, 3])
+ def test_read(self, tmpdir, async_queue):
+
+ ref_files = []
+ for i in range(async_queue):
+ f, _ = _do_ref_write(tmpdir, i)
+ ref_files.append(f)
+
+ single_submit = True
+ overlap_events = True
+ h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL)
+
+ gds_buffers = [
+ torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name()) for _ in range(async_queue)
+ ]
+ for buf in gds_buffers:
+ h.pin_device_tensor(buf)
+
+ _validate_handle_state(h, single_submit, overlap_events)
+
+ for i in range(async_queue):
+ read_status = h.async_pread(gds_buffers[i], ref_files[i])
+ assert read_status == 0
+
+ wait_status = h.wait()
+ assert wait_status == async_queue
+
+ for i in range(async_queue):
+ with open(ref_files[i], 'rb') as f:
+ ref_buffer = list(f.read())
+ assert ref_buffer == gds_buffers[i].tolist()
+
+ for t in gds_buffers:
+ h.unpin_device_tensor(t)
+
+ @pytest.mark.parametrize("async_queue", [2, 3])
+ def test_write(self, tmpdir, async_queue):
+ ref_files = []
+ ref_buffers = []
+ for i in range(async_queue):
+ f, buf = _do_ref_write(tmpdir, i)
+ ref_files.append(f)
+ ref_buffers.append(buf)
+
+ single_submit = True
+ overlap_events = True
+ h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, single_submit, overlap_events, IO_PARALLEL)
+
+ gds_files = []
+ gds_buffers = []
+ for i in range(async_queue):
+ f, buf = _get_test_write_file_and_device_buffer(tmpdir, ref_buffers[i], h, i)
+ gds_files.append(f)
+ gds_buffers.append(buf)
+
+ _validate_handle_state(h, single_submit, overlap_events)
+
+ for i in range(async_queue):
+ read_status = h.async_pwrite(gds_buffers[i], gds_files[i])
+ assert read_status == 0
+
+ wait_status = h.wait()
+ assert wait_status == async_queue
+
+ for t in gds_buffers:
+ h.unpin_device_tensor(t)
+
+ for i in range(async_queue):
+ assert os.path.isfile(gds_files[i])
+
+ filecmp.clear_cache()
+ assert filecmp.cmp(ref_files[i], gds_files[i], shallow=False)
+
+
+@pytest.mark.parametrize("use_new_api", [True, False])
+class TestLockDeviceTensor(DistributedTest):
+ world_size = 2
+ reuse_dist_env = True
+ if not get_accelerator().is_available():
+ init_distributed = False
+ set_dist_env = False
+
+ def test_pin_device_tensor(self, use_new_api):
+
+ h = GDSBuilder().load().gds_handle(BLOCK_SIZE, QUEUE_DEPTH, True, True, IO_PARALLEL)
+
+ unpinned_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device=get_accelerator().device_name())
+ if use_new_api:
+ pinned_buffer = h.new_pinned_device_tensor(unpinned_buffer.numel(), unpinned_buffer)
+ else:
+ pinned_buffer = torch.empty_like(unpinned_buffer)
+ h.pin_device_tensor(pinned_buffer)
+
+ assert unpinned_buffer.device == pinned_buffer.device
+ assert unpinned_buffer.dtype == pinned_buffer.dtype
+ assert unpinned_buffer.numel() == pinned_buffer.numel()
+
+ if use_new_api:
+ h.free_pinned_device_tensor(pinned_buffer)
+ else:
+ h.unpin_device_tensor(pinned_buffer)
diff --git a/tests/unit/runtime/pipe/test_pipe.py b/tests/unit/runtime/pipe/test_pipe.py
index 88e26290b650..f198762c5fcc 100644
--- a/tests/unit/runtime/pipe/test_pipe.py
+++ b/tests/unit/runtime/pipe/test_pipe.py
@@ -7,12 +7,15 @@
import torch.nn as nn
import pytest
+import torch
+
+import deepspeed
import deepspeed.comm as dist
from deepspeed.runtime.pipe.topology import PipeDataParallelTopology
from deepspeed.runtime.pipe.module import PipelineModule
from unit.alexnet_model import AlexNetPipe, train_cifar
from unit.common import DistributedTest
-from unit.util import skip_on_arch
+from unit.util import skip_on_arch, no_child_process_in_deepspeed_io
PipeTopo = PipeDataParallelTopology
@@ -155,3 +158,95 @@ def test_pipe_use_reentrant(self, topo_config):
# the following check could passed on higher version docker: nvcr.io/nvidia/pytorch:23.07-py3(torch2.1.0 cuda12.1)
# Check if models have same weights after training
# self._check_model_params_equal(base_model, test_model)
+
+
+class DynamicShapeTestLayer(nn.Module):
+
+ def __init__(self, hidden_size):
+ super().__init__()
+ self.fc = nn.Linear(hidden_size, hidden_size)
+ self.shapes = set()
+
+ def forward(self, x):
+ self.shapes.add(x.shape)
+ y = self.fc(x)
+ return y
+
+
+class DynamicShapeTestModel(nn.Module):
+
+ def __init__(self, n_layers, hidden_size):
+ super().__init__()
+ self.layers = nn.ModuleList([DynamicShapeTestLayer(hidden_size) for _ in range(n_layers)])
+
+
+@pytest.mark.parametrize('topo_config', [
+ {
+ "num_pp": 1,
+ "num_dp": 4
+ },
+ {
+ "num_pp": 2,
+ "num_dp": 2
+ },
+ {
+ "num_pp": 4,
+ "num_dp": 1
+ },
+])
+class TestPipeDynamicShape(DistributedTest):
+ world_size = 4
+
+ def test_pipe_base(self, topo_config):
+ """This test checks if the pipeline engine can handle dynamic shapes correctly.
+ We pass inputs of different shapes to the pipeline engine.
+ """
+
+ n_iter = 10
+ n_layers = 4
+ n_samples = 1024
+ batch_size = 4
+ channel_dims = [8, 16, 32, 64]
+ hidden_size = 16
+
+ topo = PipeTopo(**topo_config)
+
+ model = DynamicShapeTestModel(n_layers, hidden_size)
+ model = PipelineModule(layers=model.layers, topology=topo, loss_fn=nn.MSELoss(), dynamic_shape=True)
+
+ # Each batch has different channel dim but we use the same channel dim in the same batch
+ xs = [
+ torch.randn(channel_dims[(i // batch_size) % len(channel_dims)], hidden_size, dtype=torch.float32)
+ for i in range(n_samples)
+ ]
+ ys = [torch.randn_like(x) for x in xs]
+
+ class CustomDataset(torch.utils.data.Dataset):
+
+ def __init__(self, xs, ys):
+ self.xs = xs
+ self.ys = ys
+
+ def __len__(self):
+ return len(self.xs)
+
+ def __getitem__(self, idx):
+ return self.xs[idx], self.ys[idx]
+
+ dataset = CustomDataset(xs, ys)
+
+ config_dict["train_batch_size"] = batch_size
+
+ with no_child_process_in_deepspeed_io():
+ engine, _, _, _ = deepspeed.initialize(config=config_dict,
+ model=model,
+ model_parameters=[p for p in model.parameters()],
+ training_data=dataset)
+
+ for _ in range(n_iter):
+ _ = engine.train_batch()
+
+ # Check if all layers have seen different shapes
+ for layer in model.modules():
+ if isinstance(layer, DynamicShapeTestLayer):
+ assert len(layer.shapes) > 1
diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py
index c11c63d04867..d06b35e208fe 100644
--- a/tests/unit/runtime/test_ds_config_dict.py
+++ b/tests/unit/runtime/test_ds_config_dict.py
@@ -67,13 +67,11 @@ def _batch_assert(status, ds_config, batch, micro_batch, gas, success):
if not success:
assert not status
- print("Failed but All is well")
return
assert ds_config.train_batch_size == batch
assert ds_config.train_micro_batch_size_per_gpu == micro_batch
assert ds_config.gradient_accumulation_steps == gas
- print("All is well")
#Tests different batch config provided in deepspeed json file
diff --git a/tests/unit/runtime/test_ds_config_model.py b/tests/unit/runtime/test_ds_config_model.py
index 87ea747cf423..4d184b2858a8 100644
--- a/tests/unit/runtime/test_ds_config_model.py
+++ b/tests/unit/runtime/test_ds_config_model.py
@@ -4,18 +4,25 @@
# DeepSpeed Team
import pytest
-import os
import json
-from typing import List
-from deepspeed.pydantic_v1 import Field, ValidationError
+import os
+from typing import List, Optional
+
+from pydantic import Field, ValidationError
+
from deepspeed.runtime import config as ds_config
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
class SimpleConf(DeepSpeedConfigModel):
param_1: int = 0
- param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x]))
- param_2: List[str] = None
+ param_2_old: Optional[str] = Field(None,
+ json_schema_extra={
+ "deprecated": True,
+ "new_param": "param_2",
+ "new_param_fn": (lambda x: [x])
+ })
+ param_2: Optional[List[str]] = None
param_3: int = Field(0, alias="param_3_alias")
diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py
new file mode 100644
index 000000000000..d75519b67f68
--- /dev/null
+++ b/tests/unit/runtime/zero/test_unwrap_model.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed
+from deepspeed.runtime.zero import unwrap_model_for_generation
+from deepspeed.accelerator import get_accelerator
+
+from unit.common import DistributedTest
+from unit.simple_model import SimpleModel
+
+config = {
+ "train_batch_size": 2,
+ "steps_per_print": 1,
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 0.00015
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "stage3_param_persistence_threshold": 1,
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": True
+ }
+ }
+}
+
+if get_accelerator().is_fp16_supported():
+ config["fp16"] = {"enabled": True, "loss_scale": 138.}
+elif get_accelerator().is_bf16_supported():
+ config["bf16"] = {"enabled": True}
+
+
+class TestUnwrapModel(DistributedTest):
+ # gather across more than 1 gpu
+ world_size = 2
+
+ def test(self):
+
+ def hooks_exist(engine):
+ if engine.optimizer is not None and hasattr(engine.optimizer, "parameter_offload"):
+ optimizer_offload = engine.optimizer.parameter_offload
+ elif engine.optimizer is not None:
+ optimizer_offload = engine.optimizer
+
+ hooks = 0
+ for hook in optimizer_offload.forward_hooks:
+ hooks += 1
+ if hooks > 0:
+ return True
+ return False
+
+ model = SimpleModel(hidden_dim=100)
+ engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config)
+
+ with unwrap_model_for_generation(engine):
+ # assert no hooks
+ assert not hooks_exist(engine)
+ # assert parameters gathered
+ assert model.linears[0].weight.numel() != 0, "GatheredParameters should give a non-0-sized tensor"
+
+ # assert hooks
+ assert hooks_exist(engine)
diff --git a/tests/unit/runtime/zero/test_zero_offloadpp.py b/tests/unit/runtime/zero/test_zero_offloadpp.py
index 5bfec399e19f..8ae99e2237e2 100644
--- a/tests/unit/runtime/zero/test_zero_offloadpp.py
+++ b/tests/unit/runtime/zero/test_zero_offloadpp.py
@@ -43,6 +43,7 @@ def test(self, h_dim: int, n_layers: int) -> None:
config_dict = {
"train_batch_size": 256,
"steps_per_print": 1,
+ "gradient_clipping": 1.0,
"optimizer": {
"type": "Adam",
"params": {
diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py
new file mode 100644
index 000000000000..915c89e0b00a
--- /dev/null
+++ b/tests/unit/sequence_parallelism/test_ulysses.py
@@ -0,0 +1,77 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed.comm as dist
+from deepspeed import initialize
+from transformers import AutoModel
+from unit.common import DistributedTest
+from deepspeed.sequence.layer import _SeqAllToAll
+from unit.util import skip_on_arch
+
+
+#Use mesh device to create data and sequence parallel group
+class TestUlyssesUtils(DistributedTest):
+ world_size = 4
+
+ def test_mesh_device_creation(self) -> None:
+ skip_on_arch(min_arch=8)
+ model = AutoModel.from_pretrained('bert-base-uncased')
+ sp_size = 2
+ dp_size = 2
+ ds_engine, _, _, _ = initialize(
+ model=model,
+ config_params={
+ "train_batch_size": 8,
+ "data_parallel_size": dp_size,
+ "sequence_parallel_size": sp_size
+ },
+ )
+ assert ds_engine.seq_parallel_group is not None
+ assert ds_engine.data_parallel_group is not None
+ assert dist.get_world_size(group=ds_engine.seq_parallel_group) == sp_size
+ assert dist.get_world_size(group=ds_engine.data_parallel_group) == dp_size
+ assert dist.get_world_size() == sp_size * dp_size
+
+
+#Sweep b,s,h,d to test all2all consistency
+@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension
+@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension
+@pytest.mark.parametrize("num_heads", [4, 8])
+@pytest.mark.parametrize("head_dim", [16, 32])
+class TestUlyssesAll2All(DistributedTest):
+ world_size = 4
+
+ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
+ skip_on_arch(min_arch=8)
+ model = AutoModel.from_pretrained('bert-base-uncased')
+ ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
+ #4D tensor : b,s,h,d or s,b,h,d
+ input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
+ scatter_idx = 2
+ batch_dim_idx = 0
+ outputs = []
+ seq_dims = [0] #seq first API
+ #TODO: Add support for batch first (that seq_dims=[0,1]) after PR for bs>1 issue with batch first is fixed
+ ## See discussion in : https://github.com/microsoft/DeepSpeed/issues/5808
+ for seq_dim in seq_dims:
+ gather_idx = seq_dim
+ #first all2all: sequence parallel to head parallel
+ s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
+ batch_dim_idx)
+
+ #No op
+ # second all2all: head parallel to sequence parallel
+ h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
+ batch_dim_idx)
+ print(
+ f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
+ )
+ outputs.append(h2s_tensor)
+
+ # Check outputs are the same as input
+ for i in range(1, len(outputs)):
+ assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
diff --git a/tests/unit/util.py b/tests/unit/util.py
index feec326ede6c..dba29ed27a4c 100644
--- a/tests/unit/util.py
+++ b/tests/unit/util.py
@@ -5,6 +5,8 @@
import pytest
import torch
+
+import deepspeed
from deepspeed.accelerator import get_accelerator, is_current_accelerator_supported
from deepspeed.git_version_info import torch_info
@@ -67,3 +69,22 @@ def required_amp_check():
return False
else:
return True
+
+
+class no_child_process_in_deepspeed_io:
+
+ def __enter__(self):
+ # deepspeed_io defaults to creating a dataloader that uses a
+ # multiprocessing pool. Our tests use pools and we cannot nest pools in
+ # python. Therefore we're injecting this kwarg to ensure that no pools
+ # are used in the dataloader.
+ self.old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io
+
+ def new_method(*args, **kwargs):
+ kwargs["num_local_io_workers"] = 0
+ return self.old_method(*args, **kwargs)
+
+ deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method
+
+ def __exit__(self, *_):
+ deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method
diff --git a/version.txt b/version.txt
index 436d0ce0df76..e815b861f023 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.14.5
+0.15.1