-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
99 lines (88 loc) · 4.77 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.optim import AdamW
from torch.optim. lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import Restormer
from utils import parse_args, DegradeDataset, rgb_to_y, psnr, ssim
# Device used for the model and all tensors. Falls back to the CPU when no
# CUDA device is available, so evaluation still runs on CPU-only machines.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
def test_loop(net, data_loader, num_iter):
    """Evaluate ``net`` on ``data_loader`` and save every restored image.

    Each prediction is clamped to [0, 1], cropped to ``h x w`` and quantised
    to uint8. The ground truth gets the same treatment. PSNR/SSIM are then
    computed on the Y channel in double precision using ``rgb_to_y``,
    ``psnr`` and ``ssim`` from ``utils``. Restored images are written to
    ``<args.save_path>/<args.data_name>/<name[0]>``.

    Relies on the module-level ``args`` (set in ``__main__``) and ``device``.

    Args:
        net: model to evaluate; it is put into eval mode.
        data_loader: yields ``(lr, hr, name, h, w)`` batches. The saving code
            uses ``name[0]`` and squeezes dim 0, so it assumes
            ``batch_size=1``.
        num_iter: current training iteration, shown in the progress bar.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)`` averaged over all batches.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    net.eval()
    total_psnr, total_ssim, count = 0.0, 0.0, 0
    with torch.no_grad():
        test_bar = tqdm(data_loader, dynamic_ncols=True)
        for lr, hr, name, h, w in test_bar:
            lr, hr = lr.to(device), hr.to(device)
            # h/w presumably hold the un-padded image size (TODO confirm in
            # DegradeDataset); cropping drops any padding before scoring.
            # .byte() truncates toward zero, matching the original quantisation.
            out = torch.clamp(net(lr)[:, :, :h, :w], 0, 1).mul(255).byte()
            hr = torch.clamp(hr[:, :, :h, :w].mul(255), 0, 255).byte()
            # compute the metrics with Y channel and double precision
            y, gt = rgb_to_y(out.double()), rgb_to_y(hr.double())
            current_psnr, current_ssim = psnr(y, gt), ssim(y, gt)
            total_psnr += current_psnr.item()
            total_ssim += current_ssim.item()
            count += 1
            save_path = '{}/{}/{}'.format(args.save_path, args.data_name, name[0])
            # exist_ok avoids the check-then-create race of os.path.exists()
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            Image.fromarray(out.squeeze(dim=0).permute(1, 2, 0).contiguous().cpu().numpy()).save(save_path)
            test_bar.set_description('Test Iter: [{}/{}] PSNR: {:.2f} SSIM: {:.3f}'
                                     .format(num_iter, 1 if args.model_file else args.num_iter,
                                             total_psnr / count, total_ssim / count))
    if count == 0:
        raise ValueError('test data loader yielded no batches; cannot compute metrics')
    return total_psnr / count, total_ssim / count
def save_loop(net, data_loader, num_iter):
    """Evaluate ``net``, append metrics to the CSV log, and checkpoint on improvement.

    Updates the module-level ``results``, ``best_psnr`` and ``best_ssim``
    (set up in ``__main__``) and reads ``args`` for the output paths.

    Args:
        net: model to evaluate and, if it improves, save.
        data_loader: test loader passed through to ``test_loop``.
        num_iter: current training iteration; recorded in the best-result text file.

    Side effects:
        Writes ``<save_path>/<data_name>.csv`` after every call. When both
        PSNR and SSIM beat the previous best, it also writes
        ``<save_path>/<data_name>.txt`` and ``<save_path>/<data_name>.pth``.
    """
    global best_psnr, best_ssim
    val_psnr, val_ssim = test_loop(net, data_loader, num_iter)
    results['PSNR'].append('{:.2f}'.format(val_psnr))
    results['SSIM'].append('{:.3f}'.format(val_ssim))
    # save statistics: one row per evaluation so far, numbered from 1. Deriving
    # the index from the recorded rows keeps it in sync with every column.
    data_frame = pd.DataFrame(data=results, index=range(1, len(results['PSNR']) + 1))
    data_frame.to_csv('{}/{}.csv'.format(args.save_path, args.data_name), index_label='Iter', float_format='%.3f')
    # A checkpoint is kept only when BOTH metrics improve on the best so far.
    if val_psnr > best_psnr and val_ssim > best_ssim:
        best_psnr, best_ssim = val_psnr, val_ssim
        with open('{}/{}.txt'.format(args.save_path, args.data_name), 'w') as f:
            f.write('Iter: {} PSNR:{:.2f} SSIM:{:.3f}'.format(num_iter, best_psnr, best_ssim))
        torch.save(net.state_dict(), '{}/{}.pth'.format(args.save_path, args.data_name))
if __name__ == '__main__':
    args = parse_args()
    test_dataset = DegradeDataset(args.data_path, args.data_name, 'test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=args.workers)
    # Module-level state read and updated by save_loop().
    results, best_psnr, best_ssim = {'PSNR': [], 'SSIM': []}, 0.0, 0.0
    model = Restormer(args.num_blocks, args.num_heads, args.channels, args.num_refinement, args.expansion_factor).to(device)
    if args.model_file:
        # Evaluation-only mode: load the given weights and run a single test pass.
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded onto the current one.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        model.load_state_dict(torch.load(args.model_file, map_location=device))
        save_loop(model, test_loader, 1)
    else:
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.num_iter, eta_min=1e-6)
        total_loss, total_num, results['Loss'], stage = 0.0, 0, [], 0
        train_bar = tqdm(range(1, args.num_iter + 1), dynamic_ncols=True)
        for n_iter in train_bar:
            # progressive learning: on the first iteration and right after each
            # milestone, rebuild the training loader with this stage's patch
            # size and batch size.
            if n_iter == 1 or n_iter - 1 in args.milestone:
                end_iter = args.milestone[stage] if stage < len(args.milestone) else args.num_iter
                start_iter = args.milestone[stage - 1] if stage > 0 else 0
                # Presumably the dataset length, chosen so the loader yields one
                # batch per iteration of this stage (TODO confirm in DegradeDataset).
                length = args.batch_size[stage] * (end_iter - start_iter)
                train_dataset = DegradeDataset(args.data_path, args.data_name, 'train', args.patch_size[stage], length)
                train_loader = iter(DataLoader(train_dataset, batch_size=args.batch_size[stage], shuffle=True,
                                               num_workers=args.workers))
                stage += 1
            # train (save_loop switches the model to eval, so switch back each step)
            model.train()
            lr, hr, name, h, w = next(train_loader)
            lr, hr = lr.to(device), hr.to(device)
            out = model(lr)
            loss = F.l1_loss(out, hr)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # running sample-weighted mean loss since the start of training
            total_num += lr.size(0)
            total_loss += loss.item() * lr.size(0)
            train_bar.set_description('Train Iter: [{}/{}] Loss: {:.3f}'
                                      .format(n_iter, args.num_iter, total_loss / total_num))
            lr_scheduler.step()
            if n_iter % 1000 == 0:
                results['Loss'].append('{:.3f}'.format(total_loss / total_num))
                save_loop(model, test_loader, n_iter)