predict.py
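"""Run a trained NL2SQL model over a JSON file of questions and write the
predicted SQL queries, one per line, to the output file.

Example invocation (file names are illustrative):
    python predict.py checkpoint.pt dev.json --dataset spider --tables tables.json --db database
"""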
import os
import copy
import json
import torch
import pprint
import utils
import random
import dataset
import argparse
from model.model import Module
mydir = os.path.dirname(__file__)

parser = argparse.ArgumentParser()
parser.add_argument('resume', help='path to a trained model checkpoint')
parser.add_argument('input', help='JSON file of examples to predict on')
parser.add_argument('--resumes', nargs='*', help='optional list of checkpoints whose weights are averaged')
parser.add_argument('--dataset', default='spider', choices=['spider', 'cosql', 'sparc'])
parser.add_argument('--tables', default='tables.json', help='database schema file')
parser.add_argument('--db', default='database', help='directory of databases')
parser.add_argument('--dcache', default='cache', help='preprocessing cache directory')
parser.add_argument('--batch', type=int, default=6, help='batch size')
parser.add_argument('--output', default='output.txt', help='file to write predictions to')


def main(orig_args):
    # load pretrained model
    fresume = os.path.abspath(orig_args.resume)
    # print('resuming from {}'.format(fresume))
    assert os.path.isfile(fresume), '{} does not exist'.format(fresume)
    orig_args.input = os.path.abspath(orig_args.input)
    orig_args.tables = os.path.abspath(orig_args.tables)
    orig_args.db = os.path.abspath(orig_args.db)
    orig_args.dcache = os.path.abspath(orig_args.dcache)
    # the checkpoint bundles the training-time args and extra state
    binary = torch.load(fresume, map_location=torch.device('cpu'))
    args = binary['args']
    ext = binary['ext']
    # override training-time settings with the current environment and flags
    args.gpu = torch.cuda.is_available()
    args.tables = orig_args.tables
    args.db = orig_args.db
    args.dcache = orig_args.dcache
    args.batch = orig_args.batch
    Model = utils.load_module(args.model)
    if args.model == 'nl2sql':
        Reranker = utils.load_module(args.beam_rank)
        ext['reranker'] = Reranker(args, ext)
    m = Model(args, ext).place_on_device()
    # either average several checkpoints or load the single one given
    if orig_args.resumes:
        m.average_saves(orig_args.resumes)
    else:
        m.load_save(fname=fresume)
    # preprocess data
    data = dataset.Dataset()
    if orig_args.dataset == 'spider':
        import preprocess_nl2sql as preprocess
    elif orig_args.dataset == 'sparc':
        import preprocess_nl2sql_sparc as preprocess
    elif orig_args.dataset == 'cosql':
        import preprocess_nl2sql_cosql as preprocess
    proc_errors = set()
    with open(orig_args.input) as f:
        C = preprocess.SQLDataset
        raw = json.load(f)
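        # the input is assumed to be a JSON list of Spider-style examples,
        # e.g. [{"db_id": "concert_singer", "question": "How many singers?"}, ...];
        # any gold annotations ("query", "sql", ...) are stripped below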
        # make contexts and populate vocab
        for i, ex in enumerate(raw):
            for k in ['query', 'query_toks', 'query_toks_no_value', 'sql']:
                if k in ex:
                    del ex[k]
            ex['id'] = '{}/{}'.format(ex['db_id'], i)
            new = C.make_example(ex, m.bert_tokenizer, m.sql_vocab, m.conv.kmaps, m.conv, train=False, evaluation=True)
            if new is not None:
                new['question'] = ex['question']
                new['cands_query'], new['cands_value'] = C.make_cands(new, m.sql_vocab)
                data.append(new)
            else:
                print('failed to preprocess {}'.format(ex['id']))
                proc_errors.add(ex['id'])
    # run preds
    if orig_args.dataset in {'cosql', 'sparc'}:
        # interactive (multi-turn) prediction is not fully supported yet
        preds = m.run_interactive_pred(data, args, verbose=True)
        raise NotImplementedError('interactive prediction for {} is not implemented'.format(orig_args.dataset))
    else:
        preds = m.run_pred(data, args, verbose=True)
    assert len(preds) + len(proc_errors) == len(data), 'got {} predictions for {} examples'.format(len(preds), len(data))
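    # write one prediction per line, in input order; examples that failed
    # preprocessing get the literal line 'ERROR'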
    # print('writing to {}'.format(orig_args.output))
    with open(orig_args.output, 'wt') as f:
        for ex in data:
            if ex['id'] in proc_errors:
                s = 'ERROR'
            else:
                p = preds[ex['id']]
                s = p['query']
            f.write(s + '\n')
        f.flush()


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)