Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
CurryRice233 authored Oct 17, 2023
2 parents da7edd6 + 12aedac commit 4b6197f
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 50 deletions.
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def max_memory_reserved(self, device_index=None):
def total_memory(self, device_index=None):
...

@abc.abstractmethod
def available_memory(self, device_index=None):
...

# Data types
@abc.abstractmethod
def is_bf16_supported(self):
Expand Down
3 changes: 3 additions & 0 deletions accelerator/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def max_memory_reserved(self, device_index=None):
def total_memory(self, device_index=None):
return psutil.virtual_memory().total

def available_memory(self, device_index=None):
return psutil.virtual_memory().available

# Misc
def amp(self):
return torch.cpu.amp
Expand Down
25 changes: 25 additions & 0 deletions accelerator/cuda_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,29 @@
except ImportError:
pass

# Delay import pynvml to avoid import error when CUDA is not available
pynvml = None


class CUDA_Accelerator(DeepSpeedAccelerator):

def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
if pynvml is None:
self._init_pynvml()

def _init_pynvml(self):
global pynvml
try:
import pynvml
except ImportError:
return
try:
pynvml.nvmlInit()
except pynvml.NVMLError:
pynvml = None
return

def is_synchronized_device(self):
return False
Expand Down Expand Up @@ -136,6 +153,14 @@ def max_memory_reserved(self, device_index=None):
def total_memory(self, device_index=None):
return torch.cuda.get_device_properties(device_index).total_memory

def available_memory(self, device_index=None):
if pynvml:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.free
else:
return self.total_memory(device_index) - self.memory_allocated(device_index)

# Data types
def is_bf16_supported(self):
return torch.cuda.is_bf16_supported()
Expand Down
3 changes: 3 additions & 0 deletions accelerator/mps_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def max_memory_reserved(self, device_index=None):
def total_memory(self, device_index=None):
return

def available_memory(self, device_index=None):
return

# Data types
def is_bf16_supported(self):
return False
Expand Down
3 changes: 3 additions & 0 deletions accelerator/npu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def max_memory_reserved(self, device_index=None):
def total_memory(self, device_index=None):
return torch.npu.get_device_properties(device_index).total_memory

def available_memory(self, device_index=None):
return self.total_memory(device_index) - self.memory_allocated(device_index)

# Data types
def is_bf16_supported(self):
return torch.npu.is_bf16_supported()
Expand Down
134 changes: 84 additions & 50 deletions blogs/deepspeed-ulysses/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@

</div>

To cite DeepSpeed-Ulysses, please cite our [arxiv report](https://arxiv.org/abs/2309.14509):

```
@article{jacobs2023deepspeed,
title={DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models},
author={Sam Ade Jacobs and Masahiro Tanaka and Chengming Zhang and Minjia Zhang and Shuaiwen Leon Song and Samyam Rajbhandari and Yuxiong He},
journal={arXiv preprint arXiv:2309.14509},
year={2023},
}
```

## Introduction

Training large models with long sequences is becoming very important
Expand Down Expand Up @@ -193,7 +204,7 @@ scaling not just to large sequence lengths but also to large models.

## Evaluation

We evaluate DeepSpeed-Ulysses on GPT,
We evaluate DeepSpeed-Ulysses (Ulysses) on GPT,
a foundation model for many NLP tasks on up to 64 A100 GPUs with 40GB memory. Our
evaluations are four-fold: i) sequence length scalability, ii)
throughput for dense attention and comparison with existing system, and
Expand All @@ -212,74 +223,97 @@ maintains similar computation throughput across different sequence
length at appropriate GPU count.

<div align="center">
<img src="./media/fig2Ulysses.png" style="width:5in;height:4in" />
<img src="./media/dense1B1Mscale.png" style="width:5in;height:4in" />

*Figure 2: DeepSpeed sequence parallelism strong scalability evaluation
at different sequence length and GPU count.*
</div>

### Dense Attention Evaluation

