Explainer API for Local Classifier per parent node #minor (#106)
1 parent c767c9b
commit c436275
Showing 20 changed files with 746 additions and 7 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,5 +19,5 @@ sphinx-rtd-theme = "0.5.2" | |
|
||
[extras] | ||
ray = "*" | ||
- shap = "*" | ||
+ shap = "0.44.1" | ||
xarray = "*" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
============================================ | ||
Explaining Local Classifier Per Parent Node | ||
============================================ | ||
A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPPN model. | ||
A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`. | ||
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_. | ||
""" | ||
from sklearn.ensemble import RandomForestClassifier | ||
from hiclass import LocalClassifierPerParentNode, Explainer | ||
import requests | ||
import pandas as pd | ||
import shap | ||
|
||
# Download training data | ||
url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv" | ||
path = "platypus_diseases.csv" | ||
response = requests.get(url) | ||
with open(path, "wb") as file: | ||
    file.write(response.content) | ||
|
||
# Load training data into pandas dataframe | ||
training_data = pd.read_csv(path).fillna(" ") | ||
|
||
# Define data | ||
X_train = training_data.drop(["label"], axis=1) | ||
X_test = X_train[:100] # Use first 100 samples as test set | ||
Y_train = training_data["label"] | ||
Y_train = [eval(my) for my in Y_train] | ||
|
||
# Use random forest classifiers for every node | ||
rfc = RandomForestClassifier() | ||
classifier = LocalClassifierPerParentNode( | ||
    local_classifier=rfc, replace_classifiers=False | ||
) | ||
|
||
# Train local classifier per parent node | ||
classifier.fit(X_train, Y_train) | ||
|
||
# Define Explainer | ||
explainer = Explainer(classifier, data=X_train, mode="tree") | ||
explanations = explainer.explain(X_test.values) | ||
print(explanations) | ||
|
||
# Filter samples which only predicted "Respiratory" at first level | ||
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory" | ||
|
||
# Specify additional filters to obtain only level 0 | ||
shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx} | ||
|
||
# Use .sel() method to apply the filter and obtain filtered results | ||
shap_val_respiratory = explanations.sel(shap_filter) | ||
|
||
# Plot feature importance on test set | ||
shap.plots.violin( | ||
    shap_val_respiratory.shap_values, | ||
    feature_names=X_train.columns.values, | ||
    plot_size=(13, 8), | ||
) |
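Review note on the label-parsing step in the example above: each `label` cell is converted with `eval`, which will execute any expression found in the downloaded CSV. The following is a minimal sketch of a stricter alternative, assuming each cell holds a Python list literal such as `['Respiratory', 'Cold']`. The helper name `parse_label` and the sample cells are hypothetical and are not part of this commit.

```python
import ast


def parse_label(cell):
    """Parse one label cell into a list of hierarchy levels.

    ast.literal_eval accepts only Python literals, so a cell that
    contains a function call or other code raises ValueError
    instead of being executed.
    """
    value = ast.literal_eval(cell)
    if not isinstance(value, (list, tuple)):
        raise ValueError(
            f"expected a list of labels, got {type(value).__name__}"
        )
    return list(value)


# Hypothetical cells in the format the example appears to assume
cells = ["['Respiratory', 'Cold']", "['Allergy', 'Pollen']"]
Y_train = [parse_label(cell) for cell in cells]
print(Y_train)
```

In the example script this would stand in for the `[eval(my) for my in Y_train]` line. Whether the real dataset's cells all parse as list literals is not confirmed here, so treat it as a suggestion to verify rather than a drop-in change.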