Skip to content

Commit

Permalink
updated to be in accordance with the lua code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengdao-chen committed Dec 5, 2018
1 parent afa9c6f commit c48f2bf
Showing 1 changed file with 399 additions and 0 deletions.
399 changes: 399 additions & 0 deletions src/main_lg_lua_otf_cp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,399 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-

import numpy as np
import os
# import dependencies
from data_generator_mod import Generator
from load import get_lg_inputs
from model_lua import lGNN_multiclass
from Logger import Logger
import time
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

#Pytorch requirements
import unicodedata
import string
import re
import random
import argparse

import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from losses import compute_loss_multiclass, compute_accuracy_multiclass
# import losses

from Logger import compute_accuracy_bcd

parser = argparse.ArgumentParser()

###############################################################################
#                           General & GNN Settings                            #
###############################################################################

# Scalar options, registered in the original order.  A type of None marks a
# boolean --flag (action='store_true'); everything else keeps the original
# nargs='?', const=1 registration.
_GENERAL_OPTIONS = [
    ('--num_examples_train', int, int(6000)),
    ('--num_examples_test', int, int(1000)),
    ('--edge_density', float, 0.2),
    ('--p_SBM', float, 0.8),
    ('--q_SBM', float, 0.2),
    ('--random_noise', None, None),
    ('--noise', float, 0.03),
    ('--noise_model', int, 2),
    ('--generative_model', str, 'ErdosRenyi'),
    ('--batch_size', int, 1),
    ('--mode', str, 'train'),
    ('--path_dataset', str, ''),
    ('--path_logger', str, ''),
    ('--path_gnn', str, ''),
    ('--filename_existing_gnn', str, ''),
    ('--print_freq', int, 100),
    ('--test_freq', int, 500),
    ('--save_freq', int, 2000),
    ('--clip_grad_norm', float, 40.0),
]

_GNN_OPTIONS = [
    ('--num_features', int, 20),
    ('--num_layers', int, 20),
    ('--n_classes', int, 2),
    ('--J', int, 4),
    ('--N_train', int, 50),
    ('--N_test', int, 50),
    ('--lr', float, 1e-3),
]

for _flag, _type, _default in _GENERAL_OPTIONS:
    if _type is None:
        parser.add_argument(_flag, action='store_true')
    else:
        parser.add_argument(_flag, nargs='?', const=1, type=_type,
                            default=_default)

# --freeze_bn flips the model into eval() mode for final testing (see __main__).
parser.add_argument('--freeze_bn', dest='eval_vs_train', action='store_true')
parser.set_defaults(eval_vs_train=False)

for _flag, _type, _default in _GNN_OPTIONS:
    parser.add_argument(_flag, nargs='?', const=1, type=_type,
                        default=_default)

args = parser.parse_args()

# Choose CUDA or CPU tensor types and seed the matching RNG for reproducibility.
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    dtype_l = torch.cuda.LongTensor
    torch.cuda.manual_seed(42)
else:
    dtype = torch.FloatTensor
    dtype_l = torch.LongTensor
    torch.manual_seed(42)

batch_size = args.batch_size
criterion = nn.CrossEntropyLoss()

# Fixed-width column templates for console logging: header row / value row.
template1 = '{:<10} {:<10} {:<10} {:<15} {:<10} {:<10} {:<10} '
template2 = '{:<10} {:<10.5f} {:<10.5f} {:<15} {:<10} {:<10} {:<10.3f} \n'
template3 = '{:<10} {:<10} {:<10} '
template4 = '{:<10} {:<10.5f} {:<10.5f} \n'

