-
Notifications
You must be signed in to change notification settings - Fork 151
/
alchemy.py
301 lines (259 loc) · 11.4 KB
/
alchemy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Tencent alchemy Dataset https://alchemy.tencent.com/
import os
import os.path as osp
import pathlib
import zipfile
from collections import defaultdict
from functools import lru_cache

import numpy as np
import pandas as pd
from dgl import backend as F
from dgl.data.utils import download, get_download_dir, _get_dgl_url, save_graphs, load_graphs
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import ChemicalFeatures

from ..utils.mol_to_graph import mol_to_complete_graph
from ..utils.featurizers import atom_type_one_hot, atom_hybridization_one_hot, atom_is_aromatic
__all__ = ['TencentAlchemyDataset']
@lru_cache(maxsize=None)
def _alchemy_feature_factory():
    """Return the RDKit chemical feature factory used by :func:`alchemy_nodes`.

    The factory is built from RDKit's ``BaseFeatures.fdef``, which does not
    depend on the molecule. It is therefore built once on first use and reused,
    instead of re-parsing the definition file for every molecule.
    """
    fdef_name = osp.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    return ChemicalFeatures.BuildFeatureFactory(fdef_name)


def alchemy_nodes(mol):
    """Featurization for all atoms in a molecule. The atom indices
    will be preserved, i.e. row ``i`` of each feature corresponds to the
    atom with RDKit index ``i``.

    The per-atom feature vector concatenates, in this order:

    * one-hot encoding of the atom type over ``H, C, N, O, F, S, Cl``
    * the atomic number
    * whether the atom is an acceptor (1 or 0)
    * whether the atom is a donor (1 or 0)
    * whether the atom is aromatic
    * one-hot encoding of the hybridization over ``SP, SP2, SP3``
    * the total number of Hs on the atom

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object. It must have exactly one conformer.

    Returns
    -------
    atom_feats_dict : dict
        Dictionary for atom features with keys

        * ``'node_type'``: int64 tensor holding the atomic number of each atom
        * ``'n_feat'``: float32 tensor with one row of features per atom
    """
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1

    # Flag every atom that belongs to at least one 'Donor' / 'Acceptor' feature.
    is_donor = defaultdict(int)
    is_acceptor = defaultdict(int)
    for mol_feat in _alchemy_feature_factory().GetFeaturesForMol(mol):
        family = mol_feat.GetFamily()
        if family == 'Donor':
            flags = is_donor
        elif family == 'Acceptor':
            flags = is_acceptor
        else:
            continue
        for u in mol_feat.GetAtomIds():
            flags[u] = 1

    atom_types = ['H', 'C', 'N', 'O', 'F', 'S', 'Cl']
    hybridization_types = [Chem.rdchem.HybridizationType.SP,
                           Chem.rdchem.HybridizationType.SP2,
                           Chem.rdchem.HybridizationType.SP3]
    node_types, node_feats = [], []
    for u in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(u)
        atom_type = atom.GetAtomicNum()
        node_types.append(atom_type)
        h_u = []
        h_u += atom_type_one_hot(atom, atom_types)
        h_u.append(atom_type)
        h_u.append(is_acceptor[u])
        h_u.append(is_donor[u])
        h_u += atom_is_aromatic(atom)
        h_u += atom_hybridization_one_hot(atom, hybridization_types)
        h_u.append(atom.GetTotalNumHs())
        node_feats.append(h_u)

    atom_feats_dict = defaultdict(list)
    atom_feats_dict['node_type'] = F.tensor(np.array(node_types).astype(np.int64))
    # Convert the whole (num_atoms, feat_size) matrix at once rather than
    # creating one tensor per atom and stacking them.
    atom_feats_dict['n_feat'] = F.tensor(np.array(node_feats).astype(np.float32))
    return atom_feats_dict
