# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# MGCN
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn

from ..gnn import MGCNGNN
from ..readout import MLPNodeReadout

__all__ = ['MGCNPredictor']

# pylint: disable=W0221
class MGCNPredictor(nn.Module):
"""MGCN for for regression and classification on graphs.
MGCN is introduced in `Molecular Property Prediction: A Multilevel Quantum Interactions
Modeling Perspective <https://arxiv.org/abs/1906.11081>`__.
Parameters
----------
feats : int
Size for the node and edge embeddings to learn. Default to 128.
n_layers : int
Number of gnn layers to use. Default to 3.
classifier_hidden_feats : int
(Deprecated, see ``predictor_hidden_feats``) Size for hidden
representations in the classifier. Default to 64.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
num_node_types : int
Number of node types to embed. Default to 100.
num_edge_types : int
Number of edge types to embed. Default to 3000.
cutoff : float
Largest center in RBF expansion. Default to 5.0
gap : float
Difference between two adjacent centers in RBF expansion. Default to 1.0
predictor_hidden_feats : int
Size for hidden representations in the output MLP predictor. Default to 64.
"""

    def __init__(self, feats=128, n_layers=3, classifier_hidden_feats=64,
                 n_tasks=1, num_node_types=100, num_edge_types=3000,
                 cutoff=5.0, gap=1.0, predictor_hidden_feats=64):
        super(MGCNPredictor, self).__init__()

        if predictor_hidden_feats == 64 and classifier_hidden_feats != 64:
            print('classifier_hidden_feats is deprecated and will be removed in the future, '
                  'use predictor_hidden_feats instead')
            predictor_hidden_feats = classifier_hidden_feats

        self.gnn = MGCNGNN(feats=feats,
                           n_layers=n_layers,
                           num_node_types=num_node_types,
                           num_edge_types=num_edge_types,
                           cutoff=cutoff,
                           gap=gap)
        self.readout = MLPNodeReadout(node_feats=(n_layers + 1) * feats,
                                      hidden_feats=predictor_hidden_feats,
                                      graph_feats=n_tasks,
                                      activation=nn.Softplus(beta=1, threshold=20))

    def forward(self, g, node_types, edge_dists):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_types : int64 tensor of shape (V)
            Node types to embed, V for the number of nodes.
        edge_dists : float32 tensor of shape (E, 1)
            Distances between end nodes of edges, E for the number of edges.

        Returns
        -------
        float32 tensor of shape (G, n_tasks)
            Prediction for the graphs in the batch. G for the number of graphs.
        """
        node_feats = self.gnn(g, node_types, edge_dists)
        return self.readout(g, node_feats)
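

# The block below is a minimal usage sketch, not part of the library code above.
# It assumes ``dgl`` and ``torch`` are installed and that this module is importable
# as part of the dgllife package; the toy graph, node types and edge distances are
# made up purely for illustration.
if __name__ == '__main__':
    import dgl
    import torch

    # A toy graph with 3 nodes and 4 directed edges
    g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
    # Node types to embed, int64 tensor of shape (V)
    node_types = torch.LongTensor([0, 1, 2])
    # Illustrative distances between edge end nodes, float32 tensor of shape (E, 1),
    # kept within the default RBF cutoff of 5.0
    edge_dists = torch.rand(g.number_of_edges(), 1) * 5.0

    model = MGCNPredictor(feats=128, n_layers=3, n_tasks=1)
    # Output has shape (G, n_tasks), here (1, 1) for a single graph and task
    preds = model(g, node_types, edge_dists)
    print(preds.shape)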