Next, we evaluate DeepSpeed sequence parallelism on 30 billion parameter
dense attention model and benchmark against Megatron sequence
parallelism on 64 A100 GPUs. The results of these evaluations are shown
in Figures 3.

We compare DeepSpeed sequence parallelism with Megatron-LM for a 30B
model running various sequence lengths. For our evaluation we chose the
sequence parallelism degree and global batch size that produced the best
performance (measured as throughput or TFLOPs) for both DeepSpeed
sequence parallelism and Megatron-LM, this we call optimal (batch
size-sequence length) configurations. For DeepSpeed sequence
parallelism, we always use a ZeRO parallelism degree of 64.

Figure 3 shows that DeepSpeed sequence parallelism consistently
outperforms Megatron-LM for the sequence length that can be run with
both. In addition, DeepSpeed sequence parallelism can run longer
sequence than Megatron-LM. DeepSpeed sequence parallelism performance
advantages are two folds: (1) DeepSpeed sequence parallelism in
combination with ZeRO-3 fits more sample than Megatron-LM because of
memory optimization leading to higher throughput (2) DeepSpeed sequence
parallelism benefits from efficient all-to-all communication relative to
*all-gather* communication as applied in Megatron-LM sequence
parallelism.
Next, we evaluate Ulysses on 7 billion (7B) and 30 billion (30B) parameter
GPT dense attention models and compare against Megatron-LM's sequence
parallelism (Megatron LM) and Colosal AI sequence parallelism (ColAI-SP) on
32 and 64 A100 GPUs respectively. The results of these evaluations are shown
in Figures 3 and 4.

We compare Ulysses with Megatron-LM and ColAI-SP for 7B and 30B models
running various sequence lengths. We chose the sequence parallelism
degree and micro-batch size that produced the best performance
(measured as TFLOPs) for the three methods, this we call optimal
(batch size-sequence length) configurations. For Ulysses, we always
use a ZeRO-3 parallelism degrees of 32 and 64 for 7B and 30B models
respectively.


Figures 3 and 4 show that Ulysses consistently outperforms Megatron-LM
and ColAI-SP for the sequence length that can be run with them. In addition,
Ulysses can run longer sequence than the two existing methods. Ulysses
performance advantages are two folds: (1) Ulysses in combination with ZeRO-3
parameter sharding across both data and sequence parallel groups fits more
samples than Megatron-LM and ColAI-SP because of the memory optimization
leading to higher throughput (2) Ulysses benefits from efficient *all-to-all*
communication relative to *all-gather* *reduce-scatter* and *ring-style* P2P
communication as applied in Megatron-LM and ColAI-SP sequence parallelism.
However, for dense attention at long sequence length, the throughput is
primarily determined by local attention computation due to quadratic
computation complexity of attention, therefore performance gap between Ulysses
and the two existing methods closes for sequence length that can be run with them.

<div align="center">
<img src="./media/fig3Ulysses.png" style="width:5in;height:4in" />
<img src="./media/dense7B.png" style="width:5in;height:4in" />

*Figure 3: Evaluation of DeepSpeed and Megatron LM sequence parallelism on 30B
parameter model with dense attention.*
*Figure 3: Evaluation of Ulysses vs Megatron LM vs ColAI-SP on GPT-7B parameter
model with dense attention (32 GPUs).*
</div>

<div align="center">
<img src="./media/dense30B.png" style="width:5in;height:4in" />

*Figure 4: Evaluation of Ulysses vs Megatron LM vs ColAI-SP on GPT-30B parameter
model with dense attention (64 GPUs).*
</div>

### Sparse Attention Evaluation

Similarly, we evaluate DeepSpeed sequence parallelism on 30 billion
parameter sparse attention model and benchmark against Megatron sequence
parallelism. Results of our evaluation are shown in Figure 4. We observe
similar trends with sparse attention as dense attention experiments. We
observe more than 2X throughput performance of DeepSpeed sequence
parallelism compared to Megatron-LM. For memory saving, DeepSpeed
sequence parallelism leveraging ZeRO-3 scales to 4X longer sequence
lengths than Megatron-LM.