def alchemy_edges(mol, self_loop=False):
    """Featurization for all pairs of atoms in a molecule.

    Edges are enumerated over ordered atom pairs ``(u, v)``. ``u`` is the outer
    loop and ``v`` the inner loop, both following RDKit atom indices. Pairs with
    ``u == v`` are skipped unless ``self_loop`` is True. Pairs that are not
    bonded still get an edge, marked by the "no bond" column of ``'e_feat'``.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule object. It must have exactly one conformer, whose
        positions are used for the distances.
    self_loop : bool
        Whether to add self loops. Default to be False.

    Returns
    -------
    bond_feats_dict : dict
        Dictionary for edge features with keys

        * ``'e_feat'``: float32 tensor of shape (E, 5). Each row is a one-hot
          encoding over SINGLE, DOUBLE, TRIPLE, AROMATIC bonds and "no bond".
          Other bond types yield an all-zero row.
        * ``'distance'``: float32 tensor of shape (E, 1) holding the Euclidean
          distance between the conformer positions of ``u`` and ``v``.
    """
    mol_conformers = mol.GetConformers()
    assert len(mol_conformers) == 1
    geom = mol_conformers[0].GetPositions()

    # ``None`` stands for "no bond between u and v".
    edge_types = (Chem.rdchem.BondType.SINGLE,
                  Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE,
                  Chem.rdchem.BondType.AROMATIC,
                  None)
    e_feats, distances = [], []
    num_atoms = mol.GetNumAtoms()
    for u in range(num_atoms):
        for v in range(num_atoms):
            if u == v and not self_loop:
                continue
            bond = mol.GetBondBetweenAtoms(u, v)
            bond_type = None if bond is None else bond.GetBondType()
            e_feats.append([float(bond_type == x) for x in edge_types])
            distances.append(np.linalg.norm(geom[u] - geom[v]))

    bond_feats_dict = defaultdict(list)
    # Reshape in NumPy before converting, so the code does not depend on the
    # tensor backend providing a ``.reshape`` method, and so the column count
    # is kept even when there are no edges.
    bond_feats_dict['e_feat'] = F.tensor(
        np.array(e_feats, dtype=np.float32).reshape(-1, len(edge_types)))
    bond_feats_dict['distance'] = F.tensor(
        np.array(distances, dtype=np.float32).reshape(-1, 1))
    return bond_feats_dict
