-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
155 lines (135 loc) · 7.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# coding: utf-8
import argparse
import os
import time
import random
# Process-wide environment setup, executed at import time between the stdlib
# imports above and the project imports below. Keep these lines above the
# project imports: presumably the libraries loaded by those modules read these
# variables when they initialise (TODO confirm before reordering).
# NOTE(review): the variable name suggests TensorFlow reads it; '3' is meant to
# suppress its C++ log output — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Hard-coded device selection; this overwrites any CUDA_VISIBLE_DEVICES value
# already exported in the shell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Unconditional 2-second pause on every start-up.
# NOTE(review): the reason for this delay is not evident from this file;
# confirm whether it is still needed.
time.sleep(2)
import val_g
import val_multitask_include_d
import val_nomulti_include_d
import train_g
import train_d_vgg
import train_multitask_vgg
def start(args):
    """Run the validation or training routine selected by the parsed options.

    Selection rules (the same for both stages):
      * ``args.is_multitask`` wins: use the multitask module;
      * otherwise, if ``args.d_name`` is not ``'null'``, use the
        generator-plus-discriminator module;
      * otherwise fall back to the generator-only module.

    ``args.is_val`` chooses between the ``val`` functions (validation) and the
    ``train`` functions (training). The chosen function receives *args*
    unchanged; its return value is discarded.

    Args:
        args: the ``argparse.Namespace`` built by the command-line parser.
    """
    if args.is_val:
        print("Go into the validation stage")
        if args.is_multitask:
            runner = val_multitask_include_d.val
        elif args.d_name != 'null':
            runner = val_nomulti_include_d.val
        else:
            runner = val_g.val
    else:
        print("Go into the train stage")
        if args.is_multitask:
            runner = train_multitask_vgg.train
        elif args.d_name != 'null':
            runner = train_d_vgg.train
        else:
            runner = train_g.train
    runner(args)
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    argparse's ``type=bool`` is unusable for this: ``bool("False")`` is
    ``True``, so any non-empty string would switch the option on.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false, yes/no, 1/0), got %r" % value)


def _parse_img_size(value):
    """Convert a ``"HEIGHT,WIDTH"`` string such as ``"321,321"`` to ``(321, 321)``."""
    parts = value.split(',')
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            "expected HEIGHT,WIDTH (e.g. 321,321), got %r" % value)
    try:
        return int(parts[0]), int(parts[1])
    except ValueError:
        raise argparse.ArgumentTypeError(
            "expected integer HEIGHT,WIDTH (e.g. 321,321), got %r" % value)


