-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_preparation_banana.py
333 lines (289 loc) · 11.3 KB
/
data_preparation_banana.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import pandas as pd
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
# Raise the RDKit logger threshold to CRITICAL so that lower-severity RDKit
# messages are not emitted while the molecules are being parsed below.
logger = RDLogger.logger()
logger.setLevel(RDLogger.CRITICAL)
def get_mol_fragments(mol, add_dummy=True, remove_label=True):
    """Generate aromatic fragments and unique aliphatic atoms from an RDKit molecule.

    The molecule is cut on every single bond and on every aliphatic double or
    triple bond; each resulting piece is returned as a SMILES string.

    Args:
        mol: RDKit molecule to fragment.
        add_dummy: keep substitution sites as dummy atoms (label: ``*``).
            Substitution sites on aliphatic atoms are always removed.
        remove_label: clear the isotope value on every atom of every fragment,
            which removes the number attached to the substitution-site label.

    Returns:
        A list with one SMILES string per fragment. If no bond matches the
        split patterns, the list holds the SMILES of the whole molecule.

    Raises:
        Exception: whatever RDKit raises while sanitizing a fragment or
            removing its hydrogens; it is re-raised unchanged.
    """
    import sys

    split_pattern_list = [
        Chem.MolFromSmarts("*-*"),  # All single bonds
        Chem.MolFromSmarts("A=A"),  # All aliphatic double bonds
        Chem.MolFromSmarts("A#A"),  # All aliphatic triple bonds
    ]

    # determine unique bond indices
    split_bonds = set()
    for split_pattern in split_pattern_list:
        split_bonds.update(
            mol.GetBondBetweenAtoms(a1, a2).GetIdx()
            for a1, a2 in mol.GetSubstructMatches(split_pattern)
        )

    # if no bonds are selected, the molecule is its own single fragment
    if not split_bonds:
        return [Chem.MolToSmiles(mol)]

    # fragment on bonds and extract fragments
    fragmented_mol = Chem.FragmentOnBonds(mol, split_bonds, addDummies=add_dummy)
    fragment_list = list(Chem.GetMolFrags(fragmented_mol, asMols=True))

    # remove number from substitution site label, if requested
    if remove_label:
        for fragment in fragment_list:
            for atom in fragment.GetAtoms():
                atom.SetIsotope(0)

    # dummy atoms (atomic number 0) attached to an aliphatic atom
    subsitution_site_aliphatic = Chem.MolFromSmarts("[$([#0]~[A])]")

    fragment_smiles = []
    for fragment in fragment_list:
        # for aliphatic atoms with substitution sites: remove substitution site
        fragment = AllChem.DeleteSubstructs(fragment, subsitution_site_aliphatic)
        try:
            Chem.SanitizeMol(fragment)
            fragment = Chem.RemoveHs(fragment)
        except Exception:
            # The original called `display(fragment)` here, which only exists
            # under IPython/Jupyter; in a plain script that raised NameError
            # and hid the real failure. Report on stderr and re-raise instead.
            print(
                "failed to sanitize a fragment of molecule:",
                Chem.MolToSmiles(mol),
                file=sys.stderr,
            )
            raise
        fragment_smiles.append(Chem.MolToSmiles(fragment))
    return fragment_smiles
# odors kept from the raw SMILES/odor mapping
odor_list = [
    "alliaceous",
    "almond",
    "amber",
    "animal",
    "apple",
    "apricot",
    "balsamic",
    "banana",
    "berry",
    "camphoraceous",
    "cherry",
    "cinnamyl",
    "citrus",
    "coconut",
    "coffee",
    "garlic",
    "grape",
    "jasmine",
    "lemon",
    "lily",
    "melon",
    "peach",
    "pear",
    "pine",
    "pineapple",
    "raspberry",
    "vanilla",
]

# keep only the rows whose odor is in odor_list; copy so that adding a column
# below never writes into a view of the frame returned by read_csv
aroma_df = (
    pd.read_csv("data/SMILES_odor_mapping.tsv", sep="\t")
    .query("odor.isin(@odor_list)")
    .copy()
)

# removing stereochemistry
aroma_df["nonstereo_smiles"] = aroma_df["SMILES"].apply(
    lambda smi: Chem.MolToSmiles(Chem.MolFromSmiles(smi), isomericSmiles=False)
)

# creating a boolean matrix with compounds x odors: True where the odor is
# recorded for the compound. Rows and columns are put back in order of first
# appearance, because later code refers to molecules by their row position.
unique_odors = aroma_df["odor"].unique()
full_odor_df = (
    pd.crosstab(aroma_df["nonstereo_smiles"], aroma_df["odor"])
    .astype(bool)
    .reindex(
        index=pd.Index(
            aroma_df["nonstereo_smiles"].unique(), name="nonstereo_smiles"
        ),
        columns=pd.Index(unique_odors, name="odor"),
    )
)
full_odor_df.reset_index(inplace=True)