def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it):
    """Run one multiclass training step on a freshly sampled graph.

    Samples a graph ``W`` with ground-truth community labels from ``gen``,
    builds the graph / line-graph operators, takes one optimizer step, and
    prints a one-row progress report.

    Args:
        gnn: the line-graph GNN being trained.
        optimizer: torch optimizer over ``gnn``'s parameters.
        logger: Logger instance (unused here; kept for interface parity).
        gen: Generator with ``sample_otf_single`` for on-the-fly sampling.
        n_classes: number of communities / output classes.
        it: iteration index, used only in the printed report.

    Returns:
        (loss_value, acc): python float loss and accuracy for this sample.
    """
    start = time.time()
    W, labels = gen.sample_otf_single(is_training=True, cuda=torch.cuda.is_available())
    labels = labels.type(dtype_l)

    print ('Num of edges: ', np.sum(W))

    # Binary SBM labels arrive in {-1, +1}; map them to class ids {0, 1}.
    if (args.generative_model == 'SBM_multiclass') and (args.n_classes == 2):
        labels = (labels + 1)/2

    WW, x, WW_lg, y, P = get_lg_inputs(W, args.J)

    if (torch.cuda.is_available()):
        # BUGFIX: Tensor.cuda() is NOT in-place — the original called it and
        # discarded the result, leaving the tensors on CPU. Re-assign instead.
        WW = WW.cuda()
        x = x.cuda()
        WW_lg = WW_lg.cuda()
        y = y.cuda()
        P = P.cuda()
    pred = gnn(WW.type(dtype), x.type(dtype), WW_lg.type(dtype), y.type(dtype), P.type(dtype))

    loss = compute_loss_multiclass(pred, labels, n_classes)
    gnn.zero_grad()
    loss.backward()
    # BUGFIX: clip_grad_norm was deprecated (and later removed) in favor of
    # the in-place clip_grad_norm_, which has identical semantics.
    nn.utils.clip_grad_norm_(gnn.parameters(), args.clip_grad_norm)
    optimizer.step()

    acc = compute_accuracy_multiclass(pred, labels, n_classes)

    elapsed = time.time() - start

    if(torch.cuda.is_available()):
        loss_value = float(loss.data.cpu().numpy())
    else:
        loss_value = float(loss.data.numpy())

    info = ['epoch', 'avg loss', 'avg acc', 'edge_density',
            'noise', 'model', 'elapsed']
    out = [it, loss_value, acc, args.edge_density,
           args.noise, 'LGNN', elapsed]
    print(template1.format(*info))
    print(template2.format(*out))

    # Release the (potentially large) graph operators before the next sample.
    del WW
    del WW_lg
    del x
    del y
    del P

    return loss_value, acc

def train(gnn, logger, gen, n_classes=args.n_classes, iters=args.num_examples_train):
    """Train ``gnn`` on ``iters`` on-the-fly samples, printing running stats."""
    gnn.train()
    optimizer = torch.optim.Adamax(gnn.parameters(), lr=args.lr)
    losses = np.zeros([iters])
    accuracies = np.zeros([iters])
    for step in range(iters):
        step_loss, step_acc = train_mcd_single(gnn, optimizer, logger, gen, n_classes, step)
        losses[step] = step_loss
        accuracies[step] = step_acc
        torch.cuda.empty_cache()

        # Every 100 iterations (once at least 100 have run), report stats
        # over the trailing 100-sample window.
        if step >= 100 and step % 100 == 0:
            window = slice(step - 100, step)
            print ('Check_pt at iteration ' + str(step) + ' :')
            print ('Avg train loss', np.mean(losses[window]))
            print ('Avg train acc', np.mean(accuracies[window]))
            print ('Std train acc', np.std(accuracies[window]))

    print ('Final avg train loss', np.mean(losses))
    print ('Final avg train acc', np.mean(accuracies))
    print ('Final std train acc', np.std(accuracies))

