-
Notifications
You must be signed in to change notification settings - Fork 0
/
k_NN.py
20 lines (16 loc) · 862 Bytes
/
k_NN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
from ..utils.evaluation_utils import *
from ..utils.settings import Model_Settings
def _multi_hot_from_proba(proba, classes):
    """Convert ``predict_proba`` output into a (n_samples, n_outputs) 0/1 matrix."""
    if not isinstance(proba, list):
        # Single-output fit: scikit-learn returns one 2-D array and a flat
        # ``classes_`` array instead of per-output lists.
        proba, classes = [proba], [classes]
    columns = [
        # argmax yields a column index into ``classes_``; translate it back to
        # the class label before comparing with the positive label 1.  This
        # also copes with outputs that saw a single class during training,
        # whose probability array has only one column.
        (np.asarray(label_classes)[np.argmax(label_proba, axis=1)] == 1).astype(int)
        for label_proba, label_classes in zip(proba, classes)
    ]
    return np.column_stack(columns)


def train_knn_model(X_train, y_train, X_test, y_test, model_settings: Model_Settings, params, exclude_fet):
    """Fit a k-nearest-neighbours classifier and score it on the test split.

    Args:
        X_train: Training features; must support ``.drop(columns=...)``.
        y_train: Training targets passed unchanged to ``model.fit``.
        X_test: Test features; must support ``.drop(columns=...)``.
        y_test: Ground-truth targets handed to the evaluation helpers.
        model_settings: Not used by this function; kept so the signature
            stays compatible with existing callers.
        params: Keyword arguments forwarded to ``KNeighborsClassifier``.
        exclude_fet: Column names removed from both feature sets before
            fitting and predicting.

    Returns:
        A tuple ``(model, y_hot, accuracy, f_score, precision, recall)`` where
        ``y_hot`` is an integer array of shape (n_samples, n_outputs) holding
        1 where the most probable class of an output is the label ``1`` and 0
        otherwise. The metrics are whatever ``compute_avg_acuracy`` and
        ``compute_avg_f_score`` return for ``(y_hot, y_test)``.
    """
    X_train_drop = X_train.drop(columns=exclude_fet)
    X_test_drop = X_test.drop(columns=exclude_fet)

    model = KNeighborsClassifier(n_jobs=-1, **params)
    model.fit(X_train_drop, y_train)

    # For multi-output targets predict_proba returns one probability array per
    # output; the helper picks the most probable class of each output.
    y_pred_proba = model.predict_proba(X_test_drop)
    y_hot = _multi_hot_from_proba(y_pred_proba, model.classes_)

    f_score, precision, recall = compute_avg_f_score(y_hot, y_test)
    accuracy = compute_avg_acuracy(y_hot, y_test)
    return model, y_hot, accuracy, f_score, precision, recall