-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit b1afea1
Showing
21 changed files
with
3,214 additions
and
0 deletions.
There are no files selected for viewing
Binary file added
BIN
+3.68 MB
...Variance Sampling with Provable Guarantees for Fast Training of Graph Neural Networks.pdf
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Run `run_ppi.sh` for experiments on PPI dataset. | ||
|
||
The same for other datasets. | ||
|
||
Create a folder "./data" by `mkdir data` then download data into this folder. | ||
|
||
Datasets can be download from https://drive.google.com/drive/folders/1qrFuQOxrbaDziJFeEkpAiXmL_C8dlk3K?usp=sharing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from utils import * | ||
import argparse | ||
""" | ||
Dataset arguments | ||
""" | ||
parser = argparse.ArgumentParser( | ||
description='Training GCN on Large-scale Graph Datasets') | ||
|
||
parser.add_argument('--dataset', type=str, default='yelp', | ||
help='Dataset name: cora/citeseer/pubmed/flickr/reddit/ppi/ppi-large') | ||
parser.add_argument('--nhid', type=int, default=256, | ||
help='Hidden state dimension') | ||
parser.add_argument('--epoch_num', type=int, default=400, | ||
help='Number of Epoch') | ||
parser.add_argument('--pool_num', type=int, default=10, | ||
help='Number of Pool') | ||
parser.add_argument('--batch_num', type=int, default=20, | ||
help='Maximum Batch Number') | ||
parser.add_argument('--batch_size', type=int, default=512, | ||
help='size of output node in a batch') | ||
parser.add_argument('--n_layers', type=int, default=2, | ||
help='Number of GCN layers') | ||
parser.add_argument('--n_stops', type=int, default=1000, | ||
help='Stop after number of batches that f1 dont increase') | ||
parser.add_argument('--samp_num', type=int, default=512, | ||
help='Number of sampled nodes per layer (only for ladies & factgcn)') | ||
parser.add_argument('--dropout', type=float, default=0.1, | ||
help='Dropout rate') | ||
parser.add_argument('--cuda', type=int, default=-1, | ||
help='Avaiable GPU ID') | ||
parser.add_argument('--is_ratio', type=float, default=1.0, | ||
help='Importance sampling rate') | ||
parser.add_argument('--show_grad_norm', type=int, default=0, | ||
help='Whether show gradient norm 0-False, 1-True') | ||
parser.add_argument('--cluster_bsize', type=int, default=5, | ||
help='how many cluster selected each mini-batch') | ||
args = parser.parse_args() | ||
print(args) | ||
|
||
# pubmed n_layers 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch | ||
|
||
""" | ||
Only use the last layer | ||
""" | ||
def capture_activations(layer, inputs, outputs): | ||
setattr(layer, "activations", inputs[0].detach()) | ||
|
||
def capture_backprops(layer, inputs, outputs): | ||
setattr(layer, "backprops", outputs[0].detach()) | ||
|
||
def calculate_sample_grad(layer): | ||
A = layer.activations | ||
B = layer.backprops | ||
|
||
n = A.shape[0] | ||
B = B * n | ||
weight_grad = torch.einsum('ni,nj->nij', B, A) | ||
bias_grad = B | ||
grad_norm = torch.sqrt(weight_grad.norm(p=2, dim=(1,2)).pow(2) + bias_grad.norm(p=2, dim=1).pow(2)).squeeze().detach() | ||
return grad_norm | ||
|
||
""" | ||
Use all layers | ||
""" | ||
def capture_activations_exact(layer, inputs, outputs): | ||
setattr(layer, "activations", inputs[0].detach()) | ||
|
||
def capture_backprops_exact(layer, inputs, outputs): | ||
if not hasattr(layer, 'backprops_list'): | ||
setattr(layer, 'backprops_list', []) | ||
layer.backprops_list.append(outputs[0].detach()) | ||
|
||
def add_hooks(model): | ||
for layer in model.modules(): | ||
if layer.__class__.__name__=='Linear': | ||
layer.register_forward_hook(capture_activations_exact) | ||
layer.register_backward_hook(capture_backprops_exact) | ||
|
||
def calculate_exact_sample_grad(model): | ||
grad_norm_sum = None | ||
for layer in model.modules(): | ||
if layer.__class__.__name__!='Linear': | ||
continue | ||
A = layer.activations | ||
n = A.shape[0] | ||
|
||
B = layer.backprops_list[0] | ||
B = B*n | ||
|
||
weight_grad = torch.einsum('ni,nj->nij', B, A) | ||
bias_grad = B | ||
grad_norm = weight_grad.norm(p=2, dim=(1,2)).pow(2) + bias_grad.norm(p=2, dim=1).pow(2) | ||
if grad_norm_sum is None: | ||
grad_norm_sum = grad_norm | ||
else: | ||
grad_norm_sum += grad_norm | ||
grad_norm = torch.sqrt(grad_norm_sum).squeeze().detach() | ||
return grad_norm | ||
|
||
def del_backprops(model): | ||
for layer in model.modules(): | ||
if hasattr(layer, 'backprops_list'): | ||
del layer.backprops_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from utils import * | ||
import autograd_wl | ||
""" | ||
Wrapper for variance reduction opts | ||
""" | ||
|
||
class ForwardWrapper(nn.Module): | ||
def __init__(self, n_nodes, n_hid, n_layers, n_classes, concat=False): | ||
super(ForwardWrapper, self).__init__() | ||
self.n_layers = n_layers | ||
if concat: | ||
self.hiddens = torch.zeros(n_layers, n_nodes, 2*n_hid) | ||
else: | ||
self.hiddens = torch.zeros(n_layers, n_nodes, n_hid) | ||
|
||
def forward_full(self, net, x, adjs, sampled_nodes): | ||
for ell in range(len(net.gcs)): | ||
x = net.gcs[ell](x, adjs[ell]) | ||
self.hiddens[ell,sampled_nodes[ell]] = x.cpu().detach() | ||
x = net.relu(x) | ||
x = net.dropout(x) | ||
|
||
x = net.gc_out(x) | ||
return x | ||
|
||
def forward_mini(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes): | ||
cached_outputs = [] | ||
for ell in range(len(net.gcs)): | ||
x_bar = x if ell == 0 else net.dropout(net.relu(self.hiddens[ell-1,sampled_nodes[ell-1]].to(x))) | ||
x_bar_exact = x_exact[input_exact_nodes[ell]] if ell == 0 else net.dropout(net.relu(self.hiddens[ell-1,input_exact_nodes[ell]].to(x))) | ||
x = net.gcs[ell](x, adjs[ell]) - net.gcs[ell](x_bar, adjs[ell]) + net.gcs[ell](x_bar_exact, adjs_exact[ell]) | ||
cached_outputs += [x.detach().cpu()] | ||
x = net.relu(x) | ||
x = net.dropout(x) | ||
|
||
x = net.gc_out(x) | ||
|
||
for ell in range(len(net.gcs)): | ||
self.hiddens[ell, sampled_nodes[ell]] = cached_outputs[ell] | ||
return x | ||
|
||
def calculate_sample_grad(self, net, x, adjs, sampled_nodes, targets, batch_nodes): | ||
outputs = self.forward_full(net, x, adjs, sampled_nodes) | ||
loss = net.loss_f(outputs, targets[batch_nodes]) | ||
loss.backward() | ||
grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out) | ||
return grad_per_sample.cpu().numpy() | ||
|
||
def partial_grad(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight=None): | ||
outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes) | ||
if weight is None: | ||
loss = net.loss_f(outputs, targets) | ||
else: | ||
if net.multi_class: | ||
loss = net.loss_f_vec(outputs, targets) | ||
loss = loss.mean(1) * weight | ||
else: | ||
loss = net.loss_f_vec(outputs, targets) * weight | ||
loss = loss.sum() | ||
loss.backward() | ||
return loss.detach() | ||
|
||
def partial_grad_with_norm(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight): | ||
num_samples = targets.size(0) | ||
outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes) | ||
|
||
if net.multi_class: | ||
loss = net.loss_f_vec(outputs, targets) | ||
loss = loss.mean(1) * weight | ||
else: | ||
loss = net.loss_f_vec(outputs, targets) * weight | ||
loss = loss.sum() | ||
loss.backward() | ||
grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out) | ||
|
||
grad_per_sample = grad_per_sample*(1/weight/num_samples) | ||
return loss.detach(), grad_per_sample.cpu().numpy() | ||
|
||
class ForwardWrapper_v2(nn.Module): | ||
def __init__(self, n_nodes, n_hid, n_layers, n_classes, concat=False): | ||
super(ForwardWrapper_v2, self).__init__() | ||
self.n_layers = n_layers | ||
if concat: | ||
self.hiddens = torch.zeros(n_layers, n_nodes, 2*n_hid) | ||
else: | ||
self.hiddens = torch.zeros(n_layers, n_nodes, n_hid) | ||
|
||
def forward_full(self, net, x, adjs, sampled_nodes): | ||
for ell in range(len(net.gcs)): | ||
x = net.gcs[ell](x, adjs[ell]) | ||
self.hiddens[ell,sampled_nodes[ell]] = x.cpu().detach() | ||
x = net.dropout(net.relu(x)) | ||
|
||
x = net.gc_out(x) | ||
return x | ||
|
||
def forward_mini(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes): | ||
cached_outputs = [] | ||
for ell in range(len(net.gcs)): | ||
x_bar_exact = x_exact[input_exact_nodes[ell]] if ell == 0 else net.dropout(net.relu(self.hiddens[ell-1,input_exact_nodes[ell]].to(x))) | ||
x = torch.cat([x, x_bar_exact], dim=0) | ||
x = net.gcs[ell](x, adjs_exact[ell]) | ||
cached_outputs += [x.detach().cpu()] | ||
x = net.dropout(net.relu(x)) | ||
|
||
x = net.gc_out(x) | ||
|
||
for ell in range(len(net.gcs)): | ||
self.hiddens[ell, sampled_nodes[ell]] = cached_outputs[ell] | ||
return x | ||
|
||
def calculate_sample_grad(self, net, x, adjs, sampled_nodes, targets, batch_nodes): | ||
outputs = self.forward_full(net, x, adjs, sampled_nodes) | ||
loss = net.loss_f(outputs, targets[batch_nodes]) | ||
loss.backward() | ||
grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out) | ||
return grad_per_sample.cpu().numpy() | ||
|
||
def partial_grad(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight=None): | ||
outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes) | ||
if weight is None: | ||
loss = net.loss_f(outputs, targets) | ||
else: | ||
if net.multi_class: | ||
loss = net.loss_f_vec(outputs, targets) | ||
loss = loss.mean(1) * weight | ||
else: | ||
loss = net.loss_f_vec(outputs, targets) * weight | ||
loss = loss.sum() | ||
loss.backward() | ||
return loss.detach() | ||
|
||
def partial_grad_with_norm(self, net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes, targets, weight): | ||
num_samples = targets.size(0) | ||
outputs = self.forward_mini(net, x, adjs, sampled_nodes, x_exact, adjs_exact, input_exact_nodes) | ||
|
||
if net.multi_class: | ||
loss = net.loss_f_vec(outputs, targets) | ||
loss = loss.mean(1) * weight | ||
else: | ||
loss = net.loss_f_vec(outputs, targets) * weight | ||
loss = loss.sum() | ||
loss.backward() | ||
grad_per_sample = autograd_wl.calculate_sample_grad(net.gc_out) | ||
|
||
grad_per_sample = grad_per_sample*(1/weight/num_samples) | ||
return loss.detach(), grad_per_sample.cpu().numpy() |
Oops, something went wrong.