-
Notifications
You must be signed in to change notification settings - Fork 29
/
utils.py
171 lines (145 loc) · 6.31 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from __future__ import print_function
import numpy as np
from os import listdir, remove
from os.path import join
# Temporary Wrapping
# from scipy.misc import imread, imresize
import cv2
def imread(path, output_mode='RGB'):
    """Read the image file at `path` and decode it as a color image.

    The raw bytes are loaded with `np.fromfile` and decoded with
    `cv2.imdecode` rather than by calling `cv2.imread(path)` directly.
    NOTE(review): presumably this is done to support paths that `cv2.imread`
    cannot open (e.g. non-ASCII file names) -- confirm.

    :param path: path of the image file to read
    :param output_mode: 'RGB' converts the decoded BGR image to RGB; any other
        value returns the decoded image unchanged
    :return: the decoded image
    """
    encoded = np.fromfile(path, dtype=np.uint8)
    bgr_image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if output_mode != 'RGB':
        return bgr_image
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
def imsave(path, image, input_mode='RGB'):
    """Write `image` to `path` through `cv2.imwrite`.

    :param path: destination file path
    :param image: image array to save
    :param input_mode: 'RGB' means `image` is converted from RGB to BGR before
        writing; any other value writes `image` unchanged
    :return: whatever `cv2.imwrite` returns
    """
    to_write = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) if input_mode == 'RGB' else image
    return cv2.imwrite(path, to_write)
def _interpolation_flag(interp):
    """Translate an interpolation name into the matching OpenCV flag.

    :param interp: one of 'nearest', 'bilinear', 'bicubic', 'cubic',
        'lanczos' or 'area'
    :return: the corresponding `cv2.INTER_*` constant
    :raises ValueError: if `interp` is not a supported name
    """
    flags = {
        'nearest': cv2.INTER_NEAREST,
        'bilinear': cv2.INTER_LINEAR,
        'bicubic': cv2.INTER_CUBIC,
        'cubic': cv2.INTER_CUBIC,
        'lanczos': cv2.INTER_LANCZOS4,
        'area': cv2.INTER_AREA,
    }
    try:
        return flags[interp]
    except KeyError:
        raise ValueError(
            f"Unsupported interp {interp!r}; expected one of {sorted(flags)}"
        ) from None


def imresize(image, dst_size: (list, tuple), interp='nearest'):
    """Resize `image` to `dst_size` with OpenCV.

    :param image: image array to resize
    :param dst_size: target size as [w, h] (the order `cv2.resize` expects)
    :param interp: interpolation name, see `_interpolation_flag`;
        defaults to 'nearest'
    :return: the resized image
    :raises ValueError: if `interp` is not a supported name
    """
    if isinstance(dst_size, list):
        dst_size = tuple(dst_size)
    return cv2.resize(image, dst_size, interpolation=_interpolation_flag(interp))


def imresize_square(image, long_side: int, interp='nearest'):
    """Pad `image` to a square with black borders, then resize it.

    The shorter axis is padded on both sides (the extra pixel of an odd
    difference goes to the bottom/right) so the content stays centered and
    keeps its aspect ratio.

    :param image: image array whose shape unpacks as (h, w, channels)
    :param long_side: edge length of the square output
    :param interp: interpolation name, see `_interpolation_flag`;
        defaults to 'nearest'
    :return: the image resized to `long_side` x `long_side`
    :raises ValueError: if `interp` is not a supported name
    """
    h, w, _ = image.shape
    if h != w:
        # Only one of the two paddings is non-zero: the one for the shorter axis.
        pad_h = max(w - h, 0)
        pad_w = max(h - w, 0)
        top = pad_h // 2
        left = pad_w // 2
        background = [0, 0, 0]  # IMPROVE: use dominant color, or tf.pad(mode='REFLECT')
        image = cv2.copyMakeBorder(
            image, top, pad_h - top, left, pad_w - left,
            cv2.BORDER_CONSTANT, value=background,
        )
    return cv2.resize(image, (long_side, long_side), interpolation=_interpolation_flag(interp))
