-
Notifications
You must be signed in to change notification settings - Fork 2
/
lime_wrap.py
63 lines (52 loc) · 2.48 KB
/
lime_wrap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
from lime.lime_tabular import LimeTabularExplainer
class LimeWrapper:
    """Thin convenience wrapper around ``LimeTabularExplainer`` for pandas data.

    The constructor takes a training ``DataFrame`` that still contains the
    target column, label-encodes the categorical features, decides whether the
    task looks like classification or regression, and builds the explainer.
    ``explain`` then explains a single row.

    Attributes:
        explainer: The configured ``LimeTabularExplainer``.
        predict_fn: Callable handed to the explainer; wraps ``model``.
        labels: Default labels explained by ``explain`` (``[1]`` for
            classification, ``[0]`` for regression).
        encoders: Fitted ``LabelEncoder`` per categorical feature name, kept
            so callers can encode an instance the same way as the train set.
    """

    # Heuristic for numeric targets: fewer than _MAX_CLASSES distinct values
    # among the first _SAMPLE_ROWS rows means classification.
    _SAMPLE_ROWS = 1000
    _MAX_CLASSES = 20

    def __init__(self, train, model, target_column, categorical=None, kernel_width=None):
        """Build the explainer from a training set.

        Args:
            train: ``pandas.DataFrame`` holding features and the target.
            model: Fitted model. ``predict_proba`` is used for classification,
                ``predict`` for regression.
            target_column: Name of the target column in ``train``.
            categorical: Optional names of extra columns to treat as
                categorical, on top of the object-dtype columns that are
                detected automatically.
            kernel_width: Forwarded to ``LimeTabularExplainer``.

        Raises:
            TypeError: If ``train`` is not a ``DataFrame``.
            ValueError: If ``target_column`` or a requested categorical column
                is not a column of ``train``.
        """
        if not isinstance(train, pd.DataFrame):
            raise TypeError('Use a pandas DataFrame to store the train set')

        # Handling target
        if target_column not in train.columns:
            raise ValueError('Target column not in train set.')
        target = train[target_column]
        train_x = train.drop(target_column, axis=1)

        # Features - detecting categoricals and label encoding
        features = train_x.columns.tolist()
        detected = train_x.select_dtypes(include=[object]).columns.tolist()
        requested = [] if categorical is None else list(categorical)
        # De-duplicate while keeping order: a column that is both detected and
        # requested must be encoded only once, otherwise the second pass would
        # encode the integer codes and lose the original category names.
        cat = list(dict.fromkeys(detected + requested))
        unknown = [c for c in cat if c not in features]
        if unknown:
            raise ValueError(
                'Categorical columns not in train set: {}'.format(unknown))
        cat_idx = [features.index(c) for c in cat]

        categorical_names = {}
        self.encoders = {}
        for idx in cat_idx:
            feat = features[idx]
            encoder = LabelEncoder()
            train_x[feat] = encoder.fit_transform(train_x[feat])
            categorical_names[idx] = encoder.classes_
            self.encoders[feat] = encoder

        # Classif or Reg ?
        if self._infer_pred_type(target) == 'classification':
            class_names = np.unique(target)
            self.labels = [1]
            self.predict_fn = lambda x: model.predict_proba(x)
        else:
            class_names = ['lower', 'higher']
            self.labels = [0]
            # Predictions are reshaped to a single column, which is the one
            # label 0 refers to.
            self.predict_fn = lambda x: model.predict(x).reshape(-1, 1)

        # setting up the explainer
        self.explainer = LimeTabularExplainer(
            train_x.values,
            feature_names=features,
            class_names=class_names,
            categorical_features=cat_idx,
            categorical_names=categorical_names,
            kernel_width=kernel_width,
            verbose=True,
        )

    @classmethod
    def _infer_pred_type(cls, target):
        """Return 'classification' or 'regression' for the target Series."""
        # Non-numeric targets (e.g. strings, stored by pandas as object dtype)
        # can only be class labels. The dtype is inspected through pandas
        # because object columns never compare equal to a NumPy string type.
        if not pd.api.types.is_numeric_dtype(target):
            return 'classification'
        distinct = np.unique(target.iloc[:cls._SAMPLE_ROWS]).shape[0]
        return 'classification' if distinct < cls._MAX_CLASSES else 'regression'

    def explain(self, instance, num_features=10, num_samples=1000, labels=None, show=True):
        """Explain the prediction for a single row.

        Args:
            instance: ``pandas.Series`` with the feature values of the row;
                its ``.values`` are passed to the explainer unchanged.
            num_features: Forwarded to ``explain_instance``.
            num_samples: Forwarded to ``explain_instance``.
            labels: Labels to explain. Defaults to ``self.labels`` when
                ``None``.
            show: When true, call ``show_in_notebook()`` on the explanation.

        Returns:
            The explanation object returned by ``explain_instance``.

        Raises:
            TypeError: If ``instance`` is not a ``pandas.Series``.
        """
        if not isinstance(instance, pd.Series):
            raise TypeError('Use a pandas Series to store the instance to explain')

        exp = self.explainer.explain_instance(
            instance.values,
            self.predict_fn,
            num_features=num_features,
            num_samples=num_samples,
            labels=self.labels if labels is None else labels,
        )
        if show:
            exp.show_in_notebook()
        return exp