# fragments of every compound: as a list, as a set, and how many there are
full_odor_df["fragment_list"] = full_odor_df["nonstereo_smiles"].apply(
    lambda smi: get_mol_fragments(Chem.MolFromSmiles(smi), add_dummy=True)
)
full_odor_df["fragment_set"] = full_odor_df["fragment_list"].apply(set)
full_odor_df["fragment_count"] = full_odor_df["fragment_list"].apply(len)
import numpy as np
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch
import os
# types of fragments with banana as the target odor
frag_types = [
    "C",
    "O",
    "*c1ccco1",
    "*c1cccc(*)c1",
    "*c1ccccc1",
    "*c1ccccc1*",
    "*c1ccc(*)c(*)c1",
]
# covalence for each type of fragment, indexed by the value that
# get_frag_type returns
covalences = [4, 2, 1, 3]


def get_frag_type(frag_smiles):
    """Return the fragment-type index (0-3) of a fragment SMILES.

    After ignoring the attachment positions there are four types of
    fragments: entries of ``frag_types`` from position 3 onwards all map to
    type 3, so the result is always a valid index into ``covalences``.

    Args:
        frag_smiles: SMILES of one fragment; it must be an entry of
            ``frag_types``.

    Returns:
        The fragment-type index as an int.

    Raises:
        ValueError: if ``frag_smiles`` is not listed in ``frag_types``.
            Previously None was returned implicitly, which callers then used
            as an array index.
    """
    try:
        position = frag_types.index(frag_smiles)
    except ValueError:
        raise ValueError(f"unknown fragment SMILES: {frag_smiles!r}") from None
    return min(position, len(covalences) - 1)
def _is_acceptable_negative(X, db_tag, n_frags, n_bonds):
    """Return True if a molecule without the target odor passes the banana filters.

    X is the fragment feature matrix, db_tag the per-edge double-bond tags
    (each bond appears twice, once per direction), n_frags the number of
    fragments and n_bonds the number of undirected bonds between fragments.
    """
    n_oxygen = int(X[:, 1].sum())
    n_aromatic = int(X[:, 2].sum()) + int(X[:, 3].sum())
    # at least one aromatic ring, or an oxygen fragment linked with a double bond
    has_motif = n_aromatic > 0 or bool(np.any((X[:, 1] + X[:, -1]) == 2))

    # (i) moderate size, (ii) required motif present,
    # (iii) at most one ring (except for aromatic rings)
    if n_frags < 8 or n_frags > 10 or not has_motif:
        return False
    if n_bonds - (n_frags - 1) > 1:
        return False
    # bounds for oxygen atoms
    if n_oxygen > n_frags / 4.0:
        return False
    # bounds for aromatic rings
    if n_aromatic > 2:
        return False
    # bounds for double bonds
    if int(db_tag.sum()) // 2 > n_frags // 4:
        return False
    return True


