forked from deepchem/deepchem
-
Notifications
You must be signed in to change notification settings - Fork 0
/
benchmark_low_data.py
62 lines (51 loc) · 1.46 KB
/
benchmark_low_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 8 16:48:05 2016
@author: Michael Wu
Low-data benchmark test.
Reports the performance of the Siamese, attention-based embedding and residual
embedding models on the MUV, SIDER and Tox21 datasets.
Time estimates are listed in the README file.
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals
import numpy as np
import deepchem as dc
import argparse
# Models and datasets benchmarked when no -m / -d flag is given.
DEFAULT_MODELS = ('siamese', 'attn', 'res')
DEFAULT_DATASETS = ('tox21', 'sider', 'muv')


def build_parser():
    """Build the command-line parser for the low-data benchmark.

    Returns:
        An argparse.ArgumentParser accepting repeatable ``-m`` (model) and
        ``-d`` (dataset) options plus a ``--cv`` flag.
    """
    parser = argparse.ArgumentParser(
        description='Deepchem benchmark: ' +
        'giving performances of different learning models on datasets')
    # action='append' collects every occurrence of the flag into a list.
    # A default of None (rather than a list) lets parse_selection() tell
    # "flag not given" apart from an explicit selection.
    parser.add_argument(
        '-m',
        action='append',
        dest='model_args',
        default=None,
        help='Choice of model: siamese, attn, res')
    parser.add_argument(
        '-d',
        action='append',
        dest='dataset_args',
        default=None,
        help='Choice of dataset: tox21, sider, muv')
    parser.add_argument(
        '--cv',
        action='store_true',
        dest='cross_valid',
        default=False,
        help='whether to implement cross validation')
    return parser


def parse_selection(argv=None):
    """Resolve which models and datasets to benchmark.

    Args:
        argv: Argument list to parse; None means argparse reads sys.argv[1:].

    Returns:
        Tuple ``(models, datasets, cross_valid)``:
        - models: the ``-m`` values in the order given, or a fresh copy of
          DEFAULT_MODELS when none were given;
        - datasets: likewise for ``-d`` and DEFAULT_DATASETS;
        - cross_valid: True when ``--cv`` was passed.
    """
    args = build_parser().parse_args(argv)
    # list(...) gives each call its own list so callers may mutate it freely.
    models = args.model_args or list(DEFAULT_MODELS)
    datasets = args.dataset_args or list(DEFAULT_DATASETS)
    return models, datasets, args.cross_valid


def main(argv=None):
    """Run the low-data benchmark for every selected (dataset, model) pair.

    Args:
        argv: Optional argument list; defaults to the process command line.
    """
    # Fixed seed for NumPy's global RNG, as the original script set it
    # before doing anything else.
    np.random.seed(123)
    models, datasets, cross_valid = parse_selection(argv)
    for dataset in datasets:
        for model in models:
            # str() is a no-op on Python 3. It is presumably kept so that on
            # Python 2, where unicode_literals makes the defaults unicode
            # objects, a native str is passed.
            dc.molnet.run_benchmark_low_data(
                [dataset], str(model), cross_valid=cross_valid)


if __name__ == '__main__':
    main()