-
Notifications
You must be signed in to change notification settings - Fork 1
/
rnn_test.py
35 lines (30 loc) · 917 Bytes
/
rnn_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from keras.models import load_model
import pickle
import numpy as np
def create_model(data_path, model_path):
    """Load a saved Keras model and report its accuracy on a pickled dataset.

    Despite its name, this function does not build a model: it loads one from
    ``model_path``, runs it on the samples stored in the pickle at
    ``data_path``, and prints the following:

    * every misclassified sample as ``predicted --> expected``;
    * the number of correct predictions followed by the total sample count;
    * the accuracy as a fraction in ``[0, 1]``, or ``nan`` if there are no
      samples.

    Args:
        data_path: Path to a pickle file holding a mapping with at least the
            keys ``'bigdata1'`` (inputs fed to ``model.predict``) and
            ``'class'`` (expected labels; presumably a NumPy array, since
            ``.shape`` and ``.flatten()`` are used — TODO confirm).
        model_path: Path to a saved model readable by Keras ``load_model``.
    """
    model = load_model(model_path)

    print("loading samples")
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point data_path at files from a trusted source.
    with open(data_path, "rb") as pickle_in:
        data = pickle.load(pickle_in)
    print("samples loaded")
    print(data['class'].shape)

    pred = model.predict(data['bigdata1'])
    # Predicted label = index of the highest score in each output row.
    a = np.argmax(pred, axis=1)
    # Flatten expected labels so they line up element-wise with `a`
    # (assumes exactly one label per sample — TODO confirm stored shape).
    b = data['class'].flatten()

    # Build the mismatch mask once rather than on every loop iteration.
    wrong = a != b
    for predicted, expected in zip(a[wrong], b[wrong]):
        print(predicted, "-->", expected)

    correct = int(np.sum(a == b))
    total = len(b)
    print(correct, total)
    # Accuracy is undefined for an empty dataset; avoid ZeroDivisionError.
    print(float(correct) / float(total) if total else float("nan"))
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a saved Keras model on a pickled dataset.")
    # Both paths are mandatory: if omitted, argparse would otherwise pass None
    # into create_model, which then fails with a far less helpful error.
    parser.add_argument('-d', type=str, required=True,
                        help="path to the pickled data file")
    parser.add_argument('-m', type=str, required=True,
                        help="path to the saved Keras model")
    args = parser.parse_args()
    create_model(args.d, args.m)