Make PSPNet Fully-convolutional #41

Open · wants to merge 2 commits into base: master
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,4 +5,6 @@
*.prototxt
*.h5
*.json
*.jpg
*.log
output/
58 changes: 58 additions & 0 deletions nnet/keras_layers.py
@@ -0,0 +1,58 @@
from keras import layers
from keras.backend import tf as ktf
from .tf_layers import adaptive_pooling_2d


class Interp(layers.Layer):
"""Bilinear interpolation
__call__ Takes two params. First param is layer we need to resize.
Second param is tensor which shape is target.
"""

def __init__(self, new_size=None, **kwargs):
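# new_size is not used by call(); presumably kept so that configs saved by the old fixed-size Interp still deserialize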
self.new_size = new_size
super(Interp, self).__init__(**kwargs)

def build(self, input_shape):
super(Interp, self).build(input_shape)

def call(self, inputs, **kwargs):
assert(len(inputs) == 2)
shape = ktf.shape(inputs[1])
new_height, new_width = shape[1], shape[2]
resized = ktf.image.resize_images(inputs[0], [new_height, new_width],
align_corners=True)
return resized

def compute_output_shape(self, input_shape):
return tuple([input_shape[0][0], None, None, input_shape[0][3]])

def get_config(self):
config = super(Interp, self).get_config()
return config


class AdaptivePooling2D(layers.Layer):
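"""Pooling layer with a fixed spatial output size.

Reduces a 4D tensor (B, H, W, C) to (B, out_size, out_size, C) regardless
of the input's height and width, using adaptive_pooling_2d ('avg' or 'max').
"""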

def __init__(self, out_size, mode='avg', **kwargs):
if mode not in ['avg', 'max']:
msg = "Mode must be either 'max' or 'avg'. Got '{0}'"
raise ValueError(msg.format(mode))
self.out_size = out_size
self.mode = mode
super(AdaptivePooling2D, self).__init__(**kwargs)

def build(self, input_shape):
super(AdaptivePooling2D, self).build(input_shape)

def call(self, inputs, **kwargs):
return adaptive_pooling_2d(inputs, self.out_size, self.mode)

def compute_output_shape(self, input_shape):
return tuple([input_shape[0], self.out_size, self.out_size, input_shape[3]])

def get_config(self):
config = super(AdaptivePooling2D, self).get_config()
config['out_size'] = self.out_size
config['mode'] = self.mode
return config
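For orientation, a minimal sketch of how the two layers compose (not part of the diff; it assumes Keras 2.x with the TF 1.x backend): AdaptivePooling2D collapses any input to a fixed spatial grid, and Interp resizes its first input to the spatial size of its second.

    from keras.layers import Input, Conv2D
    from keras.models import Model
    from nnet.keras_layers import AdaptivePooling2D, Interp

    inp = Input((None, None, 3))                     # arbitrary H and W
    feat = Conv2D(8, (3, 3), padding='same')(inp)    # toy feature map
    pooled = AdaptivePooling2D(2, mode='avg')(feat)  # always 2x2 spatially
    back = Interp()([pooled, feat])                  # resized to feat's H, W
    model = Model(inp, back)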
99 changes: 21 additions & 78 deletions layers_builder.py → nnet/model.py
@@ -2,53 +2,20 @@
from math import ceil
from keras import layers
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from keras.layers import BatchNormalization, Activation, Input, Dropout, ZeroPadding2D, Lambda
from keras.layers import BatchNormalization, Activation, Input, Dropout, ZeroPadding2D
from keras.layers.merge import Concatenate, Add
from keras.models import Model
from keras.optimizers import SGD
from keras.backend import tf as ktf
from .keras_layers import AdaptivePooling2D, Interp

import tensorflow as tf

learning_rate = 1e-3 # Layer specific learning rate
# Weight decay not implemented


def BN(name=""):
return BatchNormalization(momentum=0.95, name=name, epsilon=1e-5)


class Interp(layers.Layer):

def __init__(self, new_size, **kwargs):
self.new_size = new_size
super(Interp, self).__init__(**kwargs)

def build(self, input_shape):
super(Interp, self).build(input_shape)

def call(self, inputs, **kwargs):
new_height, new_width = self.new_size
resized = ktf.image.resize_images(inputs, [new_height, new_width],
align_corners=True)
return resized

def compute_output_shape(self, input_shape):
return tuple([None, self.new_size[0], self.new_size[1], input_shape[3]])

def get_config(self):
config = super(Interp, self).get_config()
config['new_size'] = self.new_size
return config


# def Interp(x, shape):
# new_height, new_width = shape
# resized = ktf.image.resize_images(x, [new_height, new_width],
# align_corners=True)
# return resized


def residual_conv(prev, level, pad=1, lvl=1, sub_lvl=1, modify_stride=False):
lvl = str(lvl)
sub_lvl = str(sub_lvl)
@@ -191,55 +158,36 @@ def ResNet(inp, layers):
return res


def interp_block(prev_layer, level, feature_map_shape, input_shape):
if input_shape == (473, 473):
kernel_strides_map = {1: 60,
2: 30,
3: 20,
6: 10}
elif input_shape == (713, 713):
kernel_strides_map = {1: 90,
2: 45,
3: 30,
6: 15}
else:
print("Pooling parameters for input shape ",
input_shape, " are not defined.")
exit(1)

def interp_block(prev_block, level):
names = [
"conv5_3_pool" + str(level) + "_conv",
"conv5_3_pool" + str(level) + "_conv_bn"
]
kernel = (kernel_strides_map[level], kernel_strides_map[level])
strides = (kernel_strides_map[level], kernel_strides_map[level])
prev_layer = AveragePooling2D(kernel, strides=strides)(prev_layer)

prev_layer = AdaptivePooling2D(level, mode='avg')(prev_block)
prev_layer = Conv2D(512, (1, 1), strides=(1, 1), name=names[0],
use_bias=False)(prev_layer)
prev_layer = BN(name=names[1])(prev_layer)
prev_layer = Activation('relu')(prev_layer)
# prev_layer = Lambda(Interp, arguments={
# 'shape': feature_map_shape})(prev_layer)
prev_layer = Interp(feature_map_shape)(prev_layer)

prev_layer = Interp()([prev_layer, prev_block])
return prev_layer


def build_pyramid_pooling_module(res, input_shape):
def build_pyramid_pooling_module(res):
"""Build the Pyramid Pooling Module."""
# ---PSPNet concat layers with Interpolation
feature_map_size = tuple(int(ceil(input_dim / 8.0))
for input_dim in input_shape)
print("PSP module will interpolate to a final feature map size of %s" %
(feature_map_size, ))

interp_block1 = interp_block(res, 1, feature_map_size, input_shape)
interp_block2 = interp_block(res, 2, feature_map_size, input_shape)
interp_block3 = interp_block(res, 3, feature_map_size, input_shape)
interp_block6 = interp_block(res, 6, feature_map_size, input_shape)
interp_block1 = interp_block(res, 1)
interp_block2 = interp_block(res, 2)
interp_block3 = interp_block(res, 3)
interp_block6 = interp_block(res, 6)

# Concatenate all these levels; the resulting
# shape is (1, feature_map_size_x, feature_map_size_y, 4096)
res = Concatenate()([res,
first = Interp()([res, res])  # identity resize; concatenating res directly did not work

res = Concatenate()([first,
interp_block6,
interp_block3,
interp_block2,
@@ -252,9 +200,10 @@ def build_pspnet(nb_classes, resnet_layers, input_shape, activation='softmax'):
print("Building a PSPNet based on ResNet %i expecting inputs of shape %s predicting %i classes" % (
resnet_layers, input_shape, nb_classes))

inp = Input((input_shape[0], input_shape[1], 3))
inp = Input((None, None, 3))
res = ResNet(inp, layers=resnet_layers)
psp = build_pyramid_pooling_module(res, input_shape)

psp = build_pyramid_pooling_module(res)

x = Conv2D(512, (3, 3), strides=(1, 1), padding="same", name="conv5_4",
use_bias=False)(psp)
@@ -263,16 +212,10 @@ def build_pspnet(nb_classes, resnet_layers, input_shape, activation='softmax'):
x = Dropout(0.1)(x)

x = Conv2D(nb_classes, (1, 1), strides=(1, 1), name="conv6")(x)
# x = Lambda(Interp, arguments={'shape': (
# input_shape[0], input_shape[1])})(x)
x = Interp([input_shape[0], input_shape[1]])(x)
x = Activation('softmax')(x)

x = Interp()([x, inp])
x = Activation(activation)(x)

model = Model(inputs=inp, outputs=x)

# Solver
sgd = SGD(lr=learning_rate, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd,
loss='categorical_crossentropy',
metrics=['accuracy'])
return model
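As a quick illustration of the fully-convolutional behaviour this PR is after (a sketch, not part of the diff; it assumes resnet_layers=50 is supported as in the original repo, that nb_classes=150 is just an example value, and that input_shape is now only used for the log message):

    import numpy as np
    from nnet.model import build_pspnet

    model = build_pspnet(nb_classes=150, resnet_layers=50,
                         input_shape=(473, 473))
    for size in (473, 713):
        pred = model.predict(np.zeros((1, size, size, 3), dtype=np.float32))
        print(pred.shape)  # expected: (1, size, size, 150)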
61 changes: 61 additions & 0 deletions nnet/tf_layers.py
@@ -0,0 +1,61 @@
import tensorflow as tf


def adaptive_pooling_2d(inputs, output_size: int, mode: str):
"""
Performs a pooling operation that results in a fixed size:
output_size x output_size.

Used by the pyramid pooling module; refer to appendix A in [1].

Args:
inputs: A 4D Tensor (B, H, W, C)
output_size: The output size of the pooling operation.
mode: The pooling mode {max, avg}

Returns:
A 4D Tensor of shape (B, output_size, output_size, C) holding the
pooled value for each of the output_size * output_size bins.

References:
[1] He, Kaiming et al (2015):
Spatial Pyramid Pooling in Deep Convolutional Networks
for Visual Recognition.
https://arxiv.org/pdf/1406.4729.pdf.

Ported from: https://github.com/luizgh/Lasagne/commit/c01e3d922a5712ca4c54617a15a794c23746ac8c
"""
inputs_shape = tf.shape(inputs)
batch = tf.cast(tf.gather(inputs_shape, 0), tf.int32)
h = tf.cast(tf.gather(inputs_shape, 1), tf.int32)
w = tf.cast(tf.gather(inputs_shape, 2), tf.int32)
channels = tf.cast(tf.gather(inputs_shape, 3), tf.int32)
if mode == 'max':
pooling_op = tf.reduce_max
elif mode == 'avg':
pooling_op = tf.reduce_mean
else:
msg = "Mode must be either 'max' or 'avg'. Got '{0}'"
raise ValueError(msg.format(mode))
result = []
n = output_size
for row in range(output_size):
for col in range(output_size):
# start_h = floor(row / n * h)
start_h = tf.cast(
tf.floor(tf.multiply(tf.divide(row, n), tf.cast(h, tf.float32))), tf.int32)
# end_h = ceil((row + 1) / n * h)
end_h = tf.cast(
tf.ceil(tf.multiply(tf.divide((row + 1), n), tf.cast(h, tf.float32))), tf.int32)
# start_w = floor(col / n * w)
start_w = tf.cast(
tf.floor(tf.multiply(tf.divide(col, n), tf.cast(w, tf.float32))), tf.int32)
# end_w = ceil((col + 1) / n * w)
end_w = tf.cast(
tf.ceil(tf.multiply(tf.divide((col + 1), n), tf.cast(w, tf.float32))), tf.int32)
pooling_region = inputs[:, start_h:end_h, start_w:end_w, :]
pool_result = pooling_op(
pooling_region, axis=(1, 2), keepdims=True)
result.append(pool_result)
return tf.reshape(tf.concat(result, axis=1), [batch, output_size, output_size, channels])
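Concretely, for a 30-pixel axis pooled to output_size=3, the bin boundaries above reduce to [0, 10), [10, 20) and [20, 30]. A minimal smoke test, as a sketch assuming the TF 1.x session API and a random input:

    import numpy as np
    import tensorflow as tf
    from nnet.tf_layers import adaptive_pooling_2d

    x = tf.placeholder(tf.float32, [None, None, None, 4])
    pooled = adaptive_pooling_2d(x, output_size=3, mode='avg')
    with tf.Session() as sess:
        out = sess.run(pooled, {x: np.random.rand(1, 30, 30, 4)})
        print(out.shape)  # (1, 3, 3, 4)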
13 changes: 7 additions & 6 deletions pspnet.py
@@ -6,9 +6,10 @@
import numpy as np
from scipy import misc, ndimage
from keras import backend as K
from keras.models import model_from_json, load_model
from keras.models import model_from_json, load_model, Model
import tensorflow as tf
import layers_builder as layers
from nnet.model import build_pspnet
from nnet.keras_layers import Interp, AdaptivePooling2D
from python_utils import utils
from python_utils.preprocessing import preprocess_img
from keras.utils.generic_utils import CustomObjectScope
@@ -27,15 +28,15 @@ def __init__(self, nb_classes, resnet_layers, input_shape, weights):
if 'pspnet' in weights:
if os.path.isfile(json_path) and os.path.isfile(h5_path):
print("Keras model & weights found, loading...")
with CustomObjectScope({'Interp': layers.Interp}):
with CustomObjectScope({'Interp': Interp, 'AdaptivePooling2D': AdaptivePooling2D}):
with open(json_path, 'r') as file_handle:
self.model = model_from_json(file_handle.read())
self.model.load_weights(h5_path)
else:
print("No Keras model & weights found, import from npy weights.")
self.model = layers.build_pspnet(nb_classes=nb_classes,
resnet_layers=resnet_layers,
input_shape=self.input_shape)
self.model = build_pspnet(nb_classes=nb_classes,
resnet_layers=resnet_layers,
input_shape=self.input_shape)
self.set_npy_weights(weights)
else:
print('Load pre-trained weights')
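To round this off, a hedged sketch of how the updated wrapper would be constructed (the weights identifier and example values are assumptions based on the original repo, not shown in this diff):

    # Hypothetical: loads <weights>.json/.h5 under CustomObjectScope when present,
    # otherwise builds the net via build_pspnet and imports the .npy weights.
    net = PSPNet(nb_classes=150, resnet_layers=50,
                 input_shape=(473, 473), weights='pspnet50_ade20k')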