-
Notifications
You must be signed in to change notification settings - Fork 91
/
main.py
53 lines (40 loc) · 1.31 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse,json,random,os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision as tv
from trainer import Model
from opts import get_opts
def _load_opts(config_path):
    """Return the options object from *config_path*, or from ``get_opts()`` if none given.

    ``torch.load`` unpickles arbitrary Python objects, so *config_path* must
    point to a trusted file.
    """
    if not config_path:
        return get_opts()
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which refuses
    # to unpickle non-tensor objects such as a saved options namespace, so ask
    # for full unpickling explicitly. Older torch releases do not know the
    # keyword and raise TypeError; fall back to the legacy call there.
    try:
        return torch.load(config_path, weights_only=False)
    except TypeError:
        return torch.load(config_path)


def _print_opts(opt):
    """Print a header, then every attribute of *opt* as ``name : value``, one per line."""
    print('===Options==')
    for key, value in vars(opt).items():
        print(key, ':', value)


def _seed_everything(seed):
    """Seed Python's, NumPy's and torch's (CPU and all CUDA devices) RNGs with *seed*."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _make_dirs(out_path):
    """Create *out_path* plus its ``checkpoints`` and ``log_files`` subdirectories.

    Each directory is created independently with ``exist_ok=True``, so an
    existing *out_path* no longer prevents missing subdirectories from being
    created. Real failures (e.g. permission denied, or *out_path* being a
    regular file) propagate instead of being reported as "already exists".
    """
    existed = os.path.isdir(out_path)
    for path in (out_path,
                 os.path.join(out_path, 'checkpoints'),
                 os.path.join(out_path, 'log_files')):
        os.makedirs(path, exist_ok=True)
    if existed:
        print('Directory {} already exists.'.format(out_path))
    else:
        print('Directory {} was successfully created.'.format(out_path))


def main():
    """Parse ``--config``, load and print options, seed RNGs, create output dirs, and train.

    Options are unpickled from the ``--config`` file when given, otherwise
    taken from ``opts.get_opts()``. This function reads ``opt.manual_seed``
    and ``opt.out_path``; the whole object is then handed to
    ``trainer.Model``.
    """
    # Load options
    parser = argparse.ArgumentParser(description='Attribute Learner')
    parser.add_argument('--config', type=str, help='Path to config .opt file. Leave blank if loading from opts.py')
    conf = parser.parse_args()
    opt = _load_opts(conf.config)
    _print_opts(opt)
    # Fix seed
    _seed_everything(opt.manual_seed)
    # NOTE(review): benchmark mode lets cuDNN choose algorithms by timing
    # them. Per the PyTorch reproducibility notes, runs may therefore not be
    # bit-reproducible despite the fixed seeds above.
    cudnn.benchmark = True
    # Create working directories
    _make_dirs(opt.out_path)
    # Training
    M = Model(opt)
    M.train()
    # TODO: M.test()
# Run the training entry point only when executed as a script, not when imported.
if __name__ == '__main__':
    main()