-
Notifications
You must be signed in to change notification settings - Fork 0
/
gist.py~
57 lines (44 loc) · 1.98 KB
/
gist.py~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def put_kernels_on_grid(kernel, grid_shape, pad=1):
    '''Visualize conv. features as an image (mostly for the 1st layer).

    Place the kernels into a grid, with some padding between adjacent filters.

    Args:
      kernel:     tensor of shape [Y, X, NumChannels, NumKernels]; Y, X and
                  NumKernels must be statically known.
      grid_shape: (grid_Y, grid_X), the shape of the grid.
                  Require: NumKernels == grid_Y * grid_X.
                  The caller is responsible for choosing the two factors.
      pad:        number of black pixels inserted above and to the left of
                  each filter (this separates adjacent filters).

    Return:
      uint8 tensor of shape [1, (Y+pad)*grid_Y, (X+pad)*grid_X, NumChannels],
      i.e. a single image in tf.image_summary order
      [batch_size, height, width, channels]. Kernel k is placed at grid row
      k % grid_Y, grid column k // grid_Y. Kernel values are min-max scaled
      to [0, 255] over the whole kernel tensor.

    Raises:
      ValueError: if NumKernels != grid_Y * grid_X.
    '''
    # Tuple-parameter unpacking in the signature is a SyntaxError on Python 3
    # (PEP 3113); unpack here instead. Positional callers are unaffected.
    grid_Y, grid_X = grid_shape

    # Static dims as plain ints (None if unknown), so the reshape targets
    # below can be Python lists and need no tf.pack / tf.stack.
    kernel_Y, kernel_X, _, num_kernels = kernel.get_shape().as_list()
    if num_kernels is not None and num_kernels != grid_Y * grid_X:
        raise ValueError(
            'put_kernels_on_grid: %d kernels do not fit a %d x %d grid'
            % (num_kernels, grid_Y, grid_X))

    # Scale to [0, 1] *before* padding: the zero padding then renders black
    # and does not take part in min/max (otherwise it turns gray whenever
    # the weights contain negative values).
    x_min = tf.reduce_min(kernel)
    x_max = tf.reduce_max(kernel)
    # Guard against a constant kernel (e.g. zero-initialized weights), which
    # would otherwise divide by zero and yield NaNs.
    x = (kernel - x_min) / tf.maximum(x_max - x_min, 1e-12)

    # pad X and Y (top/left only, so each cell is (Y+pad) x (X+pad))
    x = tf.pad(x, tf.constant([[pad, 0], [pad, 0], [0, 0], [0, 0]]))
    # X and Y dimensions, w.r.t. padding
    Y = kernel_Y + pad
    X = kernel_X + pad

    # put NumKernels to the 1st dimension: [NumKernels, Y, X, NumChannels]
    x = tf.transpose(x, (3, 0, 1, 2))
    # organize grid on Y axis; -1 carries NumChannels through instead of
    # hard-coding 3, so grayscale (or any channel count) kernels work too
    x = tf.reshape(x, [grid_X, Y * grid_Y, X, -1])
    # switch X and Y axes
    x = tf.transpose(x, (0, 2, 1, 3))
    # organize grid on X axis
    x = tf.reshape(x, [1, X * grid_X, Y * grid_Y, -1])
    # back to normal order (not combining with the next step for clarity)
    x = tf.transpose(x, (2, 1, 3, 0))
    # to tf.image_summary order [batch_size, height, width, channels],
    # where in this case batch_size == 1
    x = tf.transpose(x, (3, 0, 1, 2))

    # scale [0, 1] floats to [0, 255] and convert to uint8
    return tf.image.convert_image_dtype(x, dtype=tf.uint8)
#
# ... and somewhere inside "def train():"
#
# Visualize conv1 features.
# get_variable('weights') is called without a shape, so it can only fetch an
# existing 'conv1/weights' variable; that requires re-entering the scope with
# reuse=True, otherwise TensorFlow raises ValueError ("... already exists").
with tf.variable_scope('conv1', reuse=True):
    weights = tf.get_variable('weights')
    grid_x = grid_y = 8  # to get a square grid for 64 conv1 features
    grid = put_kernels_on_grid(weights, (grid_y, grid_x))
    tf.image_summary('conv1/features', grid, max_images=1)