diff --git a/neuralmonkey/encoders/cnn_encoder.py b/neuralmonkey/encoders/cnn_encoder.py
index b9fa181da..d8a7cf928 100644
--- a/neuralmonkey/encoders/cnn_encoder.py
+++ b/neuralmonkey/encoders/cnn_encoder.py
@@ -49,32 +49,23 @@ def __init__(self, data_id, convolutions,
                  local_response_normalization=True,
                  dropout_keep_prob=0.5,
                  attention_type=Attention):
-        """
-        Initilizes and configures the computational graph creator.
-
-        Arguments:
+        """Initialize a convolutional network for image processing.
+
+        Args:
             convolutions (list): Configuration convolutional layers. It is
                 a list of tripplets of integers where the values are: size
                 of the convolutional window, number of convolutional
                 filters, and size of max-pooling window.  If the max-pooling
                 size is set to None, no pooling is performed.
-
-            data_id: Identifier of the data series in the dataset.
-
-            image_height (int): Height of the input image in pixels.
-
-            image_width (int): Width of the images (padded)
-
-            pixel_dim (int): Number of color channels in the input images.
-
-            batch_normalization (bool): Flag whether the batch normalization
+            image_height: Height of the input image in pixels.
+            image_width: Width of the images (padded)
+            pixel_dim: Number of color channels in the input images.
+            batch_normalization: Flag whether the batch normalization
                 should be used between the convolutional layers.
-
-            local_response_normalization (bool): Flag whether to use local
+            local_response_normalization: Flag whether to use local
                 response normalization between the convolutional layers.
-
-            dropout_placeholder (tf.Placeholder): Placeholder keeping the
+            dropout_placeholder: Placeholder keeping the
                 dropout keeping probability
         """
@@ -92,16 +83,15 @@ def __init__(self, data_id, convolutions,
         self.dropout_placeholder = tf.placeholder(
             tf.float32, name="dropout")
         self.is_training = tf.placeholder(tf.bool, name="is_training")
-        self.input_op = \
-            tf.placeholder(tf.float32,
-                           shape=(None, image_height,
-                                  image_width, pixel_dim),
-                           name="input_images")
+        self.input_op = tf.placeholder(
+            tf.float32,
+            shape=(None, image_height, image_width, pixel_dim),
+            name="input_images")
 
-        self.padding_masks = \
-            tf.placeholder(tf.float32,
-                           shape=(None, image_height, image_width, 1),
-                           name="padding_masks")
+        self.padding_masks = tf.placeholder(
+            tf.float32,
+            shape=(None, image_height, image_width, 1),
+            name="padding_masks")
 
         last_layer = self.input_op
         last_padding_masks = self.padding_masks
@@ -115,23 +105,14 @@ def __init__(self, data_id, convolutions,
                     n_filters,
                     pool_size) in enumerate(convolutions):
                 with tf.variable_scope("cnn_layer_{}".format(i)):
-                    conv_w = tf.get_variable(
-                        "wieghts",
-                        shape=[filter_size, filter_size,
-                               last_n_channels, n_filters],
-                        initializer=tf.truncated_normal_initializer(
-                            stddev=.1))
-                    conv_b = tf.get_variable(
-                        "biases",
-                        shape=[n_filters],
-                        initializer=tf.constant_initializer(.1))
-                    conv_activation = tf.nn.conv2d(
-                        last_layer, conv_w, [1, 1, 1, 1], "SAME") + conv_b
-                    last_layer = tf.nn.relu(conv_activation)
+                    last_layer = _convolution(
+                        last_layer, last_n_channels, filter_size,
+                        n_filters)
                     last_n_channels = n_filters
                     self.image_processing_layers.append(last_layer)
 
                     if pool_size:
+                        # TODO do the pooling properly
                         last_layer = tf.nn.max_pool(
                             last_layer, [1, 2, 2, 1], [1, 2, 2, 1], "SAME")
                         last_padding_masks = tf.nn.max_pool(
@@ -148,7 +129,7 @@ def __init__(self, data_id, convolutions,
                             last_layer)
 
                     if batch_normalization:
-                        last_layer = batch_norm(
+                        last_layer = _batch_norm(
                             last_layer, n_filters, self.is_training)
 
                     last_layer = tf.nn.dropout(
@@ -157,7 +138,6 @@ def __init__(self, data_id, convolutions,
 
         # last_layer shape is batch X height X width X channels
         last_layer = last_layer * last_padding_masks
-
         # we average out by the image size -> shape is number
         # channels from the last convolution
         self.encoded = tf.reduce_mean(last_layer, [1, 2])
@@ -197,12 +177,27 @@ def feed_dict(self, dataset, train=False):
         f_dict[self.is_training] = train
 
         return f_dict
 
+
+def _convolution(last_layer: tf.Tensor, last_n_channels: int,
+                 filter_size: int, n_filters: int) -> tf.Tensor:
+    """Apply convolution with a trainable filter bank."""
+    conv_w = tf.get_variable(
+        "weights",
+        shape=[filter_size, filter_size, last_n_channels, n_filters],
+        initializer=tf.truncated_normal_initializer(stddev=.1))
+    conv_b = tf.get_variable("biases", shape=[n_filters],
+                             initializer=tf.constant_initializer(.1))
+    conv_activation = tf.nn.conv2d(
+        last_layer, conv_w, [1, 1, 1, 1], "SAME") + conv_b
+    # TODO assert shape
+    return tf.nn.relu(conv_activation)
+
+
 # pylint: disable=too-many-locals
-def batch_norm(tensor, n_out, phase_train, scope='bn', scale_after_norm=True):
-    """
-    Batch normalization on convolutional maps.
+def _batch_norm(tensor, n_out, phase_train, scope='bn', scale_after_norm=True):
+    """Batch normalization on convolutional maps.
 
     Taken from
     http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
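
Note on the `convolutions` argument: per the docstring, each entry is a
triplet of (convolutional window size, number of filters, max-pooling
size), with None disabling pooling for that layer. Since the pooling
window is currently hard-coded to 2x2 (see the TODO above), the third
value is only tested for truthiness. A hypothetical configuration, under
the assumption that the class in this module is named CNNEncoder and
takes the constructor arguments documented above:

    convolutions = [
        (3, 32, 2),      # 3x3 window, 32 filters, then 2x2 max-pooling
        (3, 64, 2),      # 3x3 window, 64 filters, then 2x2 max-pooling
        (3, 128, None),  # 3x3 window, 128 filters, no pooling
    ]

    encoder = CNNEncoder(          # class name assumed, not shown in the diff
        data_id="images",          # hypothetical data series identifier
        convolutions=convolutions,
        image_height=32,
        image_width=256,
        pixel_dim=3)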
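
On the _convolution extraction: the helper only works inside the layer
loop because each call is wrapped in a fresh
tf.variable_scope("cnn_layer_{}".format(i)), so tf.get_variable creates
distinct weights and biases per layer instead of colliding on a single
variable. A minimal TF 1.x sketch of that interaction, with made-up
input shapes and the helper body copied from this diff:

    import tensorflow as tf

    def _convolution(last_layer, last_n_channels, filter_size, n_filters):
        # Trainable filter bank plus bias; 'SAME' padding keeps the
        # spatial dimensions, ReLU is the nonlinearity.
        conv_w = tf.get_variable(
            "weights",
            shape=[filter_size, filter_size, last_n_channels, n_filters],
            initializer=tf.truncated_normal_initializer(stddev=.1))
        conv_b = tf.get_variable("biases", shape=[n_filters],
                                 initializer=tf.constant_initializer(.1))
        conv_activation = tf.nn.conv2d(
            last_layer, conv_w, [1, 1, 1, 1], "SAME") + conv_b
        return tf.nn.relu(conv_activation)

    images = tf.placeholder(tf.float32, shape=(None, 32, 256, 3))

    last_layer, last_n_channels = images, 3
    for i, (filter_size, n_filters, _) in enumerate([(3, 32, 2), (3, 64, 2)]):
        with tf.variable_scope("cnn_layer_{}".format(i)):
            last_layer = _convolution(last_layer, last_n_channels,
                                      filter_size, n_filters)
            last_n_channels = n_filters

    print([v.name for v in tf.trainable_variables()])
    # ['cnn_layer_0/weights:0', 'cnn_layer_0/biases:0',
    #  'cnn_layer_1/weights:0', 'cnn_layer_1/biases:0']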
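
On the final encoding step: the padding masks are max-pooled alongside
the feature maps so they stay aligned, padded positions are zeroed by the
multiplication, and tf.reduce_mean over axes [1, 2] collapses each image
to one value per channel of the last convolution. Note the mean still
divides by the full (padded) image size, which is what the "average out
by the image size" comment refers to. A small shape sketch with made-up
dimensions:

    import tensorflow as tf

    # Made-up sizes: 8x64 feature maps with 128 channels after the last layer.
    last_layer = tf.placeholder(tf.float32, shape=(None, 8, 64, 128))
    last_padding_masks = tf.placeholder(tf.float32, shape=(None, 8, 64, 1))

    # The single-channel mask broadcasts over the 128 channels.
    masked = last_layer * last_padding_masks

    encoded = tf.reduce_mean(masked, [1, 2])
    print(encoded.shape)  # (?, 128)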