def list_images(directory):
    """List the image files directly inside `directory`.

    A file counts as an image when its lower-cased name ends with '.png',
    '.jpg' or '.jpeg'. Sub-directories are not descended into.

    :param directory: directory to scan
    :return: list of `directory`-joined paths, in `listdir` order
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    return [
        join(directory, file)
        for file in listdir(directory)
        if file.lower().endswith(image_suffixes)
    ]
def get_train_images(paths, resize_len=512, crop_height=256, crop_width=256):
    """Load a batch of randomly cropped training images.

    Each readable image is padded and resized to `resize_len` x `resize_len`
    by `imresize_square`, then a random `crop_height` x `crop_width` window is
    cut out. Unreadable files are skipped with a warning and the batch is
    topped up with copies of the first good image so its length still matches
    `len(paths)`.

    :param paths: sequence of image file paths
    :param resize_len: edge length of the intermediate square image
    :param crop_height: height of the random crop
    :param crop_width: width of the random crop
    :return: array of the crops stacked along a new first axis, or an empty
        list when no image could be read
    """
    images = []
    for path in paths:
        try:
            image = imread(path, output_mode='RGB')
        except Exception:
            # Best-effort loading: any read/decode failure just skips the file.
            image = None
        if image is None:
            print(f"[WARN] Bypassed unreadable train image: {path}")
            continue

        # Per the original author's note the images share a width but differ
        # in height, so they are padded to a square instead of being resized
        # along the shorter side.
        image = imresize_square(image, resize_len, interp='nearest')

        # Random crop inside the resize_len x resize_len square.
        start_h = np.random.choice(resize_len - crop_height + 1)
        start_w = np.random.choice(resize_len - crop_width + 1)
        images.append(image[start_h:start_h + crop_height, start_w:start_w + crop_width, :])

    if not images:
        return []

    # Something was bypassed: pad with duplicates to keep the batch size.
    images.extend([images[0]] * (len(paths) - len(images)))
    return np.stack(images, axis=0)
def single_inputs_generator(paired_filenames, content_path, style_path, constrained_longer_side=1200):
    """Yield one (content, style) image pair at a time for inference.

    Every image whose longer side exceeds `constrained_longer_side` is scaled
    down, keeping its aspect ratio, so that the longer side equals that limit.
    Pairs for which reading or resizing fails are reported and skipped.

    :param paired_filenames: iterable of (content_filename, style_filename)
    :param content_path: directory containing the content images
    :param style_path: directory containing the style images
    :param constrained_longer_side: maximum allowed length of the longer side
    :return: generator yielding `([content_image], [style_image])`
    """
    def constrained_resize(image):
        h, w, _ = image.shape
        longest = max(h, w)
        if constrained_longer_side >= longest:
            return image
        # Derive both targets from the ORIGINAL h and w so the aspect ratio
        # is preserved; clamp to 1 so extreme ratios never give a 0-size axis.
        new_h = max(1, int(constrained_longer_side * h / longest))
        new_w = max(1, int(constrained_longer_side * w / longest))
        # imresize takes the target size as [w, h].
        return imresize(image, [new_w, new_h], interp='nearest')

    for content_filename, style_filename in paired_filenames:
        try:
            content_image = constrained_resize(imread(join(content_path, content_filename), output_mode='RGB'))
            style_image = constrained_resize(imread(join(style_path, style_filename), output_mode='RGB'))
        except Exception as e:
            print(f'[ERROR] Failed reading test image: {e}')
            continue  # bypass
        yield [content_image], [style_image]
# NOTE: this computation will also be tracked by AutoGraph
def AdaIN(content_features, style_features, alpha=1, epsilon=1e-5):
    """Normalize `content_features` with the scaling and offset of `style_features`.

    Mean and variance of both inputs are taken over axes [1, 2] (kept as
    size-1 dimensions). The content features are normalized with their own
    statistics, then scaled by the style standard deviation and shifted by
    the style mean.
    NOTE(review): assumes axes 1 and 2 are the spatial dimensions (NHWC
    layout) -- confirm against the caller.

    :param content_features: feature tensor to be re-normalized
    :param style_features: feature tensor supplying the target statistics
    :param alpha: blend weight; 1 returns the fully normalized features,
        0 returns `content_features` unchanged
    :param epsilon: variance epsilon passed to `tf.nn.batch_normalization`
    :return: the blended, normalized content features
    """
    import tensorflow as tf

    moment_axes = [1, 2]
    # UPDATE: keep_dims -> keepdims
    c_mean, c_var = tf.nn.moments(content_features, moment_axes, keepdims=True)
    s_mean, s_var = tf.nn.moments(style_features, moment_axes, keepdims=True)

    stylized = tf.nn.batch_normalization(
        content_features,
        mean=c_mean,
        variance=c_var,
        offset=s_mean,
        scale=tf.sqrt(s_var),
        variance_epsilon=epsilon,
    )
    return alpha * stylized + (1 - alpha) * content_features
def pre_process_dataset(dir_path, shorter_side=512):
    """Delete every image under `dir_path` that cannot be used for training.

    A file is removed from disk when it cannot be read or decoded, when it
    is not a 3-channel image, or when resizing it (shorter side scaled to
    `shorter_side`, aspect ratio kept) fails.
    NOTE(review): the resized image is never written back, so the resize only
    acts as a validity check -- confirm this is intended.

    :param dir_path: directory whose jpg/jpeg/png files are checked
    :param shorter_side: target length of the shorter side for the trial resize
    :return: None; progress and the final counts are printed
    """
    import tensorlayer as tl

    paths = tl.files.load_file_list(dir_path, regx='\\.(jpg|jpeg|png)', keep_prefix=True)
    print('\norigin files number: %d\n' % len(paths))

    num_delete = 0
    for path in paths:
        try:
            image = imread(path, output_mode='RGB')
        except (IOError, cv2.error):
            # Undecodable data surfaces as cv2.error rather than IOError.
            image = None
        if image is None:
            num_delete += 1
            print('Cant read this file, will delete it')
            remove(path)
            # Skip the checks below: there is no valid image for this path
            # and the file is already gone.
            continue

        if len(image.shape) != 3 or image.shape[2] != 3:
            num_delete += 1
            remove(path)
            print('\nimage.shape:', image.shape, ' Remove image <%s>\n' % path)
            continue

        height, width, _ = image.shape
        if height < width:
            new_height = shorter_side
            new_width = int(width * new_height / height)
        else:
            new_width = shorter_side
            new_height = int(height * new_width / width)

        try:
            # imresize takes the target size as [w, h].
            imresize(image, [new_width, new_height], interp='nearest')
        except Exception:
            print('Cant resize this file, will delete it')
            num_delete += 1
            remove(path)

    print('\n\ndelete %d files! Current number of files: %d\n\n' % (num_delete, len(paths) - num_delete))