-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
75 lines (58 loc) · 2.82 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
# Module-wide parser shared by every option group below, plus a record of
# each group created through add_argument_group().
parser = argparse.ArgumentParser()
arg_lists = []


def add_argument_group(name):
    """Create a titled argument group on the module-level ``parser``.

    The group is also appended to ``arg_lists`` before being returned.

    Args:
        name: Title of the group (argparse uses it as a section heading in
            ``--help`` output).

    Returns:
        The new argparse argument group, on which ``add_argument`` can be
        called.
    """
    new_group = parser.add_argument_group(name)
    arg_lists.append(new_group)
    return new_group
# ---------------------------------------------------------------------------
# Dataset options
# ---------------------------------------------------------------------------
data_arg = add_argument_group('Dataset')
data_arg.add_argument('--dataset', default='bs-ergb', type=str)
data_arg.add_argument('--data_root', default='../BS-ERGB/', type=str)
# NOTE(review): this default looks like a machine-specific mount point;
# confirm it resolves on the target machine or override it on the CLI.
data_arg.add_argument(
    '--sparse_data_root',
    default='../../../media/ogam/FC8E9F708E9F21E6/BS-ERGB_sparse',
    type=str,
)
# Alternative dataset presets (pass on the command line instead of defaults):
#   --dataset Davis  --data_root /home/zhihao/DATA-M2/video_interpolation/Davis/
#   --dataset ucf101 --data_root /home/zhihao/DATA-M2/video_interpolation/UCF/

# ---------------------------------------------------------------------------
# Model options
# ---------------------------------------------------------------------------
model_arg = add_argument_group('Model')
# Values accepted by --model (kept at module level so other code can import it).
model_choices = ["MAEVI", "EVFIT"]
model_arg.add_argument(
    '--model',
    default="MAEVI",
    type=str,
    choices=model_choices,
)
model_arg.add_argument('--nbr_frame', default=4, type=int)
model_arg.add_argument(
    '--joinType',
    default="concat",
    choices=["concat", "add", "none"],
)

# ---------------------------------------------------------------------------
# Training / test parameters
# ---------------------------------------------------------------------------
learn_arg = add_argument_group('Learning')
learn_arg.add_argument('--loss', default='1*L1', type=str)
# Learning rate and beta coefficients; presumably fed to an Adam-style
# optimizer -- TODO confirm against the training script.
learn_arg.add_argument('--lr', default=2e-4, type=float)
learn_arg.add_argument('--beta1', default=0.9, type=float)
learn_arg.add_argument('--beta2', default=0.999, type=float)
learn_arg.add_argument('--batch_size', default=4, type=int)
learn_arg.add_argument('--voxel_grid_size', default=128, type=int)
learn_arg.add_argument('--test_batch_size', default=1, type=int)
learn_arg.add_argument('--start_epoch', default=1, type=int)
learn_arg.add_argument('--max_epoch', default=100, type=int)
# Checkpoint / resume handling.
learn_arg.add_argument('--resume', action='store_true')
learn_arg.add_argument('--resume_exp', default=None, type=str)
learn_arg.add_argument('--checkpoint_dir', default=".", type=str)
learn_arg.add_argument(
    "--load_from",
    default='pretrained/model_best.pth',
    type=str,
)
learn_arg.add_argument(
    "--pretrained",
    type=str,
    help="Load from a pretrained model.",
)

# ---------------------------------------------------------------------------
# Miscellaneous options
# ---------------------------------------------------------------------------
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--exp_name', default='exp', type=str)
misc_arg.add_argument('--log_iter', default=100, type=int)
# get_args() sets args.cuda to True when this is greater than zero.
misc_arg.add_argument('--num_gpu', default=2, type=int)
misc_arg.add_argument('--random_seed', default=103, type=int)
misc_arg.add_argument('--num_workers', default=4, type=int)
misc_arg.add_argument('--val_freq', default=1, type=int)
def get_args(argv=None):
    """Parse the options registered on the module-level ``parser``.

    Unrecognized arguments do not raise an error. They are returned
    separately (via ``parse_known_args``) and echoed to stdout.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        A ``(args, unparsed)`` tuple. ``args`` is the parsed namespace with
        an extra boolean attribute ``cuda`` that is True when
        ``args.num_gpu > 0``. ``unparsed`` is the list of argument strings
        argparse did not recognize.
    """
    args, unparsed = parser.parse_known_args(argv)
    # Derived flag: request CUDA whenever at least one GPU is configured.
    args.cuda = args.num_gpu > 0
    # Warn about any leftover argument, including a single one, so a
    # mistyped flag is not silently ignored.
    if unparsed:
        print("Unparsed args: {}".format(unparsed))
    return args, unparsed