-
Notifications
You must be signed in to change notification settings - Fork 1
/
modelEvaluation.py
63 lines (51 loc) · 2.28 KB
/
modelEvaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame as df
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
def plot_results(predicted_labels, true_labels, snrs,
                 save_path="./result/metaSGD_result.png"):
    """Plot classification accuracy (in percent) per SNR value and save it.

    Args:
        predicted_labels: 2-D array-like with one row per sample; the argmax
            along axis 1 is taken as the predicted class.
        true_labels: 2-D array-like laid out like ``predicted_labels``; the
            argmax along axis 1 is taken as the true class.
        snrs: Array-like with one SNR value per sample. It is flattened, so
            lists and column vectors work as well as 1-D arrays.
        save_path: File the figure is written to. Missing parent directories
            are created. Defaults to the previously hard-coded location.

    Returns:
        None. The curve is drawn on the current pyplot figure, written to
        ``save_path`` and then displayed with ``plt.show()``.
    """
    snrs = np.asarray(snrs).ravel()

    # The per-sample class decision does not depend on the SNR, so compute it
    # once instead of once per SNR bucket.
    true_classes = np.argmax(np.asarray(true_labels), axis=1)
    predicted_classes = np.argmax(np.asarray(predicted_labels), axis=1)
    is_correct = true_classes == predicted_classes

    # np.unique already returns its values sorted in ascending order.
    sorted_snrs = np.unique(snrs)
    # Accuracy of a bucket is the fraction of correct samples within it.
    accuracies = np.array([is_correct[snrs == snr].mean() for snr in sorted_snrs])

    plt.xlabel('SNR')
    plt.ylabel('Accuracy')
    plt.title('Classification Accuracy over different SNRs')
    plt.plot(sorted_snrs, accuracies * 100, 'ro--')
    plt.grid(True)

    # savefig does not create directories; make sure the target exists.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path)
    plt.show()
def print_results(predicted_labels, true_labels, snrs):
    """Return the classification accuracy (in percent) per SNR as a table.

    Args:
        predicted_labels: 2-D array-like with one row per sample; the argmax
            along axis 1 is taken as the predicted class.
        true_labels: 2-D array-like laid out like ``predicted_labels``; the
            argmax along axis 1 is taken as the true class.
        snrs: Array-like with one SNR value per sample. It is flattened, so
            lists and column vectors work as well as 1-D arrays.

    Returns:
        A one-row ``DataFrame`` indexed by ``'accuracy'`` whose columns are
        the distinct SNR values in ascending order and whose values are the
        accuracies in percent, rounded to two decimals.
    """
    snrs = np.asarray(snrs).ravel()

    # The per-sample class decision does not depend on the SNR, so compute it
    # once instead of once per SNR bucket.
    true_classes = np.argmax(np.asarray(true_labels), axis=1)
    predicted_classes = np.argmax(np.asarray(predicted_labels), axis=1)
    is_correct = true_classes == predicted_classes

    # np.unique already returns its values sorted in ascending order.
    sorted_snrs = np.unique(snrs)
    # Accuracy of a bucket is the fraction of correct samples within it.
    accuracies = np.array([is_correct[snrs == snr].mean() for snr in sorted_snrs])

    table = df(
        data=accuracies.reshape(1, -1) * 100,
        columns=sorted_snrs,
        index=['accuracy'],
    )
    return table.round(2)
def plot_confusion_matrix(y_true, y_pred, classes, cmap=plt.cm.Blues,
                          save_path='result/metaSGD_confusion.png'):
    """Plot a row-normalised confusion matrix and save it to disk.

    Args:
        y_true: 2-D array-like with one row per sample; the argmax along
            axis 1 is taken as the true class index.
        y_pred: 2-D array-like laid out like ``y_true``; the argmax along
            axis 1 is taken as the predicted class index.
        classes: Sequence of tick labels, one per class. Class index ``i`` is
            labelled ``classes[i]`` (assumes the columns of ``y_true`` /
            ``y_pred`` follow the order of ``classes`` -- confirm at caller).
        cmap: Matplotlib colormap used for the heat map.
        save_path: File the figure is written to. Missing parent directories
            are created. Defaults to the previously hard-coded location.

    Returns:
        None. The figure is written to ``save_path`` and then displayed with
        ``plt.show()``.
    """
    true_idx = np.argmax(y_true, axis=1)
    pred_idx = np.argmax(y_pred, axis=1)

    # Pin the label set to every class index. Without it, a class missing
    # from both inputs shrinks the matrix and the tick labels stop lining up
    # with ``classes``.
    cm = confusion_matrix(true_idx, pred_idx, labels=np.arange(len(classes)))

    # Normalise each row by the number of true samples of that class. Rows
    # without any true sample stay at 0 instead of turning into NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    fig.colorbar(im, ax=ax)
    # Show all ticks and label them with the respective class names.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title='Confusion Matrix',
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the x tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # savefig does not create directories; make sure the target exists.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path)
    plt.show()