diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index e69ecd9acb5a..c0768deae62b 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -21,7 +21,9 @@
 import math
 import os
 import re
+import gc
 import json
+import numpy as np
 from tqdm import tqdm
 from collections import OrderedDict
 from dataclasses import dataclass
@@ -146,8 +148,8 @@ def parse_model_states(files):
 def parse_optim_states(files, ds_checkpoint_dir):
     total_files = len(files)
     state_dicts = []
-    for f in files:
-        state_dict = torch.load(f, map_location=device)
+    for f in tqdm(files, desc='Loading checkpoint shards'):
+        state_dict = torch.load(f, map_location=device, mmap=True)
         # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
         # and also handle the case where it was already removed by another helper script
         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
@@ -179,19 +181,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     else:
         raise ValueError(f"unknown zero stage {zero_stage}")
 
-    if zero_stage <= 2:
-        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
-    elif zero_stage == 3:
-        # if there is more than one param group, there will be multiple flattened tensors - one
-        # flattened tensor per group - for simplicity merge them into a single tensor
-        #
-        # XXX: could make the script more memory efficient for when there are multiple groups - it
-        # will require matching the sub-lists of param_shapes for each param group flattened tensor
-
-        fp32_flat_groups = [
-            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
-        ]
-
+    fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
     return zero_stage, world_size, fp32_flat_groups
 
 
@@ -398,9 +388,56 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
 
 
+class GatheredTensor:
+    """
+    A pseudo tensor that collects partitioned weights.
+    It is more memory efficient when there are multiple groups.
+    """
+
+    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
+        self.flat_groups = flat_groups
+        self.flat_groups_offset = flat_groups_offset
+        self.offset = offset
+        self.partitioned_numel = partitioned_numel
+        self.shape = shape
+        self.dtype = self.flat_groups[0][0].dtype
+
+    def contiguous(self):
+        """
+        Merge partitioned weights from flat_groups into a single tensor.
+ """ + end_idx = self.offset + self.partitioned_numel + world_size = len(self.flat_groups) + pad_flat_param_chunks = [] + + for rank_i in range(world_size): + # for each rank, we need to collect weights from related group/groups + flat_groups_at_rank_i = self.flat_groups[rank_i] + start_group_id = None + end_group_id = None + for group_id in range(len(self.flat_groups_offset)): + if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]: + start_group_id = group_id + if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]: + end_group_id = group_id + break + # collect weights from related group/groups + for group_id in range(start_group_id, end_group_id + 1): + flat_tensor = flat_groups_at_rank_i[group_id] + start_offset = self.offset - self.flat_groups_offset[group_id] + end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id] + pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset]) + + # collect weights from all ranks + pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0) + param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous() + return param + + def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): param_shapes = zero_model_states[0].param_shapes - avail_numel = fp32_flat_groups[0].numel() * world_size + avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each # param, re-consolidating each param, while dealing with padding if any @@ -424,7 +461,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero offset = 0 total_numel = 0 total_params = 0 - for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'): + flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])) + for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'): unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel total_params += 1 @@ -435,10 +473,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" ) - # XXX: memory usage doubles here - state_dict[name] = torch.cat( - tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), - 0).narrow(0, 0, unpartitioned_numel).view(shape) + # memory efficient tensor + tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape) + state_dict[name] = tensor offset += partitioned_numel offset *= world_size @@ -473,7 +510,29 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer return state_dict -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): +def to_torch_tensor(state_dict, return_empty_tensor=False): + """ + Convert state_dict of GatheredTensor to torch tensor + """ + converted_tensors = {} + for name, tensor in state_dict.items(): + tensor_id = id(tensor) + if tensor_id in converted_tensors: + shared_tensor = state_dict[converted_tensors[tensor_id]] + state_dict[name] = shared_tensor + else: + converted_tensors[tensor_id] = name + if return_empty_tensor: + state_dict[name] = torch.empty(tensor.shape, 
@@ -473,7 +510,29 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters):
     return state_dict
 
 
-def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
+def to_torch_tensor(state_dict, return_empty_tensor=False):
+    """
+    Convert a state_dict of GatheredTensor objects to torch tensors.
+    """
+    converted_tensors = {}
+    for name, tensor in state_dict.items():
+        tensor_id = id(tensor)
+        if tensor_id in converted_tensors:
+            shared_tensor = state_dict[converted_tensors[tensor_id]]
+            state_dict[name] = shared_tensor
+        else:
+            converted_tensors[tensor_id] = name
+            if return_empty_tensor:
+                state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
+            else:
+                state_dict[name] = tensor.contiguous()
+    return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+                                             tag=None,
+                                             exclude_frozen_parameters=False,
+                                             lazy_mode=False):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -483,14 +542,12 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
     - ``checkpoint_dir``: path to the desired checkpoint folder
     - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
     - ``exclude_frozen_parameters``: exclude frozen parameters
+    - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
+      Convert a pseudo tensor to a torch tensor with ``.contiguous()``
 
     Returns:
         - pytorch ``state_dict``
 
-    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
-    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
-    the checkpoint.
-
     A typical usage might be ::
 
         from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
@@ -506,6 +563,16 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
 
     If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
 
+    Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
+    You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+    the checkpoint. Or you can load the state_dict in lazy mode ::
+
+        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
+        for name, lazy_tensor in state_dict.items():
+            tensor = lazy_tensor.contiguous() # to cpu
+            print(name, tensor)
+            # del tensor to release memory if it is no longer in use
     """
     if tag is None:
         latest_path = os.path.join(checkpoint_dir, 'latest')
@@ -520,7 +587,11 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
     if not os.path.isdir(ds_checkpoint_dir):
         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
 
-    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+    if lazy_mode:
+        return state_dict
+    else:
+        return to_torch_tensor(state_dict)
 
 
 def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
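``to_torch_tensor`` keys on ``id(tensor)`` because tied parameters share a single ``GatheredTensor``: the shared object is materialized once and every other name pointing at it reuses that result, so ties survive conversion and saving. A minimal sketch of that behaviour (illustrative only; ``FakeGathered`` is a hypothetical stand-in so the example runs without a real ZeRO checkpoint)::

    import torch

    class FakeGathered:
        """Stands in for GatheredTensor: knows shape/dtype, materializes on .contiguous()."""

        def __init__(self, shape):
            self.shape, self.dtype = torch.Size(shape), torch.float32

        def contiguous(self):
            return torch.zeros(self.shape, dtype=self.dtype)

    shared = FakeGathered((300, 300))
    state_dict = {"layer1.weight": shared, "layer2.weight": shared}   # tied weights

    converted = {}
    for name, tensor in state_dict.items():
        if id(tensor) in converted:
            state_dict[name] = state_dict[converted[id(tensor)]]      # reuse the first conversion
        else:
            converted[id(tensor)] = name
            state_dict[name] = tensor.contiguous()                    # materialize once

    assert state_dict["layer1.weight"] is state_dict["layer2.weight"]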
@@ -541,6 +612,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
     - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
     - ``exclude_frozen_parameters``: exclude frozen parameters
     """
+
     # Dependency pre-check
     if safe_serialization:
         try:
@@ -556,13 +628,18 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
             raise
 
     # Convert zero checkpoint to state_dict
-    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
+                                                          tag,
+                                                          exclude_frozen_parameters,
+                                                          lazy_mode=True)
 
     # Shard the model if it is too big.
     weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
     if max_shard_size is not None:
         filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
-        state_dict_split = split_torch_state_dict_into_shards(state_dict,
+        # a memory-efficient approach for sharding
+        empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
+        state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
                                                               filename_pattern=filename_pattern,
                                                               max_shard_size=max_shard_size)
     else:
@@ -571,15 +648,22 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         state_dict_split = StateDictSplit(is_sharded=False,
                                           filename_to_tensors={weights_name: list(state_dict.keys())})
 
-    # Save the model
+    # Save the model by shard
+    os.makedirs(output_dir, exist_ok=True)
     filename_to_tensors = state_dict_split.filename_to_tensors.items()
     for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
-        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
+        shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
+        shard_state_dict = to_torch_tensor(shard_state_dict)
         output_path = os.path.join(output_dir, shard_file)
         if safe_serialization:
-            save_file(shard, output_path, metadata={"format": "pt"})
+            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
         else:
-            torch.save(shard, output_path)
+            torch.save(shard_state_dict, output_path)
+        # release the memory of the current shard
+        for tensor_name in shard_state_dict:
+            del state_dict[tensor_name]
+        del shard_state_dict
+        gc.collect()
 
     # Save index if sharded
     if state_dict_split.is_sharded:
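The saving path above keeps peak host memory to roughly one shard: shard boundaries are planned against shape/dtype-only placeholders (``return_empty_tensor=True``), then each shard is materialized, written, and released before the next. A condensed sketch of that loop (illustrative only; ``plan_shards`` stands in for ``huggingface_hub.split_torch_state_dict_into_shards``, and shared-tensor handling is omitted for brevity)::

    import gc
    import os
    import torch

    def save_in_shards(lazy_sd, output_dir, plan_shards):
        # plan the shard layout from placeholders; no fp32 data is materialized yet
        empty_sd = {name: torch.empty(t.shape, dtype=t.dtype) for name, t in lazy_sd.items()}
        split = plan_shards(empty_sd)

        os.makedirs(output_dir, exist_ok=True)
        for shard_file, names in split.filename_to_tensors.items():
            # materialize only this shard, write it, then free it
            shard = {name: lazy_sd[name].contiguous() for name in names}
            torch.save(shard, os.path.join(output_dir, shard_file))
            for name in names:
                del lazy_sd[name]
            del shard
            gc.collect()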
diff --git a/tests/unit/checkpoint/test_convert_checkpoint.py b/tests/unit/checkpoint/test_convert_checkpoint.py
new file mode 100644
index 000000000000..68fdecb32e16
--- /dev/null
+++ b/tests/unit/checkpoint/test_convert_checkpoint.py
@@ -0,0 +1,60 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import torch.nn as nn
+
+import deepspeed
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from unit.common import DistributedTest
+
+
+class ModelWithSharedWeights(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.layer0 = nn.Linear(100, 100)
+        self.layer1 = nn.Linear(200, 200)
+        self.layer2 = nn.Linear(300, 300)
+        # tie layer 1 and layer 2
+        self.layer1.weight = self.layer2.weight
+
+
+class TestCheckpointConvert(DistributedTest):
+    world_size = 2
+
+    def test_convert_zero_checkpoint_to_fp32_state_dict(self, tmpdir):
+        config = {
+            "train_micro_batch_size_per_gpu": 2,
+            "zero_allow_untested_optimizer": True,
+            "zero_optimization": {
+                "stage": 3
+            },
+        }
+        model = ModelWithSharedWeights()
+        optimizer = torch.optim.Adam(model.parameters())
+
+        deepspeed_engine, _, _, _ = deepspeed.initialize(
+            config=config,
+            model=model,
+            optimizer=optimizer,
+        )
+        ds_save_dir = tmpdir / "checkpoint_ds"
+        deepspeed_engine.save_checkpoint(ds_save_dir, tag="checkpoint")
+
+        model = ModelWithSharedWeights()
+
+        # save checkpoint
+        fp32_save_dir = tmpdir / "checkpoint_fp32"
+        convert_zero_checkpoint_to_fp32_state_dict(ds_save_dir, fp32_save_dir)
+
+        # load state_dict from fp32 checkpoint
+        state_dict = torch.load(fp32_save_dir / 'pytorch_model.bin')
+
+        # check shared tensor
+        assert id(state_dict['layer1.weight']) == id(state_dict['layer2.weight'])
+
+        # load state_dict into model
+        model.load_state_dict(state_dict, strict=True)
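Taken together, a checkpoint can now be converted or consumed without ever holding the full fp32 model in memory. An illustrative end-to-end usage (paths are placeholders; ``max_shard_size`` and ``safe_serialization`` are the converter's existing keyword arguments)::

    from deepspeed.utils.zero_to_fp32 import (convert_zero_checkpoint_to_fp32_state_dict,
                                              get_fp32_state_dict_from_zero_checkpoint)

    # offline conversion: writes sharded fp32 weights, plus an index file if sharding kicks in
    convert_zero_checkpoint_to_fp32_state_dict("path/to/checkpoint_dir",
                                               "path/to/output_dir",
                                               max_shard_size="5GB",
                                               safe_serialization=True)

    # in-memory alternative: lazy pseudo tensors, materialized one at a time
    state_dict = get_fp32_state_dict_from_zero_checkpoint("path/to/checkpoint_dir", lazy_mode=True)
    for name, lazy_tensor in state_dict.items():
        tensor = lazy_tensor.contiguous()
        # ... consume `tensor`, then drop it before moving on
        del tensor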