DeepSpeed sequence parallelism outperforms Megatron-LM for sequence
length that can be run with both. In fact, the current DeepSpeed
throughput is bottlenecked by the local sparse attention implementation,
and as a result DeepSpeed throughput decreases as the sequence length
increases. We expect this gap in performance between DeepSpeed and
Megatron to increase further for larger sequence lengths as we improve
the performance of the local sparse attention implementation in future.
Similarly, we evaluate Ulysses on 7 billion and 30 billion parameter sparse
attention models and benchmark against Megatron-LM sequence parallelism.
There is no public implementation of block sparse attention for ColAI-SP,
therefore, evaluation of sparse attention is in comparison with Megatron-LM.
Results of our evaluation are shown in Figures 5 and 6. We observe similar
trends with sparse attention as dense attention experiments. We observe more
than 2x throughput performance of Ulysses compared to Megatron-LM. For memory
saving, Ulysses leveraging ZeRO-3 scales to 4x longer sequence lengths
than Megatron-LM.

Ulysses outperforms Megatron-LM for sequence length that can be run with both.
In fact, the current Ulysses throughput is bottle-necked by the local sparse
attention implementation, and as a result Ulysses throughput decreases as
the sequence length increases. We expect this gap in performance between our
method and Megatron-LM to increase further for larger sequence lengths as we
improve the performance of the local sparse attention implementation in future.
A noteworthy observation is that the decreasing performance gap between Ulysses
and Megatron-LM observed in dense attention evaluation is less pronounced in
sparse attention evaluation, because the attention computation in sparse attention
is less dominant compared to dense attention.

<div align="center">
<img src="./media/sparse7B.png" style="width:5in;height:4in" />

*Figure 5: Evaluation of Ulysses and Megatron LM sequence parallelism on GPT-7B
parameter model with block sparse attention (32 GPUs).*
</div>

<div align="center">
<img src="./media/fig4Ulysses.png" style="width:5in;height:4in" />
<img src="./media/sparse30B.png" style="width:5in;height:4in" />

*Figure 4: Evaluation of DeepSpeed and Megatron LM sequence parallelism on 30B
parameter model with block sparse attention.*
*Figure 6: Evaluation of Ulysses and Megatron LM sequence parallelism on GPT-30B
parameter model with block sparse attention (64 GPUs).*
</div>

### Convergence Study

Lastly, Figure 5 shows convergence of a 1.3 billion GPT model at 32K
Lastly, Figure 7 shows convergence of a 1.3 billion GPT model at 32K
sequence length on 8 A100 GPUs with sequence parallelism degree set at 4
for both DeepSpeed and Megatron-LM sequence parallelism. For DeepSpeed
sequence parallelism, we evaluate convergence with different ZeRO
Expand All @@ -289,9 +323,9 @@ there is no (negative) impact on quality of trained models, this assertion is
validated through experiments and is shown in Figure 5.

<div align="center">
<img src="./media/convg.png" width="500px" />
<img src="./media/convgZ.png" width="500px" />

*Figure 5: Convergence evaluation of DeepSpeed sequence parallelism with different
*Figure 7: Convergence evaluation of DeepSpeed sequence parallelism with different
ZeRO memory optimization stages.*
</div>

Expand Down
Binary file added blogs/deepspeed-ulysses/media/convgZ.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added blogs/deepspeed-ulysses/media/dense1B1Mscale.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added blogs/deepspeed-ulysses/media/dense30B.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added blogs/deepspeed-ulysses/media/dense7B.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added blogs/deepspeed-ulysses/media/sparse30B.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added blogs/deepspeed-ulysses/media/sparse7B.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ packaging>=20.0
psutil
py-cpuinfo
pydantic
pynvml
torch
tqdm

0 comments on commit 4b6197f

Please sign in to comment.