parameters.py
import argparse


def get_parameters():
    """Build and return the argument parser for training and evaluation."""
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--src_file", default=None, type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file", default=None, type=str,
                        help="The output data file name.")
    parser.add_argument("--label_file", default=None, type=str,
                        help="The label data file name.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Bert config file path.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir",
                        default='',
                        type=str,
                        required=True,
                        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The checkpoint file of the fine-tuned pretrained model to recover from.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The checkpoint file of the pretraining optimizer to recover from.")
    # Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
parser.add_argument("--label_smoothing", default=0, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay",
default=0.01,
type=float,
help="The weight decay rate for Adam.")
parser.add_argument("--finetune_decay",
action='store_true',
help="Weight decay to the original weights.")
parser.add_argument("--num_train_epochs",
default=3.0,
type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion",
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
parser.add_argument("--hidden_dropout_prob", default=0.1, type=float,
help="Dropout rate for hidden states.")
parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float,
help="Dropout rate for attention probabilities.")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp32_embedding', action='store_true',
help="Whether to use 32-bit float precision instead of 16-bit for embeddings")
parser.add_argument('--loss_scale', type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
parser.add_argument('--amp', action='store_true',
help="Whether to use amp for fp16")
    parser.add_argument('--from_scratch', action='store_true',
                        help="Initialize parameters with random values (i.e., training from scratch).")
    parser.add_argument('--new_segment_ids', action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids', action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is already tokenized.")
    parser.add_argument('--max_len_a', type=int, default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b', type=int, default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument('--trunc_seg', default='',
                        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument('--always_truncate_tail', action='store_true',
                        help="Truncate_config: whether we should always truncate the tail.")
parser.add_argument("--mask_prob", default=0.15, type=float,
help="Number of prediction is sometimes less than max_pred when sequence is short.")
parser.add_argument("--mask_prob_eos", default=0, type=float,
help="Number of prediction is sometimes less than max_pred when sequence is short.")
parser.add_argument('--max_pred', type=int, default=20,
help="Max tokens of prediction.")
parser.add_argument("--num_workers", default=0, type=int,
help="Number of workers for the data loader.")
parser.add_argument('--mask_source_words', action='store_true',
help="Whether to mask source words for training")
parser.add_argument('--skipgram_prb', type=float, default=0.0,
help='prob of ngram mask')
parser.add_argument('--skipgram_size', type=int, default=1,
help='the max size of ngram mask')
parser.add_argument('--mask_whole_word', action='store_true',
help="Whether masking a whole word.")
parser.add_argument('--do_l2r_training', action='store_true',
help="Whether to do left to right training")
parser.add_argument('--has_sentence_oracle', action='store_true',
help="Whether to have sentence level oracle for training. "
"Only useful for summary generation")
parser.add_argument('--max_position_embeddings', type=int, default=None,
help="max position embeddings")
    parser.add_argument('--relax_projection', action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type', default=0, type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv', default=0, type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb', action='store_true',
                        help="Use segment embedding for self-attention.")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) for S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segment embeddings for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Share segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Use position shift for fine-tuning.")
    parser.add_argument('--use_SRL', action='store_true',
                        help="Use Cross-task Interaction.")
    parser.add_argument('--use_bwloss', action='store_true',
                        help="Use Bag-of-Words Loss.")
    return parser
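

# Minimal usage sketch (not part of the original file): build the parser,
# parse the command line, and read back a few values. The paths and flags
# below are illustrative placeholders, not defaults shipped with this repo.
if __name__ == "__main__":
    parser = get_parameters()
    # Passing an explicit argv list here only for demonstration; in real use,
    # call parser.parse_args() with no arguments to read sys.argv.
    args = parser.parse_args([
        "--data_dir", "./data",                    # hypothetical path
        "--bert_model", "bert-base-uncased",
        "--output_dir", "./output",                # hypothetical path
        "--log_dir", "./log",                      # hypothetical path
        "--model_recover_path", "./model.bin",     # hypothetical checkpoint
        "--do_train",
        "--do_lower_case",
    ])
    # Downstream training code would consume the parsed namespace, e.g.:
    print(args.train_batch_size, args.learning_rate, args.fp16)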