Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgraded for Python3.6+ and PyTorch 1.1.0+ #25

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ Pytorch implementation for reproducing StackGAN_v2 results in the paper [StackGA

<img src="examples/framework.jpg" width="900px" height="350px"/>


Note: Code has been updated for Python3 usage. Thank you [David Stap](https://github.com/davidstap/AttnGAN) for your help upgrading the original StackGAN-v2 file. Also, sometimes during training my computer randomly shut down. I think was because the GPU was pulling in too much power, but be aware of this.

### Dependencies
python 2.7
python 3.6+

Pytorch
Pytorch 1.1.0+

In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
- `tensorboard`
- `tensorboardX`
- `python-dateutil`
- `easydict`
- `pandas`
Expand Down Expand Up @@ -56,11 +56,11 @@ In addition, please add the project folder to PYTHONPATH and `pip install` the f


**Pretrained Model**
- [StackGAN-v2 for bird](https://drive.google.com/open?id=1s5Yf3nFiXx0lltMFOiJWB6s1LP24RcwH). Download and save it to `models/` (The [inception score](https://github.com/hanzhanggit/StackGAN-inception-model) for this Model is 4.04±0.05)
- [StackGAN-v2 for dog](https://drive.google.com/open?id=1zcwYfvhsKqb8svQDecTbx_mdYy3TG3F0). Download and save it to `models/` (The [inception score](https://github.com/openai/improved-gan/tree/master/inception_score) for this Model is 9.55±0.11)
- [StackGAN-v2 for cat](https://drive.google.com/open?id=1yPX62c-eCLCNxpziGX9qF_V6Verom3v9). Download and save it to `models/`
- [StackGAN-v2 for bedroom](https://drive.google.com/open?id=1Kqowg0ZLZbN1ek5N-YqEw9TlZeI3XV-K). Download and save it to `models/`
- [StackGAN-v2 for church](https://drive.google.com/open?id=13Pw4PZOkiAM5y_KoOwBzlXK9eQ2hHLfT). Download and save it to `models/`
- [StackGAN-v2 for bird](). Download and save it to `models/` (The [inception score](https://github.com/hanzhanggit/StackGAN-inception-model) for this Model is 4.04±0.05)
- [StackGAN-v2 for dog](). Download and save it to `models/` (The [inception score](https://github.com/openai/improved-gan/tree/master/inception_score) for this Model is 9.55±0.11)
- [StackGAN-v2 for cat](). Download and save it to `models/`
- [StackGAN-v2 for bedroom](). Download and save it to `models/`
- [StackGAN-v2 for church](). Download and save it to `models/`



Expand Down
10 changes: 5 additions & 5 deletions code/cfg/birds_3stages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CONFIG_NAME: '3stages'

DATASET_NAME: 'birds'
EMBEDDING_TYPE: 'cnn-rnn'
DATA_DIR: '../data/birds'
DATA_DIR: 'data/birds'
GPU_ID: '0'
WORKERS: 4

Expand All @@ -13,11 +13,11 @@ TREE:

TRAIN:
FLAG: True
NET_G: '' # '../output/birds_3stages/Model/netG_epoch_700.pth'
NET_D: '' # '../output/birds_3stages/Model/netD'
BATCH_SIZE: 24
NET_G: '' # 'output/birds_3stages/Model/netG_epoch_700.pth'
NET_D: '' # 'output/birds_3stages/Model/netD'
BATCH_SIZE: 9 #24
MAX_EPOCH: 600
SNAPSHOT_INTERVAL: 2000
SNAPSHOT_INTERVAL: 1000 #2000
DISCRIMINATOR_LR: 0.0002
GENERATOR_LR: 0.0002
COEFF:
Expand Down
79 changes: 27 additions & 52 deletions code/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,18 @@
from __future__ import print_function
from __future__ import unicode_literals


import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
import PIL
import torch.utils.data as data
import os
import os.path
import pickle
import random
import numpy as np
import pandas as pd
from miscc.config import cfg

import torch.utils.data as data
from PIL import Image
import os
import os.path
import six
import string
import sys
import torch

from miscc.config import cfg
from PIL import Image

if sys.version_info[0] == 2:
import cPickle as pickle
else:
Expand All @@ -37,8 +28,7 @@ def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_imgs(img_path, imsize, bbox=None,
transform=None, normalize=None):
def get_imgs(img_path, imsize, bbox=None, transform=None, normalize=None):
img = Image.open(img_path).convert('RGB')
width, height = img.size
if bbox is not None:
Expand All @@ -57,7 +47,7 @@ def get_imgs(img_path, imsize, bbox=None,
ret = []
for i in range(cfg.TREE.BRANCH_NUM):
if i < (cfg.TREE.BRANCH_NUM - 1):
re_img = transforms.Scale(imsize[i])(img)
re_img = transforms.Resize(imsize[i])(img)
else:
re_img = img
ret.append(normalize(re_img))
Expand All @@ -71,7 +61,7 @@ def __init__(self, root, split_dir='train', custom_classes=None,
root = os.path.join(root, split_dir)
classes, class_to_idx = self.find_classes(root, custom_classes)
imgs = self.make_dataset(classes, class_to_idx)
if len(imgs) == 0:
if imgs:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

Expand All @@ -93,13 +83,13 @@ def __init__(self, root, split_dir='train', custom_classes=None,
base_size = base_size * 2
print('num_classes', self.num_classes)

def find_classes(self, dir, custom_classes):
def find_classes(self, directory, custom_classes):
classes = []

for d in os.listdir(dir):
for d in os.listdir(directory):
if os.path.isdir:
if custom_classes is None or d in custom_classes:
classes.append(os.path.join(dir, d))
classes.append(os.path.join(directory, d))
print('Valid classes: ', len(classes), classes)

classes.sort()
Expand All @@ -120,19 +110,15 @@ def make_dataset(self, classes, class_to_idx):

def __getitem__(self, index):
path, target = self.imgs[index]
imgs_list = get_imgs(path, self.imsize,
transform=self.transform,
normalize=self.norm)

imgs_list = get_imgs(path, self.imsize, transform=self.transform, normalize=self.norm)
return imgs_list

def __len__(self):
return len(self.imgs)


class LSUNClass(data.Dataset):
def __init__(self, db_path, base_size=64,
transform=None, target_transform=None):
def __init__(self, db_path, base_size=64, transform=None, target_transform=None):
import lmdb
self.db_path = db_path
self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False,
Expand Down Expand Up @@ -168,9 +154,7 @@ def __getitem__(self, index):
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
imgs = get_imgs(buf, self.imsize,
transform=self.transform,
normalize=self.norm)
imgs = get_imgs(buf, self.imsize, transform=self.transform, normalize=self.norm)
return imgs

def __len__(self):
Expand Down Expand Up @@ -215,19 +199,15 @@ def __init__(self, data_dir, split='train', embedding_type='cnn-rnn',
def load_bbox(self):
data_dir = self.data_dir
bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
df_bounding_boxes = pd.read_csv(bbox_path,
delim_whitespace=True,
header=None).astype(int)
df_bounding_boxes = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)
#
filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')
df_filenames = \
pd.read_csv(filepath, delim_whitespace=True, header=None)
df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None)
filenames = df_filenames[1].tolist()
print('Total filenames: ', len(filenames), filenames[0])
#
filename_bbox = {img_file[:-4]: [] for img_file in filenames}
numImgs = len(filenames)
for i in xrange(0, numImgs):
for i, item in enumerate(filenames): # this is the range of the number of images
# bbox = [x-left, y-top, width, height]
bbox = df_bounding_boxes.iloc[i][1:].tolist()

Expand All @@ -240,14 +220,13 @@ def load_all_captions(self):
def load_captions(caption_name): # self,
cap_path = caption_name
with open(cap_path, "r") as f:
captions = f.read().decode('utf8').split('\n')
captions = [cap.replace("\ufffd\ufffd", " ")
for cap in captions if len(cap) > 0]
captions = f.read().split('\n')
captions = [cap.replace("\ufffd\ufffd", " ") for cap in captions if len(cap) > 0]
return captions

caption_dict = {}
for key in self.filenames:
caption_name = '%s/text/%s.txt' % (self.data_dir, key)
caption_name = '%s/text_c10/%s.txt' % (self.data_dir, key)
captions = load_captions(caption_name)
caption_dict[key] = captions
return caption_dict
Expand All @@ -261,7 +240,7 @@ def load_embedding(self, data_dir, embedding_type):
embedding_filename = '/skip-thought-embeddings.pickle'

with open(data_dir + embedding_filename, 'rb') as f:
embeddings = pickle.load(f)
embeddings = pickle.load(f, encoding="bytes")
embeddings = np.array(embeddings)
# embedding_shape = [embeddings.shape[-1]]
print('embeddings: ', embeddings.shape)
Expand All @@ -270,7 +249,7 @@ def load_embedding(self, data_dir, embedding_type):
def load_class_id(self, data_dir, total_num):
if os.path.isfile(data_dir + '/class_info.pickle'):
with open(data_dir + '/class_info.pickle', 'rb') as f:
class_id = pickle.load(f)
class_id = pickle.load(f, encoding="bytes")
else:
class_id = np.arange(total_num)
return class_id
Expand All @@ -293,21 +272,18 @@ def prepair_training_pairs(self, index):
# captions = self.captions[key]
embeddings = self.embeddings[index, :, :]
img_name = '%s/images/%s.jpg' % (data_dir, key)
imgs = get_imgs(img_name, self.imsize,
bbox, self.transform, normalize=self.norm)
imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm)

wrong_ix = random.randint(0, len(self.filenames) - 1)
if(self.class_id[index] == self.class_id[wrong_ix]):
if self.class_id[index] == self.class_id[wrong_ix]:
wrong_ix = random.randint(0, len(self.filenames) - 1)
wrong_key = self.filenames[wrong_ix]
if self.bbox is not None:
wrong_bbox = self.bbox[wrong_key]
else:
wrong_bbox = None
wrong_img_name = '%s/images/%s.jpg' % \
(data_dir, wrong_key)
wrong_imgs = get_imgs(wrong_img_name, self.imsize,
wrong_bbox, self.transform, normalize=self.norm)
wrong_img_name = '%s/images/%s.jpg' % (data_dir, wrong_key)
wrong_imgs = get_imgs(wrong_img_name, self.imsize, wrong_bbox, self.transform, normalize=self.norm)

embedding_ix = random.randint(0, embeddings.shape[0] - 1)
embedding = embeddings[embedding_ix, :]
Expand All @@ -327,8 +303,7 @@ def prepair_test_pairs(self, index):
# captions = self.captions[key]
embeddings = self.embeddings[index, :, :]
img_name = '%s/images/%s.jpg' % (data_dir, key)
imgs = get_imgs(img_name, self.imsize,
bbox, self.transform, normalize=self.norm)
imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm)

if self.target_transform is not None:
embeddings = self.target_transform(embeddings)
Expand Down
26 changes: 11 additions & 15 deletions code/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import print_function
from miscc.config import cfg, cfg_from_file

import torch
import torchvision.transforms as transforms

Expand All @@ -16,9 +18,6 @@
sys.path.append(dir_path)


from miscc.config import cfg, cfg_from_file


# 19 classes --> 7 valid classes with 8,555 images
DOG_LESS = ['n02084071', 'n01322604', 'n02112497', 'n02113335', 'n02111277',
'n02084732', 'n02111129', 'n02103406', 'n02112826', 'n02111626',
Expand Down Expand Up @@ -95,8 +94,7 @@ def parse_args():

now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/%s_%s_%s' % \
(cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
output_dir = '../output/%s_%s_%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

split_dir, bshuffle = 'train', True
if not cfg.TRAIN.FLAG:
Expand All @@ -107,7 +105,7 @@ def parse_args():
# Get data loader
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
image_transform = transforms.Compose([
transforms.Scale(int(imsize * 76 / 64)),
transforms.Resize(int(imsize * 76 / 64)),
transforms.RandomCrop(imsize),
transforms.RandomHorizontalFlip()])
if cfg.DATA_DIR.find('lsun') != -1:
Expand All @@ -128,9 +126,8 @@ def parse_args():
transform=image_transform)
assert dataset
num_gpu = len(cfg.GPU_ID.split(','))
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))

# Define models and go to train/evaluate
if not cfg.GAN.B_CONDITION:
Expand All @@ -146,9 +143,8 @@ def parse_args():
algo.evaluate(split_dir)
end_t = time.time()
print('Total time for training:', end_t - start_t)
''' Running time comparison for 10epoch with batch_size 24 on birds dataset
T(1gpu) = 1.383 T(2gpus)
- gpu 2: 2426.228544 -> 4min/epoch
- gpu 2 & 3: 1754.12295008 -> 2.9min/epoch
- gpu 3: 2514.02744293
'''
# Running time comparison for 10epoch with batch_size 24 on birds dataset
# T(1gpu) = 1.383 T(2gpus)
# - gpu 2: 2426.228544 -> 4min/epoch
# - gpu 2 & 3: 1754.12295008 -> 2.9min/epoch
# - gpu 3: 2514.02744293
13 changes: 5 additions & 8 deletions code/miscc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def _merge_a_into_b(a, b):
options in b whenever they are also specified in a.
"""
if type(a) is not edict:
return
raise TypeError('{} is not a valid edict type'.format(a))

for k, v in a.iteritems():
for k, v in a.items():
# a must specify keys that are in b
if not b.has_key(k):
if k not in b:
raise KeyError('{} is not a valid config key'.format(k))

# the types must match, too
Expand All @@ -81,17 +81,14 @@ def _merge_a_into_b(a, b):
if isinstance(b[k], np.ndarray):
v = np.array(v, dtype=b[k].dtype)
else:
raise ValueError(('Type mismatch ({} vs. {}) '
'for config key: {}').format(type(b[k]),
type(v), k))
raise TypeError(('Type mismatch ({} vs. {}) for config key: {}'.format(type(b[k]), type(v), k)))

# recursively merge dicts
if type(v) is edict:
try:
_merge_a_into_b(a[k], b[k])
except:
print('Error under config key: {}'.format(k))
raise
raise KeyError('Error under config key: {}'.format(k))
else:
b[k] = v

Expand Down
Loading