Skip to content

Commit

Permalink
deleted logger
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengdao-chen committed May 4, 2019
1 parent 795d55d commit 7e8bc2e
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 290 deletions.
254 changes: 0 additions & 254 deletions src/Logger.py

This file was deleted.

29 changes: 12 additions & 17 deletions src/main_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from data_generator import Generator
from load import get_lg_inputs, get_gnn_inputs
from models import GNN_bcd, GNN_multiclass
from Logger import Logger
import time
import matplotlib
matplotlib.use('Agg')
Expand Down Expand Up @@ -52,7 +51,6 @@
default='ErdosRenyi')
parser.add_argument('--batch_size', nargs='?', const=1, type=int, default=1)
parser.add_argument('--mode', nargs='?', const=1, type=str, default='train')
parser.add_argument('--path_logger', nargs='?', const=1, type=str, default='')
parser.add_argument('--path_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--filename_existing_gnn', nargs='?', const=1, type=str, default='')
parser.add_argument('--print_freq', nargs='?', const=1, type=int, default=100)
Expand Down Expand Up @@ -96,7 +94,7 @@
template3 = '{:<10} {:<10} {:<10} '
template4 = '{:<10} {:<10.5f} {:<10.5f} \n'

def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it):
def train_mcd_single(gnn, optimizer, gen, n_classes, it):
start = time.time()
W, labels = gen.sample_otf_single(is_training=True, cuda=torch.cuda.is_available())
labels = labels.type(dtype_l)
Expand Down Expand Up @@ -132,7 +130,7 @@ def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it):
else:
loss_value = float(loss.data.numpy())

info = ['epoch', 'avg loss', 'avg acc', 'edge_density',
info = ['iter', 'avg loss', 'avg acc', 'edge_density',
'noise', 'model', 'elapsed']
out = [it, loss_value, acc, args.edge_density,
args.noise, 'GNN', elapsed]
Expand All @@ -144,21 +142,21 @@ def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it):

return loss_value, acc

def train(gnn, logger, gen, n_classes=args.n_classes, iters=args.num_examples_train):
def train(gnn, gen, n_classes=args.n_classes, iters=args.num_examples_train):
gnn.train()
optimizer = torch.optim.Adamax(gnn.parameters(), lr=args.lr)
loss_lst = np.zeros([iters])
acc_lst = np.zeros([iters])
for it in range(iters):
loss_single, acc_single = train_mcd_single(gnn, optimizer, logger, gen, n_classes, it)
loss_single, acc_single = train_mcd_single(gnn, optimizer, gen, n_classes, it)
loss_lst[it] = loss_single
acc_lst[it] = acc_single
torch.cuda.empty_cache()
print ('Avg train loss', np.mean(loss_lst))
print ('Avg train acc', np.mean(acc_lst))
print ('Std train acc', np.std(acc_lst))

def test_mcd_single(gnn, logger, gen, n_classes, iter):
def test_mcd_single(gnn, gen, n_classes, it):

start = time.time()
W, labels = gen.sample_otf_single(is_training=False, cuda=torch.cuda.is_available())
Expand Down Expand Up @@ -196,9 +194,9 @@ def test_mcd_single(gnn, logger, gen, n_classes, iter):
else:
loss_value = float(loss_test.data.numpy())

info = ['epoch', 'avg loss', 'avg acc', 'edge_density',
info = ['iter', 'avg loss', 'avg acc', 'edge_density',
'noise', 'model', 'elapsed']
out = [iter, loss_value, acc_test, args.edge_density,
out = [it, loss_value, acc_test, args.edge_density,
args.noise, 'GNN', elapsed]
print(template1.format(*info))
print(template2.format(*out))
Expand All @@ -208,13 +206,13 @@ def test_mcd_single(gnn, logger, gen, n_classes, iter):

return loss_value, acc_test

def test(gnn, logger, gen, n_classes, iters=args.num_examples_test):
def test(gnn, gen, n_classes, iters=args.num_examples_test):
gnn.train()
loss_lst = np.zeros([iters])
acc_lst = np.zeros([iters])
for it in range(iters):
# inputs, labels, W = gen.sample_single(it, cuda=torch.cuda.is_available(), is_training=False)
loss_single, acc_single = test_mcd_single(gnn, logger, gen, n_classes, it)
loss_single, acc_single = test_mcd_single(gnn, gen, n_classes, it)
loss_lst[it] = loss_single
acc_lst[it] = acc_single
torch.cuda.empty_cache()
Expand All @@ -230,9 +228,6 @@ def count_parameters(model):
if __name__ == '__main__':
# print (args.eval_vs_train)

logger = Logger(args.path_logger)
logger.write_settings(args)

## One fixed generator
gen = Generator()
## generator setup
Expand Down Expand Up @@ -295,9 +290,9 @@ def count_parameters(model):
gnn.cuda()
print ('Training begins')
if (args.generative_model == 'SBM'):
train(gnn, logger, gen, 2)
train(gnn, gen, 2)
elif (args.generative_model == 'SBM_multiclass'):
train(gnn, logger, gen, args.n_classes)
train(gnn, gen, args.n_classes)
print ('Saving gnn ' + filename)
if torch.cuda.is_available():
torch.save(gnn.cpu(), path_plus_name)
Expand All @@ -314,7 +309,7 @@ def count_parameters(model):
print ('model status: train')
gnn.train()

test(gnn, logger, gen, args.n_classes)
test(gnn, gen, args.n_classes)

print ('total num of params:', count_parameters(gnn))

Expand Down
Loading

0 comments on commit 7e8bc2e

Please sign in to comment.