-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
174 lines (147 loc) · 7.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import argparse
import glob
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as T
from PIL import Image
from torch.backends import cudnn
from torch.utils.data import Dataset
from torchvision.transforms import RandomCrop
def parse_args():
    """Define and parse the command line options, then finalise them via init_args."""
    parser = argparse.ArgumentParser(
        description='Pytorch Implementation of \'Restormer: Efficient Transformer for High-Resolution Image Restoration\'')
    add = parser.add_argument
    # dataset / output locations
    add('--data_path', type=str, default='data')
    add('--data_name', type=str, default='degraded', choices=['SR', 'DN'])
    add('--save_path', type=str, default='result')
    # model architecture (one entry per encoder/decoder level)
    add('--num_blocks', nargs='+', type=int, default=[4, 6, 6, 8],
        help='number of transformer blocks for each level')
    add('--num_heads', nargs='+', type=int, default=[1, 2, 4, 8],
        help='number of attention heads for each level')
    add('--channels', nargs='+', type=int, default=[48, 96, 192, 384],
        help='number of channels for each level')
    add('--expansion_factor', type=float, default=2.66, help='factor of channel expansion for GDFN')
    add('--num_refinement', type=int, default=4, help='number of channels for refinement stage')
    # optimisation / progressive-learning schedule
    add('--num_iter', type=int, default=300000, help='iterations of training')
    add('--batch_size', nargs='+', type=int, default=[28, 20, 16, 8, 4, 4],
        help='batch size of loading images for progressive learning')
    add('--patch_size', nargs='+', type=int, default=[64, 80, 96, 128, 160, 192],
        help='patch size of each image for progressive learning')
    add('--lr', type=float, default=0.0003, help='initial learning rate')
    add('--milestone', nargs='+', type=int, default=[92000, 156000, 204000, 240000, 276000],
        help='when to change patch size and batch size')
    # misc runtime options
    add('--workers', type=int, default=8, help='number of data loading workers')
    add('--seed', type=int, default=-1, help='random seed (-1 for no manual seed)')
    add('--model_file', type=str, default=None, help='path of pre-trained model file')
    add('--checkpoint_file', type=str, default=None, help='path of checkpoint file to resume training')
    # init_args applies side effects (dir creation, seeding) and wraps into a Config
    return init_args(parser.parse_args())
class Config(object):
    """Plain value object holding the run-time settings copied from parsed CLI args.

    Every attribute mirrors the identically-named field on the argparse namespace.
    """

    # attribute names copied verbatim from the parsed argument namespace
    _FIELDS = ('data_path', 'data_name', 'save_path', 'num_blocks', 'num_heads',
               'channels', 'expansion_factor', 'num_refinement', 'num_iter',
               'batch_size', 'patch_size', 'lr', 'milestone', 'workers',
               'model_file', 'checkpoint_file')

    def __init__(self, args):
        for field in self._FIELDS:
            setattr(self, field, getattr(args, field))
def init_args(args):
    """Apply the side effects implied by the parsed args and wrap them in a Config.

    Creates the save directory if missing and, when a non-negative seed is given,
    seeds every RNG in play and switches cuDNN into deterministic mode.
    """
    os.makedirs(args.save_path, exist_ok=True)
    if args.seed >= 0:
        # seed python, numpy and torch (CPU + all CUDA devices) identically
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(args.seed)
        # reproducibility over speed: disable cuDNN autotuning
        cudnn.deterministic, cudnn.benchmark = True, False
    return Config(args)
def pad_image_needed(img, size):
    """Reflect-pad ``img`` so that both spatial dims reach at least ``size`` (h, w).

    Dimensions already large enough are left untouched. Note the 2-element pad
    list pads both sides symmetrically, so a short dimension may end up larger
    than the requested minimum — callers only rely on the minimum being met.
    """
    min_h, min_w = size
    width, height = T.get_image_size(img)
    if width < min_w:
        img = T.pad(img, [min_w - width, 0], padding_mode='reflect')
    if height < min_h:
        img = T.pad(img, [0, min_h - height], padding_mode='reflect')
    return img
