Skip to content

Commit

Permalink
fixed overlapping convolutions and updated README
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeid committed Jul 4, 2017
1 parent def6101 commit a4d0478
Show file tree
Hide file tree
Showing 15 changed files with 49 additions and 30 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,36 @@ For description of the generator's output, a pretrained [VGG network](https://ar
<td><img src="lib/images/examples/results-direction.png" height="194px" width="194px"></td>
<td><img src="lib/images/style/great-wave-of-kanagawa.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/style/starry-night.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/style/alley-by-the-lake.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/style/scream.jpg" height="194px" width="194px"></td>
</tr>

<tr>
<td><img src="lib/images/content/nyc.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/nyc-wave.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/nyc-night.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/nyc-alley.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/nyc-scream.jpg" height="194px" width="194px"></td>
</tr>

<tr>
<td><img src="lib/images/content/beach.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/beach-wave.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/beach-night.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/beach-alley.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/beach-scream.jpg" height="194px" width="194px"></td>
</tr>

<tr>
<td><img src="lib/images/content/drawing.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/drawing-wave.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/drawing-night.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/drawing-alley.jpg" height="194px" width="194px"></td>
<td><img src="lib/images/examples/drawing-scream.jpg" height="194px" width="194px"></td>
</tr>

</table>

## Prerequisites

