-
Notifications
You must be signed in to change notification settings - Fork 46
/
utils.py
58 lines (38 loc) · 1.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import cv2
import tensorflow as tf
# Per-channel mean pixel values (R, G, B order, per the name), written as
# fractions of full scale and multiplied by 255 to put them on the 0-255
# pixel range. Used as the default offset by normalize/denormalize.
# NOTE(review): presumably measured on the DIV2K dataset, as the name
# suggests; the values are not derived in this file.
DIV2K_RGB_MEAN = np.array([0.4488, 0.4371, 0.4040]) * 255
def resolve_single(model, lr):
    """Super-resolve a single low-resolution image.

    Wraps *lr* in a batch of size one, runs it through ``resolve`` and
    returns the only element of the resulting batch.
    """
    lr_batch = tf.expand_dims(lr, axis=0)
    sr_batch = resolve(model, lr_batch)
    return sr_batch[0]
def resolve(model, lr_batch):
    """Run *model* on a batch and convert its output to 8-bit pixel values.

    The input is cast to float32 before the forward pass. The model output
    is clipped to [0, 255], rounded to the nearest integer and cast to uint8.
    """
    model_input = tf.cast(lr_batch, tf.float32)
    raw_output = model(model_input)
    # Clamp into the uint8 range first so the final cast never sees
    # out-of-range values.
    pixels = tf.round(tf.clip_by_value(raw_output, 0, 255))
    return tf.cast(pixels, tf.uint8)
def evaluate(model, dataset):
    """Return the mean PSNR of *model* over every image in *dataset*.

    Args:
        model: Model passed to ``resolve`` to super-resolve each batch.
        dataset: Iterable of ``(lr, hr)`` batch pairs. ``hr`` is compared
            against the uint8 output of ``resolve`` via ``psnr``, so it must
            be compatible with that output (same dtype and shape).

    Returns:
        Scalar tensor holding the mean per-image PSNR across all batches.
        It is NaN when *dataset* yields nothing, because ``tf.reduce_mean``
        is then taken over an empty list.
    """
    psnr_values = []
    for lr, hr in dataset:
        sr = resolve(model, lr)
        # psnr() yields one value per image in the batch. Keep all of them
        # rather than only index 0, so batches larger than one are fully
        # evaluated and the result is a true per-image mean.
        psnr_values.extend(tf.unstack(psnr(hr, sr)))
    return tf.reduce_mean(psnr_values)
def normalize(x, rgb_mean=DIV2K_RGB_MEAN):
    """Center *x* on *rgb_mean*, then divide by 127.5.

    ``denormalize`` with the same *rgb_mean* reverses this mapping.
    """
    centered = x - rgb_mean
    return centered / 127.5
def denormalize(x, rgb_mean=DIV2K_RGB_MEAN):
    """Reverse ``normalize``: multiply *x* by 127.5, then add *rgb_mean* back."""
    rescaled = x * 127.5
    return rescaled + rgb_mean
def normalize_01(x):
    """Map pixel values from [0, 255] onto [0, 1] by dividing by 255."""
    max_pixel = 255.0
    return x / max_pixel
def normalize_m11(x):
    """Map pixel values from [0, 255] onto [-1, 1]."""
    halved = x / 127.5
    return halved - 1
def denormalize_m11(x):
    """Map values from [-1, 1] back onto [0, 255]; reverses ``normalize_m11``."""
    shifted = x + 1
    return shifted * 127.5
def psnr(x1, x2, max_val=255):
    """Compute the peak signal-to-noise ratio between *x1* and *x2*.

    Thin wrapper around ``tf.image.psnr``.

    Args:
        x1: First image or batch of images.
        x2: Second image or batch; must be compatible with *x1* as required
            by ``tf.image.psnr`` (same dtype and shape).
        max_val: Dynamic range of the pixel values. The default of 255
            matches 8-bit data, such as the uint8 output of ``resolve``.
            Pass 1.0 for images scaled to [0, 1], e.g. by ``normalize_01``.

    Returns:
        The PSNR tensor produced by ``tf.image.psnr`` (one value per image
        when batched inputs are given).
    """
    return tf.image.psnr(x1, x2, max_val=max_val)
def subpixel_conv2d(scale):
    """Return a callable that applies ``tf.nn.depth_to_space`` with block size *scale*.

    The returned function takes a single tensor argument ``x``.
    """
    def _depth_to_space(x):
        return tf.nn.depth_to_space(x, scale)

    return _depth_to_space