forked from arnab39/cycleGAN-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
153 lines (104 loc) · 6.7 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import torch
from torch import nn
from torch.autograd import Variable
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import utils
from arch import define_Gen, define_Dis
import numpy as np
from sklearn.metrics import mean_absolute_error
from skimage.metrics import peak_signal_noise_ratio
from calculate_fid import calculate_fid
import tensorflow as tf #to print shape of tensor with tf.shape()
def _to_single_plane(batch, height, width):
    """Reduce a (B, 3, H, W) image batch to a single 2-D (H, W) numpy array.

    Moves the tensor to CPU, reshapes, drops singleton dimensions and keeps
    channel 0 of the RGB image (assumes the three channels carry the same
    greyscale MRI data — TODO confirm).

    NOTE(review): this only yields a 2-D plane when the batch size is 1;
    with a larger batch the first index selects an image, not a channel —
    confirm batch_size == 1 is intended for evaluation.
    """
    arr = batch.detach().cpu().view(-1, 3, height, width).numpy()
    arr = np.squeeze(arr)            # (3, H, W) when the batch dimension is 1
    return np.squeeze(arr[0, :, :])  # channel 0 -> (H, W)


def test(args):
    """Evaluate trained CycleGAN generators on the held-out test sets.

    Loads the latest checkpoint, translates every test batch in both
    directions (A <-> B), and prints mean/variance of MAE, PSNR and FID
    between real T1 slices (domain B) and T1 slices translated from T2
    (domain A).  Also saves a sample grid of real / translated /
    reconstructed images to ``args.results_dir/sample.jpg``.

    Args:
        args: parsed CLI namespace; uses crop_height, crop_width,
            dataset_dir, batch_size, ngf, norm, no_dropout, gpu_ids,
            checkpoint_dir and results_dir.
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        # Map [0, 1] -> [-1, 1] to match the generators' tanh output range.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)
    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size,
                                                shuffle=False, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size,
                                                shuffle=False, num_workers=4)

    # Gab: A -> B translator; Gba: B -> A translator.
    Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks',
                     norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks',
                     norm=args.norm, use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    # Best-effort checkpoint restore: evaluation proceeds with untrained
    # weights if no checkpoint is found (original behaviour), but the
    # actual error is now surfaced instead of being swallowed by a bare
    # `except:`.
    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % args.checkpoint_dir)
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception as err:
        print(' [*] No checkpoint! (%s)' % err)

    # Fall back to CPU instead of crashing on CUDA-less machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    a_real_test = iter(a_test_loader)
    b_real_test = iter(b_test_loader)

    # First batch of each domain, kept for the sample image at the end.
    # (`next(it)` replaces the Python-2-only `it.next()`, which raises
    # AttributeError on Python 3 / current PyTorch loaders.)
    a_real_example = next(a_real_test)[0].to(device)
    b_real_example = next(b_real_test)[0].to(device)

    Gab.eval()
    Gba.eval()

    list_T1_mae = []
    list_T1_psnr = []
    list_T1_fid = []

    # NOTE(review): the example batches above were drawn from the same
    # iterators, so this loop skips the first batch of each loader —
    # confirm that is intended.
    # b_real_test yields real T1 images, a_real_test real T2 images.
    for imagesT1, imagesT2 in zip(b_real_test, a_real_test):
        with torch.no_grad():
            # ImageFolder batches are (images, labels); only images are used.
            imagesT1 = imagesT1[0].to(device)
            imagesT2 = imagesT2[0].to(device)
            a_fake_test = Gab(imagesT1)      # T2 translated from T1
            b_recon_test = Gba(a_fake_test)  # T1 reconstructed (cycle)
            b_fake_test = Gba(imagesT2)      # T1 translated from T2
            a_recon_test = Gab(b_fake_test)  # T2 reconstructed (cycle)

        real_T1 = _to_single_plane(imagesT1, args.crop_height, args.crop_width)
        fake_T1 = _to_single_plane(b_fake_test, args.crop_height, args.crop_width)

        # Metrics between real T1 slices and T1 slices translated from T2.
        list_T1_mae.append(mean_absolute_error(real_T1, fake_T1))
        list_T1_psnr.append(peak_signal_noise_ratio(real_T1, fake_T1))
        list_T1_fid.append(calculate_fid(real_T1, fake_T1))

    # Summary statistics over all evaluated batches.
    print("Mean of MAE = " + str(np.mean(list_T1_mae)))
    print("Mean of PSNR = " + str(np.mean(list_T1_psnr)))
    print("Mean of FID = " + str(np.mean(list_T1_fid)))
    print("Variance of MAE = " + str(np.var(list_T1_mae)))
    print("Variance of PSNR = " + str(np.var(list_T1_psnr)))
    print("Variance of FID = " + str(np.var(list_T1_fid)))

    # Save a sample picture built from the first batch of each dataset.
    with torch.no_grad():
        b_fake_example = Gba(a_real_example)   # T1 translated from T2
        a_recon_example = Gab(b_fake_example)  # T2 reconstructed
        a_fake_example = Gab(b_real_example)   # T2 translated from T1
        b_recon_example = Gba(a_fake_example)  # T1 reconstructed

    # Grid rows: [T2 real, T1 translated, T2 recon | T1 real, T2 translated,
    # T1 recon]; +1 / 2 maps tanh output [-1, 1] back to [0, 1] for saving.
    pic = (torch.cat([a_real_example, b_fake_example, a_recon_example,
                      b_real_example, a_fake_example, b_recon_example],
                     dim=0).data + 1) / 2.0
    os.makedirs(args.results_dir, exist_ok=True)  # idempotent; replaces isdir check
    torchvision.utils.save_image(pic, args.results_dir + '/sample.jpg', nrow=3)

    # Metrics for just the example pair, printed as a quick sanity check.
    real_T1 = _to_single_plane(b_real_example, args.crop_height, args.crop_width)
    fake_T1 = _to_single_plane(b_fake_example, args.crop_height, args.crop_width)
    print(mean_absolute_error(real_T1, fake_T1))
    print(peak_signal_noise_ratio(real_T1, fake_T1))
    print(calculate_fid(real_T1, fake_T1))