Explainer API for Local Classifier per parent node #minor (#106)
ashishpatel16 authored Mar 27, 2024
1 parent c767c9b commit c436275
Showing 20 changed files with 746 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Pipfile
@@ -19,5 +19,5 @@ sphinx-rtd-theme = "0.5.2"

[extras]
ray = "*"
shap = "*"
shap = "0.44.1"
xarray = "*"
81 changes: 78 additions & 3 deletions Pipfile.lock

Some generated files are not rendered by default.

7 changes: 6 additions & 1 deletion README.md
@@ -16,6 +16,7 @@ HiClass is an open-source Python library for hierarchical classification compati
- [Who is using HiClass?](#who-is-using-hiclass)
- [Install](#install)
- [Quick start](#quick-start)
- [Explaining Hierarchical Classifiers](#explaining-hierarchical-classifiers)
- [Step-by-step walk-through](#step-by-step-walk-through)
- [API documentation](#api-documentation)
- [FAQ](#faq)
@@ -34,6 +35,7 @@ HiClass is an open-source Python library for hierarchical classification compati
- **[Hierarchical metrics](https://hiclass.readthedocs.io/en/latest/api/utilities.html#hierarchical-metrics):** HiClass supports the computation of hierarchical precision, recall and f-score, which are more appropriate for hierarchical data than traditional metrics.
- **[Compatible with pickle](https://hiclass.readthedocs.io/en/latest/auto_examples/plot_model_persistence.html):** Easily store trained models on disk for future use.
- **[BERT sklearn](https://hiclass.readthedocs.io/en/latest/auto_examples/plot_bert.html):** Compatible with the library [BERT sklearn](https://github.com/charles9n/bert-sklearn).
- **[Hierarchical Explainability](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html):** HiClass allows explaining hierarchical models using the [SHAP](https://github.com/shap/shap) package.

**Any feature missing on this list?** Search our [issue tracker](https://github.com/scikit-learn-contrib/hiclass/issues) to see if someone has already requested it and add a comment to it explaining your use-case. Otherwise, please open a new issue describing the requested feature and possible use-case scenario. We prioritize our roadmap based on user feedback, so we would love to hear from you.

@@ -113,7 +115,7 @@ pip install hiclass"[<extra_name>]"
Replace <extra_name> with one of the following options:

- ray: Installs the ray package, which is required for parallel processing support.
-- xai: Installs the shap and xarray packages, which are required for explaining Hiclass predictions.
+- xai: Installs the shap and xarray packages, which are required for explaining Hiclass' predictions.

### Option 2: Conda

@@ -199,6 +201,9 @@ pipeline.fit(X_train, Y_train)
predictions = pipeline.predict(X_test)
```

## Explaining Hierarchical Classifiers

Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).
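
A condensed sketch of this workflow, based on the `docs/examples/plot_lcppn_explainer.py` example added in this commit (`X_train`, `Y_train`, and `X_test` are assumed to be prepared as in that example):

```python
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer

# Train a local classifier per parent node with a random forest at each node
classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier())
classifier.fit(X_train, Y_train)

# Wrap the trained hierarchical model in an Explainer and compute SHAP values
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)  # xarray Dataset indexed by level, class and sample
```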

## Step-by-step walk-through

A step-by-step walk-through is available on our documentation hosted on [Read the Docs](https://hiclass.readthedocs.io/en/latest/index.html).
61 changes: 61 additions & 0 deletions docs/examples/plot_lcppn_explainer.py
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""
============================================
Explaining Local Classifier Per Parent Node
============================================
A minimalist example showing how to use the HiClass Explainer to obtain SHAP values for an LCPPN model.
A detailed summary of the Explainer class is given in the Algorithms Overview section under :ref:`Hierarchical Explainability`.
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
"""
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer
import requests
import pandas as pd
import shap

# Download training data
url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
path = "platypus_diseases.csv"
response = requests.get(url)
with open(path, "wb") as file:
    file.write(response.content)

# Load training data into pandas dataframe
training_data = pd.read_csv(path).fillna(" ")

# Define data
X_train = training_data.drop(["label"], axis=1)
X_test = X_train[:100] # Use first 100 samples as test set
Y_train = training_data["label"]
Y_train = [eval(my) for my in Y_train]  # parse the string-encoded label paths into lists

# Use random forest classifiers for every node
rfc = RandomForestClassifier()
classifier = LocalClassifierPerParentNode(
    local_classifier=rfc, replace_classifiers=False
)

# Train local classifier per parent node
classifier.fit(X_train, Y_train)

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Select samples whose predicted class at the first level is "Respiratory"
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"

# Specify additional filters to obtain only level 0
shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx}

# Use .sel() method to apply the filter and obtain filtered results
shap_val_respiratory = explanations.sel(shap_filter)

# Plot feature importance on test set
shap.plots.violin(
    shap_val_respiratory.shap_values,
    feature_names=X_train.columns.values,
    plot_size=(13, 8),
)
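
Since `explain()` returns an xarray Dataset, a quick way to see which levels, classes, and samples are available for the `.sel()` filtering shown above is to inspect its coordinates (a small illustrative addition, not part of the example file itself):

```python
# The explanations object is an xarray Dataset; its dimensions and coordinates
# ("level", "class", "sample") are what the shap_filter above keys on
print(explanations.dims)
print(explanations.coords)
```

The same `.sel()` pattern used for the "Respiratory" class at level 0 applies to any other level/class combination listed in those coordinates.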
2 changes: 2 additions & 0 deletions docs/requirements.txt
@@ -9,3 +9,5 @@ pandas==1.4.2
ray==1.13.0
numpy
git+https://github.com/charles9n/bert-sklearn.git@master
shap==0.44.1
xarray==2023.1.0
Binary file added docs/source/algorithms/explainer-indexing.png
