-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathmain.py
executable file
·119 lines (87 loc) · 3.59 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
""" PyTorch implimentation of VAE and Super-Resolution VAE.
Reposetory Author:
Ioannis Gatopoulos, 2020
"""
import os
from datetime import datetime
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from src import *
def train_model(dataset, model, writer=None):
    """Train a VAE model on `dataset` with periodic visual evaluation and checkpointing.

    Args:
        dataset: Dataset identifier passed to the project's `dataloader()`.
        model: Name (str) of a model class resolvable via `globals()`
               (brought into scope by `from src import *`).
        writer: Optional TensorBoard `SummaryWriter`; `None` disables logging.
    """
    train_loader, valid_loader, test_loader = dataloader(dataset)
    data_shape = get_data_shape(train_loader)

    # Resolve the model class by name and wrap it for multi-GPU execution.
    model = nn.DataParallel(globals()[model](data_shape).to(args.device))
    # Data-dependent initialization pass over the training data.
    model.module.initialize(train_loader)

    criterion = ELBOLoss()
    optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3, betas=(0.9, 0.999), eps=1e-7)
    scheduler = LowerBoundedExponentialLR(optimizer, gamma=0.999999, lower_bound=0.0001)

    n_parameters(model, writer)

    # BUG FIX: `range(1, args.epochs)` ran only args.epochs - 1 epochs;
    # include the final epoch so exactly args.epochs epochs are trained.
    for epoch in range(1, args.epochs + 1):
        # Train and validation epoch
        train_losses = train(model, criterion, optimizer, scheduler, train_loader)
        valid_losses = evaluate(model, criterion, valid_loader)

        # Visual evaluation (sample generation and reconstructions)
        generate(model, args.n_samples, epoch, writer)
        reconstruction(model, valid_loader, args.n_samples, epoch, writer)

        # Model checkpointing and logging
        is_saved = save_model(model, optimizer, valid_losses['nelbo'], epoch)
        logging(epoch, train_losses, valid_losses, is_saved, writer)
def load_and_evaluate(dataset, model, writer=None):
    """Load pretrained weights for `model` and run the full evaluation suite.

    Evaluates ELBO on the test set, generates/reconstructs/interpolates images,
    and estimates NLL with importance-weighted samples.

    Args:
        dataset: Dataset identifier passed to the project's `dataloader()`.
        model: Name (str) of a model class resolvable via `globals()`.
        writer: Unused here; kept for signature symmetry with `train_model`.

    Raises:
        SystemExit: If the checkpoint's parameters do not match the model.
    """
    # Configure checkpoint path. (The unused 'inference' path from the original
    # code was removed; only the trainable checkpoint is loaded below.)
    pth = './src/models/'
    pth = os.path.join(pth, 'pretrained', args.model, args.dataset)
    pth_train = os.path.join(pth, 'trainable', 'model.pth')

    # Get data
    train_loader, valid_loader, test_loader = dataloader(dataset)
    data_shape = get_data_shape(train_loader)

    # Define model
    model = globals()[model](data_shape).to(args.device)
    model.initialize(train_loader)

    # Load trained weights. BUG FIX: map_location ensures a checkpoint saved on
    # GPU can be loaded on a CPU-only (or different-device) host.
    checkpoint = torch.load(pth_train, map_location=args.device)
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
        print('Model successfully loaded!')
    except RuntimeError:
        print('* Failed to load the model. Parameter mismatch.')
        # Exit with a non-zero status instead of quit(), which relies on the
        # interactive `site` module and reports success (exit code 0).
        raise SystemExit(1)

    model = nn.DataParallel(model).to(args.device)
    model.eval()
    criterion = ELBOLoss()

    # Evaluation of the model
    # --- calculate elbo ---
    test_losses = evaluate(model, criterion, test_loader)
    print('ELBO: {} bpd'.format(test_losses['bpd']))

    # --- image generation ---
    generate(model, n_samples=15*15)

    # --- image reconstruction ---
    reconstruction(model, test_loader, n_samples=15)

    # --- image interpolation ---
    interpolation(model, test_loader, n_samples=15)

    # --- calculate nll ---
    bpd = calculate_nll(model, test_loader, criterion, args, iw_samples=args.iw_test)
    print('NLL with {} weighted samples: {:4.2f}'.format(args.iw_test, bpd))
# ----- main -----
def main():
    """Run the full experiment: configuration, training, then evaluation."""
    # Print configs
    print_args(args)

    # Control random seeds for reproducibility
    fix_random_seed(seed=args.seed)

    # Initialize TensorBoard writer (if enabled)
    writer = None
    if args.use_tb:
        writer = SummaryWriter(log_dir='./logs/'+args.dataset+'_'+args.model+'_'+args.tags +
                               datetime.now().strftime("/%d-%m-%Y/%H-%M-%S"))
        writer.add_text('args', namespace2markdown(args))

    # Train model
    train_model(args.dataset, args.model, writer)

    # Evaluate best (latest saved) model
    load_and_evaluate(args.dataset, args.model, writer)

    # End experiment.
    # BUG FIX: guard close() — `writer` is None when TensorBoard is disabled,
    # so the unconditional call crashed every non-TensorBoard run at the end.
    if writer is not None:
        writer.close()
    print('\n'+24*'='+' Experiment Ended '+24*'=')
# ----- python main.py -----
# Script entry point: run the experiment only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()