-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
68 lines (56 loc) · 2.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import math
import numpy as np
import os
from os import listdir
from os.path import join
import torchvision.transforms as transforms
import torch.nn.functional as F
from math import log10
import skimage.metrics as measure
def save_checkpoint(model, epoch, model_folder):
    """Write the model weights to ``checkpoints/<model_folder>/best_psnr.pth``.

    The saved object is a dict with two keys: ``'epoch'`` (the value passed
    in) and ``'state_dict'`` (the model's state dict with every entry moved
    to CPU). Missing directories are created on demand, and an existing
    checkpoint at the same path is overwritten.

    Args:
        model: Object exposing ``state_dict()`` whose values support ``.cpu()``.
        epoch: Value stored under the ``'epoch'`` key.
        model_folder: Name of the sub-directory under ``checkpoints/``.
    """
    model_out_path = "checkpoints/" + model_folder + "/best_psnr.pth"

    state_dict = model.state_dict()
    # Snapshot the items first so the mapping is never iterated while its
    # values are being replaced.
    for key, value in list(state_dict.items()):
        state_dict[key] = value.cpu()

    # exist_ok=True creates every missing level in one call and avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(os.path.dirname(model_out_path), exist_ok=True)

    torch.save({'epoch': epoch, 'state_dict': state_dict}, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))
def to_psnr(dehaze, gt):
    """Return the PSNR of each sample in a batch as a list of Python floats.

    The squared error is averaged over every element of a sample (all axes
    except axis 0), then converted with ``10 * log10(intensity_max / mse)``
    using a peak intensity of 1.0.
    NOTE(review): a peak of 1.0 presumably means inputs are scaled to
    [0, 1] -- confirm against the callers.

    Args:
        dehaze: Predicted batch; axis 0 indexes samples.
        gt: Ground-truth batch, compared element-wise with ``dehaze``.

    Returns:
        One float per sample. A sample identical to its ground truth
        (MSE of exactly 0) yields ``float('inf')`` instead of raising
        ``ZeroDivisionError``.
    """
    squared_error = F.mse_loss(dehaze, gt, reduction='none')
    intensity_max = 1.0

    psnr_list = []
    for sample_error in torch.split(squared_error, 1, dim=0):
        mse = torch.mean(sample_error).item()
        if mse == 0:
            # Perfect reconstruction: the ratio is unbounded.
            psnr_list.append(float('inf'))
        else:
            psnr_list.append(10.0 * log10(intensity_max / mse))
    return psnr_list
def to_ssim_skimage(dehaze, gt):
    """Return the SSIM of each sample in a batch, computed with scikit-image.

    Each sample is taken from axis 0, its first remaining axis is moved to
    the end (channels-last), and the result is converted to a NumPy array on
    the CPU before being passed to ``structural_similarity`` with
    ``data_range=1`` and ``channel_axis=-1``.
    NOTE(review): this presumes 4-D ``(batch, channel, height, width)``
    inputs scaled to [0, 1] -- confirm against the callers.

    Only the batch axis is removed from each sample, so a single-channel
    image keeps its trailing channel axis and ``channel_axis=-1`` stays
    valid (an argument-less ``squeeze()`` would have dropped it).

    Args:
        dehaze: Predicted batch tensor.
        gt: Ground-truth batch tensor with at least as many samples.

    Returns:
        One SSIM value per sample of ``dehaze``; an empty list for an empty
        batch.
    """
    def _to_hwc(batch, ind):
        # (C, H, W) tensor -> (H, W, C) NumPy array, detached from autograd.
        return batch[ind].permute(1, 2, 0).detach().cpu().numpy()

    return [
        measure.structural_similarity(
            _to_hwc(dehaze, ind),
            _to_hwc(gt, ind),
            data_range=1,
            channel_axis=-1,
        )
        for ind in range(dehaze.shape[0])
    ]
def edge_compute(x):
    """Build a one-channel edge map from absolute neighbour differences.

    Absolute differences between adjacent elements along axis 3 and along
    axis 2 are each added to both positions of the pair they were computed
    from. The accumulated map is then summed over axis 1 (kept as a
    singleton axis), divided by 3 and then by 4.
    NOTE(review): the divisor 3 presumably reflects three input channels
    -- confirm against the callers.

    Args:
        x: Tensor indexed with four axes; it is not modified.

    Returns:
        Tensor shaped like ``x`` except that axis 1 has size 1.
    """
    horizontal = torch.diff(x, dim=3).abs()
    vertical = torch.diff(x, dim=2).abs()

    edges = torch.zeros_like(x)
    # Every difference contributes to both elements of its pair.
    edges[:, :, :, 1:] += horizontal
    edges[:, :, :, :-1] += horizontal
    edges[:, :, 1:, :] += vertical
    edges[:, :, :-1, :] += vertical

    return torch.sum(edges, 1, keepdim=True) / 3 / 4
def save_image(dehaze, image_name):
    """Save every sample of a batch as a PNG under ``./results/``.

    Each output file is named after the matching entry of ``image_name``
    with its extension replaced by ``.png``. The destination directory is
    created when it does not exist.

    Args:
        dehaze: Batch tensor; axis 0 indexes samples.
        image_name: Sequence of file names, one per sample.
    """
    # The module-level imports never bind ``utils``, so it is imported here.
    # NOTE(review): torchvision.utils.save_image is assumed to be the
    # intended function -- confirm.
    import torchvision.utils as utils

    dehaze_images = torch.split(dehaze, 1, dim=0)
    for ind, image in enumerate(dehaze_images):
        # splitext handles extensions of any length; slicing off the last
        # three characters only worked for three-letter extensions.
        stem, _ = os.path.splitext(image_name[ind])
        out_path = './results/{}'.format(stem + '.png')
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        utils.save_image(image, out_path)