Convert meg ds to hf (bigcode-project#15)
* add direct meg-ds to hf format script (NVIDIA#110)

* add direct meg-ds to hf format script (part2) (NVIDIA#111)

* add direct meg-ds to hf format script

* split into 2 function

* update the usage doc

* make scripts executable

* add shebang

Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
3 people authored and mayank31398 committed Jun 21, 2023
1 parent a7e7d11 commit 9ac6c16
Showing 3 changed files with 146 additions and 17 deletions.
54 changes: 49 additions & 5 deletions tools/convert_checkpoint/README.md
@@ -1,11 +1,17 @@
# Introduction
This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into that of a different framework (e.g., Megatron-LM). inspecting checkpoints. The folder also contains scripts for inspecting checkpoint files and folders, which could be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM checkpoints (this motivated this effort as part of the BigScience project).

Here are the list and details of checkpoint conversions provided by the available scripts.
1. [DeepSpeed to Megatron-LM](#DeepSpeed-to-Megatron)
This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into that of a different framework (e.g., Megatron-LM, HF Transformers).

The folder also contains scripts for inspecting checkpoint files and folders, which could be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM and HF Transformers checkpoints (the use case that motivated this effort as part of the BigScience project).

Here is the list of checkpoint conversions provided by the available scripts:

1. [Megatron-DeepSpeed to Megatron-LM](#Megatron-DeepSpeed-to-Megatron)
1. [Megatron-DeepSpeed to HF Transformers](#Megatron-DeepSpeed-to-HF-Transformers)


## Megatron-DeepSpeed to Megatron

## DeepSpeed to Megatron
The (current implementation of the) converter extracts args and model parameters from a DeepSpeed checkpoint (i.e., it excludes other training state such as optimizer and scheduler) and converts them into a Megatron-LM checkpoint that similarly contains only model parameters. The converter also provides a best-effort attempt to reshape the tensor-parallelism and pipeline-parallelism degrees for the checkpoint. The resulting Megatron-LM checkpoint could be loaded into the Megatron-LM framework for finetuning or inference. Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will have the same TP and PP degrees as the training that created the input DeepSpeed checkpoint. The entry point of the converter is `deepspeed_to_megatron.py`, which has the following usage:
```bash
python tools/convert_checkpoint/deepspeed_to_megatron.py -h
@@ -31,4 +37,42 @@ optional arguments:

The following scripts which proved useful for debugging are also included:
1. `inspect_deepspeed_checkpoint.py`: view the contents of a DeepSpeed checkpoint folder.
2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file.
2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file.
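
As a rough illustration of what such an inspection does (a minimal sketch only, not the bundled `inspect_checkpoint.py`), the following dumps the keys and tensor shapes stored in a single PyTorch checkpoint file:

```python
# Minimal sketch of checkpoint-file inspection (illustrative; not the bundled inspect_checkpoint.py).
import sys
import torch

def dump(obj, prefix=""):
    # Recursively print dict keys; for tensors, show their shapes.
    if isinstance(obj, dict):
        for key, value in obj.items():
            dump(value, f"{prefix}{key}.")
    elif torch.is_tensor(obj):
        print(f"{prefix[:-1]}: tensor{tuple(obj.shape)}")
    else:
        print(f"{prefix[:-1]}: {type(obj).__name__}")

if __name__ == "__main__":
    state_dict = torch.load(sys.argv[1], map_location="cpu")
    dump(state_dict)
```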

## Megatron-DeepSpeed to HF Transformers

To convert directly from Megatron-DeepSpeed to HF Transformers, run:

```bash
python tools/convert_checkpoint/deepspeed_to_transformers.py \
--input_folder /path/to/Megatron-Deepspeed/checkpoint/global_step97500 \
--output_folder /path/to/transformers/checkpoint
```
Since `transformers` currently only works with PP=1/TP=1, we use the defaults `--target_tp 1 --target_pp 1`.

The script taps into `transformers` and, as of this writing, requires `transformers@master` (or `transformers==4.11` once that version is released).

Note that you may run into problems with `megatron.enums` not being defined, since `Megatron-DeepSpeed` in the `bigscience-workshop` tree has diverged from the `microsoft` tree. In such cases you can fix this on the fly by ensuring the former appears first in `sys.path`. For example:


```bash
PYTHONPATH=/hf/Megatron-DeepSpeed-bigscience:/hf/Megatron-DeepSpeed-microsoft \
python tools/convert_checkpoint/deepspeed_to_transformers.py \
--input_folder /path/to/Megatron-Deepspeed/checkpoint/global_step97500 \
--output_folder /path/to/transformers/checkpoint
```
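
To double-check which tree actually wins, a quick sanity check along these lines can help. This is only a sketch and assumes both trees expose a `megatron` package while only the `bigscience-workshop` one provides `megatron.enums`:

```python
# Sanity check (assumes the PYTHONPATH ordering shown above).
import megatron
print(megatron.__file__)   # should point into the Megatron-DeepSpeed-bigscience tree

import megatron.enums      # an ImportError here means the wrong tree is shadowing it
```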

Alternatively, you can convert first from Megatron-DeepSpeed to Megatron and then to HF Transformers:

```bash
# 1. Megatron-DeepSpeed to Megatron
cd /hf/Megatron-DeepSpeed-bigscience
python tools/convert_checkpoint/deepspeed_to_megatron.py --target_tp 1 --target_pp 1 \
--input_folder /path/to/Megatron-Deepspeed/checkpoint/global_step97500 \
--output_folder /path/to/Megatron/checkpoint

# 2. Megatron to HF Transformers
cd /hf/transformers
python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py \
/path/to/Megatron/checkpoint/iter_0097500/mp_rank_00/model_optim_rng.pt
```
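
Once a conversion has produced a folder containing `config.json` and `pytorch_model.bin`, it loads like any other GPT-2 checkpoint in `transformers`. A minimal sketch, assuming the placeholder `--output_folder` path used above:

```python
# Minimal sketch: load the converted checkpoint (weights + config only; tokenizer files are added separately).
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("/path/to/transformers/checkpoint")
print(model.config)
```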
26 changes: 14 additions & 12 deletions tools/convert_checkpoint/deepspeed_to_megatron.py
100644 → 100755
@@ -1,3 +1,5 @@
#!/usr/bin/env python

import argparse
import os
import torch
@@ -20,12 +22,12 @@ def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder')
parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder')
parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree')
parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree')
parser.add_argument('--target_tp', default=1, type=int, help='Target TP degree')
parser.add_argument('--target_pp', default=1, type=int, help='Target PP degree')
parser.add_argument('--for_release', action='store_true', help='Convert for release purpose, reset some (progress) counters.')
args = parser.parse_args()
print(f'args = {args}')
return args
return args


def _convert_ds_transformer_state(sd_list):
@@ -35,7 +37,7 @@ def _convert_ds_transformer_state(sd_list):
new_key = f'layers.{i}.{key}'
new_sd[new_key] = value

return new_sd
return new_sd

def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
path_list = []
@@ -69,16 +71,16 @@ def _save_checkpoint(file_path, chkpt_sd):


def _renest_sd(sd):
new_sd = OrderedDict()
new_sd = OrderedDict()
for key, value in sd.items():
a, b = key.split('.')
new_sd[a] = {b: value}
return new_sd


def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False):
meg_encoder_sd = OrderedDict()
meg_embedding_sd = OrderedDict()
meg_encoder_sd = OrderedDict()
meg_embedding_sd = OrderedDict()
meg_embedding_for_head_sd = OrderedDict()

transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index)
@@ -97,7 +99,7 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index,
new_fields = fields[1:]
new_key = '.'.join(new_fields)
meg_embedding_for_head_sd[new_key] = value

final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index)
new_final_norm_sd = {f'{FINAL_LAYER_NORM_KEY}.{key}': value for key, value in final_norm_sd.items()}
meg_encoder_sd.update(new_final_norm_sd)
@@ -120,7 +122,7 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index,
checkpoint_sd[ARGS_KEY].consumed_train_samples = 0
checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0

_save_checkpoint(checkpoint_path, checkpoint_sd)
return checkpoint_sd


def _create_latest_file(base_folder, iteration):
@@ -131,7 +133,7 @@

def main():
print(f'Convert DeepSpeed Checkpoint to Megatron Checkpoint')

args = parse_arguments()
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}')

@@ -141,8 +143,8 @@ def main():
checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree)
for i in range(0, ds_checkpoint.tp_degree):
for j in range(0, ds_checkpoint.pp_degree):
_create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release)

sd = _create_rank_checkpoint(ds_checkpoint, i, j, args.for_release)
_save_checkpoint(checkpoint_paths[i][j], sd)

if __name__ == "__main__":
main()
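
With `_create_rank_checkpoint` now returning the state dict and `_save_checkpoint` writing it out, the conversion can also be driven from Python rather than the CLI. The following is only a sketch based on the signatures visible in this diff; the paths are placeholders, and the `tools/convert_checkpoint` folder is assumed to be on `sys.path` so that `deepspeed_checkpoint` and `deepspeed_to_megatron` can be imported:

```python
# Sketch: programmatic use of the split helpers (arguments passed positionally,
# matching the calls shown in this diff; all paths are placeholders).
from deepspeed_checkpoint import DeepSpeedCheckpoint
from deepspeed_to_megatron import _create_rank_checkpoint, _save_checkpoint

ds_checkpoint = DeepSpeedCheckpoint("/path/to/global_step97500", 1, 1)  # input folder, target TP, target PP
sd = _create_rank_checkpoint(ds_checkpoint, 0, 0, False)                # tp_index=0, pp_index=0, for_release=False
_save_checkpoint("/path/to/megatron/iter_0097500/mp_rank_00/model_optim_rng.pt", sd)
```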
83 changes: 83 additions & 0 deletions tools/convert_checkpoint/deepspeed_to_transformers.py
@@ -0,0 +1,83 @@
#!/usr/bin/env python

import os
import torch
import json

from deepspeed_checkpoint import DeepSpeedCheckpoint
from deepspeed_to_megatron import _create_rank_checkpoint, parse_arguments

# the import was tested to work with this version
# https://github.com/huggingface/transformers/commit/0af901e83 if it diverges we may consider
# copying that version here instead
from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import convert_megatron_checkpoint
from transformers import GPT2Config

def main():

# this first part comes mainly from deepspeed_to_megatron.main
args = parse_arguments()
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}')

ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
iteration = ds_checkpoint.get_iteration()
input_state_dict = _create_rank_checkpoint(ds_checkpoint, 0, 0, args.for_release)

# the 2nd part comes from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint.main
# Spell out all parameters in case the defaults change.
config = GPT2Config(
vocab_size=50257,
n_positions=1024,
n_ctx=1024,
n_embd=1024,
n_layer=24,
n_head=16,
n_inner=4096,
activation_function="gelu", # used to be "gelu_new" in earlier versions
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
gradient_checkpointing=False,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
)

# Convert.
print("Converting to HF Checkpoint")
output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

basename = args.output_folder
os.makedirs(basename, exist_ok=True)

# Print the structure of converted state dict.
#if args.print_checkpoint_structure:
# recursive_print(None, output_state_dict)

# Store the config to file.
output_config_file = os.path.join(basename, "config.json")
output_config = config.to_dict()
output_config["architectures"] = ["GPT2LMHeadModel"]
output_config["model_type"] = "gpt2"
print(f'Saving config to "{output_config_file}"')
with open(output_config_file, "w") as f:
json.dump(output_config, f)

# Store the state_dict to file.
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
print(f'Saving checkpoint to "{output_checkpoint_file}"')
torch.save(output_state_dict, output_checkpoint_file)

print("Now add tokenizer files and upload to the hub")


if __name__ == "__main__":
main()
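
The script ends with a reminder to add tokenizer files. One way to do that is to save a pretrained tokenizer into the same output folder. This is a sketch that assumes the model was trained with the standard GPT-2 vocabulary, which may not hold for every model:

```python
# Sketch: add tokenizer files next to config.json and pytorch_model.bin.
# Assumes the standard GPT-2 vocabulary; use the tokenizer the model was actually trained with if it differs.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.save_pretrained("/path/to/transformers/checkpoint")
```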