def test_mcd_single(gnn, logger, gen, n_classes, it):
    """Evaluate ``gnn`` on one freshly sampled test graph.

    Args:
        gnn: the trained line-graph GNN.
        logger: Logger instance (unused here; kept for interface parity).
        gen: Generator with ``sample_otf_single`` for on-the-fly sampling.
        n_classes: number of communities / output classes.
        it: iteration index, used only in the printed report.

    Returns:
        (loss_value, acc_test): python float loss and accuracy for this sample.
    """
    start = time.time()
    W, labels = gen.sample_otf_single(is_training=False, cuda=torch.cuda.is_available())
    labels = labels.type(dtype_l)
    # Binary SBM labels arrive in {-1, +1}; map them to class ids {0, 1}.
    if (args.generative_model == 'SBM_multiclass') and (args.n_classes == 2):
        labels = (labels + 1)/2
    WW, x, WW_lg, y, P = get_lg_inputs(W, args.J)

    if (torch.cuda.is_available()):
        # BUGFIX: Tensor.cuda() is NOT in-place — the original called it and
        # discarded the result, leaving the tensors on CPU. Re-assign instead.
        WW = WW.cuda()
        x = x.cuda()
        WW_lg = WW_lg.cuda()
        y = y.cuda()
        P = P.cuda()
    pred_single = gnn(WW.type(dtype), x.type(dtype), WW_lg.type(dtype), y.type(dtype), P.type(dtype))
    labels_single = labels

    loss_test = compute_loss_multiclass(pred_single, labels_single, n_classes)
    acc_test = compute_accuracy_multiclass(pred_single, labels_single, n_classes)

    # FIX: the original computed `elapsed = time.time() - start` twice
    # (around a commented-out print block); one measurement suffices.
    elapsed = time.time() - start

    if(torch.cuda.is_available()):
        loss_value = float(loss_test.data.cpu().numpy())
    else:
        loss_value = float(loss_test.data.numpy())

    info = ['epoch', 'avg loss', 'avg acc', 'edge_density',
            'noise', 'model', 'elapsed']
    out = [it, loss_value, acc_test, args.edge_density,
           args.noise, 'LGNN', elapsed]
    print(template1.format(*info))
    print(template2.format(*out))

    # Release the (potentially large) graph operators before the next sample.
    del WW
    del WW_lg
    del x
    del y
    del P

    return loss_value, acc_test

def test(gnn, logger, gen, n_classes, iters=args.num_examples_test):
    """Evaluate on ``iters`` fresh test graphs and print aggregate metrics."""
    # NOTE(review): train() mode is deliberate here — __main__ selects eval
    # vs train via --freeze_bn before the final call to this function.
    gnn.train()
    losses = np.zeros([iters])
    accuracies = np.zeros([iters])
    for step in range(iters):
        step_loss, step_acc = test_mcd_single(gnn, logger, gen, n_classes, step)
        losses[step] = step_loss
        accuracies[step] = step_acc
        torch.cuda.empty_cache()
    print ('Testing results:')
    print ('Avg test loss', np.mean(losses))
    print ('Avg test acc', np.mean(accuracies))
    print ('Std test acc', np.std(accuracies))



