Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgraded for Python3.6+ and PyTorch 1.1.0+ #25

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ Pytorch implementation for reproducing StackGAN_v2 results in the paper [StackGA

<img src="examples/framework.jpg" width="900px" height="350px"/>


Note: Code has been updated for Python3 usage. Thank you [David Stap](https://github.com/davidstap/AttnGAN) for your help upgrading the original StackGAN-v2 file.

### Dependencies
python 2.7
python 3.6+

Pytorch
Pytorch 1.1.0+

In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
- `tensorboard`
Expand Down
10 changes: 5 additions & 5 deletions code/cfg/birds_3stages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CONFIG_NAME: '3stages'

DATASET_NAME: 'birds'
EMBEDDING_TYPE: 'cnn-rnn'
DATA_DIR: '../data/birds'
DATA_DIR: 'data/birds'
GPU_ID: '0'
WORKERS: 4

Expand All @@ -13,11 +13,11 @@ TREE:

TRAIN:
FLAG: True
NET_G: '' # '../output/birds_3stages/Model/netG_epoch_700.pth'
NET_D: '' # '../output/birds_3stages/Model/netD'
BATCH_SIZE: 24
NET_G: '' # 'output/birds_3stages/Model/netG_epoch_700.pth'
NET_D: '' # 'output/birds_3stages/Model/netD'
BATCH_SIZE: 9 #24
MAX_EPOCH: 600
SNAPSHOT_INTERVAL: 2000
SNAPSHOT_INTERVAL: 1000 #2000
DISCRIMINATOR_LR: 0.0002
GENERATOR_LR: 0.0002
COEFF:
Expand Down
72 changes: 25 additions & 47 deletions code/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,18 @@
from __future__ import print_function
from __future__ import unicode_literals


import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
import PIL
import torch.utils.data as data
import os
import os.path
import pickle
import random
import numpy as np
import pandas as pd
from miscc.config import cfg

import torch.utils.data as data
from PIL import Image
import os
import os.path
import six
import string
import sys
import torch

from miscc.config import cfg
from PIL import Image

if sys.version_info[0] == 2:
import cPickle as pickle
else:
Expand All @@ -37,8 +28,7 @@ def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_imgs(img_path, imsize, bbox=None,
transform=None, normalize=None):
def get_imgs(img_path, imsize, bbox=None, transform=None, normalize=None):
img = Image.open(img_path).convert('RGB')
width, height = img.size
if bbox is not None:
Expand All @@ -57,7 +47,7 @@ def get_imgs(img_path, imsize, bbox=None,
ret = []
for i in range(cfg.TREE.BRANCH_NUM):
if i < (cfg.TREE.BRANCH_NUM - 1):
re_img = transforms.Scale(imsize[i])(img)
re_img = transforms.Resize(imsize[i])(img)
else:
re_img = img
ret.append(normalize(re_img))
Expand Down Expand Up @@ -120,19 +110,15 @@ def make_dataset(self, classes, class_to_idx):

def __getitem__(self, index):
path, target = self.imgs[index]
imgs_list = get_imgs(path, self.imsize,
transform=self.transform,
normalize=self.norm)

imgs_list = get_imgs(path, self.imsize, transform=self.transform, normalize=self.norm)
return imgs_list

def __len__(self):
return len(self.imgs)


class LSUNClass(data.Dataset):
def __init__(self, db_path, base_size=64,
transform=None, target_transform=None):
def __init__(self, db_path, base_size=64, transform=None, target_transform=None):
import lmdb
self.db_path = db_path
self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False,
Expand Down Expand Up @@ -168,9 +154,7 @@ def __getitem__(self, index):
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
imgs = get_imgs(buf, self.imsize,
transform=self.transform,
normalize=self.norm)
imgs = get_imgs(buf, self.imsize, transform=self.transform, normalize=self.norm)
return imgs

def __len__(self):
Expand Down Expand Up @@ -215,19 +199,16 @@ def __init__(self, data_dir, split='train', embedding_type='cnn-rnn',
def load_bbox(self):
data_dir = self.data_dir
bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
df_bounding_boxes = pd.read_csv(bbox_path,
delim_whitespace=True,
header=None).astype(int)
df_bounding_boxes = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)
#
filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')
df_filenames = \
pd.read_csv(filepath, delim_whitespace=True, header=None)
df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None)
filenames = df_filenames[1].tolist()
print('Total filenames: ', len(filenames), filenames[0])
#
filename_bbox = {img_file[:-4]: [] for img_file in filenames}
numImgs = len(filenames)
for i in xrange(0, numImgs):
for i in range(numImgs):
# bbox = [x-left, y-top, width, height]
bbox = df_bounding_boxes.iloc[i][1:].tolist()

Expand All @@ -240,14 +221,13 @@ def load_all_captions(self):
def load_captions(caption_name): # self,
cap_path = caption_name
with open(cap_path, "r") as f:
captions = f.read().decode('utf8').split('\n')
captions = [cap.replace("\ufffd\ufffd", " ")
for cap in captions if len(cap) > 0]
captions = f.read().split('\n')
captions = [cap.replace("\ufffd\ufffd", " ") for cap in captions if len(cap) > 0]
return captions

caption_dict = {}
for key in self.filenames:
caption_name = '%s/text/%s.txt' % (self.data_dir, key)
caption_name = '%s/text_c10/%s.txt' % (self.data_dir, key)
captions = load_captions(caption_name)
caption_dict[key] = captions
return caption_dict
Expand All @@ -261,7 +241,8 @@ def load_embedding(self, data_dir, embedding_type):
embedding_filename = '/skip-thought-embeddings.pickle'

with open(data_dir + embedding_filename, 'rb') as f:
embeddings = pickle.load(f)
embeddings = pickle.load(f, encoding="bytes")
# embeddings = pickle.load(f, encoding="latin-1")
embeddings = np.array(embeddings)
# embedding_shape = [embeddings.shape[-1]]
print('embeddings: ', embeddings.shape)
Expand All @@ -270,7 +251,8 @@ def load_embedding(self, data_dir, embedding_type):
def load_class_id(self, data_dir, total_num):
if os.path.isfile(data_dir + '/class_info.pickle'):
with open(data_dir + '/class_info.pickle', 'rb') as f:
class_id = pickle.load(f)
class_id = pickle.load(f, encoding="bytes")
# class_id = pickle.load(f, encoding="latin-1")
else:
class_id = np.arange(total_num)
return class_id
Expand All @@ -293,21 +275,18 @@ def prepair_training_pairs(self, index):
# captions = self.captions[key]
embeddings = self.embeddings[index, :, :]
img_name = '%s/images/%s.jpg' % (data_dir, key)
imgs = get_imgs(img_name, self.imsize,
bbox, self.transform, normalize=self.norm)
imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm)

wrong_ix = random.randint(0, len(self.filenames) - 1)
if(self.class_id[index] == self.class_id[wrong_ix]):
if self.class_id[index] == self.class_id[wrong_ix]:
wrong_ix = random.randint(0, len(self.filenames) - 1)
wrong_key = self.filenames[wrong_ix]
if self.bbox is not None:
wrong_bbox = self.bbox[wrong_key]
else:
wrong_bbox = None
wrong_img_name = '%s/images/%s.jpg' % \
(data_dir, wrong_key)
wrong_imgs = get_imgs(wrong_img_name, self.imsize,
wrong_bbox, self.transform, normalize=self.norm)
wrong_img_name = '%s/images/%s.jpg' % (data_dir, wrong_key)
wrong_imgs = get_imgs(wrong_img_name, self.imsize, wrong_bbox, self.transform, normalize=self.norm)

embedding_ix = random.randint(0, embeddings.shape[0] - 1)
embedding = embeddings[embedding_ix, :]
Expand All @@ -327,8 +306,7 @@ def prepair_test_pairs(self, index):
# captions = self.captions[key]
embeddings = self.embeddings[index, :, :]
img_name = '%s/images/%s.jpg' % (data_dir, key)
imgs = get_imgs(img_name, self.imsize,
bbox, self.transform, normalize=self.norm)
imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm)

