# A faster and more memory-efficient implementation of `zero_to_fp32` (#6658)

It is a faster and more memory-efficient implementation of `zero_to_fp32`. The previous version doubles the memory usage, which causes CPU OOM for very large models (e.g. llama 405B).

https://github.com/microsoft/DeepSpeed/blob/b647fb2470f8f6fefe5cab0ea84a2d89696eb898/deepspeed/utils/zero_to_fp32.py#L438-L441

## How does it work?

1. **Lazy loading**: Load the checkpoint with `mmap=True`, so the weights are mmapped rather than loading all the storages into memory.
2. **Lazy merge**: `GatheredTensor` contains the mmapped weights and tensor offsets. It is a memory-efficient pseudo tensor. Only when `tensor.contiguous()` is called does it load the related weights into memory and merge them into a single tensor.
3. **Release memory in time**: Save the checkpoint shard by shard, and release the memory once a shard is saved. Throughout the process, only one shard of tensors is kept in memory.

A minimal sketch of this lazy-load/merge flow is shown after the commit message below.

## How much benefit in speed and memory?

Experiments were conducted on a Linux host with 1 TB of memory. Here is a detailed comparison:

| | world size | peak memory (GB) | elapsed time (h:mm:ss) |
|---|---|---|---|
| llama3-8B (old -> new) | 8 | 90 -> 41 | 0:02:17 -> 0:01:10 |
| llama2-13B (old -> new) | 8 | 146 -> 54 | 0:02:30 -> 0:01:47 |
| llama2-70B (old -> new) | 16 | 789 -> 159 | 0:20:47 -> 0:20:45 |
| qwen1.5-110B (old -> new) | 32 | OOM -> 217 | ? -> 0:34:21 |
| llama3-405B (old -> new) | 192 | OOM -> 262 | ? -> 2:09:59 |

You can reproduce the measurements with the following script:

```sh
# 1. install requirements
apt-get install time
# 2. prepare zero-3 checkpoints
# 3. convert zero to fp32 checkpoints
/usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization
```

- **memory**: Theoretically, this PR reduces the memory cost from `2M` to `(1/n)M`, where `M` is the memory cost of the full weights and `n` is the number of shards.
- **speed**: The speed gain mainly comes from avoiding extra tensor copying. The benefit may be slight.

## Impl history

- [v1](xu-song@19712a1#diff-6a2ca3427fa608c387b7351359f98cfc1313be6e960cee86344ff246bf1b8326R441-R447): an hf_hub-compatible approach. It has been discarded due to the controversial implementation of `data_ptr()`.
- [v2](https://github.com/microsoft/DeepSpeed/pull/6658/files): a simple approach with `torch.empty`.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
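To make the three steps above concrete, here is a minimal, self-contained sketch of the lazy-load/merge/flush idea. Everything named here (`LazyShardedTensor`, the `/tmp` file layout, the offsets) is an illustrative assumption, not DeepSpeed's actual `GatheredTensor` implementation; only `torch.load(..., mmap=True)` and materialization via `.contiguous()` come from the description above.

```python
# Illustrative sketch only: LazyShardedTensor and the file layout are
# made up for the demo; DeepSpeed's real GatheredTensor differs in detail.
import torch


class LazyShardedTensor:
    """Memory-efficient pseudo tensor: holds references to mmapped flat
    shards plus offsets, and only materializes on .contiguous()."""

    def __init__(self, flat_shards, offsets, numels, shape):
        self.flat_shards = flat_shards  # one mmapped 1-D tensor per rank
        self.offsets = offsets          # where this param starts in each shard
        self.numels = numels            # how many elements come from each shard
        self.shape = shape              # final merged parameter shape

    def contiguous(self):
        # Only here are the relevant pages read from disk and merged.
        out = torch.empty(sum(self.numels), dtype=self.flat_shards[0].dtype)
        pos = 0
        for shard, off, n in zip(self.flat_shards, self.offsets, self.numels):
            out[pos:pos + n].copy_(shard[off:off + n])
            pos += n
        return out.view(self.shape)


# Fake per-rank flat shards standing in for ZeRO-3 partition files.
for r in range(2):
    torch.save(torch.full((4,), float(r)), f"/tmp/rank{r}_flat.pt")

# 1. Lazy loading: mmap=True keeps the storages on disk until touched.
shards = [torch.load(f"/tmp/rank{r}_flat.pt", mmap=True, weights_only=True)
          for r in range(2)]

# 2. Lazy merge: no bytes are copied until .contiguous() is called.
param = LazyShardedTensor(shards, offsets=[0, 0], numels=[4, 4], shape=(2, 4))

# 3. Release memory in time: materialize one output shard, save it, drop it.
dense = param.contiguous()
torch.save({"weight": dense}, "/tmp/shard_00000.bin")
del dense  # only one materialized shard is ever held in memory
```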
1 parent f594dbe, commit dd40269. 2 changed files with 177 additions and 33 deletions.
The commit adds the following unit test (all 60 lines are additions), which checks that tied weights survive the fp32 conversion:

```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn

import deepspeed
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from unit.common import DistributedTest


class ModelWithSharedWeights(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(100, 100)
        self.layer1 = nn.Linear(200, 200)
        self.layer2 = nn.Linear(300, 300)
        # tie layer 1 and layer 2
        self.layer1.weight = self.layer2.weight


class TestCheckpointConvert(DistributedTest):
    world_size = 2

    def test_convert_zero_checkpoint_to_fp32_state_dict(self, tmpdir):
        config = {
            "train_micro_batch_size_per_gpu": 2,
            "zero_allow_untested_optimizer": True,
            "zero_optimization": {
                "stage": 3
            },
        }
        model = ModelWithSharedWeights()
        optimizer = torch.optim.Adam(model.parameters())

        deepspeed_engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            optimizer=optimizer,
        )
        ds_save_dir = tmpdir / "checkpoint_ds"
        deepspeed_engine.save_checkpoint(ds_save_dir, tag="checkpoint")

        model = ModelWithSharedWeights()

        # save checkpoint
        fp32_save_dir = tmpdir / "checkpoint_fp32"
        convert_zero_checkpoint_to_fp32_state_dict(ds_save_dir, fp32_save_dir)

        # load state_dict from fp32 checkpoint
        state_dict = torch.load(fp32_save_dir / 'pytorch_model.bin')

        # check shared tensor
        assert id(state_dict['layer1.weight']) == id(state_dict['layer2.weight'])

        # load state_dict into model
        model.load_state_dict(state_dict, strict=True)
```
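Why assert on `id()`? Pickle (which `torch.save` uses for the container) preserves object identity within one file, so if the converter stores the same tensor object under both keys, the loaded dict references a single object under both keys. The standalone snippet below (the file path and key names are made up for the demo) illustrates the same check outside the test:

```python
import torch

# One tensor stored under two keys, mimicking tied layer1/layer2 weights.
w = torch.randn(8, 8)
state = {"layer1.weight": w, "layer2.weight": w}
torch.save(state, "/tmp/tied_demo.bin")

loaded = torch.load("/tmp/tied_demo.bin", weights_only=True)
# Pickle memoization deserializes the shared object exactly once, so the
# identity check used by the test above passes.
assert id(loaded["layer1.weight"]) == id(loaded["layer2.weight"])

# By contrast, two distinct tensor objects that merely share storage
# would deserialize as two distinct objects and fail the id() check.
```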