diff --git a/src/data_generator_mod.py b/src/data_generator_mod.py index 9d49c45..388026d 100644 --- a/src/data_generator_mod.py +++ b/src/data_generator_mod.py @@ -23,21 +23,27 @@ from torch.autograd import Variable from torch import optim import torch.nn.functional as F +from load import get_Pm, get_Pd, get_W_lg class Generator(object): - def __init__(self): - self.N_train = 50 - self.N_test = 100 + def __init__(self, N_train=50, N_test=100, generative_model='SBM_multiclass', p_SBM=0.8, q_SBM=0.2, n_classes=2, path_dataset='', num_examples_train=100, num_examples_test=10): + self.N_train = N_train + self.N_test = N_test # self.generative_model = 'ErdosRenyi' - self.generative_model = 'SBM' - self.edge_density = 0.2 - self.random_noise = False - self.noise = 0.03 - self.noise_model = 2 - self.p_SBM = 0.8 - self.q_SBM = 0.2 - self.n_classes = 5 + self.generative_model = generative_model + # self.edge_density = 0.2 + # self.random_noise = False + # self.noise = 0.03 + # self.noise_model = 2 + self.p_SBM = p_SBM + self.q_SBM = q_SBM + self.n_classes = n_classes + self.path_dataset = path_dataset + self.data_train = None + self.data_test = None + self.num_examples_train = num_examples_train + self.num_examples_test = num_examples_test def SBM(self, p, q, N): W = np.zeros((N, N)) @@ -115,32 +121,107 @@ def RegularGraph_netx(self, p, N): W = np.array(W) return W + def create_dataset(self, directory, is_training): + if (self.generative_model == 'SBM_multiclass'): + if not os.path.exists(directory): + os.mkdir(directory) + if is_training: + graph_size = self.N_train + num_graphs = self.num_examples_train + else: + graph_size = self.N_test + num_graphs = self.num_examples_test + dataset = [] + for i in range(num_graphs): + W, labels = self.SBM_multiclass(self.p_SBM, self.q_SBM, graph_size, self.n_classes) + Pm = get_Pm(W) + Pd = get_Pd(W) + NB = get_W_lg(W) + example = {} + example['W'] = W + example['labels'] = labels + # example['Pm'] = Pm + # example['Pd'] = Pd + # example['NB'] = NB + dataset.append(example) + if is_training: + print ('Saving the training dataset') + else: + print ('Saving the testing dataset') + np.save(directory + '/dataset.npy', dataset) + if is_training: + self.data_train = dataset + else: + self.data_test = dataset + else: + raise ValueError('Generative model {} not supported'.format(self.generative_model)) + + + + def prepare_data(self): + train_directory = self.generative_model + '_nc' + str(self.n_classes) + '_p' + str(self.p_SBM) + '_q' + str(self.q_SBM) + '_gstr' + str(self.N_train) + '_numtr' + str(self.num_examples_train) + + train_path = os.path.join(self.path_dataset, train_directory) + if os.path.exists(train_path + '/dataset.npy'): + print('Reading training dataset at {}'.format(train_path)) + # self.data_train = self.load_from_directory(train_path) + self.data_train = np.load(train_path + '/dataset.npy') + else: + print('Creating training dataset.') + self.create_dataset(train_path, is_training=True) + # print('Saving training datatset at {}'.format(train_path)) + # np.save(path, self.data_train) + # load test dataset + test_directory = self.generative_model + '_nc' + str(self.n_classes) + '_p' + str(self.p_SBM) + '_q' + str(self.q_SBM) + '_gste' + str(self.N_test) + '_numte' + str(self.num_examples_test) + test_path = os.path.join(self.path_dataset, test_directory) + if os.path.exists(test_path + '/dataset.npy'): + print('Reading testing dataset at {}'.format(test_path)) + # self.data_test = self.load_from_directory(test_path) + self.data_test = 
np.load(test_path + '/dataset.npy') + else: + print('Creating testing dataset.') + self.create_dataset(test_path, is_training=False) + # print('Saving testing datatset at {}'.format(test_path)) + # np.save(open(path, 'wb'), self.data_test) + + def sample_single(self, i, is_training=True): + if is_training: + dataset = self.data_train + else: + dataset = self.data_test + example = dataset[i] + if (self.generative_model == 'SBM_multiclass'): + W_np = example['W'] + labels = np.expand_dims(example['labels'], 0) + labels_var = Variable(torch.from_numpy(labels), volatile=not is_training) + return W_np, labels_var + def sample_otf_single(self, is_training=True, cuda=True): if is_training: N = self.N_train else: N = self.N_test if self.generative_model == 'SBM': - W, labels = self.SBM(self.p_SBM, self.q_SBM, self.N_train) + W, labels = self.SBM(self.p_SBM, self.q_SBM, N) elif self.generative_model == 'SBM_multiclass': - W, labels = self.SBM_multiclass(self.p_SBM, self.q_SBM, self.N_train, self.n_classes) + W, labels = self.SBM_multiclass(self.p_SBM, self.q_SBM, N, self.n_classes) else: raise ValueError('Generative model {} not supported'.format(self.generative_model)) - if self.random_noise: - self.noise = np.random.uniform(0.000, 0.050, 1) - if self.noise_model == 1: - # use noise model from [arxiv 1602.04181], eq (3.8) - noise = self.ErdosRenyi(self.noise, self.N_train) - W_noise = W*(1-noise) + (1-W)*noise - elif self.noise_model == 2: - # use noise model from [arxiv 1602.04181], eq (3.9) - pe1 = self.noise - pe2 = (self.edge_density*self.noise)/(1.0-self.edge_density) - noise1 = self.ErdosRenyi_netx(pe1, self.N_train) - noise2 = self.ErdosRenyi_netx(pe2, self.N_train) - W_noise = W*(1-noise1) + (1-W)*noise2 - else: - raise ValueError('Noise model {} not implemented'.format(self.noise_model)) + # if self.random_noise: + # self.noise = np.random.uniform(0.000, 0.050, 1) + # if self.noise_model == 1: + # # use noise model from [arxiv 1602.04181], eq (3.8) + # noise = self.ErdosRenyi(self.noise, self.N_train) + # W_noise = W*(1-noise) + (1-W)*noise + # elif self.noise_model == 2: + # # use noise model from [arxiv 1602.04181], eq (3.9) + # pe1 = self.noise + # pe2 = (self.edge_density*self.noise)/(1.0-self.edge_density) + # noise1 = self.ErdosRenyi_netx(pe1, self.N_train) + # noise2 = self.ErdosRenyi_netx(pe2, self.N_train) + # W_noise = W*(1-noise1) + (1-W)*noise2 + # else: + # raise ValueError('Noise model {} not implemented'.format(self.noise_model)) labels = np.expand_dims(labels, 0) labels = Variable(torch.from_numpy(labels), volatile=not is_training) W = np.expand_dims(W, 0) diff --git a/src/load.py b/src/load.py index ef804bb..af3e8be 100644 --- a/src/load.py +++ b/src/load.py @@ -24,6 +24,25 @@ from torch import optim import torch.nn.functional as F +if (torch.cuda.is_available()): + dtype_sp = torch.cuda.sparse.FloatTensor + dtype = torch.cuda.FloatTensor +else: + dtype_sp = torch.sparse.FloatTensor + dtype = torch.FloatTensor + +def to_sparse(x): + """ converts dense tensor x to sparse format """ + x_typename = torch.typename(x).split('.')[-1] + sparse_tensortype = getattr(torch.sparse, x_typename) + + indices = torch.nonzero(x) + if len(indices.shape) == 0: # if all elements are zeros + return sparse_tensortype(*x.shape) + indices = indices.t() + values = x[tuple(indices[i] for i in range(indices.shape[0]))] + return sparse_tensortype(indices, values, x.size()) + def compute_operators(W, J): N = W.shape[0] # print ('W', W) @@ -38,16 +57,67 @@ def compute_operators(W, J): WW[:, :, j + 1] 
= QQ.copy() # QQ = np.dot(QQ, QQ) QQ = np.minimum(np.dot(QQ, QQ), np.ones(QQ.shape)) - WW[:, :, J] = D - WW[:, :, J + 1] = np.ones((N, N)) * 1.0 / float(N) + WW[:, :, J + 1] = D + # WW[:, :, J + 1] = np.ones((N, N)) * 1.0 / float(N) WW = np.reshape(WW, [N, N, J + 2]) x = np.reshape(d, [N, 1]) return WW, x +def compute_operators_sp(W, J): + N = W.shape[0] + # print ('W', W) + # print ('W size', W.size()) + # operators: {Id, W, W^2, ..., W^{J-1}, D, U} + d = W.sum(1) + D = np.diag(d) + QQ = W.copy() + WW = [] + I = np.eye(N) + I = torch.from_numpy(I)#.unsqueeze(0) + I = to_sparse(I).type(dtype_sp) + I = Variable(I, volatile=False) + WW.append(I) + for j in range(J): + # QQc = np.expand_dims(QQ.copy(), 0) + QQc = QQ.copy() + QQc = torch.from_numpy(QQc) + QQc = to_sparse(QQc).type(dtype_sp) + QQc = Variable(QQc, volatile=False) + WW.append(QQc) + # QQ = np.dot(QQ, QQ) + QQ = np.minimum(np.dot(QQ, QQ), np.ones(QQ.shape)) + # D = torch.from_numpy(np.expand_dims(D, 0)) + D = torch.from_numpy(D) + D = to_sparse(D).type(dtype_sp) + D = Variable(D, volatile=False) + WW.append(D) + # WW[:, :, J + 1] = np.ones((N, N)) * 1.0 / float(N) + x = np.reshape(d, [N, 1]) + return WW, x + +def compute_operators_noD(W, J): + N = W.shape[0] + # print ('W', W) + # print ('W size', W.size()) + # operators: {Id, W, W^2, ..., W^{J-1}, D, U} + d = W.sum(1) + D = np.diag(d) + QQ = W.copy() + WW = np.zeros([N, N, J + 1]) + WW[:, :, 0] = np.eye(N) + for j in range(J): + WW[:, :, j + 1] = QQ.copy() + # QQ = np.dot(QQ, QQ) + QQ = np.minimum(np.dot(QQ, QQ), np.ones(QQ.shape)) + # WW[:, :, J] = np.ones((N, N)) * 1.0 / float(N) + WW = np.reshape(WW, [N, N, J + 1]) + x = np.reshape(d, [N, 1]) + return WW, x + def get_Pm(W): N = W.shape[0] W = W * (np.ones([N, N]) - np.eye(N)) - M = int(W.sum()) + M = int(W.sum()) // 2 p = 0 Pm = np.zeros([N, M * 2]) for n in range(N): @@ -63,15 +133,15 @@ def get_Pm(W): def get_Pd(W): N = W.shape[0] W = W * (np.ones([N, N]) - np.eye(N)) - M = int(W.sum()) + M = int(W.sum()) // 2 p = 0 Pd = np.zeros([N, M * 2]) for n in range(N): for m in range(n+1, N): if (W[n][m]==1): Pd[n][p] = 1 - Pd[m][p] = 1 - Pd[n][p + M] = 1 + Pd[m][p] = -1 + Pd[n][p + M] = -1 Pd[m][p + M] = 1 p += 1 return Pd @@ -81,14 +151,57 @@ def get_P(W): P = np.concatenate((np.expand_dims(get_Pm(W), 2), np.expand_dims(get_Pd(W), 2)), axis=2) return P +def get_P_sp(W): + # P = np.concatenate((np.expand_dims(get_Pm(W), 2), np.expand_dims(get_Pd(W), 2)), axis=2) + P = [] + # Pm = np.expand_dims(get_Pm(W), 0) + Pm = get_Pm(W) + Pm = torch.from_numpy(Pm) + Pm = to_sparse(Pm).type(dtype_sp) + Pm = Variable(Pm, volatile=False) + # Pd = np.expand_dims(get_Pd(W), 0) + Pd = get_Pd(W) + Pd = torch.from_numpy(Pd) + Pd = to_sparse(Pd).type(dtype_sp) + Pd = Variable(Pd, volatile=False) + P.append(Pm) + P.append(Pd) + return P + def get_W_lg(W): W_lg = np.transpose(get_Pm(W)).dot(get_Pd(W)) return W_lg +def get_NB(W): + product = np.transpose(get_Pm(W)).dot(get_Pd(W)) + M_doubled = product.shape[0] + NB = product * (product > 0) + for i in range(M_doubled): + NB[i, i] = 0 + return NB + +def get_NB_2(W): + Pm = get_Pm(W) + Pd = get_Pd(W) + Pf = (Pm + Pd) / 2 + Pt = (Pm - Pd) / 2 + NB = np.transpose(Pt).dot(Pf) * (1 - np.transpose(Pf).dot(Pt)) + return NB + +# def get_NB_direct(W): +# Pd = get_Pd(W) +# M_doubled = Pd.shape[1] +# NB = np.zeros(M_doubled, M_doubled) +# for i in range(M_doubled): + + def get_lg_inputs(W, J): - W = W[0, :, :] + if (W.ndim == 3): + W = W[0, :, :] WW, x = compute_operators(W, J) - W_lg = get_W_lg(W) + # W_lg = 
get_W_lg(W) + # W_lg = get_NB(W) + W_lg = get_NB_2(W) WW_lg, y = compute_operators(W_lg, J) P = get_P(W) x = x.astype(float) @@ -103,4 +216,60 @@ def get_lg_inputs(W, J): P = Variable(torch.from_numpy(P).unsqueeze(0), volatile=False) return WW, x, WW_lg, y, P +def get_lg_inputs_noD(W, J): + if (W.ndim == 3): + W = W[0, :, :] + WW, x = compute_operators_noD(W, J) + # W_lg = get_W_lg(W) + W_lg = get_NB_2(W) + WW_lg, y = compute_operators_noD(W_lg, J) + P = get_P(W) + x = x.astype(float) + y = y.astype(float) + WW = WW.astype(float) + WW_lg = WW_lg.astype(float) + P = P.astype(float) + WW = Variable(torch.from_numpy(WW).unsqueeze(0), volatile=False) + x = Variable(torch.from_numpy(x).unsqueeze(0), volatile=False) + WW_lg = Variable(torch.from_numpy(WW_lg).unsqueeze(0), volatile=False) + y = Variable(torch.from_numpy(y).unsqueeze(0), volatile=False) + P = Variable(torch.from_numpy(P).unsqueeze(0), volatile=False) + return WW, x, WW_lg, y, P + +def get_splg_inputs(W, J): + if (W.ndim == 3): + W = W[0, :, :] + WW, x = compute_operators_sp(W, J) + # W_lg = get_W_lg(W) + W_lg = get_NB_2(W) + WW_lg, y = compute_operators_sp(W_lg, J) + P = get_P_sp(W) + # x = x.astype(float) + # y = y.astype(float) + # WW = WW.astype(float) + # WW_lg = WW_lg.astype(float) + # P = P.astype(float) + # WW = Variable(to_sparse(torch.from_numpy(WW).unsqueeze(0)), volatile=False) + x = Variable(torch.from_numpy(x), volatile=False).type(dtype) + # WW_lg = Variable(to_sparse(torch.from_numpy(WW_lg).unsqueeze(0)), volatile=False) + y = Variable(torch.from_numpy(y), volatile=False).type(dtype) + # P = Variable(to_sparse(torch.from_numpy(P).unsqueeze(0)), volatile=False) + return WW, x, WW_lg, y, P + +def get_gnn_inputs(W, J): + W = W[0, :, :] + WW, x = compute_operators(W, J) + WW = WW.astype(float) + WW = Variable(torch.from_numpy(WW).unsqueeze(0), volatile=False) + x = Variable(torch.from_numpy(x).unsqueeze(0), volatile=False) + return WW, x + +def get_gnn_inputs_noD(W, J): + W = W[0, :, :] + WW, x = compute_operators_noD(W, J) + WW = WW.astype(float) + WW = Variable(torch.from_numpy(WW).unsqueeze(0), volatile=False) + x = Variable(torch.from_numpy(x).unsqueeze(0), volatile=False) + return WW, x + diff --git a/src/main_gnn_otf_cp.py b/src/main_gnn_otf_cp.py index 32d8cc2..fef6c1c 100644 --- a/src/main_gnn_otf_cp.py +++ b/src/main_gnn_otf_cp.py @@ -5,7 +5,7 @@ import os # import dependencies from data_generator_mod import Generator -from load import get_lg_inputs +from load import get_lg_inputs, get_gnn_inputs from model import GNN_bcd, GNN_multiclass from Logger import Logger import time @@ -106,7 +106,7 @@ def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it): if (args.generative_model == 'SBM_multiclass') and (args.n_classes == 2): labels = (labels + 1)/2 - WW, x, WW_lg, y, P = get_lg_inputs(W, args.J) + WW, x = get_gnn_inputs(W, args.J) # print ('WW', WW.shape) # print ('WW_lg', WW_lg.shape) @@ -114,9 +114,7 @@ def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it): if (torch.cuda.is_available()): WW.cuda() x.cuda() - WW_lg.cuda() - y.cuda() - P.cuda() + # print ('input', input) # pred = gnn(WW.type(dtype), x.type(dtype), WW_lg.type(dtype), y.type(dtype), P.type(dtype)) pred = gnn(WW.type(dtype), x.type(dtype)) @@ -139,15 +137,12 @@ def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it): info = ['epoch', 'avg loss', 'avg acc', 'edge_density', 'noise', 'model', 'elapsed'] out = [it, loss_value, acc, args.edge_density, - args.noise, 'lGNN', elapsed] + args.noise, 'GNN', elapsed] 
print(template1.format(*info)) print(template2.format(*out)) del WW - del WW_lg del x - del y - del P return loss_value, acc @@ -163,6 +158,7 @@ def train(gnn, logger, gen, n_classes=args.n_classes, iters=args.num_examples_tr torch.cuda.empty_cache() print ('Avg train loss', np.mean(loss_lst)) print ('Avg train acc', np.mean(acc_lst)) + print ('Std train acc', np.std(acc_lst)) def test_mcd_single(gnn, logger, gen, n_classes, iter): @@ -171,17 +167,14 @@ def test_mcd_single(gnn, logger, gen, n_classes, iter): labels = labels.type(dtype_l) if (args.generative_model == 'SBM_multiclass') and (args.n_classes == 2): labels = (labels + 1)/2 - WW, x, WW_lg, y, P = get_lg_inputs(W, args.J) + WW, x = get_gnn_inputs(W, args.J) print ('WW', WW.shape) - print ('WW_lg', WW_lg.shape) if (torch.cuda.is_available()): WW.cuda() x.cuda() - WW_lg.cuda() - y.cuda() - P.cuda() + # print ('input', input) pred_single = gnn(WW.type(dtype), x.type(dtype)) labels_single = labels @@ -208,15 +201,12 @@ def test_mcd_single(gnn, logger, gen, n_classes, iter): info = ['epoch', 'avg loss', 'avg acc', 'edge_density', 'noise', 'model', 'elapsed'] out = [iter, loss_value, acc_test, args.edge_density, - args.noise, 'lGNN', elapsed] + args.noise, 'GNN', elapsed] print(template1.format(*info)) print(template2.format(*out)) del WW - del WW_lg del x - del y - del P return loss_value, acc_test @@ -232,6 +222,7 @@ def test(gnn, logger, gen, n_classes, iters=args.num_examples_test): torch.cuda.empty_cache() print ('Avg test loss', np.mean(loss_lst)) print ('Avg test acc', np.mean(acc_lst)) + print ('Std test acc', np.std(acc_lst)) diff --git a/src/main_lg_otf_cp.py b/src/main_lg_otf_cp.py index 03dc8c0..0a29bc4 100644 --- a/src/main_lg_otf_cp.py +++ b/src/main_lg_otf_cp.py @@ -103,6 +103,8 @@ def train_mcd_single(gnn, optimizer, logger, gen, n_classes, it): W, labels = gen.sample_otf_single(is_training=True, cuda=torch.cuda.is_available()) labels = labels.type(dtype_l) + print ('Num of edges: ', np.sum(W)) + if (args.generative_model == 'SBM_multiclass') and (args.n_classes == 2): labels = (labels + 1)/2 @@ -156,12 +158,19 @@ def train(gnn, logger, gen, n_classes=args.n_classes, iters=args.num_examples_tr loss_lst = np.zeros([iters]) acc_lst = np.zeros([iters]) for it in range(iters): + # W, labels = gen.sample_otf_single(is_training=True, cuda=torch.cuda.is_available()) + # WW, x, WW_lg, y, P = get_lg_inputs(W, args.J) + # print ("Num of edges: ", np.sum(W)) loss_single, acc_single = train_mcd_single(gnn, optimizer, logger, gen, n_classes, it) loss_lst[it] = loss_single acc_lst[it] = acc_single torch.cuda.empty_cache() + + if (it % 200 == 0): + test(gnn, logger, gen, args.n_classes, iters = 20) print ('Avg train loss', np.mean(loss_lst)) print ('Avg train acc', np.mean(acc_lst)) + print ('Std train acc', np.std(acc_lst)) def test_mcd_single(gnn, logger, gen, n_classes, iter): @@ -172,8 +181,8 @@ def test_mcd_single(gnn, logger, gen, n_classes, iter): labels = (labels + 1)/2 WW, x, WW_lg, y, P = get_lg_inputs(W, args.J) - print ('WW', WW.shape) - print ('WW_lg', WW_lg.shape) + # print ('WW', WW.shape) + # print ('WW_lg', WW_lg.shape) if (torch.cuda.is_available()): WW.cuda() @@ -231,6 +240,7 @@ def test(gnn, logger, gen, n_classes, iters=args.num_examples_test): torch.cuda.empty_cache() print ('Avg test loss', np.mean(loss_lst)) print ('Avg test acc', np.mean(acc_lst)) + print ('Std test acc', np.std(acc_lst)) diff --git a/src/model.py b/src/model.py index 97aa166..4a445d6 100644 --- a/src/model.py +++ b/src/model.py @@ -59,9 +59,16 
@@ def GMul(W, x): # print (W) N = W_size[-3] J = W_size[-1] - W = W.split(1, 3) - W = torch.cat(W, 1).squeeze(3) # W is now a tensor of size (bs, J*N, N) - output = torch.bmm(W, x) # output has size (bs, J*N, num_features) + W_lst = W.split(1, 3) + # print (len(W_lst)) + if N > 5000: + output_lst = [] + for W in W_lst: + output_lst.append(torch.bmm(W.squeeze(3),x)) + output = torch.cat(output_lst, 1) + else: + W = torch.cat(W_lst, 1).squeeze(3) # W is now a tensor of size (bs, J*N, N) + output = torch.bmm(W, x) # output has size (bs, J*N, num_features) output = output.split(N, 1) output = torch.cat(output, 2) # output has size (bs, N, J*num_features) return output @@ -170,6 +177,8 @@ def forward(self, WW, x): # # print ('W size', W.size()) # # print ('x size', input[1].size()) # x = gmul(input) # out has size (bs, N, num_inputs) + # print ('WW size', WW.shape) + # print ('x size', x.shape) x = GMul(WW, x) x_size = x.size() # print (x_size) @@ -252,7 +261,8 @@ def forward(self, WW, x, WW_lg, y, P): x_output = x_output.view(*x2x_size[:-1], self.num_outputs) - + # print ('WW_lg shape', WW_lg.shape) + # print ('y shape', y.shape) y2y = GMul(WW_lg, y) y2y_size = y2y.size() y2y = y2y.contiguous() @@ -355,6 +365,7 @@ def __init__(self, num_features, num_layers, J, n_classes=2): def forward(self, W, x, W_lg, y, P): cur = self.layer0(W, x, W_lg, y, P) for i in range(self.num_layers): + # print ('layer', i) cur = self._modules['layer{}'.format(i+1)](*cur) out = self.layerlast(*cur) return out[1] @@ -401,7 +412,9 @@ def __init__(self, num_features, num_layers, J, n_classes=2): def forward(self, W, x): cur = self.layer0(W, x) + # print ('layer0') for i in range(self.num_layers): + # print ('layer', i+1) cur = self._modules['layer{}'.format(i+1)](*cur) out = self.layerlast(*cur) return out[1] diff --git a/src/model_lua.py b/src/model_lua.py new file mode 100644 index 0000000..20b7b20 --- /dev/null +++ b/src/model_lua.py @@ -0,0 +1,463 @@ +#!/usr/bin/python +# -*- coding: UTF-8 -*- + +import matplotlib +matplotlib.use('Agg') + +# Pytorch requirements +import unicodedata +import string +import re +import random + +import torch +import torch.nn as nn +from torch.nn import init +from torch.autograd import Variable +from torch import optim +import torch.nn.functional as F + +if torch.cuda.is_available(): + dtype = torch.cuda.FloatTensor + dtype_l = torch.cuda.LongTensor +else: + dtype = torch.FloatTensor + dtype_l = torch.cuda.LongTensor + + +def GMul(W, x): + # x is a tensor of size (bs, N, num_features) + # W is a tensor of size (bs, N, N, J) + x_size = x.size() + # print (x) + W_size = W.size() + # print ('WW', W_size) + # print (W) + N = W_size[-3] + J = W_size[-1] + W_lst = W.split(1, 3) + # print (len(W_lst)) + if N > 5000: + output_lst = [] + for W in W_lst: + output_lst.append(torch.bmm(W.squeeze(3),x)) + output = torch.cat(output_lst, 1) + else: + W = torch.cat(W_lst, 1).squeeze(3) # W is now a tensor of size (bs, J*N, N) + # print ('W', W.shape) + # print ('x', x.shape) + output = torch.bmm(W, x) # output has size (bs, J*N, num_features) + output = output.split(N, 1) + output = torch.cat(output, 2) # output has size (bs, N, J*num_features) + return output + +class Gconv_last(nn.Module): + def __init__(self, feature_maps, J): + super(Gconv_last, self).__init__() + self.num_inputs = J*feature_maps[0] + self.num_outputs = feature_maps[2] + self.fc = nn.Linear(self.num_inputs, self.num_outputs) + + def forward(self, input): + W = input[0] + x = gmul(input) # out has size (bs, N, num_inputs) + x_size 
= x.size() + x = x.contiguous() + x = x.view(x_size[0]*x_size[1], -1) + x = self.fc(x) # has size (bs*N, num_outputs) + x = x.view(*x_size[:-1], self.num_outputs) + return W, x + + +class gnn_atomic(nn.Module): + def __init__(self, feature_maps, J): + super(gnn_atomic, self).__init__() + self.num_inputs = J*feature_maps[0] + self.num_outputs = feature_maps[2] + self.fc1 = nn.Linear(self.num_inputs, self.num_outputs // 2) + self.fc2 = nn.Linear(self.num_inputs, self.num_outputs - self.num_outputs // 2) + self.bn2d = nn.BatchNorm2d(self.num_outputs) + + def forward(self, WW, x): + # W = input[0] + # # print ('W size', W.size()) + # # print ('x size', input[1].size()) + # x = gmul(input) # out has size (bs, N, num_inputs) + # print ('WW size', WW.shape) + # print ('x size', x.shape) + x = GMul(WW, x) + x_size = x.size() + # print (x_size) + x = x.contiguous() + x = x.view(-1, self.num_inputs) + # print (x.size()) + x1 = F.relu(self.fc1(x)) # has size (bs*N, num_outputs) + x2 = self.fc2(x) + x = torch.cat((x1, x2), 1) + # x = self.bn2d(x.unsqueeze(0).unsqueeze(3)).squeeze(3).squeeze(0) + x = self.bn2d(x) + # print (x.size()) + x = x.view(*x_size[:-1], self.num_outputs) + return WW, x + +class gnn_atomic_final(nn.Module): + def __init__(self, feature_maps, J, n_classes): + super(gnn_atomic_final, self).__init__() + self.num_inputs = J*feature_maps[0] + self.num_outputs = n_classes + self.fc = nn.Linear(self.num_inputs, self.num_outputs) + + def forward(self, WW, x): + x = GMul(WW, x) # out has size (bs, N, num_inputs) + x_size = x.size() + x = x.contiguous() + x = x.view(x_size[0]*x_size[1], -1) + x = self.fc(x) # has size (bs*N, num_outputs) + # x = F.tanh(x) # added for last layer + x = x.view(*x_size[:-1], self.num_outputs) + return WW, x + + + +class gnn_atomic_lg(nn.Module): + def __init__(self, feature_maps, J): + super(gnn_atomic_lg, self).__init__() + # self.num_inputs_1 = J*feature_maps[0] + # self.num_inputs_2 = 2 * feature_maps[1] + # self.num_inputs_3 = 4 * feature_maps[2] + # self.num_outputs = feature_maps[2] + self.feature_maps = feature_maps + self.J = J + self.fcx2x_1 = nn.Linear(J * feature_maps[0], feature_maps[2]) + self.fcy2x_1 = nn.Linear(2 * feature_maps[1], feature_maps[2]) + self.fcx2x_2 = nn.Linear(J * feature_maps[0], feature_maps[2]) + self.fcy2x_2 = nn.Linear(2 * feature_maps[1], feature_maps[2]) + self.fcx2y_1 = nn.Linear(J * feature_maps[1], feature_maps[2]) + self.fcy2y_1 = nn.Linear(4 * feature_maps[2], feature_maps[2]) + self.fcx2y_2 = nn.Linear(J * feature_maps[1], feature_maps[2]) + self.fcy2y_2 = nn.Linear(4 * feature_maps[2], feature_maps[2]) + self.bn2d_x = nn.BatchNorm2d(2 * feature_maps[2]) + self.bn2d_y = nn.BatchNorm2d(2 * feature_maps[2]) + + def forward(self, WW, x, WW_lg, y, P): + # print ('W size', W.size()) + # print ('x size', input[1].size()) + xa1 = GMul(WW, x) # out has size (bs, N, num_inputs) + xa1_size = xa1.size() + # print (x_size) + xa1 = xa1.contiguous() + # print ('xa1', xa1.shape) + # print ('J', self.J) + # print ('fm0', self.feature_maps[0]) + xa1 = xa1.view(-1, self.J * self.feature_maps[0]) + # print (x.size()) + # print ('x2x', x2x) + # xa1 = xa1.type(dtype) + + # y2x = torch.bmm(P, y) + xb1 = GMul(P, y) + # xb1 = xb1.size() + xb1 = xb1.contiguous() + # print ('xb1', xb1.shape) + xb1 = xb1.view(-1, 2 * self.feature_maps[1]) + + # y2x = y2x.type(dtype) + + # xy2x = x2x + y2x + z1 = F.relu(self.fcx2x_1(xa1) + self.fcy2x_1(xb1)) # has size (bs*N, num_outputs) + + yl1 = self.fcx2x_2(xa1) + self.fcy2x_2(xb1) + zb1 = torch.cat((yl1, 
z1), 1) + # x_output = self.bn2d_x(x_cat) + zc1 = self.bn2d_x(zb1.unsqueeze(2).unsqueeze(3)).squeeze(3).squeeze(2) + # print ('zc1', zc1.shape) + zc1 = zc1.view(*xa1_size[:-1], 2 * self.feature_maps[2]) + # print ('zc1', zc1.shape) + x_output = zc1 + + # print ('WW_lg shape', WW_lg.shape) + # print ('y shape', y.shape) + xda1 = GMul(WW_lg, y) + xda1_size = xda1.size() + xda1 = xda1.contiguous() + # print ('xda1', xda1.shape) + xda1 = xda1.view(-1, self.J * self.feature_maps[1]) + + # y2y = y2y.type(dtype) + + # x2y = torch.bmm(torch.t(P), x) + xdb1 = GMul(torch.transpose(P, 2, 1), zc1) + # xdb1_size = xdb1.size() + xdb1 = xdb1.contiguous() + # print ('xdb1', xdb1.shape) + xdb1 = xdb1.view(-1, 4 * self.feature_maps[2]) + + # x2y = x2y.type(dtype) + + # xy2y = x2y + y2y + zd1 = F.relu(self.fcx2y_1(xda1) + self.fcy2y_1(xdb1)) + + ydl1 = self.fcx2y_2(xda1) + self.fcy2y_2(xdb1) + + zdb1 = torch.cat((ydl1, zd1), 1) + # y_output = self.bn2d_x(y_cat) + zdc1 = self.bn2d_y(zdb1.unsqueeze(2).unsqueeze(3)).squeeze(3).squeeze(2) + + # print ('zdc1', zdc1.shape) + + zdc1 = zdc1.view(*xda1_size[:-1], 2 * self.feature_maps[2]) + y_output = zdc1 + + # WW = WW.type(dtype) + + return WW, x_output, WW_lg, y_output, P + +# class gnn_atomic_lg(nn.Module): +# def __init__(self, feature_maps, J): +# super(gnn_atomic_lg, self).__init__() +# self.num_inputs = J*feature_maps[0] +# self.num_inputs_2 = 2 * feature_maps[1] +# # self.num_inputs_3 = 4 * feature_maps[2] +# self.num_outputs = feature_maps[2] +# self.fcx2x_1 = nn.Linear(self.num_inputs, self.num_outputs // 2) +# self.fcy2x_1 = nn.Linear(self.num_inputs_2, self.num_outputs // 2) +# self.fcx2x_2 = nn.Linear(self.num_inputs, self.num_outputs - self.num_outputs // 2) +# self.fcy2x_2 = nn.Linear(self.num_inputs_2, self.num_outputs - self.num_outputs // 2) +# self.fcx2y_1 = nn.Linear(self.num_inputs_2, self.num_outputs // 2) +# self.fcy2y_1 = nn.Linear(self.num_inputs, self.num_outputs // 2) +# self.fcx2y_2 = nn.Linear(self.num_inputs_2, self.num_outputs - self.num_outputs // 2) +# self.fcy2y_2 = nn.Linear(self.num_inputs, self.num_outputs - self.num_outputs // 2) +# self.bn2d_x = nn.BatchNorm2d(self.num_outputs) +# self.bn2d_y = nn.BatchNorm2d(self.num_outputs) + +# def forward(self, WW, x, WW_lg, xd, P): +# # print ('W size', W.size()) +# # print ('x size', input[1].size()) +# xa1 = GMul(WW, x) # out has size (bs, N, num_inputs) +# # x2x_size = xa1.size() +# # # print (x_size) +# # x2x = x2x.contiguous() +# # x2x = x2x.view(-1, self.num_inputs) +# # # print (x.size()) +# # # print ('x2x', x2x) +# # x2x = x2x.type(dtype) + +# # y2x = torch.bmm(P, y) +# xb1 = GMul(P, xd) +# # y2x_size = y2x.size() +# # y2x = y2x.contiguous() +# # y2x = y2x.view(-1, self.num_inputs_2) + +# # y2x = y2x.type(dtype) + +# # x1 = torch.cat([xa1, xb1], 2) +# # xy2x = x2x + y2x +# xy2x = F.relu(self.fcx2x_1(x2x) + self.fcy2x_1(y2x)) # has size (bs*N, num_outputs) + +# xy2x_l = self.fcx2x_2(x2x) + self.fcy2x_2(y2x) +# x_cat = torch.cat((xy2x, xy2x_l), 1) +# # x_output = self.bn2d_x(x_cat) +# x_output = self.bn2d_x(x_cat.unsqueeze(2).unsqueeze(3)).squeeze(3).squeeze(2) + +# x_output = x_output.view(*x2x_size[:-1], self.num_outputs) + +# # print ('WW_lg shape', WW_lg.shape) +# # print ('y shape', y.shape) +# y2y = GMul(WW_lg, y) +# y2y_size = y2y.size() +# y2y = y2y.contiguous() +# y2y = y2y.view(-1, self.num_inputs) + +# y2y = y2y.type(dtype) + +# # x2y = torch.bmm(torch.t(P), x) +# x2y = GMul(torch.transpose(P, 2, 1), x) +# x2y_size = x2y.size() +# x2y = x2y.contiguous() +# x2y = 
x2y.view(-1, self.num_inputs_2) + +# x2y = x2y.type(dtype) + +# # xy2y = x2y + y2y +# xy2y = F.relu(self.fcx2y_1(x2y) + self.fcy2y_1(y2y)) + +# xy2y_l = self.fcx2y_2(x2y) + self.fcy2y_2(y2y) + +# y_cat = torch.cat((xy2y, xy2y_l), 1) +# # y_output = self.bn2d_x(y_cat) +# y_output = self.bn2d_y(y_cat.unsqueeze(2).unsqueeze(3)).squeeze(3).squeeze(2) + +# y_output = y_output.view(*y2y_size[:-1], self.num_outputs) + +# # WW = WW.type(dtype) + +# return WW, x_output, WW_lg, y_output, P + +class gnn_atomic_lg_final(nn.Module): + def __init__(self, feature_maps, J, n_classes): + super(gnn_atomic_lg_final, self).__init__() + self.num_inputs = J*feature_maps[0] + self.num_inputs_2 = 2 * feature_maps[1] + self.num_outputs = n_classes + self.fcx2x_1 = nn.Linear(self.num_inputs, self.num_outputs) + self.fcy2x_1 = nn.Linear(self.num_inputs_2, self.num_outputs) + + def forward(self, W, x, W_lg, y, P): + # print ('W size', W.size()) + # print ('x size', input[1].size()) + x2x = GMul(W, x) # out has size (bs, N, num_inputs) + x2x_size = x2x.size() + # print (x_size) + x2x = x2x.contiguous() + x2x = x2x.view(-1, self.num_inputs) + # print (x.size()) + + # y2x = torch.bmm(P, y) + y2x = GMul(P, y) + y2x_size = x2x.size() + y2x = y2x.contiguous() + y2x = y2x.view(-1, self.num_inputs_2) + + # xy2x = x2x + y2x + xy2x = self.fcx2x_1(x2x) + self.fcy2x_1(y2x) # has size (bs*N, num_outputs) + + x_output = xy2x.view(*x2x_size[:-1], self.num_outputs) + + return W, x_output + +class GNN(nn.Module): + def __init__(self, num_features, num_layers, J): + super(GNN, self).__init__() + self.num_features = num_features + self.num_layers = num_layers + self.featuremap_in = [1, 1, num_features] + self.featuremap_mi = [num_features, num_features, num_features] + self.featuremap_end = [num_features, num_features, 1] + self.layer0 = Gconv(self.featuremap_in, J) + for i in range(num_layers): + module = Gconv(self.featuremap_mi, J) + self.add_module('layer{}'.format(i + 1), module) + self.layerlast = Gconv_last(self.featuremap_end, J) + + def forward(self, input): + cur = self.layer0(input) + for i in range(self.num_layers): + cur = self._modules['layer{}'.format(i+1)](cur) + out = self.layerlast(cur) + return out[1] + +class lGNN_multiclass(nn.Module): + def __init__(self, num_features, num_layers, J, n_classes=2): + super(lGNN_multiclass, self).__init__() + self.num_features = num_features + self.num_layers = num_layers + self.featuremap_in = [1, 1, num_features // 2] + self.featuremap_mi = [num_features, num_features, num_features // 2] + self.featuremap_end = [num_features, num_features, 1] + # self.layer0 = Gconv(self.featuremap_in, J) + self.layer0 = gnn_atomic_lg(self.featuremap_in, J) + for i in range(num_layers): + # module = Gconv(self.featuremap_mi, J) + module = gnn_atomic_lg(self.featuremap_mi, J) + self.add_module('layer{}'.format(i + 1), module) + self.layerlast = gnn_atomic_lg_final(self.featuremap_end, J, n_classes) + + def forward(self, W, x, W_lg, y, P): + cur = self.layer0(W, x, W_lg, y, P) + for i in range(self.num_layers): + # print ('layer', i) + cur = self._modules['layer{}'.format(i+1)](*cur) + out = self.layerlast(*cur) + return out[1] + + +class GNN_bcd(nn.Module): + def __init__(self, num_features, num_layers, J, n_classes=2): + super(GNN_bcd, self).__init__() + self.num_features = num_features + self.num_layers = num_layers + self.featuremap_in = [1, 1, num_features] + self.featuremap_mi = [num_features, num_features, num_features] + self.featuremap_end = [num_features, num_features, num_features] + # 
self.layer0 = Gconv(self.featuremap_in, J) + self.layer0 = Gconv_new(self.featuremap_in, J) + for i in range(num_layers): + # module = Gconv(self.featuremap_mi, J) + module = Gconv_new(self.featuremap_mi, J) + self.add_module('layer{}'.format(i + 1), module) + self.layerlast = Gconv_last_bcd(self.featuremap_end, J, n_classes) + + def forward(self, input): + cur = self.layer0(input) + for i in range(self.num_layers): + cur = self._modules['layer{}'.format(i+1)](cur) + out = self.layerlast(cur) + return out[1] + +class GNN_multiclass(nn.Module): + def __init__(self, num_features, num_layers, J, n_classes=2): + super(GNN_multiclass, self).__init__() + self.num_features = num_features + self.num_layers = num_layers + self.featuremap_in = [1, 1, num_features] + self.featuremap_mi = [num_features, num_features, num_features] + self.featuremap_end = [num_features, num_features, num_features] + # self.layer0 = Gconv(self.featuremap_in, J) + self.layer0 = gnn_atomic(self.featuremap_in, J) + for i in range(num_layers): + # module = Gconv(self.featuremap_mi, J) + module = gnn_atomic(self.featuremap_mi, J) + self.add_module('layer{}'.format(i + 1), module) + self.layerlast = gnn_atomic_final(self.featuremap_end, J, n_classes) + + def forward(self, W, x): + cur = self.layer0(W, x) + # print ('layer0') + for i in range(self.num_layers): + # print ('layer', i+1) + cur = self._modules['layer{}'.format(i+1)](*cur) + out = self.layerlast(*cur) + return out[1] + + + +if __name__ == '__main__': + # test modules + bs = 4 + num_features = 10 + num_layers = 5 + N = 8 + x = torch.ones((bs, N, num_features)) + W1 = torch.eye(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) + W2 = torch.ones(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) + J = 2 + W = torch.cat((W1, W2), 3) + input = [Variable(W), Variable(x)] + ######################### test gmul ############################## + # feature_maps = [num_features, num_features, num_features] + # out = gmul(input) + # print(out[0, :, num_features:]) + ######################### test gconv ############################## + # feature_maps = [num_features, num_features, num_features] + # gconv = Gconv(feature_maps, J) + # _, out = gconv(input) + # print(out.size()) + ######################### test gnn ############################## + # x = torch.ones((bs, N, 1)) + # input = [Variable(W), Variable(x)] + # gnn = GNN(num_features, num_layers, J) + # out = gnn(input) + # print(out.size()) + ######################### test siamese gnn ############################## + x = torch.ones((bs, N, 1)) + input1 = [Variable(W), Variable(x)] + input2 = [Variable(W.clone()), Variable(x.clone())] + siamese_gnn = Siamese_GNN(num_features, num_layers, J) + out = siamese_gnn(input1, input2) + print(out.size()) + print(out) + + gnn = GNN_bcd(num_features, num_layers, 2) + out=gnn(input1) + print(out.size()) + print(out) diff --git a/src/script_disac5_gnn_yeD_otf_cp4.sh b/src/script_disac5_gnn_yeD_otf_cp4.sh new file mode 100644 index 0000000..6e72a9a --- /dev/null +++ b/src/script_disac5_gnn_yeD_otf_cp4.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +# all commands that start with SBATCH contain commands that are just used by SLURM forscheduling +################# +# set a job name +#SBATCH --job-name=81_5gnnyeD_30_8_1_3200 +################# +# a file for job output, you can check job progress +#SBATCH --output=disac5_gnn_yeD_otf_nl30_nf8_J1_ntr6000.out +################# +# a file for errors from the job +#SBATCH --error=disac5_gnn_yeD_otf_nl30_nf8_J1_ntr6000.err +################# +# time you think 
you need; default is one hour
+# in minutes
+# In this case, hh:mm:ss, select whatever time you want; the less you ask for, the
+# faster your job will run.
+# Default is one hour; this example will run in less than 5 minutes.
+#SBATCH --time=06:00:00
+#################
+# --gres will give you one GPU, you can ask for more, up to 4 (or however many are on the node/card)
+#SBATCH --gres gpu:1
+# We are submitting to the batch partition
+#SBATCH --qos=batch
+#################
+#number of nodes you are requesting
+#SBATCH --nodes=1
+#################
+#memory per node; default is 4000 MB per CPU
+#SBATCH --mem=100000
+#################
+# Have SLURM send you an email when the job ends or fails, careful, the email could end up in your clutter folder
+#SBATCH --mail-type=END,FAIL # notifications for job done & fail
+#SBATCH --mail-user=zc1216@nyu.edu
+
+source activate py36
+python3 main_gnn_otf_cp.py \
+--path_logger '' \
+--path_gnn '' \
+--filename_existing_gnn '' \
+--num_examples_train 6000 \
+--num_examples_test 100 \
+--p_SBM 0.0 \
+--q_SBM 0.045 \
+--generative_model 'SBM_multiclass' \
+--batch_size 1 \
+--mode 'train' \
+--clip_grad_norm 40.0 \
+--num_features 8 \
+--num_layers 30 \
+--J 1 \
+--N_train 400 \
+--N_test 400 \
+--print_freq 1 \
+--n_classes 5 \
+--lr 0.004
diff --git a/src/script_disac5_lgnn_yeD_luaNB_cp4.sh b/src/script_disac5_lgnn_yeD_luaNB_cp4.sh
new file mode 100644
index 0000000..8372358
--- /dev/null
+++ b/src/script_disac5_lgnn_yeD_luaNB_cp4.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+
+# all commands that start with SBATCH contain commands that are just used by SLURM for scheduling
+#################
+# set a job name
+#SBATCH --job-name=J2NB2otf_5YlgNf8J3RS_J2_nl30_ntr5000_RS
+#################
+# a file for job output, you can check job progress
+#SBATCH --output=disac5_lgnn_yeD_luaNB2_nl30_nf8_J2_ntr5000_RS42.out
+#################
+# a file for errors from the job
+#SBATCH --error=disac5_lgnn_yeD_luaNB2_nl30_nf8_J2_ntr5000_RS42.err
+#################
+# time you think you need; default is one hour
+# in minutes
+# In this case, hh:mm:ss, select whatever time you want; the less you ask for, the
+# faster your job will run.
+# Default is one hour; this example will run in less than 5 minutes.
+#SBATCH --time=6-23:00:00
+#################
+# --gres will give you one GPU, you can ask for more, up to 4 (or however many are on the node/card)
+#SBATCH --gres gpu:1
+# We are submitting to the batch partition
+#SBATCH --qos=batch
+#################
+#number of nodes you are requesting
+#SBATCH --nodes=1
+#################
+#memory per node; default is 4000 MB per CPU
+#SBATCH --mem=100000
+#SBATCH --constraint=gpu_12gb
+#################
+# Have SLURM send you an email when the job ends or fails, careful, the email could end up in your clutter folder
+#SBATCH --mail-type=END,FAIL # notifications for job done & fail
+#SBATCH --mail-user=zc1216@nyu.edu
+
+source activate py36
+python3 main_lg_lua_otf_cp.py \
+--path_logger '' \
+--path_gnn '' \
+--filename_existing_gnn '' \
+--num_examples_train 6000 \
+--num_examples_test 300 \
+--p_SBM 0.0 \
+--q_SBM 0.045 \
+--generative_model 'SBM_multiclass' \
+--batch_size 1 \
+--mode 'train' \
+--clip_grad_norm 40.0 \
+--num_features 8 \
+--num_layers 30 \
+--J 2 \
+--N_train 400 \
+--N_test 400 \
+--print_freq 1 \
+--n_classes 5 \
+--lr 0.004
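
A few reference notes on the pieces this patch touches follow; the code sketches are illustrative and are not code from this repository.

Generator.create_dataset and sample_otf_single both call self.SBM_multiclass(p_SBM, q_SBM, graph_size, n_classes), whose body is unchanged by the patch and therefore not shown above. A minimal NumPy sketch of such a sampler is given below; the function name sbm_multiclass_sketch, the balanced label assignment and the seed handling are assumptions for illustration only. The argument values mirror the shell scripts above (p_SBM 0.0, q_SBM 0.045, 400 nodes, 5 classes); with p_SBM = 0.0 the planted structure is purely disassortative, i.e. every edge goes between communities.

import numpy as np

def sbm_multiclass_sketch(p, q, n, n_classes, seed=None):
    # Illustrative sketch, not the repository's SBM_multiclass.
    # Vertices in the same community are linked with probability p,
    # vertices in different communities with probability q.
    rng = np.random.RandomState(seed)
    # Roughly balanced community assignment, then shuffled.
    labels = np.arange(n) % n_classes
    rng.shuffle(labels)
    same = labels[:, None] == labels[None, :]
    probs = np.where(same, p, q)
    # Sample the upper triangle and symmetrize; no self-loops.
    upper = np.triu(rng.rand(n, n) < probs, k=1)
    W = (upper | upper.T).astype(float)
    return W, labels

# Example matching the script arguments above.
W, labels = sbm_multiclass_sketch(p=0.0, q=0.045, n=400, n_classes=5, seed=0)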
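
get_lg_inputs (and the new get_lg_inputs_noD / get_splg_inputs) now build the line-graph operator with get_NB_2 instead of get_W_lg. get_NB_2 forms Pf = (Pm + Pd) / 2 and Pt = (Pm - Pd) / 2 from the incidence matrices, and the sign change in get_Pd appears intended to make Pf and Pt select the two endpoints of each directed edge separately, so that their products encode non-backtracking transitions. The patch also leaves a get_NB_direct stub commented out; a self-contained direct construction from the doubled edge list, under the common convention B[(u,v),(v,w)] = 1 iff w != u, could look like the sketch below. The edge ordering here is illustrative and need not match the column ordering produced by get_Pm / get_Pd, so the result should agree with get_NB_2 only up to a permutation of the edge index and possibly a transpose.

import numpy as np

def non_backtracking_direct(W):
    # Rows and columns are indexed by directed edges (u, v) of the
    # undirected graph W; B[(u, v), (v, w)] = 1 whenever both edges
    # exist and w != u (no immediate backtracking).
    n = W.shape[0]
    # Doubled edge list: both orientations of every undirected edge.
    edges = [(u, v) for u in range(n) for v in range(n)
             if u != v and W[u, v] == 1]
    index = {e: i for i, e in enumerate(edges)}
    B = np.zeros((len(edges), len(edges)))
    for (u, v) in edges:
        for w in range(n):
            if w != u and w != v and W[v, w] == 1:
                B[index[(u, v)], index[(v, w)]] = 1.0
    return B, edges

# Tiny usage example on a 4-cycle: 4 undirected edges, 8 directed edges,
# and each directed edge has exactly one non-backtracking continuation.
W = np.zeros((4, 4))
for a, b in [(0, 1), (1, 2), (2, 3), (3, 0)]:
    W[a, b] = W[b, a] = 1
B, edges = non_backtracking_direct(W)
print(B.shape)   # (8, 8)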
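
GMul in model.py (and its copy in the new model_lua.py) now splits W along the operator dimension and performs one bmm per slice when N > 5000, instead of materialising the (bs, J*N, N) concatenation in one piece. The standalone check below, on small hypothetical shapes, confirms that the chunked path and the original concatenated path produce the same output:

import torch

bs, N, J, F = 2, 16, 3, 4
W = torch.rand(bs, N, N, J)   # stacked graph operators
x = torch.rand(bs, N, F)      # node features

# Original path: concatenate the J slices into a (bs, J*N, N) matrix.
W_cat = torch.cat(W.split(1, 3), 1).squeeze(3)
out_cat = torch.bmm(W_cat, x)
out_cat = torch.cat(out_cat.split(N, 1), 2)          # (bs, N, J*F)

# Chunked path from the patch: one bmm per operator slice.
out_chunk = torch.cat([torch.bmm(Wj.squeeze(3), x) for Wj in W.split(1, 3)], 1)
out_chunk = torch.cat(out_chunk.split(N, 1), 2)      # (bs, N, J*F)

assert torch.allclose(out_cat, out_chunk)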