class DegradeDataset(Dataset):
    """Paired degraded/clean image dataset laid out as

        {data_path}/{data_name}/{data_type}/input/*.jpg   (degraded)
        {data_path}/{data_name}/{data_type}/output/*.jpg  (ground truth)

    Training mode yields random ``patch_size`` crops with random flips and a
    virtual length of ``length`` (indices wrap around the real file count);
    evaluation mode yields full images reflect-padded to multiples of 8.
    """

    def __init__(self, data_path, data_name, data_type, patch_size=None, length=None):
        super().__init__()
        self.data_name, self.data_type, self.patch_size = data_name, data_type, patch_size
        self.input_images = sorted(glob.glob('{}/{}/{}/input/*.jpg'.format(data_path, data_name, data_type)))
        self.output_images = sorted(glob.glob('{}/{}/{}/output/*.jpg'.format(data_path, data_name, data_type)))
        self.num = len(self.input_images)
        # training uses an externally-chosen epoch length; testing visits each file once
        self.sample_num = self.num if data_type != 'train' else length

    def __len__(self):
        return self.sample_num

    def __getitem__(self, idx):
        # wrap the index so a training "epoch" longer than the file list still works
        pos = idx % self.num
        image_name = os.path.basename(self.input_images[pos])
        degraded = T.to_tensor(Image.open(self.input_images[pos]))
        clean = T.to_tensor(Image.open(self.output_images[pos]))
        h, w = degraded.shape[1:]
        if self.data_type == 'train':
            ps = self.patch_size
            # guarantee the images are at least one patch in each dimension
            degraded = pad_image_needed(degraded, (ps, ps))
            clean = pad_image_needed(clean, (ps, ps))
            # one random crop window, applied identically to both images
            top, left, crop_h, crop_w = RandomCrop.get_params(degraded, (ps, ps))
            degraded = T.crop(degraded, top, left, crop_h, crop_w)
            clean = T.crop(clean, top, left, crop_h, crop_w)
            # paired random flips (same decision for both images)
            if torch.rand(1) < 0.5:
                degraded, clean = T.hflip(degraded), T.hflip(clean)
            if torch.rand(1) < 0.5:
                degraded, clean = T.vflip(degraded), T.vflip(clean)
        else:
            # reflect-pad bottom/right so each spatial dim is a multiple of 8
            pad_h = 0 if h % 8 == 0 else ((h + 8) // 8) * 8 - h
            pad_w = 0 if w % 8 == 0 else ((w + 8) // 8) * 8 - w
            degraded = F.pad(degraded, (0, pad_w, 0, pad_h), 'reflect')
            clean = F.pad(clean, (0, pad_w, 0, pad_h), 'reflect')
        return degraded, clean, image_name, h, w
def rgb_to_y(x):
    """Convert RGB tensors of shape (N, 3, H, W) to the ITU-R BT.601 Y (luma) channel.

    Returns a (N, 1, H, W) tensor: weighted channel sum plus the 16.0 offset of
    studio-swing YCbCr.
    """
    weights = torch.tensor([0.256789, 0.504129, 0.097906], dtype=x.dtype, device=x.device)
    luma = (x * weights.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)
    return luma + 16.0
def psnr(x, y, data_range=255.0):
    """Peak signal-to-noise ratio (dB) between two tensors scaled by ``data_range``.

    Identical inputs give +inf (mse of zero) — callers compare non-identical images.
    """
    diff = (x - y) / data_range
    mse = diff.pow(2).mean()
    return -10 * torch.log10(mse)
def ssim(x, y, kernel_size=11, kernel_sigma=1.5, data_range=255.0, k1=0.01, k2=0.03):
    """Mean structural similarity between image batches x and y of shape (N, C, H, W).

    Values are first scaled to [0, 1] by ``data_range``; large images are
    average-pooled down so the shorter side is roughly 256, then local
    statistics are taken under a Gaussian window (per-channel grouped conv).
    Returns a scalar tensor (mean SSIM over all windows/channels/images).
    """
    x = x / data_range
    y = y / data_range
    # downsample when the shorter spatial side is well above 256 pixels
    factor = max(1, round(min(x.size()[-2:]) / 256))
    if factor > 1:
        x = F.avg_pool2d(x, kernel_size=factor)
        y = F.avg_pool2d(y, kernel_size=factor)
    # separable 2D Gaussian window, normalised to sum to 1
    coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device) - (kernel_size - 1) / 2.0
    sq = coords ** 2
    window = torch.exp(-(sq.unsqueeze(0) + sq.unsqueeze(1)) / (2 * kernel_sigma ** 2))
    window = window / window.sum()
    channels = x.size(1)
    # one copy of the window per channel → depthwise (grouped) convolution
    kernel = window.unsqueeze(0).repeat(channels, 1, 1, 1)

    def blur(t):
        # local Gaussian-weighted mean, valid padding
        return F.conv2d(t, weight=kernel, stride=1, padding=0, groups=channels)

    c1, c2 = k1 ** 2, k2 ** 2
    mu_x, mu_y = blur(x), blur(y)
    mu_xx, mu_yy, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y
    sigma_xx = blur(x * x) - mu_xx
    sigma_yy = blur(y * y) - mu_yy
    sigma_xy = blur(x * y) - mu_xy
    # contrast sensitivity (CS) with alpha = beta = gamma = 1
    cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2)
    # structural similarity (SSIM)
    ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs
    return ss.mean()