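"""Pretraining entry point for TencentPretrain.

Parses command-line options, resolves the run mode (DeepSpeed, multiprocessing
distributed, single GPU, or CPU), and dispatches to
tencentpretrain.trainer.train_and_validate.
"""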
import argparse

import torch

import tencentpretrain.trainer as trainer
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.opts import *


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--dataset_path", type=str, default="dataset.pt",
                        help="Path of the preprocessed dataset.")
    parser.add_argument("--pretrained_model_path", type=str, default=None,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path", type=str, required=True,
                        help="Path of the output model.")
    parser.add_argument("--config_path", type=str, default="models/bert/base_config.json",
                        help="Config file of model hyper-parameters.")

    # Training and saving options.
    parser.add_argument("--total_steps", type=int, default=100000,
                        help="Total training steps.")
    parser.add_argument("--save_checkpoint_steps", type=int, default=10000,
                        help="Save a model checkpoint every this many steps.")
    parser.add_argument("--report_steps", type=int, default=100,
                        help="Report training progress every this many steps.")
    parser.add_argument("--accumulation_steps", type=int, default=1,
                        help="Number of steps over which gradients are accumulated.")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Training batch size. The effective batch size is "
                             "batch_size x world_size x accumulation_steps.")
    parser.add_argument("--instances_buffer_size", type=int, default=25600,
                        help="The buffer size of instances kept in memory.")
    parser.add_argument("--labels_num", type=int, required=False,
                        help="Number of prediction labels.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout value.")
    parser.add_argument("--seed", type=int, default=7, help="Random seed.")

    # Preprocess options.
    tokenizer_opts(parser)
    tgt_tokenizer_opts(parser)

    # Model options.
    model_opts(parser)

    # Model parallelism options.
    mp_opts(parser)

    parser.add_argument("--data_processor",
                        choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls",
                                 "prefixlm", "gsg", "bart", "cls_mlm", "vit", "vilt", "clip",
                                 "s2t", "beit", "dalle", "llm_pretrain", "llm_sft"],
                        default="bert",
                        help="The data processor of the pretraining model.")
    parser.add_argument("--deep_init", action="store_true",
                        help="Scale the initialization of projection layers by a "
                             "factor of 1/sqrt(2N). Necessary for large models.")

    # Masking options.
    parser.add_argument("--whole_word_masking", action="store_true", help="Whole word masking.")
    parser.add_argument("--span_masking", action="store_true", help="Span masking.")
    parser.add_argument("--span_geo_prob", type=float, default=0.2,
                        help="Hyperparameter of the geometric distribution used for span masking.")
    parser.add_argument("--span_max_length", type=int, default=10,
                        help="Max length for span masking.")

    # Optimizer options.
    optimization_opts(parser)

    # GPU options.
    parser.add_argument("--world_size", type=int, default=1,
                        help="Total number of processes (GPUs) for training.")
    parser.add_argument("--gpu_ranks", default=[], nargs='+', type=int,
                        help="List of ranks of the processes. Each process has a unique integer "
                             "rank in the interval [0, world_size) and runs on a single GPU.")
    parser.add_argument("--master_ip", default="tcp://localhost:12345", type=str,
                        help="IP and port of the master process for distributed training.")
    parser.add_argument("--backend", choices=["nccl", "gloo"], default="nccl", type=str,
                        help="Distributed backend.")

    # DeepSpeed options.
    deepspeed_opts(parser)

    # LoRA options.
    lora_opts(parser)

    # Log options.
    log_opts(parser)

    args = parser.parse_args()
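
    # The values collected below are the standard LoRA knobs: lora_r is the
    # adapter rank, lora_alpha the scaling factor, and lora_dropout the dropout
    # applied inside the adapters (summary based on common LoRA usage; the
    # exact semantics live in lora_opts and the model-building code).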
    # Construct the LoRA parameter dict.
    if args.use_lora:
        args.lora_params = {
            "lora_r": args.lora_r,
            "lora_alpha": args.lora_alpha,
            "lora_dropout": args.lora_dropout,
        }
    else:
        args.lora_params = None
if "cls" in args.target:
assert args.labels_num is not None, "Cls target needs the denotation of the number of labels."
# Load hyper-parameters from config file.
if args.config_path:
args = load_hyperparam(args)
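
    # Note: load_hyperparam merges settings from the JSON config into args, so
    # values in the config file can override the argparse defaults above
    # (behavior assumed from how the config is used here; see
    # tencentpretrain/utils/config.py for the details).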
    ranks_num = len(args.gpu_ranks)

    if args.deepspeed:
        if args.world_size > 1:
            args.dist_train = True
        else:
            args.dist_train = False
    else:
        if args.world_size > 1:
            # Multiprocessing distributed mode.
            assert torch.cuda.is_available(), "No available GPUs."
            assert ranks_num <= args.world_size, "Started processes exceed the `world_size` upper limit."
            assert ranks_num <= torch.cuda.device_count(), "Started processes exceed the number of available GPUs."
            args.dist_train = True
            args.ranks_num = ranks_num
            print("Using distributed mode for training.")
        elif args.world_size == 1 and ranks_num == 1:
            # Single-GPU mode.
            assert torch.cuda.is_available(), "No available GPUs."
            args.local_rank = args.gpu_ranks[0]
            assert args.local_rank < torch.cuda.device_count(), "Invalid GPU device specified."
            args.dist_train = False
            args.single_gpu = True
            print("Using GPU %d for training." % args.local_rank)
        else:
            # CPU mode.
            assert ranks_num == 0, "GPUs are specified; please check the arguments."
            args.dist_train = False
            args.single_gpu = False
            print("Using CPU mode for training.")

    trainer.train_and_validate(args)


if __name__ == "__main__":
    main()
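

# A minimal single-GPU invocation sketch. The vocabulary path and output name
# are illustrative, and --vocab_path is assumed to be defined by
# tokenizer_opts(parser); adjust to your setup:
#
#   python3 pretrain.py --dataset_path dataset.pt \
#       --vocab_path models/google_zh_vocab.txt \
#       --config_path models/bert/base_config.json \
#       --output_model_path models/bert_base_model.bin \
#       --world_size 1 --gpu_ranks 0 \
#       --total_steps 100000 --batch_size 32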