-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
83 lines (65 loc) · 2.67 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import argparse
from tqdm import tqdm
import yaml
from attrdict import AttrMap
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.nn import functional as F
from data import Dataset
from utils import gpu_manage, save_image
from models.generator import UNet
from utils import save_image_from_tensors, get_metrics
def predict(config, args):
    """Run a pretrained UNet generator over a test set and report metrics.

    Each sample from ``Dataset(args.test_dir)`` is resized to 256x256, passed
    through the generator, written out with ``save_image_from_tensors`` and
    scored against its ground truth with ``get_metrics`` (MSE, PSNR, SSIM).
    Per-image metrics and their averages are printed to stdout.

    Args:
        config: Attribute-style mapping loaded from the YAML config. Reads
            ``threads``, ``in_ch``, ``out_ch`` and, only when ``args.out_dir``
            is missing or empty, ``out_dir``.
        args: Parsed command-line namespace. Reads ``test_dir``,
            ``pretrained``, ``cuda``, ``gpu_ids`` and ``out_dir``. It is also
            handed to ``gpu_manage`` first.
    """
    gpu_manage(args)

    # Derive the device from the --cuda flag so the model and its inputs are
    # always on the same device. The model was previously moved by
    # ``args.cuda`` while the inputs followed ``torch.cuda.is_available()``,
    # which crashed on CUDA machines run without --cuda.
    device = torch.device('cuda:0' if args.cuda else 'cpu')

    # --out_dir is a required CLI argument, so honour it; fall back to the
    # config value only if it is absent.
    out_dir = getattr(args, 'out_dir', None) or config.out_dir

    image_size = 256  # every input/target is resized to this before inference

    dataset = Dataset(args.test_dir)
    data_loader = DataLoader(dataset=dataset, num_workers=config.threads,
                             batch_size=1, shuffle=False)

    gen = UNet(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=args.gpu_ids)
    # map_location allows a checkpoint saved on GPU to be loaded on CPU runs.
    param = torch.load(args.pretrained, map_location=device)
    gen.load_state_dict(param)
    gen = gen.to(device)
    # NOTE(review): gen.eval() is never called, so any dropout/batch-norm
    # layers in UNet run in training mode. pix2pix-style setups sometimes do
    # this on purpose — confirm before switching to eval mode.

    criterionMSE = nn.MSELoss().to(device)

    total_mse = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    n_images = 0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            # batch_size=1, so batch[2] holds exactly one filename.
            filename = batch[2][0]
            input_ = F.interpolate(batch[0], size=image_size).to(device)
            ground_truth = F.interpolate(batch[1], size=image_size).to(device)

            output = gen(input_)
            save_image_from_tensors(input_, output, ground_truth, out_dir, i, 0, filename)

            mse, psnr, ssim = get_metrics(output, ground_truth, criterionMSE)
            print(filename)
            print('MSE: {:.4f}'.format(mse))
            print('PSNR: {:.4f} dB'.format(psnr))
            # SSIM is dimensionless; it was previously mislabelled as dB.
            print('SSIM: {:.4f}'.format(ssim))

            total_mse += mse
            total_psnr += psnr
            total_ssim += ssim
            n_images += 1

    # Guard against an empty test directory instead of dividing by zero.
    if n_images == 0:
        print('No test images found in {}'.format(args.test_dir))
        return

    avg_mse = total_mse / n_images
    avg_psnr = total_psnr / n_images
    avg_ssim = total_ssim / n_images
    print('Average MSE: {:.4f}'.format(avg_mse))
    print('Average PSNR: {:.4f} dB'.format(avg_psnr))
    print('Average SSIM: {:.4f}'.format(avg_ssim))
if __name__ == '__main__':
    # Command-line entry point: parse arguments, load the YAML config and run
    # evaluation with predict().
    parser = argparse.ArgumentParser(
        description='Evaluate a pretrained UNet generator on a test set.')
    parser.add_argument('--config', type=str, required=True,
                        help='path to the YAML configuration file')
    parser.add_argument('--test_dir', type=str, required=True,
                        help='directory passed to Dataset for test samples')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='directory for the generated images')
    parser.add_argument('--pretrained', type=str, required=True,
                        help='path to the generator state_dict checkpoint')
    parser.add_argument('--cuda', action='store_true',
                        help='run on GPU 0')
    # nargs='+' keeps the value a list, matching the [0] default; with
    # type=int alone, `--gpu_ids 1` would produce the scalar 1.
    parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0],
                        help='GPU ids to use (space separated)')
    parser.add_argument('--manualSeed', type=int, default=0,
                        help='random seed')
    args = parser.parse_args()

    # safe_load: yaml.load without an explicit Loader can construct arbitrary
    # Python objects and raises TypeError on PyYAML >= 6.
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    config = AttrMap(config)

    predict(config, args)