train.py
import os
import datetime
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import Restormer
from utils import parse_args, DegradeDataset, rgb_to_y, psnr, ssim
device = "cuda:0"
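
# evaluate `net` on the full test set: metrics are computed on the Y channel in
# double precision, and each restored image is saved under {save_path}/{data_name}/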
def test_loop(net, data_loader, num_iter):
    net.eval()
    total_psnr, total_ssim, count = 0.0, 0.0, 0
    with torch.no_grad():
        test_bar = tqdm(data_loader, initial=1, dynamic_ncols=True)
        for lr, hr, name, h, w in test_bar:
            lr, hr = lr.to(device), hr.to(device)
            # crop the padded output back to the original h x w, then scale to 8-bit
            out = torch.clamp(net(lr)[:, :, :h, :w], 0, 1).mul(255).byte()
            hr = torch.clamp(hr[:, :, :h, :w].mul(255), 0, 255).byte()
            # compute the metrics with Y channel and double precision
            y, gt = rgb_to_y(out.double()), rgb_to_y(hr.double())
            current_psnr, current_ssim = psnr(y, gt), ssim(y, gt)
            total_psnr += current_psnr.item()
            total_ssim += current_ssim.item()
            count += 1
            save_path = '{}/{}/{}'.format(args.save_path, args.data_name, name[0])
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            Image.fromarray(out.squeeze(dim=0).permute(1, 2, 0).contiguous().cpu().numpy()).save(save_path)
            test_bar.set_description('Test Iter: [{}/{}] PSNR: {:.2f} SSIM: {:.3f}'
                                     .format(num_iter, 1 if args.model_file else args.num_iter,
                                             total_psnr / count, total_ssim / count))
    return total_psnr / count, total_ssim / count
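
# run one evaluation pass, append the metrics to the CSV log, keep the weights of
# the best-performing model so far, and write a resumable training checkpoint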
def save_loop(net, data_loader, num_iter):
    global best_psnr, best_ssim
    val_psnr, val_ssim = test_loop(net, data_loader, num_iter)
    results['PSNR'].append('{:.2f}'.format(val_psnr))
    results['SSIM'].append('{:.3f}'.format(val_ssim))
    # ensure the lengths of all lists in results are equal
    if len(results['Loss']) < len(results['PSNR']):
        results['Loss'].extend([''] * (len(results['PSNR']) - len(results['Loss'])))
    # save statistics
    data_frame = pd.DataFrame(data=results, index=range(1, len(results['PSNR']) + 1))
    data_frame.to_csv('{}/{}.csv'.format(args.save_path, args.data_name), index_label='Iter', float_format='%.3f')
    if val_psnr > best_psnr and val_ssim > best_ssim:
        best_psnr, best_ssim = val_psnr, val_ssim
        with open('{}/{}.txt'.format(args.save_path, args.data_name), 'w') as f:
            f.write('Iter: {} PSNR:{:.2f} SSIM:{:.3f}'.format(num_iter, best_psnr, best_ssim))
        torch.save(net.state_dict(), '{}/{}.pth'.format(args.save_path, args.data_name))
    # save a resumable checkpoint; the scheduler state lets the LR curve continue correctly
    checkpoint = {
        'iteration': num_iter,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        'best_psnr': best_psnr,
        'best_ssim': best_ssim,
        'results': results
    }
    torch.save(checkpoint, '{}/checkpoint.pth'.format(args.save_path))
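
# entry point: evaluate a trained model (--model_file), resume an interrupted run
# (--checkpoint_file), or train from scratch with Restormer's progressive learning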
if __name__ == '__main__':
    args = parse_args()
    test_dataset = DegradeDataset(args.data_path, args.data_name, 'test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=args.workers)

    results, best_psnr, best_ssim = {'PSNR': [], 'SSIM': [], 'Loss': []}, 0.0, 0.0
    model = Restormer(args.num_blocks, args.num_heads, args.channels, args.num_refinement,
                      args.expansion_factor).to(device)
    start_iter = 1
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.num_iter, eta_min=1e-6)
    if args.model_file:
        # evaluation only: load trained weights and run a single test pass
        model.load_state_dict(torch.load(args.model_file, map_location=device))
        save_loop(model, test_loader, 1)
    else:
        if args.checkpoint_file:
            # resume training: restore model/optimizer/scheduler state and the running best metrics
            checkpoint = torch.load(args.checkpoint_file, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
            start_iter = checkpoint['iteration'] + 1
            best_psnr = checkpoint['best_psnr']
            best_ssim = checkpoint['best_ssim']
            results = checkpoint['results']
            save_loop(model, test_loader, start_iter)
        total_loss, total_num = 0.0, 0
        # when resuming, skip the progressive-learning phases that have already finished
        i = sum(1 for milestone in args.milestone if milestone < start_iter)
        train_bar = tqdm(range(start_iter, args.num_iter + 1), initial=start_iter, dynamic_ncols=True)
        for n_iter in train_bar:
            # progressive learning: later phases train on larger patches with smaller batches
            if n_iter == start_iter or n_iter - 1 in args.milestone:
                end_iter = args.milestone[i] if i < len(args.milestone) else args.num_iter
                begin_iter = args.milestone[i - 1] if i > 0 else 0
                length = args.batch_size[i] * (end_iter - begin_iter)
                train_dataset = DegradeDataset(args.data_path, args.data_name, 'train', args.patch_size[i], length)
                train_loader = iter(DataLoader(train_dataset, args.batch_size[i], shuffle=True,
                                               num_workers=args.workers))
                i += 1
            # train
            model.train()
            lr, hr, name, h, w = next(train_loader)
            lr, hr = lr.to(device), hr.to(device)
            out = model(lr)
            loss = F.l1_loss(out, hr)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_num += lr.size(0)
            total_loss += loss.item() * lr.size(0)
            train_bar.set_description('Train Iter: [{}/{}] Loss: {:.3f}'
                                      .format(n_iter, args.num_iter, total_loss / total_num))
            lr_scheduler.step()
            if n_iter % 1000 == 0:
                results['Loss'].append('{:.3f}'.format(total_loss / total_num))
                save_loop(model, test_loader, n_iter)
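
# illustrative invocation (flag names inferred from the parse_args fields used above;
# the dataset name and iteration count are example values, not defaults):
#   python train.py --data_path /path/to/data --data_name rain100L --num_iter 300000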