if __name__ == '__main__':
    # Entry point: build the generator from CLI args, then dispatch on
    # --mode ('test' / 'train' / 'search_opt_iters'), finishing with a
    # final test pass on the resulting model.

    print ('main file starts here')

    logger = Logger(args.path_logger)
    logger.write_settings(args)

    # ## One fixed generator
    gen = Generator()
    # generator setup
    gen.N_train = args.N_train
    gen.N_test = args.N_test
    gen.edge_density = args.edge_density
    gen.p_SBM = args.p_SBM
    gen.q_SBM = args.q_SBM
    gen.random_noise = args.random_noise
    gen.noise = args.noise
    gen.noise_model = args.noise_model
    gen.generative_model = args.generative_model
    gen.n_classes = args.n_classes

    torch.backends.cudnn.enabled=False

    if (args.mode == 'test'):
        print ('In testing mode')
        filename = args.filename_existing_gnn
        path_plus_name = os.path.join(args.path_gnn, filename)
        if ((filename != '') and (os.path.exists(path_plus_name))):
            print ('Loading gnn ' + filename)
            gnn = torch.load(path_plus_name)
            if torch.cuda.is_available():
                gnn.cuda()
        else:
            # No saved model found: build a fresh (untrained) one so the
            # final test pass below still has something to evaluate.
            print ('No such a gnn exists; creating a brand new one')
            if (args.generative_model == 'SBM'):
                gnn = lGNN_multiclass(args.num_features, args.num_layers, args.J + 2)
            elif (args.generative_model == 'SBM_multiclass'):
                gnn = lGNN_multiclass(args.num_features, args.num_layers, args.J + 2, n_classes=args.n_classes)
            filename = 'lgnn_J' + str(args.J) + '_lyr' + str(args.num_layers) + '_Ntr' + str(args.N_train) + '_num' + str(args.num_examples_train)
            path_plus_name = os.path.join(args.path_gnn, filename)
            if torch.cuda.is_available():
                gnn.cuda()
            # NOTE(review): no training actually happens in this branch; the
            # trainer call is commented out upstream.
            print ('Training begins')

    elif (args.mode == 'train'):
        filename = args.filename_existing_gnn
        path_plus_name = os.path.join(args.path_gnn, filename)
        if ((filename != '') and (os.path.exists(path_plus_name))):
            print ('Loading gnn ' + filename)
            gnn = torch.load(path_plus_name)
            filename = filename + '_Ntr' + str(args.N_train) + '_num' + str(args.num_examples_train)
            path_plus_name = os.path.join(args.path_gnn, filename)
        else:
            print ('No such a gnn exists; creating a brand new one')
            filename = 'lgnn_J' + str(args.J) + '_lyr' + str(args.num_layers) + '_Ntr' + str(args.N_train) + '_num' + str(args.num_examples_train)
            path_plus_name = os.path.join(args.path_gnn, filename)
            if (args.generative_model == 'SBM'):
                gnn = lGNN_multiclass(args.num_features, args.num_layers, args.J + 2)
            elif (args.generative_model == 'SBM_multiclass'):
                gnn = lGNN_multiclass(args.num_features, args.num_layers, args.J + 2, n_classes=args.n_classes)

        if torch.cuda.is_available():
            gnn.cuda()
        print ('Training begins')
        if (args.generative_model == 'SBM'):
            # NOTE(review): train_bcd is not defined anywhere in this file —
            # this branch raises NameError if taken. Left as-is pending the
            # intended binary trainer; no in-file equivalent exists.
            train_bcd(gnn, logger, gen)
        elif (args.generative_model == 'SBM_multiclass'):
            train(gnn, logger, gen, args.n_classes)
        print ('Saving gnn ' + filename)
        if torch.cuda.is_available():
            # Move to CPU before serializing so the checkpoint loads anywhere.
            torch.save(gnn.cpu(), path_plus_name)
            gnn.cuda()
        else:
            torch.save(gnn, path_plus_name)

    elif (args.mode == 'search_opt_iters'):
        # Sweep over training budgets in steps of `iters_per_check`,
        # training and testing a model at each checkpoint.
        iters_per_check = 600

        for check_pt in range(int(args.num_examples_train / iters_per_check) + 1):
            filename = args.filename_existing_gnn
            path_plus_name = os.path.join(args.path_gnn, filename)
            if ((filename != '') and (os.path.exists(path_plus_name))):
                print ('Loading gnn ' + filename)
                gnn = torch.load(path_plus_name)
                filename = filename + '_Ntr' + str(gen.N_train) + '_it' + str(iters_per_check * (check_pt))
                path_plus_name = os.path.join(args.path_gnn, filename)
            else:
                print ('creating a brand new gnn')
                if (args.generative_model == 'SBM'):
                    gnn = lGNN_multiclass(args.num_features, args.num_layers, args.J + 2)
                elif (args.generative_model == 'SBM_multiclass'):
                    gnn = lGNN_multiclass(args.num_features, args.num_layers, args.J + 2, n_classes=args.n_classes)
                filename = 'lgnn_J' + str(args.J) + '_lyr' + str(args.num_layers) + '_Ntr' + str(gen.N_train) + '_it' + str(iters_per_check * (check_pt))
                path_plus_name = os.path.join(args.path_gnn, filename)
            if torch.cuda.is_available():
                gnn.cuda()
            print ('Training begins for num_iters = ' + str(iters_per_check * (check_pt)))
            if (args.generative_model == 'SBM'):
                # NOTE(review): train_bcd is undefined in this file (NameError
                # if this branch is taken) — see the 'train' mode above.
                train_bcd(gnn, logger, gen)
            elif (args.generative_model == 'SBM_multiclass'):
                # BUGFIX: the original called train_mcd, which is not defined
                # in this file; the 'train' mode uses train() for the same
                # purpose (the superseded call survives there as a comment).
                train(gnn, logger, gen, args.n_classes)

            print ('Testing the GNN at check_pt ' + str(check_pt))
            # BUGFIX: the original called test_mcd, also undefined here;
            # test() is the defined equivalent with the same signature.
            test(gnn, logger, gen, args.n_classes)

    # Final evaluation: --freeze_bn switches the model to eval() so batch-norm
    # style layers use running statistics; otherwise it stays in train() mode.
    print ('Testing the GNN:')
    if args.eval_vs_train:
        gnn.eval()
        print ('model status: eval')
    else:
        print ('model status: train')
        gnn.train()
    test(gnn, logger, gen, args.n_classes)

0 comments on commit c48f2bf

Please sign in to comment.