From 9ac6c1686bcb655c744b07a7a412ae67fd751d52 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Wed, 22 Sep 2021 12:42:55 -0700
Subject: [PATCH] Convert meg ds to hf (#15)

* add direct meg-ds to hf format script (#110)

* add direct meg-ds to hf format script (part2) (#111)

* add direct meg-ds to hf format script

* split into 2 functions

* update the usage doc

* make scripts executable

* add shebang

Co-authored-by: Stas Bekman
Co-authored-by: Stas Bekman
---
 tools/convert_checkpoint/README.md            | 54 ++++++++++--
 .../deepspeed_to_megatron.py                  | 26 +++---
 .../deepspeed_to_transformers.py              | 83 +++++++++++++++++++
 3 files changed, 146 insertions(+), 17 deletions(-)
 mode change 100644 => 100755 tools/convert_checkpoint/deepspeed_to_megatron.py
 create mode 100755 tools/convert_checkpoint/deepspeed_to_transformers.py

diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md
index 29cfaa805b..3f74bb1aa4 100644
--- a/tools/convert_checkpoint/README.md
+++ b/tools/convert_checkpoint/README.md
@@ -1,11 +1,17 @@
 # Introduction
-This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into that of a different framework (e.g., Megatron-LM). inspecting checkpoints. The folder also contains scripts for inspecting checkpoint files and folders, which could be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM checkpoints (this motivated this effort as part of the BigScience project).
-Here are the list and details of checkpoint conversions provided by the available scripts.
-1. [DeepSpeed to Megatron-LM](#DeepSpeed-to-Megatron)
+This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into the format of a different framework (e.g., Megatron-LM, HF Transformers).
+The folder also contains scripts for inspecting checkpoint files and folders, which could be useful when developing checkpoint conversion logic. At the time of writing, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM and HF Transformers checkpoints (the conversions that motivated this effort as part of the BigScience project).
+
+Here is the list of checkpoint conversions provided by the available scripts, with details below:
+
+1. [Megatron-DeepSpeed to Megatron-LM](#Megatron-DeepSpeed-to-Megatron)
+1. [Megatron-DeepSpeed to HF Transformers](#Megatron-DeepSpeed-to-HF-Transformers)
+
+
+## Megatron-DeepSpeed to Megatron
 
-## DeepSpeed to Megatron
 The (current implementation of the) converter extracts args and model parameters from a DeepSpeed checkpoint (i.e., it excludes other training state such as the optimizer and scheduler) and converts them into a Megatron-LM checkpoint that similarly contains only model parameters. The converter also makes a best-effort attempt to reshape the tensor-parallelism and pipeline-parallelism degrees of the checkpoint. The resulting Megatron-LM checkpoint can be loaded into the Megatron-LM framework for finetuning or inference. Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will have the same TP and PP degrees as the training run that created the input DeepSpeed checkpoint.
 
 The entry point of the converter is `deepspeed_to_megatron.py`, which has the following usage:
 ```bash
 python tools/convert_checkpoint/deepspeed_to_megatron.py -h
@@ -31,4 +37,42 @@ optional arguments:
 
 The following scripts, which proved useful for debugging, are also included:
 1. `inspect_deepspeed_checkpoint.py`: view the contents of a DeepSpeed checkpoint folder.
-2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file.
\ No newline at end of file
+2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file.
+
+## Megatron-DeepSpeed to HF Transformers
+
+To convert from Megatron-DeepSpeed to HF Transformers directly, run:
+
+```bash
+python tools/convert_checkpoint/deepspeed_to_transformers.py \
+--input_folder /path/to/Megatron-Deepspeed/checkpoint/global_step97500 \
+--output_folder /path/to/transformers/checkpoint
+```
+Since `transformers` currently only works with PP=1/TP=1, we use the defaults `--target_tp 1 --target_pp 1`.
+
+The script taps into `transformers` and as of this writing requires `transformers@master` (or `transformers==4.11` once that version has been released).
+
+Note that you may run into problems with `megatron.enums` not being defined, since `Megatron-DeepSpeed` in the `bigscience-workshop` tree has diverged from the `microsoft` tree. In such cases you can fix this on the fly by ensuring the former appears first in `sys.path`. For example:
+
+
+```bash
+PYTHONPATH=/hf/Megatron-DeepSpeed-bigscience:/hf/Megatron-DeepSpeed-microsoft \
+python tools/convert_checkpoint/deepspeed_to_transformers.py \
+--input_folder /path/to/Megatron-Deepspeed/checkpoint/global_step97500 \
+--output_folder /path/to/transformers/checkpoint
+```
+
+Alternatively, you can convert first from Megatron-DeepSpeed to Megatron and then to HF Transformers:
+
+```bash
+# 1. Megatron-DeepSpeed to Megatron
+cd /hf/Megatron-DeepSpeed-bigscience
+python tools/convert_checkpoint/deepspeed_to_megatron.py --target_tp 1 --target_pp 1 \
+--input_folder /path/to/Megatron-Deepspeed/checkpoint/global_step97500 \
+--output_folder /path/to/Megatron/checkpoint
+
+# 2. Megatron to HF Transformers
+cd /hf/transformers
+python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py \
+/path/to/Megatron/checkpoint/iter_0097500/mp_rank_00/model_optim_rng.pt
+```
diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py
old mode 100644
new mode 100755
index d67d184087..022759372c
--- a/tools/convert_checkpoint/deepspeed_to_megatron.py
+++ b/tools/convert_checkpoint/deepspeed_to_megatron.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+
 import argparse
 import os
 import torch
@@ -20,12 +22,12 @@ def parse_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder')
     parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder')
-    parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree')
-    parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree')
+    parser.add_argument('--target_tp', default=1, type=int, help='Target TP degree')
+    parser.add_argument('--target_pp', default=1, type=int, help='Target PP degree')
     parser.add_argument('--for_release', action='store_true', help='Convert for release purpose, reset some (progress) counters.')
     args = parser.parse_args()
     print(f'args = {args}')
-    return args
+    return args
 
 
 def _convert_ds_transformer_state(sd_list):
@@ -35,7 +37,7 @@ def _convert_ds_transformer_state(sd_list):
             new_key = f'layers.{i}.{key}'
             new_sd[new_key] = value
 
-    return new_sd
+    return new_sd
 
 def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
     path_list = []
@@ -69,7 +71,7 @@ def _save_checkpoint(file_path, chkpt_sd):
 
 
 def _renest_sd(sd):
-    new_sd = OrderedDict()
+    new_sd = OrderedDict()
     for key, value in sd.items():
         a, b = key.split('.')
         new_sd[a] = {b: value}
@@ -77,8 +79,8 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False):
-    meg_encoder_sd = OrderedDict()
-    meg_embedding_sd = OrderedDict()
+    meg_encoder_sd = OrderedDict()
+    meg_embedding_sd = OrderedDict()
     meg_embedding_for_head_sd = OrderedDict()
 
     transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index)
@@ -97,7 +99,7 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False):
             new_fields = fields[1:]
             new_key = '.'.join(new_fields)
             meg_embedding_for_head_sd[new_key] = value
-
+
     final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index)
     new_final_norm_sd = {f'{FINAL_LAYER_NORM_KEY}.{key}': value for key, value in final_norm_sd.items()}
     meg_encoder_sd.update(new_final_norm_sd)
@@ -120,7 +122,7 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False):
         checkpoint_sd[ARGS_KEY].consumed_train_samples = 0
         checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0
 
-    _save_checkpoint(checkpoint_path, checkpoint_sd)
+    return checkpoint_sd
 
 
 def _create_latest_file(base_folder, iteration):
@@ -131,7 +133,7 @@ def main():
     print(f'Convert DeepSpeed Checkpoint to Megatron Checkpoint')
-
+
     args = parse_arguments()
     print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}')
 
@@ -141,8 +143,8 @@ def main():
     checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree)
     for i in range(0, ds_checkpoint.tp_degree):
         for j in range(0, ds_checkpoint.pp_degree):
-            _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release)
-
+            # keep passing checkpoint_path: _create_rank_checkpoint still takes it as its 2nd argument
+            sd = _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release)
+            _save_checkpoint(checkpoint_paths[i][j], sd)
 
 if __name__ == "__main__":
     main()
diff --git a/tools/convert_checkpoint/deepspeed_to_transformers.py b/tools/convert_checkpoint/deepspeed_to_transformers.py
new file mode 100755
index 0000000000..69375642c2
--- /dev/null
+++ b/tools/convert_checkpoint/deepspeed_to_transformers.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+
+import os
+import torch
+import json
+
+from deepspeed_checkpoint import DeepSpeedCheckpoint
+from deepspeed_to_megatron import _create_rank_checkpoint, parse_arguments
+
+# the import was tested to work with this version
+# https://github.com/huggingface/transformers/commit/0af901e83 if it diverges we may consider
+# copying that version here instead
+from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint import convert_megatron_checkpoint
+from transformers import GPT2Config
+
+def main():
+
+    # this first part comes mainly from deepspeed_to_megatron.main
+    args = parse_arguments()
+    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to HF Transformers checkpoint in {args.output_folder}')
+
+    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
+    iteration = ds_checkpoint.get_iteration()
+    input_state_dict = _create_rank_checkpoint(ds_checkpoint, None, 0, 0, args.for_release)  # checkpoint_path arg is unused here
+
+    # the 2nd part comes from transformers.models.megatron_gpt2.convert_megatron_gpt2_checkpoint.main
+    # Spell out all parameters in case the defaults change.
+    config = GPT2Config(
+        vocab_size=50257,
+        n_positions=1024,
+        n_ctx=1024,
+        n_embd=1024,
+        n_layer=24,
+        n_head=16,
+        n_inner=4096,
+        activation_function="gelu",  # used to be "gelu_new" in earlier versions
+        resid_pdrop=0.1,
+        embd_pdrop=0.1,
+        attn_pdrop=0.1,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        summary_type="cls_index",
+        summary_use_proj=True,
+        summary_activation=None,
+        summary_proj_to_labels=True,
+        summary_first_dropout=0.1,
+        scale_attn_weights=True,
+        gradient_checkpointing=False,
+        use_cache=True,
+        bos_token_id=50256,
+        eos_token_id=50256,
+    )
+
+    # Convert.
+    print("Converting to HF Checkpoint")
+    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
+
+    basename = args.output_folder
+    os.makedirs(basename, exist_ok=True)
+
+    # Print the structure of converted state dict.
+    #if args.print_checkpoint_structure:
+    #    recursive_print(None, output_state_dict)
+
+    # Store the config to file.
+    output_config_file = os.path.join(basename, "config.json")
+    output_config = config.to_dict()
+    output_config["architectures"] = ["GPT2LMHeadModel"]
+    output_config["model_type"] = "gpt2"
+    print(f'Saving config to "{output_config_file}"')
+    with open(output_config_file, "w") as f:
+        json.dump(output_config, f)
+
+    # Store the state_dict to file.
+    output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
+    print(f'Saving checkpoint to "{output_checkpoint_file}"')
+    torch.save(output_state_dict, output_checkpoint_file)
+
+    print("Now add tokenizer files and upload to the hub")
+
+
+if __name__ == "__main__":
+    main()
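
A quick way to sanity-check the converted checkpoint is to load it back with `transformers` and generate a few tokens. The sketch below is not part of the patch: the `checkpoint_dir` path is whatever was passed as `--output_folder`, and because the conversion script does not write tokenizer files (hence "Now add tokenizer files and upload to the hub"), the stock `gpt2` tokenizer is used as a stand-in on the assumption that the model uses the standard GPT-2 vocabulary (consistent with `vocab_size=50257` above).

```python
# Verification sketch (not part of the patch above): load the converted HF checkpoint and generate.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Same folder that was passed as --output_folder to deepspeed_to_transformers.py (illustrative path).
checkpoint_dir = "/path/to/transformers/checkpoint"

# The conversion script only emits config.json and pytorch_model.bin, so the tokenizer
# comes from the stock "gpt2" repo (assumes the standard GPT-2 vocabulary).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained(checkpoint_dir)
model.eval()

inputs = tokenizer("The conversion worked if", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

If the checkpoint was converted correctly, the model loads without missing/unexpected key warnings and the greedy continuation looks like plausible text rather than noise.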