* [Python 3.5](https://www.python.org/downloads/release/python-350/)
* [TensorFlow](https://www.tensorflow.org/) (>= r1.0)
* [TensorFlow](https://www.tensorflow.org/) (>= r1.2)
* [scikit-image](http://scikit-image.org/docs/dev/api/skimage.html)
* [NumPy](http://www.numpy.org/)

Expand Down Expand Up @@ -106,4 +106,4 @@ python test.py --styles

Discriminative model trained on image classification, used for gathering descriptive statistics for the loss measure.

The weights used by the VGG network. This file is not in this repository due to its size. You must download it and place it in the working directory. If the program does not find the file, it will complain and ask you to download it using a supplied link.
The weights used by the VGG network. This file is not in this repository due to its size. You must download it and place it in the working directory. If the program does not find the file, it will complain and ask you to download it using a supplied link.
Binary file modified lib/images/content/beach.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed lib/images/examples/beach-alley.jpg
Binary file not shown.
Binary file added lib/images/examples/beach-scream.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed lib/images/examples/drawing-alley.jpg
Binary file not shown.
Binary file added lib/images/examples/drawing-scream.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed lib/images/examples/nyc-alley.jpg
Binary file not shown.
Binary file added lib/images/examples/nyc-scream.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed lib/images/style/alley-by-the-lake.jpg
Binary file not shown.
Binary file added lib/images/style/scream.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions lib/images/train2014
8 changes: 4 additions & 4 deletions src/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ def build(self, img):
self.padded = self.__pad(img, 40)

self.conv1 = self.__conv_block(self.padded, maps_shape=[9, 9, 3, 32], stride=1, name='conv1')
self.conv2 = self.__conv_block(self.conv1, maps_shape=[3, 3, 32, 64], stride=2, name='conv2')
self.conv3 = self.__conv_block(self.conv2, maps_shape=[3, 3, 64, 128], stride=2, name='conv3')
self.conv2 = self.__conv_block(self.conv1, maps_shape=[2, 2, 32, 64], stride=2, name='conv2')
self.conv3 = self.__conv_block(self.conv2, maps_shape=[2, 2, 64, 128], stride=2, name='conv3')

self.resid1 = self.__residual_block(self.conv3, maps_shape=[3, 3, 128, 128], stride=1, name='resid1')
self.resid2 = self.__residual_block(self.resid1, maps_shape=[3, 3, 128, 128], stride=1, name='resid2')
self.resid3 = self.__residual_block(self.resid2, maps_shape=[3, 3, 128, 128], stride=1, name='resid3')
self.resid4 = self.__residual_block(self.resid3, maps_shape=[3, 3, 128, 128], stride=1, name='resid4')
self.resid5 = self.__residual_block(self.resid4, maps_shape=[3, 3, 128, 128], stride=1, name='resid5')

self.conv4 = self.__upsample_block(self.resid5, maps_shape=[3, 3, 64, 128], stride=2, name='conv4')
self.conv5 = self.__upsample_block(self.conv4, maps_shape=[3, 3, 32, 64], stride=2, name='conv5')
self.conv4 = self.__upsample_block(self.resid5, maps_shape=[2, 2, 64, 128], stride=2, name='conv4')
self.conv5 = self.__upsample_block(self.conv4, maps_shape=[2, 2, 32, 64], stride=2, name='conv5')
self.conv6 = self.__conv_block(self.conv5, maps_shape=[9, 9, 32, 3], stride=1, name='conv6', activation=None)

self.output = tf.nn.sigmoid(self.conv6)
Expand Down
7 changes: 7 additions & 0 deletions src/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
Description: helper module containing various functions, ranging from image retrieval to auxiliary math helpers
"""

import logging
import numpy as np
import tensorflow as tf
from functools import reduce
import skimage
import skimage.io
import skimage.transform
import sys
from scipy.misc import toimage


# Set up the root Python logger for the application
def config_logging():
    """Configure the root logger to emit INFO-and-above records to stdout.

    Delegates to ``logging.basicConfig``, passing ``sys.stdout`` as the
    stream and ``logging.INFO`` as the level.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)


# Compute the content loss given a variable image (x) and a content image (c)
def get_content_loss(x, c, layer):
with tf.name_scope('get_content_loss'):
Expand Down
29 changes: 13 additions & 16 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import argparse
import generator
import helpers
import logging
import os
import tensorflow as tf
Expand All @@ -19,13 +20,15 @@
# Model Hyper Params
CONTENT_LAYER = 'conv3_3'
STYLE_LAYERS = {'conv1_2': .25, 'conv2_2': .25, 'conv3_3': .25, 'conv4_3': .25}
EPOCHS = 250000
LEARNING_RATE = .0001
assert sum(STYLE_LAYERS.values()) == 1, "Style layer weights must up to 1"
EPOCHS = 30000
LEARNING_RATE = .001
TRAINING_DIMS = {'height': 256, 'width': 256}
RETRAIN = False

# Loss term weights
CONTENT_WEIGHT = 1.
STYLE_WEIGHT = 3.
STYLE_WEIGHT = .3
TV_WEIGHT = .1

# Default image paths
Expand All @@ -34,35 +37,29 @@
TRAINED_MODELS_PATH = DIR_PATH + '/../lib/generators/'
TRAIN_PATH = None

# Logging params
# Logging params and config
PRINT_TRAINING_STATUS = True
PRINT_EVERY_N = 100

# Logging config
log_dir = DIR_PATH + '/../log/'
if not os.path.isdir(log_dir):
os.makedirs(log_dir)
print('Directory "%s" was created for logging.' % log_dir)
log_path = ''.join([log_dir, str(time.time()), '.log'])
logging.basicConfig(filename=log_path, level=logging.INFO)
print("Printing log to %s" % log_path)
PRINT_EVERY_N = 10
helpers.config_logging()


# Parse arguments and assign them to their respective global variables
def parse_args():
global TRAIN_PATH
global TRAIN_PATH, RETRAIN

# Create flags and assign values to their respective variables
parser = argparse.ArgumentParser()
parser.add_argument('train', help="path to image with style to learn")
parser.add_argument('--retrain', action="store_true", help="whether or not to retrain a model")
args = parser.parse_args()
TRAIN_PATH = os.path.abspath(args.train)
RETRAIN = args.retrain


parse_args()
with tf.Session() as sess:
with tf.variable_scope('generator'):
gen = generator.Generator()
t = trainer.Trainer(sess, gen, TRAIN_PATH, TRAINING_DIMS, PRINT_TRAINING_STATUS, PRINT_EVERY_N)
t.train(EPOCHS, LEARNING_RATE, CONTENT_LAYER, CONTENT_WEIGHT, STYLE_LAYERS, STYLE_WEIGHT, TV_WEIGHT)
t.train(EPOCHS, LEARNING_RATE, CONTENT_LAYER, CONTENT_WEIGHT, STYLE_LAYERS, STYLE_WEIGHT, TV_WEIGHT, RETRAIN)
sess.close()
22 changes: 18 additions & 4 deletions src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, session, net, train_path, train_dims, print_training_status=T
'out_dir': self.current_path + '/../output/',
'style_file': train_path,
'trained_generators_dir': self.current_path + '/../lib/generators/',
'training_dir': self.current_path + '/../lib/train2014/',
'training_dir': self.current_path + '/../lib/images/train2014/',
'training_url': 'http://msvocds.blob.core.windows.net/coco2014/train2014.zip'
}
self.session = session
Expand All @@ -31,7 +31,7 @@ def __init__(self, session, net, train_path, train_dims, print_training_status=T
self.print_training_status = print_training_status
self.train_n = print_every_n

def train(self, epochs, learning_rate, content_layer, content_weight, style_layers, style_weight, tv_weight):
def train(self, epochs, learning_rate, content_layer, content_weight, style_layers, style_weight, tv_weight, retrain=False):
# Check if there is training data available and initialize generator network
self.__check_for_examples()

Expand All @@ -58,6 +58,13 @@ def train(self, epochs, learning_rate, content_layer, content_weight, style_laye
with tf.name_scope('vgg_variable'):
variable_model = vgg16.Vgg16()
variable_model.build(variable_img, shape=art_shape[1:])

# Continue from a pretrained model
if retrain:
print("tryue")
name = os.path.basename(self.paths['style_file']).replace('.jpg', '')
saver = tf.train.Saver()
saver.restore(self.session, "%s/%s/%s" % (self.paths['trained_generators_dir'], name, name))

# Loss ops
with tf.name_scope('loss'):
Expand Down Expand Up @@ -88,8 +95,9 @@ def train(self, epochs, learning_rate, content_layer, content_weight, style_laye
# Populate the training data
logging.info("Initializing session and loading training images..")
example = self.__next_example(height=art_shape[1], width=art_shape[2])
self.session.run(tf.local_variables_initializer())
self.session.run(tf.global_variables_initializer())

# Initialize threads and begin training
logging.info("Begining training..")
coord = tf.train.Coordinator()
Expand All @@ -109,6 +117,13 @@ def train(self, epochs, learning_rate, content_layer, content_weight, style_laye

if self.print_training_status and i % self.train_n == 0:
logging.info("Epoch %06d | Loss %.06f" % (i, loss))
in_path = self.current_path + '/../lib/images/content/nyc.jpg'
input_img, input_shape = helpers.load_img_to(in_path, height=self.train_height, width=self.train_width)
input_img = input_img.reshape([1] + input_shape).astype(np.float32)
path_out = self.current_path + '/../output/' + str(start_time) + '.jpg'
img = self.session.run(variable_img, feed_dict={variable_placeholder: input_img})
helpers.render(img, path_out=path_out)


# Alert that training has been completed and print the run time
elapsed = time.time() - start_time
Expand Down Expand Up @@ -170,7 +185,6 @@ def ask_to_unzip(path):
if num_training_files <= 1:
ask_to_download()

# Retrieves next example image from queue

# Returns a new training example
def __next_example(self, height, width):
Expand Down

0 comments on commit a4d0478

Please sign in to comment.