def build_graph_for_fragments(smiles, mol_index, odor, data_index):
    """Build the fragment graph of one molecule and save it under data/banana/.

    Feature columns of each fragment (node), in order:
      * ``len(covalences)`` columns: one-hot fragment type,
      * 4 columns: number of neighboring fragments (1, 2, 3, 4),
      * 5 columns: number of hydrogen atoms (0, 1, 2, 3, 4),
      * 1 column: set when the fragment takes part in a double bond.

    Args:
        smiles: SMILES of the molecule to fragment.
        mol_index: index of the molecule in the original dataset.
        odor: truthy if this molecule has the target odor; molecules without
            it must additionally pass the negative-sample filters.
        data_index: index of this molecule in the preprocessed dataset; used
            in the name of the saved file.

    Returns:
        1 if a ``data_{data_index}.pt`` file was written, 0 if the molecule
        was skipped (no bond to split on, or rejected by the filters).

    Raises:
        ValueError: if a fragment has no known fragment type.
        AssertionError: if two fragments are linked by a bond that is neither
            single nor double.
    """
    mol = Chem.MolFromSmiles(smiles)
    split_pattern_list = [
        Chem.MolFromSmarts("*-*"),  # All single bonds
        Chem.MolFromSmarts("A=A"),  # All aliphatic double bonds
        Chem.MolFromSmarts("A#A"),  # All aliphatic triple bonds
    ]
    # determine unique bond indices
    split_bonds = set()
    for split_pattern in split_pattern_list:
        split_bonds.update(
            mol.GetBondBetweenAtoms(a1, a2).GetIdx()
            for a1, a2 in mol.GetSubstructMatches(split_pattern)
        )
    # if no bonds are selected.
    if not split_bonds:
        return 0

    # fragment on bonds and extract fragments
    frag_mol = Chem.FragmentOnBonds(mol, split_bonds, addDummies=False)
    # this list provides the mapping between atoms and fragments
    frag_list = list(Chem.GetMolFrags(frag_mol, asMols=False))
    # list of smiles of each fragment
    frag_smiles_list = get_mol_fragments(mol)

    n_frags = len(frag_list)
    n_types = len(covalences)
    n_features = n_types + 4 + 5 + 1

    # feature matrix
    X = np.zeros((n_frags, n_features), dtype=np.int8)
    # number of neighbors for each fragment
    neighbors = np.zeros(n_frags, dtype=np.int8)
    # number of hydrogen atoms for each fragment
    counts = np.zeros(n_frags, dtype=np.int8)
    # index of the fragment that each atom belongs to; a wide integer type is
    # used because int8 cannot hold fragment indices above 127
    atom_index = np.zeros(mol.GetNumAtoms(), dtype=np.int64)

    for frag_index, frag_atoms in enumerate(frag_list):
        frag_type = get_frag_type(frag_smiles_list[frag_index])
        if frag_type is None:
            # None used as an index would broadcast over the whole row
            raise ValueError(
                f"unknown fragment {frag_smiles_list[frag_index]!r} in {smiles!r}"
            )
        X[frag_index, frag_type] = 1
        # start from the maximal number of hydrogen atoms for this fragment
        counts[frag_index] = covalences[frag_type]
        # index each atom in this fragment
        for atom in frag_atoms:
            atom_index[atom] = frag_index

    # build graph for fragments; each bond between two different fragments is
    # visited once per direction, so every edge is recorded in both directions
    edges = []
    # double bond tag for each edge
    db_tag = []
    for atom_u in mol.GetAtoms():
        u = atom_u.GetIdx()
        frag_u = int(atom_index[u])
        for atom_v in atom_u.GetNeighbors():
            v = atom_v.GetIdx()
            frag_v = int(atom_index[v])
            if frag_u == frag_v:
                continue
            bond_type = mol.GetBondBetweenAtoms(u, v).GetBondTypeAsDouble()
            # for both odors considered, there is no triple bond
            assert bond_type in [1.0, 2.0], f"unsupported bond order {bond_type}"
            is_double = bond_type == 2.0
            neighbors[frag_u] += 1
            edges.append([frag_u, frag_v])
            db_tag.append(1 if is_double else 0)
            # each bond uses up hydrogen positions of fragment frag_u
            counts[frag_u] -= int(bond_type)
            if is_double:
                X[frag_u, -1] = 1

    # update features for number of neighbors (1,2,3,4) and hydrogen atoms (0,1,2,3,4)
    for frag_index in range(n_frags):
        X[frag_index, n_types + int(neighbors[frag_index]) - 1] = 1
        X[frag_index, n_types + 4 + int(counts[frag_index])] = 1

    # shape (2, E), also when there are no edges at all
    edges = np.array(edges, dtype=np.int64).reshape(-1, 2).T
    db_tag = np.array(db_tag, dtype=np.int64)

    # add some constraints to filter negative data for banana odor
    if not odor and not _is_acceptable_negative(
        X, db_tag, n_frags, edges.shape[1] // 2
    ):
        return 0

    # construct a data corresponding to this molecule for training
    data = Data(
        x=torch.tensor(X, dtype=torch.float),
        edge_index=torch.tensor(edges, dtype=torch.long),
        y=torch.tensor([odor], dtype=torch.long),
        db_tag=torch.tensor(db_tag, dtype=torch.long),
        index=torch.tensor(mol_index, dtype=torch.long),
        smiles=smiles,
    )
    out_dir = "data/banana/"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data, os.path.join(out_dir, f"data_{data_index}.pt"))
    return 1
odor = "banana"
data_index = 0
# remove improper molecules, including disconnected molecules, molecules without target odor but with triple bond
removed_list = [243, 509, 730, 945, 1005, 1067, 1193, 1264, 1354, 1370]


def _export_molecules(frame, label, next_index):
    """Save every usable molecule of ``frame`` and return the next free data index.

    Rows whose index is listed in ``removed_list`` are skipped. ``label`` is
    passed on as the odor flag (1: has the target odor, 0: does not).
    """
    for row in frame.itertuples():
        if row.Index in removed_list:
            continue
        next_index += build_graph_for_fragments(
            row.nonstereo_smiles, row.Index, label, next_index
        )
    return next_index


# process all molecules with target odor
single_odor_df = full_odor_df.loc[
    full_odor_df[odor], ["nonstereo_smiles", "fragment_set"]
]
data_index = _export_molecules(single_odor_df, 1, data_index)
# number of molecules with target odor
print("number of molecules with target odor:", data_index)

# get the set of fragments from molecules with target odor; starting from an
# empty set keeps this valid even when no molecule has the target odor
unique_fragments = set().union(*single_odor_df["fragment_set"])
unique_compounds = single_odor_df["nonstereo_smiles"].unique()

# molecules without target odor but consisting only of fragments in unique_fragments
non_odor_cpds = full_odor_df.loc[
    ~full_odor_df["nonstereo_smiles"].isin(unique_compounds)
]
non_odor_cpds = non_odor_cpds.loc[
    non_odor_cpds["fragment_set"].apply(
        lambda fragments: fragments.issubset(unique_fragments)
    )
]

# process all molecules without target odor but consisting only of fragments in unique_fragments
data_index = _export_molecules(non_odor_cpds, 0, data_index)
# number of molecules for training
print("number of molecules for training:", data_index)