if self.target_transform is not None:
embeddings = self.target_transform(embeddings)
Expand Down
14 changes: 7 additions & 7 deletions code/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import print_function
from miscc.config import cfg, cfg_from_file

import torch
import torchvision.transforms as transforms

Expand All @@ -16,7 +18,7 @@
sys.path.append(dir_path)


from miscc.config import cfg, cfg_from_file



# 19 classes --> 7 valid classes with 8,555 images
Expand Down Expand Up @@ -95,8 +97,7 @@ def parse_args():

now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '../output/%s_%s_%s' % \
(cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
output_dir = '../output/%s_%s_%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

split_dir, bshuffle = 'train', True
if not cfg.TRAIN.FLAG:
Expand All @@ -107,7 +108,7 @@ def parse_args():
# Get data loader
imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
image_transform = transforms.Compose([
transforms.Scale(int(imsize * 76 / 64)),
transforms.Resize(int(imsize * 76 / 64)),
transforms.RandomCrop(imsize),
transforms.RandomHorizontalFlip()])
if cfg.DATA_DIR.find('lsun') != -1:
Expand All @@ -128,9 +129,8 @@ def parse_args():
transform=image_transform)
assert dataset
num_gpu = len(cfg.GPU_ID.split(','))
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))

# Define models and go to train/evaluate
if not cfg.GAN.B_CONDITION:
Expand Down
4 changes: 2 additions & 2 deletions code/miscc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def _merge_a_into_b(a, b):
if type(a) is not edict:
return

for k, v in a.iteritems():
for k, v in a.items():
# a must specify keys that are in b
if not b.has_key(k):
if k not in b:
raise KeyError('{} is not a valid config key'.format(k))

# the types must match, too
Expand Down
Loading