-
Notifications
You must be signed in to change notification settings - Fork 19
/
predict.py
57 lines (51 loc) · 1.69 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from model import *
import sys
import pickle
# Load the pickled vocabulary object from the file 'wordtagidx' in the current
# working directory; NER.__init__ unpacks it as (word2idx, tag2idx, idx2tag).
# NOTE: pickle.load can execute arbitrary code embedded in the file, so this
# file must come from a trusted source.
with open('wordtagidx', 'rb') as f:
    wordtagidx = pickle.load(f)
class NER():
    """Character-level tagger built around a trained ``BiLSTM_CRF`` model.

    The instance keeps the loaded model, the ``word2idx`` and ``idx2tag``
    lookups and the context object, and exposes :meth:`predict` to turn a
    string into tagged spans of consecutive characters.
    """

    def __init__(self, modelpath, wordtagidx, ctx=mx.cpu()):
        """Build the network and load its trained parameters.

        Args:
            modelpath: Path handed to ``model.load_params``.
            wordtagidx: Triple ``(word2idx, tag2idx, idx2tag)``.
            ctx: Context passed to the model constructor and to
                ``load_params``; inputs are moved to it in :meth:`predict`.
        """
        word2idx, tag2idx, idx2tag = wordtagidx
        # NOTE(review): 5 and 4 are positional hyper-parameters of
        # BiLSTM_CRF (presumably embedding and hidden sizes) and must match
        # the values used when the parameters were saved; confirm in model.py.
        model = BiLSTM_CRF(len(word2idx), tag2idx, 5, 4, ctx=ctx)
        model.load_params(modelpath, ctx=ctx)
        self.model = model
        self.word2idx = word2idx
        self.idx2tag = idx2tag
        self.ctx = ctx

    @staticmethod
    def _merge_tagged_runs(chars, tag_ids, idx2tag):
        """Merge adjacent positions that share a tag id into tagged spans.

        Positions whose tag id is greater than ``len(idx2tag) - 1`` are
        skipped (presumably the CRF's auxiliary start/stop tags; confirm in
        model.py). A span ends when the next kept position is not directly
        adjacent or carries a different tag id.

        Args:
            chars: Sequence of characters, indexed by position.
            tag_ids: One tag id per position.
            idx2tag: Lookup from tag id to tag label.

        Returns:
            ``(words, targs)``: parallel lists of span texts and tag labels.
            Both are empty when no position has a known tag id.
        """
        max_tag_id = len(idx2tag) - 1
        words = []
        targs = []
        run = []        # characters of the span currently being built
        run_tag = 0     # tag id shared by the characters in ``run``
        last_pos = -1   # position of the most recently kept character
        for pos, tag_id in enumerate(tag_ids):
            if tag_id > max_tag_id:
                continue
            if run and last_pos + 1 == pos and run_tag == tag_id:
                run.append(chars[pos])
            else:
                if run:
                    words.append(''.join(run))
                    targs.append(idx2tag[run_tag])
                run = [chars[pos]]
                run_tag = tag_id
            last_pos = pos
        # Flush the trailing span only if one exists, so input without any
        # known tag yields empty lists instead of a blank ('', idx2tag[0]) pair.
        if run:
            words.append(''.join(run))
            targs.append(idx2tag[run_tag])
        return words, targs

    def predict(self, string):
        """Tag ``string`` and return its spans with their tag labels.

        Args:
            string: Text to tag; it is split into a list of its items
                (single characters for a ``str``).

        Returns:
            ``(words, targs)``: ``words[k]`` is a run of consecutive
            characters that received the same tag and ``targs[k]`` is that
            tag's label from ``idx2tag``.
        """
        string = list(string)
        precheck_sent = prepare_sequence(string, self.word2idx)
        prd = self.model(precheck_sent.as_in_context(self.ctx))
        # prd[1] is consumed as one tag id per input position (presumably
        # the decoded best tag path; confirm in model.py).
        return self._merge_tagged_runs(string, prd[1], self.idx2tag)
import os
# Tag the UTF-8 text file named by the first command-line argument and print
# one "<text> <tag>" line per span returned by NER.predict.
ner = NER("model.params", wordtagidx, ctx=mx.cpu())
# Read inside a context manager so the input file is closed after reading.
with open(sys.argv[1], encoding="utf-8") as f:
    string = f.read()
# Line breaks are removed so the whole file is tagged as a single sequence.
string = string.replace("\n", "")
out = ner.predict(string)
for word, tag in zip(out[0], out[1]):
    print(word + ' ' + tag)