-
Notifications
You must be signed in to change notification settings - Fork 0
/
classifier_utils.py
42 lines (36 loc) · 1.35 KB
/
classifier_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import pickle
import sys
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
# TODO: Remove workaround for my strange Python environment
if sys.version_info.minor <= 9:
import torch
# Cross-validate a classifier and report how it scored
def check_classifier(clf, features, labels):
    """Print the 10-fold cross-validation scores of ``clf``.

    Args:
        clf: Estimator handed to ``cross_val_score``.
        features: Feature data to evaluate on.
        labels: Target labels matching ``features``.

    The per-fold score array is printed first, followed by a summary line
    with the mean score and its standard deviation. Nothing is returned.
    """
    print("Cross-Validating model")
    fold_scores = cross_val_score(clf, features, labels, cv=10)
    print(fold_scores)

    # The "+/-" figure in the summary is one standard deviation across folds.
    mean_score = fold_scores.mean()
    spread = fold_scores.std()
    print("Accuracy: %0.2f (+/- %0.2f)" % (mean_score, spread))
    print("\n")
# Loads the classifier from a pickle file (or, for any other extension, via torch)
def load_classifier(name="optimal.pkl"):
    """Load a previously saved classifier from ``name``.

    Args:
        name: Path of the file to load. When the text after the last "."
            is ``pkl`` the file is read with ``pickle``; any other name is
            passed to ``torch.load``.

    Returns:
        The object stored in the file.

    Raises:
        ImportError: If a non-pickle file is requested and ``torch`` is not
            installed.
    """
    extension = name.rsplit(".", 1)[-1]
    if extension == "pkl":
        # NOTE: unpickling can execute arbitrary code -- only load files
        # that come from a trusted source.
        with open(name, "rb") as f:
            return pickle.load(f)

    # Import lazily: the module-level torch import is conditional on the
    # interpreter version, so relying on it can leave ``torch`` undefined
    # here and turn this branch into a NameError.
    import torch

    return torch.load(name)
# Helper function to fit the classifier and save it to a pickle file
def save_classifier(clf, features, labels, name="optimal.pkl"):
    """Fit ``clf`` on the given data and pickle the fitted object to ``name``.

    Args:
        clf: Object exposing ``fit(features, labels)``; it is fitted in place.
        features: Training features passed to ``clf.fit``.
        labels: Training labels passed to ``clf.fit``.
        name: Destination path of the pickle file. An existing file is
            overwritten.

    Returns:
        None.
    """
    clf.fit(features, labels)
    # The context manager closes the handle even if pickling raises, instead
    # of leaving an open file for the garbage collector to clean up.
    with open(name, "wb") as f:
        pickle.dump(clf, f)
def grid_search_params(clf, features, labels, params):
    """Grid-search ``params`` for ``clf`` and print the best combination.

    Args:
        clf: Estimator to tune.
        features: Feature data used to fit each candidate.
        labels: Target labels matching ``features``.
        params: Parameter grid forwarded to ``GridSearchCV`` as ``param_grid``.

    The search uses 5-fold cross-validation. The winning parameters are
    printed; nothing is returned.
    """
    print("Searching Param Space")
    # ``fit`` returns the fitted search object, so the call can be chained.
    search = GridSearchCV(estimator=clf, param_grid=params, cv=5).fit(
        features, labels
    )
    print(search.best_params_)