Checkpoint conversion tools (bigcode-project#14)
* Checkpoint conversion tools

* Fix formatting

* 1) Provide args in converted checkpoint
2) Reshape TP and PP degrees

* Fix typo

* Fix link

* Tweak tag

* Fix converted TP and PP sizes

* For release mode

* Update README

* Nested embedding dicts
Iteration folder
latest checkpoint version file
tjruwase authored and mayank31398 committed Jun 21, 2023
1 parent 51bdf95 commit a7e7d11
Showing 5 changed files with 491 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tools/convert_checkpoint/README.md
# Introduction
This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into the format of a different framework (e.g., Megatron-LM). It also contains scripts for inspecting checkpoint files and folders, which can be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts for converting DeepSpeed checkpoints to Megatron-LM checkpoints (the conversion that motivated this effort as part of the BigScience project).

The following checkpoint conversions are provided by the available scripts:
1. [DeepSpeed to Megatron-LM](#DeepSpeed-to-Megatron)


## DeepSpeed to Megatron
The current implementation of the converter extracts the args and model parameters from a DeepSpeed checkpoint (i.e., it excludes other training state such as the optimizer and scheduler) and converts them into a Megatron-LM checkpoint that likewise contains only model parameters. The converter also makes a best-effort attempt to reshape the tensor-parallelism (TP) and pipeline-parallelism (PP) degrees of the checkpoint. The resulting Megatron-LM checkpoint can be loaded into the Megatron-LM framework for finetuning or inference. By default, the generated Megatron-LM checkpoint (folders and files) uses the same TP and PP degrees as the training run that created the input DeepSpeed checkpoint. The entry point of the converter is `deepspeed_to_megatron.py`, which has the following usage:
```bash
python tools/convert_checkpoint/deepspeed_to_megatron.py -h
Convert DeepSpeed Checkpoint to Megatron Checkpoint
usage: deepspeed_to_megatron.py [-h] [--input_folder INPUT_FOLDER]
                                [--output_folder OUTPUT_FOLDER]
                                [--target_tp TARGET_TP]
                                [--target_pp TARGET_PP] [--for_release]

optional arguments:
  -h, --help            show this help message and exit
  --input_folder INPUT_FOLDER
                        Input DeepSpeed Checkpoint folder
  --output_folder OUTPUT_FOLDER
                        Output Megatron checkpoint folder
  --target_tp TARGET_TP
                        Target TP degree
  --target_pp TARGET_PP
                        Target PP degree
  --for_release         Convert for release purpose, reset some (progress)
                        counters.
```
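
For example, a conversion that also reshapes the checkpoint to TP=1 and PP=1 might look like the following sketch (the input and output folders are placeholders, not paths from this repository):

```bash
# Hypothetical paths; point these at your own DeepSpeed checkpoint and desired output folder.
python tools/convert_checkpoint/deepspeed_to_megatron.py \
    --input_folder /path/to/deepspeed_checkpoint/global_step1000 \
    --output_folder /path/to/megatron_checkpoint \
    --target_tp 1 \
    --target_pp 1 \
    --for_release
```

Omitting `--target_tp` and `--target_pp` keeps the TP and PP degrees of the original training run.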

The following scripts, which have proved useful for debugging, are also included (example invocations are sketched below the list):
1. `inspect_deepspeed_checkpoint.py`: view the contents of a DeepSpeed checkpoint folder.
2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file.
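
As a sketch only (the exact command-line arguments are assumptions; run each script with `-h` to confirm), the inspection scripts can be invoked along these lines:

```bash
# Assumed invocations; check each script's -h output for the actual arguments.
python tools/convert_checkpoint/inspect_deepspeed_checkpoint.py --folder /path/to/deepspeed_checkpoint
python tools/convert_checkpoint/inspect_checkpoint.py /path/to/checkpoint_file.pt
```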
189 changes: 189 additions & 0 deletions tools/convert_checkpoint/deepspeed_checkpoint.py
import os
from typing import Dict
import torch

ZERO_FILE_PREFIX = 'zero_pp_rank_'
LAYER_FILE_PREFIX = 'layer_'
MP_RANK_FILE_PREFIX = 'mp_rank_'
EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
ITERATION_KEY = 'iteration'
SEQUENTIAL_LAYERS = [
    'input_layernorm.weight', 'input_layernorm.bias',
    'self_attention.dense.bias',
    'post_attention_layernorm.weight', 'post_attention_layernorm.bias',
    'mlp.dense_4h_to_h.bias',
    'position_embeddings.weight'
]

LAYER_CONCAT_DIM = {
    'self_attention.dense.weight': 1,
    'mlp.dense_4h_to_h.weight': 1
}

# Indexes the files of a DeepSpeed checkpoint folder and exposes its args,
# embedding, transformer and final-norm states.
class DeepSpeedCheckpoint(object):
    def __init__(self, dir, tp_degree=None, pp_degree=None):
        self.dir = dir
        self.file_list = self._get_files(dir)
        self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
        self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX)
        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)
        self.original_tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01'))
        self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
        self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree)
        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
        self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    def show_tp_embedding_map(self):
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_transformer_map(self):
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_transformer_layers')

    def show_transformer_file_map(self):
        self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')

    def _build_global_state(self):
        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def get_iteration(self):
        if ITERATION_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

        return self.global_state[ITERATION_KEY]

    def get_embedding_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_embedding_map.keys()
        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
        sd = self._merge_state_dicts(sd_list)
        return sd

    def get_args(self):
        if ARGS_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

        return self.global_state[ARGS_KEY]

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
            sd = self._merge_state_dicts(sd_list)
            t_list.append(sd)
        return t_list

    def get_final_norm_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_final_norm_map.keys()
        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
        return sd

    def _build_tp_other_layer_map(self, layer_index: int):
        assert layer_index < len(self.layer_files)
        layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
        layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
        data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
        return data_map

    def _build_pp_transformer_map(self):
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)}
        return data_map

    def _dump_mapping(self, data_map, map_tag=None):
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')
        for k, v in data_map.items():
            print(f'{k} = {v}')

    def _build_transformer_file_map(self):
        transformer_layer_keys = self.layer_keys[1:-1]
        file_map = {}
        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = self._get_files_with_prefix(self.layer_files, layer_key)
            layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                map_key = (tp_index, pp_index)
                if map_key not in file_map.keys():
                    file_map[map_key] = []
                file_map[map_key].append(layer_file_partitions[tp_index])

        return file_map

    def _sanity_check(self):
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
        assert len(self.layer_keys) > 2
        assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    def _get_files_with_prefix(self, all_files, prefix):
        file_list = []
        for file_path in all_files:
            _, fname = os.path.split(file_path)
            if fname.startswith(prefix):
                file_list.append(file_path)

        return sorted(file_list)

    def validate_files(self):
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} does not exist')

    def _get_files(self, dir):
        file_list = []
        for root, dirs, files in os.walk(dir):
            for file in files:
                file_list.append(os.path.join(root, file))
        return file_list

    def _get_layer_keys(self):
        key_set = set()
        key_len = len(LAYER_FILE_PREFIX) + 2
        for file_path in self.layer_files:
            _, fname = os.path.split(file_path)
            key_set.add(fname[:key_len])
        return sorted(list(key_set))

    def _partition_data(self, data_list, num_partitions):
        num_elems = len(data_list)
        assert num_elems % num_partitions == 0
        partition_size = num_elems // num_partitions
        partitions_list = [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]
        return partitions_list

    def _merge_state_dicts(self, sd_list):
        # Concatenate TP-partitioned tensors along their partition dimension;
        # tensors replicated across TP ranks (SEQUENTIAL_LAYERS) are copied from the first rank.
        merged_sd = {}
        for key in sd_list[0].keys():
            if key not in SEQUENTIAL_LAYERS:
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
            else:
                merged_sd[key] = sd_list[0][key]
        return merged_sd
148 changes: 148 additions & 0 deletions tools/convert_checkpoint/deepspeed_to_megatron.py
import argparse
import os
import torch
from collections import OrderedDict
from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint

MODEL_KEY = 'model'
ARGS_KEY = 'args'
LANGUAGE_MODEL_KEY = 'language_model'
EMBEDDING_KEY = 'embedding'
ENCODER_KEY = 'encoder'
WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head'
WORD_EMBEDDINGS_KEY = 'word_embeddings'
FINAL_LAYER_NORM_KEY = 'final_layernorm'
CHECKPOINT_VERSION_KEY = 'checkpoint_version'
CHECKPOINT_VERSION_VALUE = 3.0
ITERATION_KEY = 'iteration'

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder')
    parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree')
    parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree')
    parser.add_argument('--for_release', action='store_true', help='Convert for release purpose, reset some (progress) counters.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args


def _convert_ds_transformer_state(sd_list):
    new_sd = OrderedDict()
    for i, sd in enumerate(sd_list):
        for key, value in sd.items():
            new_key = f'layers.{i}.{key}'
            new_sd[new_key] = value

    return new_sd

def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
    path_list = []
    iter_folder = f'iter_{iteration:07d}'
    for i in range(0, tp_degree):
        path_list.append([])
        for j in range(0, pp_degree):
            rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
            ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
            path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

    return path_list


def _create_megatron_dict():
    language_model_dict = {
        EMBEDDING_KEY: {},
        ENCODER_KEY: {}
    }
    megatron_dict = {
        MODEL_KEY: {LANGUAGE_MODEL_KEY: language_model_dict},
        CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE
    }
    return megatron_dict


def _save_checkpoint(file_path, chkpt_sd):
    dir, _ = os.path.split(file_path)
    os.makedirs(dir, exist_ok=True)
    torch.save(chkpt_sd, file_path)


def _renest_sd(sd):
    new_sd = OrderedDict()
    for key, value in sd.items():
        a, b = key.split('.')
        new_sd[a] = {b: value}
    return new_sd


def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False):
    meg_encoder_sd = OrderedDict()
    meg_embedding_sd = OrderedDict()
    meg_embedding_for_head_sd = OrderedDict()

    transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index)
    meg_encoder_sd.update(_convert_ds_transformer_state(transformer_sd))

    # The embedding belongs to the first pipeline stage; the last stage additionally
    # needs a copy of the word embeddings for the output head and the final layernorm.
    if pp_index in [0, ds_checkpoint.pp_degree - 1]:
        embedding_sd = ds_checkpoint.get_embedding_state(tp_index)
        nested_embedding_sd = _renest_sd(embedding_sd)
        if pp_index == 0:
            meg_embedding_sd.update(nested_embedding_sd)

        if pp_index == ds_checkpoint.pp_degree - 1:
            for key, value in embedding_sd.items():
                if key.startswith(WORD_EMBEDDINGS_KEY):
                    fields = key.split('.')
                    new_fields = fields[1:]
                    new_key = '.'.join(new_fields)
                    meg_embedding_for_head_sd[new_key] = value

            final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index)
            new_final_norm_sd = {f'{FINAL_LAYER_NORM_KEY}.{key}': value for key, value in final_norm_sd.items()}
            meg_encoder_sd.update(new_final_norm_sd)

    checkpoint_sd = _create_megatron_dict()

    iteration = ds_checkpoint.get_iteration()
    checkpoint_sd[ITERATION_KEY] = iteration
    if pp_index == 0:
        checkpoint_sd[MODEL_KEY][LANGUAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd
    checkpoint_sd[MODEL_KEY][LANGUAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd
    if pp_index == ds_checkpoint.pp_degree - 1:
        checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd

    checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args()
    # Adjust specific fields
    checkpoint_sd[ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree
    checkpoint_sd[ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree
    if for_release:
        checkpoint_sd[ARGS_KEY].consumed_train_samples = 0
        checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0

    _save_checkpoint(checkpoint_path, checkpoint_sd)


def _create_latest_file(base_folder, iteration):
    file_path = os.path.join(base_folder, 'latest_checkpointed_iteration.txt')
    os.makedirs(base_folder, exist_ok=True)
    with open(file_path, 'w') as f:
        f.write(str(iteration))

def main():
    print('Convert DeepSpeed Checkpoint to Megatron Checkpoint')

    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
    iteration = ds_checkpoint.get_iteration()
    _create_latest_file(args.output_folder, iteration)
    checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree)
    for i in range(0, ds_checkpoint.tp_degree):
        for j in range(0, ds_checkpoint.pp_degree):
            _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release)


if __name__ == "__main__":
    main()