def get_arguments():
    """Build the command-line parser and parse the options known so far.

    Returns:
        A ``(parser, args)`` tuple: the ``argparse.ArgumentParser`` (returned so
        the caller can register further options and parse again) and the
        ``argparse.Namespace`` parsed from ``sys.argv``.

    Options that are not registered here, such as ``--log_dir`` (which the
    ``__main__`` block adds afterwards), are tolerated through
    ``parse_known_args``. The caller's final ``parse_args()`` is the one that
    rejects genuinely unknown options. With plain ``parse_args()`` here, any
    such option on the command line would abort the program before it could
    be registered.
    """
    BATCH_SIZE = 1
    # DATA_DIRECTORY = ['/home/SharedSSD/guozihao/dataset/VOC2012/',]
    DATA_DIRECTORY = ['/home/gzh/Workspace/Dataset/VOC2012/', ]
    IGNORE_LABEL = 255
    # None means no scaling, mirroring or resizing: the input is the original image.
    IMG_SIZE = None
    LEARNING_RATE = 1e-4  # use 3e-5 instead when the model includes a discriminator (d)
    MOMENTUM = 0.9
    NUM_CLASSES = 21
    NUM_STEPS = 100000 + 1
    POWER = 0.9
    # A fresh seed is drawn on every run, so runs are only reproducible when
    # --random_seed is passed explicitly.
    RANDOM_SEED = random.randint(0, 2 ** 31 - 1)
    IS_VAL = False
    IS_MULTITASK = False
    SAVE_NUM_IMAGES = 1
    SAVE_PRED_EVERY = 5000
    WEIGHT_DECAY = 0.0003
    D_NAME = 'null'  # options: null, disc_add_vgg, disc_add_res50
    G_NAME = 'res_50'  # options: vgg_32, vgg_16, vgg_8, res_50
    LAMBD = 0.01

    parser = argparse.ArgumentParser(description="VGG for Semantic Segmentation")
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE,
                        help="Number of images sent to the network in one step.")
    # One or more directories; type=list would split a single path into characters.
    parser.add_argument("--data_dir", type=str, nargs='+', default=DATA_DIRECTORY,
                        help="Path to the directory containing the PASCAL VOC dataset.")
    parser.add_argument("--ignore_label", type=int, default=IGNORE_LABEL,
                        help="The index of the label to ignore during the training.")
    # Parsed into an (height, width) tuple of ints; type=tuple would split into characters.
    parser.add_argument("--img_size", type=_parse_img_size, default=IMG_SIZE,
                        help="Comma_separated string with height and width of images.")
    parser.add_argument("--learning_rate", type=float, default=LEARNING_RATE,
                        help="Base learning rate for training with polynomial decay.")
    parser.add_argument("--lambd", type=float, default=LAMBD,
                        help="a constant for constrainting the D-model loss")
    parser.add_argument("--momentum", type=float, default=MOMENTUM,
                        help="Momentum component of the optimiser.")
    parser.add_argument("--num_classes", type=int, default=NUM_CLASSES,
                        help="Number of classes to predict (including background).")
    parser.add_argument("--num_steps", type=int, default=NUM_STEPS,
                        help="Number of training steps.")
    parser.add_argument("--power", type=float, default=POWER,
                        help="Decay parameter to compute the learning rate.")
    parser.add_argument("--random_seed", type=int, default=RANDOM_SEED,
                        help="Random seed to have reproducible results.")
    parser.add_argument("--d_name", type=str, default=D_NAME,
                        help="which d_model can be choosed")
    parser.add_argument("--g_name", type=str, default=G_NAME,
                        help="which g_model can be choosed")
    parser.add_argument("--save_num_images", type=int, default=SAVE_NUM_IMAGES,
                        help="How many images to save.")
    parser.add_argument("--save_pred_every", type=int, default=SAVE_PRED_EVERY,
                        help="Save summaries and checkpoint every often.")
    parser.add_argument("--weight_decay", type=float, default=WEIGHT_DECAY,
                        help="Regularisation parameter for L2-loss.")
    # Boolean options take an explicit value (e.g. "--is_val true"); see _str2bool.
    parser.add_argument("--is_val", type=_str2bool, default=IS_VAL,
                        help="Use the Val")
    parser.add_argument("--is_multitask", type=_str2bool, default=IS_MULTITASK,
                        help="train with using the multitask")
    parser.add_argument("--is_training", action="store_true",
                        help="Whether to updates the running means and variances during the training.")
    parser.add_argument("--not_restore_last", action="store_true",
                        help="Whether to not restore last (FC) layers.")
    parser.add_argument("--random_mirror", type=_str2bool, default=True,
                        help="Whether to randomly mirror the inputs during the training.")
    parser.add_argument("--random_scale", type=_str2bool, default=True,
                        help="Whether to randomly scale the inputs during the training.")
    parser.add_argument("--random_crop", type=_str2bool, default=True,
                        help="Whether to randomly crop the inputs during the training.")

    # Leave options registered later by the caller for the caller's final parse.
    args, _unknown = parser.parse_known_args()
    return parser, args
if __name__ == '__main__':
    # Local import: only needed to parse --baseweight_from given on the command line.
    import json

    parser, args = get_arguments()

    # Run-specific locations are namespaced by task type, generator name,
    # discriminator name and learning rate. '%f' formatting is kept so the
    # paths match those written by earlier runs (e.g. 1e-4 -> '0.000100').
    task_dir = 'is_multi' if args.is_multitask else 'no_multi'
    stage_dir = 'val' if args.is_val else 'train'
    run_subdir = '%s/%s/%s/%f/' % (task_dir, args.g_name, args.d_name, args.learning_rate)

    RESTORE_FROM = './weights/' + run_subdir
    LOG_DIR = './tblogs/%s/%s' % (stage_dir, run_subdir)
    VALID_IMAGE_STORE_PATH = './valid_imgs/' + run_subdir

    # Machine-specific base-weight locations (alternate machine kept for reference).
    # BASEWEIGHT_FROM = {'res50': '/home/SharedSSD/guozihao/weights/resnet-50.ckpt',
    #                    'vgg16': '/home/SharedSSD/guozihao/weights/vgg16.npy',
    #                    'g': './weights/no_multi/%s/null/0.000100/' % (args.g_name)}
    BASEWEIGHT_FROM = {'res50': '/home/gzh/Workspace/Weight/resnet50/resnet-50.ckpt',
                       'vgg16': '/home/gzh/Workspace/Weight/vgg16/vgg16.npy',
                       'g': '/home/gzh/Workspace/Weight/weights/%s/disc_add_vgg/0.000100' % (args.g_name)}

    # These defaults depend on the first parse, so they are registered now and
    # the command line is parsed again (strictly) below.
    parser.add_argument("--log_dir", type=str, default=LOG_DIR,
                        help="Where to save tensorboard log of the model.")
    parser.add_argument("--restore_from", type=str, default=RESTORE_FROM,
                        help="Where restore model parameters from.")
    parser.add_argument("--valid_image_store_path", type=str, default=VALID_IMAGE_STORE_PATH,
                        help="Where store valid image files")
    # type=dict cannot convert a command-line string (dict("...") raises), so a
    # CLI override is read as a JSON object instead. A non-string default is
    # not passed through `type`, so BASEWEIGHT_FROM is used unchanged.
    parser.add_argument("--baseweight_from", type=json.loads, default=BASEWEIGHT_FROM,
                        help="Where base model weight from (JSON object when given on the command line)")
    args = parser.parse_args()
    start(args)