class TencentAlchemyDataset(object):
    """
    Developed by the Tencent Quantum Lab, the dataset lists 12 quantum mechanical
    properties of 130,000+ organic molecules, comprising up to 12 heavy atoms
    (C, N, O, S, F and Cl), sampled from the GDBMedChem database. These properties
    have been calculated using the open-source computational chemistry program
    Python-based Simulation of Chemistry Framework (PySCF).

    For more details, check the `paper <https://arxiv.org/abs/1906.09427>`__.

    Parameters
    ----------
    mode : str
        'dev', 'valid' or 'test', separately for training, validation and test.
        Default to be 'dev'. Note that 'test' is not available as the alchemy
        contest is ongoing; passing it raises ``ValueError``.
    mol_to_graph: callable, rdkit.Chem.rdchem.Mol -> DGLGraph
        A function turning an RDKit molecule instance into a DGLGraph. It is called
        as ``mol_to_graph(mol, node_featurizer=..., edge_featurizer=...)``.
        Default to :func:`dgllife.utils.mol_to_complete_graph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to update
        ndata for a DGLGraph. By default, we construct graphs where nodes represent atoms
        and node features represent atom features. We store the atomic numbers under the
        name ``"node_type"`` and store the atom features under the name ``"n_feat"``.
        The atom features include, in this order:

        * One hot encoding for atom types
        * Atomic number of atoms
        * Whether the atom is an acceptor
        * Whether the atom is a donor
        * Whether the atom is aromatic
        * One hot encoding for atom hybridization
        * Total number of Hs on the atom
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to update
        edata for a DGLGraph. By default, we construct edges between every pair of atoms,
        excluding the self loops. We store the distance between the end atoms under the name
        ``"distance"`` and store the edge features under the name ``"e_feat"``. The edge
        features represent one hot encoding of edge types (bond types and non-bond edges).
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from scratch.
        ``load`` should be False when we want to try different graph construction and
        featurization methods and need to preprocess from scratch. When False, the
        processed graphs and SMILES are saved into the extracted ``{mode}_single_sdf``
        directory. Default to True.
    """
    def __init__(self, mode='dev',
                 mol_to_graph=mol_to_complete_graph,
                 node_featurizer=alchemy_nodes,
                 edge_featurizer=alchemy_edges,
                 load=True):
        if mode == 'test':
            raise ValueError('The test mode is not supported before '
                             'the alchemy contest finishes.')
        assert mode in ['dev', 'valid', 'test'], \
            'Expect mode to be dev, valid or test, got {}.'.format(mode)
        self.mode = mode
        # Construct DGLGraphs from raw data or use the preprocessed data
        self.load = load
        file_dir = osp.join(get_download_dir(), 'Alchemy_data')
        if load:
            file_name = "{}_processed_dgl".format(mode)
        else:
            file_name = "{}_single_sdf".format(mode)
        self.file_dir = pathlib.Path(file_dir, file_name)
        self._url = 'dataset/alchemy/'
        self.zip_file_path = pathlib.Path(file_dir, file_name + '.zip')
        download(_get_dgl_url(self._url + file_name + '.zip'),
                 path=str(self.zip_file_path), overwrite=False)
        if not os.path.exists(str(self.file_dir)):
            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(self.zip_file_path) as archive:
                archive.extractall(file_dir)
        self._load(mol_to_graph, node_featurizer, edge_featurizer)

    def _load(self, mol_to_graph, node_featurizer, edge_featurizer):
        """Populate ``self.graphs``, ``self.labels`` and ``self.smiles``.

        If ``self.load`` is True, the pre-processed graphs, labels and SMILES are
        read from ``self.file_dir``. Otherwise each molecule in the SDF file is
        converted with ``mol_to_graph`` and paired positionally with its row in
        the target CSV. The results are then written back to ``self.file_dir``.
        In both cases ``self.labels`` ends up as a single stacked float32 tensor,
        and the label mean/std are computed afterwards.
        """
        graphs_path = osp.join(self.file_dir, "{}_graphs.bin".format(self.mode))
        smiles_path = osp.join(self.file_dir, "{}_smiles.txt".format(self.mode))
        if self.load:
            self.graphs, label_dict = load_graphs(graphs_path)
            self.labels = label_dict['labels']
            with open(smiles_path, 'r') as f:
                self.smiles = [s.strip() for s in f.readlines()]
        else:
            print('Start preprocessing dataset...')
            property_cols = ['property_{:d}'.format(x) for x in range(12)]
            target_file = pathlib.Path(self.file_dir, "{}_target.csv".format(self.mode))
            self.target = pd.read_csv(
                target_file,
                index_col=0,
                usecols=['gdb_idx'] + property_cols)
            self.target = self.target[property_cols]
            self.graphs, self.smiles = [], []
            labels = []
            supp = Chem.SDMolSupplier(osp.join(self.file_dir, self.mode + ".sdf"))
            dataset_size = len(self.target)
            for cnt, (mol, (_, row)) in enumerate(zip(supp, self.target.iterrows()), start=1):
                print('Processing molecule {:d}/{:d}'.format(cnt, dataset_size))
                graph = mol_to_graph(mol, node_featurizer=node_featurizer,
                                     edge_featurizer=edge_featurizer)
                self.smiles.append(Chem.MolToSmiles(mol))
                self.graphs.append(graph)
                labels.append(F.tensor(np.array(row.tolist()).astype(np.float32)))
            # Stack once and keep the stacked tensor, so labels have the same
            # type here as in the ``load=True`` branch.
            self.labels = F.stack(labels, dim=0)
            save_graphs(graphs_path, self.graphs, labels={'labels': self.labels})
            with open(smiles_path, 'w') as f:
                f.writelines(s + '\n' for s in self.smiles)

        self.set_mean_and_std()
        print(len(self.graphs), "loaded!")

    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32 and shape (T)
            Labels of the datapoint for all tasks.
        """
        return self.smiles[item], self.graphs[item], self.labels[item]

    def __len__(self):
        """Size for the dataset.

        Returns
        -------
        int
            Size for the dataset.
        """
        return len(self.graphs)

    def set_mean_and_std(self, mean=None, std=None):
        """Set mean and std or compute from labels for future normalization.

        The mean and std can be fetched later with ``self.mean`` and ``self.std``.

        Parameters
        ----------
        mean : float32 tensor/array of shape (T) or None
            Mean of labels for all tasks. If None, it is computed column-wise
            from ``self.labels`` as a NumPy array.
        std : float32 tensor/array of shape (T) or None
            Std of labels for all tasks. If None, it is computed column-wise
            from ``self.labels`` as a NumPy array (population std, ``ddof=0``).
        """
        # ``F.asnumpy`` works for every DGL backend, unlike the torch-only ``.numpy()``.
        labels = np.array([F.asnumpy(label) for label in self.labels])
        if mean is None:
            mean = np.mean(labels, axis=0)
        if std is None:
            std = np.std(labels, axis=0)
        self.mean = mean
        self.std = std