forked from bigcode-project/Megatron-LM
Checkpoint conversion tools (bigcode-project#14)
* Checkpoint conversion tools
* Fix formatting
* 1) Provide args in converted checkpoint 2) Reshape TP and PP degrees
* Fix typo
* Fix link
* Tweak tag
* Fix converted TP and PP sizes
* For release mode
* Update README
* Nested embedding dicts

Iteration folder latest checkpoint version file
1 parent 51bdf95 · commit a7e7d11 · 5 changed files with 491 additions and 0 deletions.
@@ -0,0 +1,34 @@
# Introduction

This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into the format of a different framework (e.g., Megatron-LM). The folder also contains scripts for inspecting checkpoint files and folders, which can be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM checkpoints (the conversion that motivated this effort as part of the BigScience project).

The following checkpoint conversions are provided by the available scripts:

1. [DeepSpeed to Megatron-LM](#DeepSpeed-to-Megatron)

## DeepSpeed to Megatron

The current implementation of the converter extracts the args and model parameters from a DeepSpeed checkpoint (i.e., it excludes other training state such as the optimizer and scheduler) and converts them into a Megatron-LM checkpoint that likewise contains only model parameters. By default, the generated Megatron-LM checkpoint (folders and files) uses the same tensor-parallelism (TP) and pipeline-parallelism (PP) degrees as the training run that created the input DeepSpeed checkpoint, but the converter also makes a best-effort attempt to reshape the checkpoint to the TP and PP degrees requested via `--target_tp` and `--target_pp`. The resulting Megatron-LM checkpoint can be loaded into the Megatron-LM framework for finetuning or inference. The entry point of the converter is `deepspeed_to_megatron.py`, which has the following usage:
```bash
python tools/convert_checkpoint/deepspeed_to_megatron.py -h
Convert DeepSpeed Checkpoint to Megatron Checkpoint
usage: deepspeed_to_megatron.py [-h] [--input_folder INPUT_FOLDER]
                                [--output_folder OUTPUT_FOLDER]
                                [--target_tp TARGET_TP]
                                [--target_pp TARGET_PP] [--for_release]

optional arguments:
  -h, --help            show this help message and exit
  --input_folder INPUT_FOLDER
                        Input DeepSpeed Checkpoint folder
  --output_folder OUTPUT_FOLDER
                        Output Megatron checkpoint folder
  --target_tp TARGET_TP
                        Target TP degree
  --target_pp TARGET_PP
                        Target PP degree
  --for_release         Convert for release purpose, reset some (progress)
                        counters.
```
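
For example, assuming a DeepSpeed checkpoint folder at a hypothetical path, a conversion that reshapes the checkpoint to TP=1 and PP=1 and resets the progress counters for release could look like this:

```bash
python tools/convert_checkpoint/deepspeed_to_megatron.py \
    --input_folder /checkpoints/deepspeed/global_step300000 \
    --output_folder /checkpoints/megatron \
    --target_tp 1 \
    --target_pp 1 \
    --for_release
```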

The following scripts, which proved useful for debugging, are also included:

1. `inspect_deepspeed_checkpoint.py`: view the contents of a DeepSpeed checkpoint folder.
2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file.
@@ -0,0 +1,189 @@
import os
from typing import Dict
import torch

# File-name prefixes used by DeepSpeed pipeline-parallel checkpoints.
ZERO_FILE_PREFIX = 'zero_pp_rank_'
LAYER_FILE_PREFIX = 'layer_'
MP_RANK_FILE_PREFIX = 'mp_rank_'
EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
ITERATION_KEY = 'iteration'

# Parameters that are replicated across tensor-parallel ranks; they are taken
# from a single rank instead of being concatenated.
SEQUENTIAL_LAYERS = [
    'input_layernorm.weight', 'input_layernorm.bias',
    'self_attention.dense.bias',
    'post_attention_layernorm.weight', 'post_attention_layernorm.bias',
    'mlp.dense_4h_to_h.bias',
    'position_embeddings.weight'
]

# Tensor-parallel concatenation dimension for partitioned weights; anything
# not listed here is concatenated along dim 0.
LAYER_CONCAT_DIM = {
    'self_attention.dense.weight': 1,
    'mlp.dense_4h_to_h.weight': 1
}

class DeepSpeedCheckpoint(object):
    """Indexes a DeepSpeed pipeline-parallel checkpoint folder and exposes the
    model state per (tensor, pipeline) parallel rank, optionally reshaped to
    new TP/PP degrees."""

    def __init__(self, dir, tp_degree=None, pp_degree=None):
        self.dir = dir
        self.file_list = self._get_files(dir)
        self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
        self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX)
        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)
        # The original parallelism degrees are inferred from the file layout.
        self.original_tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01'))
        self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
        self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree)
        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
        self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    def show_tp_embedding_map(self):
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_tranformer_map(self):
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_transformer_layers')

    def show_transformer_file_map(self):
        self._dump_mapping(self.transformer_file_map, 'rank_to_transformer_files')

    def _build_global_state(self):
        # Iteration and training args are replicated across ranks; read them from the first mp_rank file.
        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def get_iteration(self):
        if ITERATION_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

        return self.global_state[ITERATION_KEY]

    def get_embedding_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_embedding_map.keys()
        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
        sd = self._merge_state_dicts(sd_list)
        return sd

    def get_args(self):
        if ARGS_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

        return self.global_state[ARGS_KEY]

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        # Returns one merged state dict per transformer layer owned by (tp_index, pp_index).
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
            sd = self._merge_state_dicts(sd_list)
            t_list.append(sd)
        return t_list

    def get_final_norm_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_final_norm_map.keys()
        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
        return sd

    def _build_tp_other_layer_map(self, layer_index: int):
        assert layer_index < len(self.layer_files)
        layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
        layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
        data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
        return data_map

    def _build_pp_transformer_map(self):
        # Evenly split the transformer layer keys (excluding embedding and final layernorm) across PP stages.
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(self.pp_degree)}
        return data_map

    def _dump_mapping(self, data_map, map_tag=None):
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')
        for k, v in data_map.items():
            print(f'{k} = {v}')

    def _build_transformer_file_map(self):
        # Map each (tp_index, pp_index) rank to the list of layer-file groups it must load.
        transformer_layer_keys = self.layer_keys[1:-1]
        file_map = {}
        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = self._get_files_with_prefix(self.layer_files, layer_key)
            layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                map_key = (tp_index, pp_index)
                if map_key not in file_map:
                    file_map[map_key] = []
                file_map[map_key].append(layer_file_partitions[tp_index])

        return file_map

    def _sanity_check(self):
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
        # Layer keys include the embedding and the final layernorm in addition to the transformer layers.
        assert len(self.layer_keys) > 2
        assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    def _get_files_with_prefix(self, all_files, prefix):
        file_list = []
        for file_path in all_files:
            _, fname = os.path.split(file_path)
            if fname.startswith(prefix):
                file_list.append(file_path)

        return sorted(file_list)

    def validate_files(self):
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} does not exist')

    def _get_files(self, dir):
        file_list = []
        for root, dirs, files in os.walk(dir):
            for file in files:
                file_list.append(os.path.join(root, file))
        return file_list

    def _get_layer_keys(self):
        # Layer files are named f'{LAYER_FILE_PREFIX}<two-digit index>...'; collect the distinct layer prefixes.
        key_set = set()
        key_len = len(LAYER_FILE_PREFIX) + 2
        for file_path in self.layer_files:
            _, fname = os.path.split(file_path)
            key_set.add(fname[:key_len])
        return sorted(list(key_set))

    def _partition_data(self, data_list, num_partitions):
        num_elems = len(data_list)
        assert num_elems % num_partitions == 0
        partition_size = num_elems // num_partitions
        partitions_list = [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]
        return partitions_list

    def _merge_state_dicts(self, sd_list):
        # Concatenate TP-partitioned tensors along their partition dim; take replicated tensors from the first rank.
        merged_sd = {}
        for key in sd_list[0].keys():
            if key not in SEQUENTIAL_LAYERS:
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
            else:
                merged_sd[key] = sd_list[0][key]
        return merged_sd
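
For illustration, here is a minimal sketch of how the `DeepSpeedCheckpoint` class above might be used on its own to inspect a checkpoint folder; the folder path and the target degrees are hypothetical:

```python
from deepspeed_checkpoint import DeepSpeedCheckpoint

# Hypothetical DeepSpeed checkpoint folder; request a reshape to TP=2, PP=1.
ds_checkpoint = DeepSpeedCheckpoint('/checkpoints/deepspeed/global_step300000',
                                    tp_degree=2, pp_degree=1)

print(f'iteration      = {ds_checkpoint.get_iteration()}')
print(f'original TP/PP = {ds_checkpoint.original_tp_degree}/{ds_checkpoint.original_pp_degree}')

# Dump which layer files feed each (tp, pp) rank and each TP rank's embedding partition.
ds_checkpoint.show_transformer_file_map()
ds_checkpoint.show_tp_embedding_map()
```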
@@ -0,0 +1,148 @@
import argparse
import os
import torch
from collections import OrderedDict
from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint

# Keys of the generated Megatron-LM checkpoint dict.
MODEL_KEY = 'model'
LANGUAGE_MODEL_KEY = 'language_model'
EMBEDDING_KEY = 'embedding'
ENCODER_KEY = 'encoder'
WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head'
WORD_EMBEDDINGS_KEY = 'word_embeddings'
FINAL_LAYER_NORM_KEY = 'final_layernorm'
CHECKPOINT_VERSION_KEY = 'checkpoint_version'
CHECKPOINT_VERSION_VALUE = 3.0
ITERATION_KEY = 'iteration'

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder')
    parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree')
    parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree')
    parser.add_argument('--for_release', action='store_true', help='Convert for release purpose, reset some (progress) counters.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args


def _convert_ds_transformer_state(sd_list):
    # Flatten the per-layer DeepSpeed state dicts into Megatron's 'layers.<i>.<param>' naming.
    new_sd = OrderedDict()
    for i, sd in enumerate(sd_list):
        for key, value in sd.items():
            new_key = f'layers.{i}.{key}'
            new_sd[new_key] = value

    return new_sd

def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
    # Megatron layout: <base>/iter_<iteration>/mp_rank_<tp>[_<pp>]/model_optim_rng.pt
    path_list = []
    iter_folder = f'iter_{iteration:07d}'
    for i in range(0, tp_degree):
        path_list.append([])
        for j in range(0, pp_degree):
            rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
            ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
            path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

    return path_list


def _create_megatron_dict():
    language_model_dict = {
        EMBEDDING_KEY: {},
        ENCODER_KEY: {}
    }
    megatron_dict = {
        MODEL_KEY: {LANGUAGE_MODEL_KEY: language_model_dict},
        CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE
    }
    return megatron_dict

def _save_checkpoint(file_path, chkpt_sd):
    dir, _ = os.path.split(file_path)
    os.makedirs(dir, exist_ok=True)
    torch.save(chkpt_sd, file_path)


def _renest_sd(sd):
    # Turn flat keys such as 'word_embeddings.weight' into nested dicts: {'word_embeddings': {'weight': ...}}.
    new_sd = OrderedDict()
    for key, value in sd.items():
        a, b = key.split('.')
        new_sd[a] = {b: value}
    return new_sd

def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False):
    meg_encoder_sd = OrderedDict()
    meg_embedding_sd = OrderedDict()
    meg_embedding_for_head_sd = OrderedDict()

    transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index)
    meg_encoder_sd.update(_convert_ds_transformer_state(transformer_sd))

    # The embedding lives on the first pipeline stage; the last stage additionally
    # needs a copy of the word embeddings for the tied output head.
    if pp_index in [0, ds_checkpoint.pp_degree - 1]:
        embedding_sd = ds_checkpoint.get_embedding_state(tp_index)
        nested_embedding_sd = _renest_sd(embedding_sd)
        if pp_index == 0:
            meg_embedding_sd.update(nested_embedding_sd)

        if pp_index == ds_checkpoint.pp_degree - 1:
            for key, value in embedding_sd.items():
                if key.startswith(WORD_EMBEDDINGS_KEY):
                    fields = key.split('.')
                    new_fields = fields[1:]
                    new_key = '.'.join(new_fields)
                    meg_embedding_for_head_sd[new_key] = value

            # The final layernorm is stored on the last pipeline stage.
            final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index)
            new_final_norm_sd = {f'{FINAL_LAYER_NORM_KEY}.{key}': value for key, value in final_norm_sd.items()}
            meg_encoder_sd.update(new_final_norm_sd)

    checkpoint_sd = _create_megatron_dict()

    iteration = ds_checkpoint.get_iteration()
    checkpoint_sd[ITERATION_KEY] = iteration
    if pp_index == 0:
        checkpoint_sd[MODEL_KEY][LANGUAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd
    checkpoint_sd[MODEL_KEY][LANGUAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd
    if pp_index == ds_checkpoint.pp_degree - 1:
        checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd

    checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args()
    # Adjust specific fields
    checkpoint_sd[ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree
    checkpoint_sd[ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree
    if for_release:
        checkpoint_sd[ARGS_KEY].consumed_train_samples = 0
        checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0

    _save_checkpoint(checkpoint_path, checkpoint_sd)

def _create_latest_file(base_folder, iteration):
    # Megatron looks up this file to find the latest checkpointed iteration.
    file_path = os.path.join(base_folder, 'latest_checkpointed_iteration.txt')
    os.makedirs(base_folder, exist_ok=True)
    with open(file_path, 'w') as f:
        f.write(str(iteration))


def main():
    print('Convert DeepSpeed Checkpoint to Megatron Checkpoint')

    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
    iteration = ds_checkpoint.get_iteration()
    _create_latest_file(args.output_folder, iteration)
    checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree)
    # Write one Megatron checkpoint file per (TP rank, PP rank) pair.
    for i in range(0, ds_checkpoint.tp_degree):
        for j in range(0, ds_checkpoint.pp_degree):
            _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release)


if __name__ == "__main__":
    main()
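
For reference, given the `_create_checkpoint_paths` and `_create_latest_file` helpers above, a conversion at iteration 300000 with TP=2 and PP=1 (hypothetical values) should produce an output folder laid out roughly as follows:

```
<output_folder>/
├── latest_checkpointed_iteration.txt   # contains "300000"
└── iter_0300000/
    ├── mp_rank_00/
    │   └── model_optim_rng.pt
    └── mp_rank_01/
        └── model_optim_rng.pt
```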