Commit b1afea1: init
CongWeilin committed Feb 24, 2020 (0 parents, commit b1afea1)
Showing 21 changed files with 3,214 additions and 0 deletions.
7 changes: 7 additions & 0 deletions README
@@ -0,0 +1,7 @@
Run `run_ppi.sh` for experiments on the PPI dataset.

Do the same for the other datasets.

Create a folder `./data` with `mkdir data`, then download the data into this folder.

Datasets can be downloaded from https://drive.google.com/drive/folders/1qrFuQOxrbaDziJFeEkpAiXmL_C8dlk3K?usp=sharing
40 changes: 40 additions & 0 deletions args.py
@@ -0,0 +1,40 @@
from utils import *
import argparse
"""
Dataset arguments
"""
parser = argparse.ArgumentParser(
    description='Training GCN on Large-scale Graph Datasets')

parser.add_argument('--dataset', type=str, default='yelp',
                    help='Dataset name: cora/citeseer/pubmed/flickr/reddit/yelp/ppi/ppi-large')
parser.add_argument('--nhid', type=int, default=256,
                    help='Hidden state dimension')
parser.add_argument('--epoch_num', type=int, default=400,
                    help='Number of epochs')
parser.add_argument('--pool_num', type=int, default=10,
                    help='Number of pools')
parser.add_argument('--batch_num', type=int, default=20,
                    help='Maximum number of batches')
parser.add_argument('--batch_size', type=int, default=512,
                    help='Number of output nodes in a batch')
parser.add_argument('--n_layers', type=int, default=2,
                    help='Number of GCN layers')
parser.add_argument('--n_stops', type=int, default=1000,
                    help='Stop after this many batches without F1 improvement')
parser.add_argument('--samp_num', type=int, default=512,
                    help='Number of sampled nodes per layer (only for LADIES & FastGCN)')
parser.add_argument('--dropout', type=float, default=0.1,
                    help='Dropout rate')
parser.add_argument('--cuda', type=int, default=-1,
                    help='Available GPU ID')
parser.add_argument('--is_ratio', type=float, default=1.0,
                    help='Importance sampling ratio')
parser.add_argument('--show_grad_norm', type=int, default=0,
                    help='Whether to show gradient norm: 0-False, 1-True')
parser.add_argument('--cluster_bsize', type=int, default=5,
                    help='Number of clusters selected per mini-batch')
args = parser.parse_args()
print(args)

# pubmed n_layers 2
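For context, a minimal sketch of how a driver script might consume these parsed arguments; the device selection and the printed summary below are illustrative assumptions, not code from this commit:

import torch
from args import args

# pick the GPU given by --cuda, falling back to CPU when it is -1 or unavailable
device = torch.device(f'cuda:{args.cuda}' if args.cuda >= 0 and torch.cuda.is_available() else 'cpu')
print(f'Training on {args.dataset}: {args.n_layers} layers, {args.nhid} hidden units, '
      f'batch size {args.batch_size}, device {device}')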
64 changes: 64 additions & 0 deletions autograd_wl.py
@@ -0,0 +1,64 @@
import torch

"""
Only use the last layer
"""
def capture_activations(layer, inputs, outputs):
    setattr(layer, "activations", inputs[0].detach())

def capture_backprops(layer, inputs, outputs):
    setattr(layer, "backprops", outputs[0].detach())

def calculate_sample_grad(layer):
    A = layer.activations
    B = layer.backprops

    n = A.shape[0]
    B = B * n  # rescale by batch size (assumes a mean-reduced loss) so each row is a per-sample backprop
    weight_grad = torch.einsum('ni,nj->nij', B, A)  # per-sample weight gradient: outer product of backprop and activation
    bias_grad = B
    grad_norm = torch.sqrt(weight_grad.norm(p=2, dim=(1, 2)).pow(2) + bias_grad.norm(p=2, dim=1).pow(2)).squeeze().detach()
    return grad_norm

"""
Use all layers
"""
def capture_activations_exact(layer, inputs, outputs):
    setattr(layer, "activations", inputs[0].detach())

def capture_backprops_exact(layer, inputs, outputs):
    if not hasattr(layer, 'backprops_list'):
        setattr(layer, 'backprops_list', [])
    layer.backprops_list.append(outputs[0].detach())

def add_hooks(model):
    for layer in model.modules():
        if layer.__class__.__name__ == 'Linear':
            layer.register_forward_hook(capture_activations_exact)
            layer.register_backward_hook(capture_backprops_exact)

def calculate_exact_sample_grad(model):
    grad_norm_sum = None
    for layer in model.modules():
        if layer.__class__.__name__ != 'Linear':
            continue
        A = layer.activations
        n = A.shape[0]

        B = layer.backprops_list[0]
        B = B * n

        weight_grad = torch.einsum('ni,nj->nij', B, A)
        bias_grad = B
        grad_norm = weight_grad.norm(p=2, dim=(1, 2)).pow(2) + bias_grad.norm(p=2, dim=1).pow(2)
        if grad_norm_sum is None:
            grad_norm_sum = grad_norm
        else:
            grad_norm_sum += grad_norm
    # accumulate squared norms across layers, then take a single square root
    grad_norm = torch.sqrt(grad_norm_sum).squeeze().detach()
    return grad_norm

def del_backprops(model):
    for layer in model.modules():
        if hasattr(layer, 'backprops_list'):
            del layer.backprops_list
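A self-contained sketch of how the last-layer hooks above can be used; the toy linear classifier and random data are illustrative assumptions, not code from this commit. The forward and backward hooks capture the layer's input activations and output gradients, and `calculate_sample_grad` turns them into one gradient norm per sample after a single mean-reduced backward pass:

import torch
import torch.nn as nn
import torch.nn.functional as F
import autograd_wl

torch.manual_seed(0)
fc = nn.Linear(16, 4)                      # stands in for the model's output layer (net.gc_out)
fc.register_forward_hook(autograd_wl.capture_activations)
fc.register_backward_hook(autograd_wl.capture_backprops)

x = torch.randn(8, 16)                     # 8 samples
y = torch.randint(0, 4, (8,))
loss = F.cross_entropy(fc(x), y)           # mean reduction matches the B * n rescaling
loss.backward()

per_sample_norms = autograd_wl.calculate_sample_grad(fc)
print(per_sample_norms.shape)              # torch.Size([8]), one gradient norm per sample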
147 changes: 147 additions & 0 deletions forward_wrapper.py
@@ -0,0 +1,147 @@
from utils import *
import autograd_wl
"""
Wrapper for variance reduction opts
"""

class ForwardWrapper(nn.Module):
    def __init__(self, n_nodes, n_hid, n_layers, n_classes, concat=False):
        super(ForwardWrapper, self).__init__()
        self.n_layers = n_layers
        # historical (stale) activations cached for every node and layer
        if concat:
            self.hiddens = torch.zeros(n_layers, n_nodes, 2*n_hid)
        else:
            self.hiddens = torch.zeros(n_layers, n_nodes, n_hid)

    def forward_full(self, net, x, adjs, sampled_nodes):
        for ell in range(len(net.gcs)):
            x = net.gcs[ell](x, adjs[ell])
            self.hiddens[ell, sampled_nodes[ell]] = x.cpu().detach()
            x = net.relu(x)
            x = net.dropout(x)

        x = net.gc_out(x)
        return x

    def forward_mini(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes):
        cached_outputs = []
        for ell in range(len(net.gcs)):
            x_bar = x if ell == 0 else net.dropout(net.relu(self.hiddens[ell-1, sampled_nodes[ell-1]].to(x)))
            x_bar_exact = x_exact[input_exact_nodes[ell]] if ell == 0 else net.dropout(net.relu(self.hiddens[ell-1, input_exact_nodes[ell]].to(x)))
            # control variate: sampled adjacency applied to (fresh - stale) inputs,
            # plus the exact adjacency applied to the stale inputs
            x = net.gcs[ell](x, adjs[ell]) - net.gcs[ell](x_bar, adjs[ell]) + net.gcs[ell](x_bar_exact, adjs_exact[ell])
            cached_outputs += [x.detach().cpu()]
            x = net.relu(x)
            x = net.dropout(x)

        x = net.gc_out(x)

        # refresh the historical activations for the nodes touched in this mini-batch
        for ell in range(len(net.gcs)):
            self.hiddens[ell, sampled_nodes[ell]] = cached_outputs[ell]
        return x

    def calculate_sample_grad(self, net, x, adjs, sampled_nodes, targets, batch_nodes):
        outputs = self.forward_full(net, x, adjs, sampled_nodes)
        loss = net.loss_f(outputs, targets[batch_nodes])
        loss.backward()
        grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out)
        return grad_per_sample.cpu().numpy()

    def partial_grad(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight=None):
        outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes)
        if weight is None:
            loss = net.loss_f(outputs, targets)
        else:
            if net.multi_class:
                loss = net.loss_f_vec(outputs, targets)
                loss = loss.mean(1) * weight
            else:
                loss = net.loss_f_vec(outputs, targets) * weight
            loss = loss.sum()
        loss.backward()
        return loss.detach()

    def partial_grad_with_norm(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight):
        num_samples = targets.size(0)
        outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes)

        if net.multi_class:
            loss = net.loss_f_vec(outputs, targets)
            loss = loss.mean(1) * weight
        else:
            loss = net.loss_f_vec(outputs, targets) * weight
        loss = loss.sum()
        loss.backward()
        grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out)

        # undo the importance weights and batch averaging to recover per-sample norms
        grad_per_sample = grad_per_sample * (1 / weight / num_samples)
        return loss.detach(), grad_per_sample.cpu().numpy()

class ForwardWrapper_v2(nn.Module):
    def __init__(self, n_nodes, n_hid, n_layers, n_classes, concat=False):
        super(ForwardWrapper_v2, self).__init__()
        self.n_layers = n_layers
        if concat:
            self.hiddens = torch.zeros(n_layers, n_nodes, 2*n_hid)
        else:
            self.hiddens = torch.zeros(n_layers, n_nodes, n_hid)

    def forward_full(self, net, x, adjs, sampled_nodes):
        for ell in range(len(net.gcs)):
            x = net.gcs[ell](x, adjs[ell])
            self.hiddens[ell, sampled_nodes[ell]] = x.cpu().detach()
            x = net.dropout(net.relu(x))

        x = net.gc_out(x)
        return x

    def forward_mini(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes):
        cached_outputs = []
        for ell in range(len(net.gcs)):
            x_bar_exact = x_exact[input_exact_nodes[ell]] if ell == 0 else net.dropout(net.relu(self.hiddens[ell-1, input_exact_nodes[ell]].to(x)))
            # concatenate fresh mini-batch inputs with the stale inputs of the exact neighbors
            x = torch.cat([x, x_bar_exact], dim=0)
            x = net.gcs[ell](x, adjs_exact[ell])
            cached_outputs += [x.detach().cpu()]
            x = net.dropout(net.relu(x))

        x = net.gc_out(x)

        for ell in range(len(net.gcs)):
            self.hiddens[ell, sampled_nodes[ell]] = cached_outputs[ell]
        return x

    def calculate_sample_grad(self, net, x, adjs, sampled_nodes, targets, batch_nodes):
        outputs = self.forward_full(net, x, adjs, sampled_nodes)
        loss = net.loss_f(outputs, targets[batch_nodes])
        loss.backward()
        grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out)
        return grad_per_sample.cpu().numpy()

    def partial_grad(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight=None):
        outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes)
        if weight is None:
            loss = net.loss_f(outputs, targets)
        else:
            if net.multi_class:
                loss = net.loss_f_vec(outputs, targets)
                loss = loss.mean(1) * weight
            else:
                loss = net.loss_f_vec(outputs, targets) * weight
            loss = loss.sum()
        loss.backward()
        return loss.detach()

    def partial_grad_with_norm(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight):
        num_samples = targets.size(0)
        outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes)

        if net.multi_class:
            loss = net.loss_f_vec(outputs, targets)
            loss = loss.mean(1) * weight
        else:
            loss = net.loss_f_vec(outputs, targets) * weight
        loss = loss.sum()
        loss.backward()
        grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out)

        grad_per_sample = grad_per_sample * (1 / weight / num_samples)
        return loss.detach(), grad_per_sample.cpu().numpy()
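The layer update in `forward_mini` has a control-variate structure: the sampled adjacency only touches the difference between the fresh inputs and the cached historical inputs, while the exact adjacency acts on the historical inputs, so the sampling error scales with how stale the cache is rather than with the activation magnitude. Below is a toy sketch of that estimator on a single linear propagation step; the dense adjacency, the `sampled_adj` helper, and the staleness level are illustrative assumptions, not code from this commit:

import torch

torch.manual_seed(0)
n, d = 100, 16
A_full = torch.rand(n, n)
A_full = A_full / A_full.sum(dim=1, keepdim=True)    # row-normalized dense "adjacency"
H = torch.randn(n, d)                                # fresh layer inputs
H_bar = H + 0.01 * torch.randn(n, d)                 # slightly stale historical inputs

def sampled_adj(A, keep=0.2):
    # unbiased edge sampling: keep each entry with probability `keep`, rescale by 1/keep
    mask = (torch.rand_like(A) < keep).float()
    return A * mask / keep

exact = A_full @ H
err_naive, err_cv = 0.0, 0.0
trials = 200
for _ in range(trials):
    A_s = sampled_adj(A_full)
    naive = A_s @ H                                  # plain sampled propagation
    cv = A_s @ (H - H_bar) + A_full @ H_bar          # control-variate propagation, as in forward_mini
    err_naive += (naive - exact).norm().item() / trials
    err_cv += (cv - exact).norm().item() / trials

print('naive estimator error:           ', err_naive)
print('control-variate estimator error: ', err_cv)

When the historical cache `H_bar` stays close to the fresh inputs, the control-variate error is far smaller than the naive one, which is why `forward_mini` refreshes `self.hiddens` after every mini-batch.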