A faster and more memory-efficient implementation of zero_to_fp32 #6658

Merged · 12 commits · Nov 18, 2024
150 changes: 117 additions & 33 deletions deepspeed/utils/zero_to_fp32.py
@@ -21,7 +21,9 @@
import math
import os
import re
import gc
import json
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
from dataclasses import dataclass
@@ -146,8 +148,8 @@ def parse_model_states(files):
def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dict = torch.load(f, map_location=device)
for f in tqdm(files, desc='Loading checkpoint shards'):
state_dict = torch.load(f, map_location=device, mmap=True)
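# NOTE: mmap=True lets torch.load memory-map the checkpoint file, paging tensor
# data in on access instead of reading the whole shard into RAM up front
# (requires a PyTorch version whose torch.load accepts the mmap argument)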
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
@@ -179,19 +181,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
else:
raise ValueError(f"unknown zero stage {zero_stage}")

if zero_stage <= 2:
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor

fp32_flat_groups = [
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
]

fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
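# the per-rank lists of flat tensors are returned unmerged for both stages;
# for stage 3, consolidation now happens lazily inside GatheredTensor rather
# than via an eager torch.cat here, avoiding a full extra copy per rank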
return zero_stage, world_size, fp32_flat_groups


@@ -398,9 +388,56 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")


class GatheredTensor:
"""
A pseudo tensor that collects partitioned weights.
It is more memory-efficient than eager concatenation when there are multiple param groups.
"""

def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
self.flat_groups = flat_groups
self.flat_groups_offset = flat_groups_offset
self.offset = offset
self.partitioned_numel = partitioned_numel
self.shape = shape
self.dtype = self.flat_groups[0][0].dtype

def contiguous(self):
"""
Merge partitioned weights from flat_groups into a single tensor.
"""
end_idx = self.offset + self.partitioned_numel
world_size = len(self.flat_groups)
pad_flat_param_chunks = []

for rank_i in range(world_size):
# for each rank, collect the weight slices from the param group(s) this tensor spans
flat_groups_at_rank_i = self.flat_groups[rank_i]
start_group_id = None
end_group_id = None
for group_id in range(len(self.flat_groups_offset)):
if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
start_group_id = group_id
if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
end_group_id = group_id
break
# collect weights from the relevant group(s)
for group_id in range(start_group_id, end_group_id + 1):
flat_tensor = flat_groups_at_rank_i[group_id]
start_offset = self.offset - self.flat_groups_offset[group_id]
end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])

# collect weights from all ranks
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
return param
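
# Illustrative sketch (not part of the diff): how GatheredTensor defers the
# merge. The toy values below are made up; assume two ranks and one param
# group, with a 2x3 parameter split 3 elements per rank:
#
#   flat_groups = [[torch.arange(3.)], [torch.arange(3., 6.)]]  # [rank][group]
#   flat_groups_offset = [0, 3]
#   t = GatheredTensor(flat_groups, flat_groups_offset, offset=0,
#                      partitioned_numel=3, shape=torch.Size([2, 3]))
#   t.contiguous()  # only now are the rank slices merged -> [[0,1,2],[3,4,5]]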


def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
avail_numel = fp32_flat_groups[0].numel() * world_size
avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size

# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
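# e.g. with world_size=2, a parameter of 5 elements is stored as ceil(5/2)=3
# elements per rank, so the final rank carries 1 element of padding that must
# be dropped after the per-rank slices are re-joined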

@@ -424,7 +461,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
offset = 0
total_numel = 0
total_params = 0
for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'):
flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
@@ -435,10 +473,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)

# XXX: memory usage doubles here
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
0).narrow(0, 0, unpartitioned_numel).view(shape)
# memory-efficient lazy tensor; materialized later via .contiguous()
tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
state_dict[name] = tensor
offset += partitioned_numel

offset *= world_size
@@ -473,7 +510,29 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
return state_dict


def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
def to_torch_tensor(state_dict, return_empty_tensor=False):
"""
Convert a state_dict of GatheredTensor objects into real torch tensors, keeping tied tensors shared
"""
converted_tensors = {}
for name, tensor in state_dict.items():
tensor_id = id(tensor)
if tensor_id in converted_tensors:
shared_tensor = state_dict[converted_tensors[tensor_id]]
state_dict[name] = shared_tensor
else:
converted_tensors[tensor_id] = name
if return_empty_tensor:
state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
else:
state_dict[name] = tensor.contiguous()
return state_dict
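
# Illustrative sketch (not part of the diff): tied weights point at the same
# GatheredTensor instance, so the id()-based dedup above materializes them
# once and keeps them shared:
#
#   sd = {"layer1.weight": gathered, "layer2.weight": gathered}  # same object
#   sd = to_torch_tensor(sd)
#   assert sd["layer1.weight"] is sd["layer2.weight"]  # still tied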


def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
tag=None,
exclude_frozen_parameters=False,
lazy_mode=False):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -483,14 +542,12 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
- ``checkpoint_dir``: path to the desired checkpoint folder
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
- ``lazy_mode``: get the state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory-efficient.
Convert a pseudo tensor to a torch tensor with ``.contiguous()``

Returns:
- pytorch ``state_dict``

Note: this approach may not work if your application doesn't have sufficient free CPU memory and
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
the checkpoint.

A typical usage might be ::

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
@@ -506,6 +563,16 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):

If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
You may need to use the offline approach with the ``zero_to_fp32.py`` script that is saved with
the checkpoint. Alternatively, you can load the state_dict in lazy mode ::

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not yet materialized on cpu
for name, lazy_tensor in state_dict.items():
tensor = lazy_tensor.contiguous() # to cpu
print(name, tensor)
# del the tensor to release memory once it is no longer in use
"""
if tag is None:
latest_path = os.path.join(checkpoint_dir, 'latest')
@@ -520,7 +587,11 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
if not os.path.isdir(ds_checkpoint_dir):
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
if lazy_mode:
return state_dict
else:
return to_torch_tensor(state_dict)


def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
@@ -541,6 +612,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
"""

# Dependency pre-check
if safe_serialization:
try:
@@ -556,13 +628,18 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
raise

# Convert zero checkpoint to state_dict
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
tag,
exclude_frozen_parameters,
lazy_mode=True)

# Shard the model if it is too big.
weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
if max_shard_size is not None:
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
state_dict_split = split_torch_state_dict_into_shards(state_dict,
# a memory-efficient approach to sharding
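# (split_torch_state_dict_into_shards only needs tensor metadata, i.e. shapes
# and dtypes, to plan shard boundaries, so planning against empty tensors
# avoids materializing every weight at once)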
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
filename_pattern=filename_pattern,
max_shard_size=max_shard_size)
else:
@@ -571,15 +648,22 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
state_dict_split = StateDictSplit(is_sharded=False,
filename_to_tensors={weights_name: list(state_dict.keys())})

# Save the model
# Save the model shard by shard
os.makedirs(output_dir, exist_ok=True)
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
shard_state_dict = to_torch_tensor(shard_state_dict)
output_path = os.path.join(output_dir, shard_file)
if safe_serialization:
save_file(shard, output_path, metadata={"format": "pt"})
save_file(shard_state_dict, output_path, metadata={"format": "pt"})
else:
torch.save(shard, output_path)
torch.save(shard_state_dict, output_path)
# release the memory of the current shard
for tensor_name in shard_state_dict:
del state_dict[tensor_name]
del shard_state_dict
gc.collect()
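# with per-shard release, peak CPU memory stays near the size of a single
# shard plus the memory-mapped flat groups, instead of the full model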

# Save index if sharded
if state_dict_split.is_sharded:
60 changes: 60 additions & 0 deletions tests/unit/checkpoint/test_convert_checkpoint.py
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.nn as nn

import deepspeed
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from unit.common import DistributedTest


class ModelWithSharedWeights(nn.Module):

def __init__(self):
super().__init__()
self.layer0 = nn.Linear(100, 100)
self.layer1 = nn.Linear(200, 200)
self.layer2 = nn.Linear(300, 300)
# tie layer 1 and layer 2
self.layer1.weight = self.layer2.weight


class TestCheckpointConvert(DistributedTest):
world_size = 2

def test_convert_zero_checkpoint_to_fp32_state_dict(self, tmpdir):
config = {
"train_micro_batch_size_per_gpu": 2,
"zero_allow_untested_optimizer": True,
"zero_optimization": {
"stage": 3
},
}
model = ModelWithSharedWeights()
optimizer = torch.optim.Adam(model.parameters())

deepspeed_engine, _, _, _ = deepspeed.initialize(
config=config,
model=model,
optimizer=optimizer,
)
ds_save_dir = tmpdir / "checkpoint_ds"
deepspeed_engine.save_checkpoint(ds_save_dir, tag="checkpoint")

model = ModelWithSharedWeights()

# convert the ZeRO checkpoint to a consolidated fp32 checkpoint
fp32_save_dir = tmpdir / "checkpoint_fp32"
convert_zero_checkpoint_to_fp32_state_dict(ds_save_dir, fp32_save_dir)

# load state_dict from fp32 checkpoint
state_dict = torch.load(fp32_save_dir / 'pytorch_model.bin')

# check that the tied tensors are still shared
assert id(state_dict['layer1.weight']) == id(state_dict['layer2.weight'])

# load state_dict into model
model.load_state_dict(state_dict, strict=True)