diff --git a/Pipfile b/Pipfile index 2039c9a7..4a08c49b 100644 --- a/Pipfile +++ b/Pipfile @@ -19,5 +19,5 @@ sphinx-rtd-theme = "0.5.2" [extras] ray = "*" -shap = "*" +shap = "0.44.1" xarray = "*" diff --git a/Pipfile.lock b/Pipfile.lock index c6179dbe..4d6c1420 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "4f36f76cfc52dc83b74cc00b92af24b8ff556ca22dc2f21401224c39659346eb" + "sha256": "eccfc9563430b35e4b8e92aea9e1b387fea449e1538b8162224087bad096ab71" }, "pipfile-spec": 6, "requires": {}, @@ -14,6 +14,14 @@ ], }, "default": { + "cloudpickle": { + "hashes": [ + "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7", + "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882" + ], + "markers": "python_version >= '3.8'", + "version": "==3.0.0" + }, "joblib": { "hashes": [ "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1", @@ -22,12 +30,69 @@ "markers": "python_version >= '3.7'", "version": "==1.3.2", }, + "llvmlite": { + "hashes": [ + "sha256:04725975e5b2af416d685ea0769f4ecc33f97be541e301054c9f741003085802", + "sha256:0dd0338da625346538f1173a17cabf21d1e315cf387ca21b294ff209d176e244", + "sha256:150d0bc275a8ac664a705135e639178883293cf08c1a38de3bbaa2f693a0a867", + "sha256:1eee5cf17ec2b4198b509272cf300ee6577229d237c98cc6e63861b08463ddc6", + "sha256:210e458723436b2469d61b54b453474e09e12a94453c97ea3fbb0742ba5a83d8", + "sha256:2181bb63ef3c607e6403813421b46982c3ac6bfc1f11fa16a13eaafb46f578e6", + "sha256:24091a6b31242bcdd56ae2dbea40007f462260bc9bdf947953acc39dffd54f8f", + "sha256:2b76acee82ea0e9304be6be9d4b3840208d050ea0dcad75b1635fa06e949a0ae", + "sha256:2d92c51e6e9394d503033ffe3292f5bef1566ab73029ec853861f60ad5c925d0", + "sha256:5940bc901fb0325970415dbede82c0b7f3e35c2d5fd1d5e0047134c2c46b3281", + "sha256:8454c1133ef701e8c050a59edd85d238ee18bb9a0eb95faf2fca8b909ee3c89a", + "sha256:855f280e781d49e0640aef4c4af586831ade8f1a6c4df483fb901cbe1a48d127", + 
"sha256:880cb57ca49e862e1cd077104375b9d1dfdc0622596dfa22105f470d7bacb309", + "sha256:8b0a9a47c28f67a269bb62f6256e63cef28d3c5f13cbae4fab587c3ad506778b", + "sha256:92c32356f669e036eb01016e883b22add883c60739bc1ebee3a1cc0249a50828", + "sha256:92f093986ab92e71c9ffe334c002f96defc7986efda18397d0f08534f3ebdc4d", + "sha256:9564c19b31a0434f01d2025b06b44c7ed422f51e719ab5d24ff03b7560066c9a", + "sha256:b67340c62c93a11fae482910dc29163a50dff3dfa88bc874872d28ee604a83be", + "sha256:bf14aa0eb22b58c231243dccf7e7f42f7beec48970f2549b3a6acc737d1a4ba4", + "sha256:c1e1029d47ee66d3a0c4d6088641882f75b93db82bd0e6178f7bd744ebce42b9", + "sha256:df75594e5a4702b032684d5481db3af990b69c249ccb1d32687b8501f0689432", + "sha256:f19f767a018e6ec89608e1f6b13348fa2fcde657151137cb64e56d48598a92db", + "sha256:f8afdfa6da33f0b4226af8e64cfc2b28986e005528fbf944d0a24a72acfc9432", + "sha256:fa1469901a2e100c17eb8fe2678e34bd4255a3576d1a543421356e9c14d6e2ae" + ], + "markers": "python_version >= '3.8'", + "version": "==0.41.1" + }, "networkx": { "hashes": [ "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36", "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61", ], "index": "pypi", + "markers": "python_version >= '3.9'", + "version": "==3.2.1" + }, + "numba": { + "hashes": [ + "sha256:07f2fa7e7144aa6f275f27260e73ce0d808d3c62b30cff8906ad1dec12d87bbe", + "sha256:240e7a1ae80eb6b14061dc91263b99dc8d6af9ea45d310751b780888097c1aaa", + "sha256:45698b995914003f890ad839cfc909eeb9c74921849c712a05405d1a79c50f68", + "sha256:487ded0633efccd9ca3a46364b40006dbdaca0f95e99b8b83e778d1195ebcbaa", + "sha256:4e79b6cc0d2bf064a955934a2e02bf676bc7995ab2db929dbbc62e4c16551be6", + "sha256:55a01e1881120e86d54efdff1be08381886fe9f04fc3006af309c602a72bc44d", + "sha256:5c765aef472a9406a97ea9782116335ad4f9ef5c9f93fc05fd44aab0db486954", + "sha256:6fe7a9d8e3bd996fbe5eac0683227ccef26cba98dae6e5cee2c1894d4b9f16c1", + "sha256:7bf1ddd4f7b9c2306de0384bf3854cac3edd7b4d8dffae2ec1b925e4c436233f", + 
"sha256:811305d5dc40ae43c3ace5b192c670c358a89a4d2ae4f86d1665003798ea7a1a", + "sha256:81fe5b51532478149b5081311b0fd4206959174e660c372b94ed5364cfb37c82", + "sha256:898af055b03f09d33a587e9425500e5be84fc90cd2f80b3fb71c6a4a17a7e354", + "sha256:9e9356e943617f5e35a74bf56ff6e7cc83e6b1865d5e13cee535d79bf2cae954", + "sha256:a1eaa744f518bbd60e1f7ccddfb8002b3d06bd865b94a5d7eac25028efe0e0ff", + "sha256:bc2d904d0319d7a5857bd65062340bed627f5bfe9ae4a495aef342f072880d50", + "sha256:bcecd3fb9df36554b342140a4d77d938a549be635d64caf8bd9ef6c47a47f8aa", + "sha256:bd3dda77955be03ff366eebbfdb39919ce7c2620d86c906203bed92124989032", + "sha256:bf68df9c307fb0aa81cacd33faccd6e419496fdc621e83f1efce35cdc5e79cac", + "sha256:d3e2fe81fe9a59fcd99cc572002101119059d64d31eb6324995ee8b0f144a306", + "sha256:e63d6aacaae1ba4ef3695f1c2122b30fa3d8ba039c8f517784668075856d79e2", + "sha256:ea5bfcf7d641d351c6a80e8e1826eb4a145d619870016eeaf20bbd71ef5caa22" + ], "markers": "python_version >= '3.8'", "version": "==3.1", }, @@ -360,6 +425,14 @@ "markers": "python_version >= '3.8'", "version": "==3.3.0", }, + "jeepney": { + "hashes": [ + "sha256:5efe48d255973902f6badc3ce55e2aa6c5c3b3bc642059ef3a91247bcfcc5806", + "sha256:c0a454ad016ca575060802ee4d590dd912e35c122fa04e70306de3d076cce755" + ], + "markers": "sys_platform == 'linux'", + "version": "==0.8.0" + }, "jinja2": { "hashes": [ "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852", @@ -594,15 +667,16 @@ "markers": "python_full_version >= '3.7.0'", "version": "==13.5.2", }, - "setuptools": { + "secretstorage": { "hashes": [ "sha256:3d4dfa6d95f1b101d695a6160a7626e15583af71a5f52176efa5d39a054d475d", "sha256:3d8083eed2d13afc9426f227b24fd1659489ec107c0e86cec2ffdde5c92e790b", ], "markers": "python_version >= '3.8'", "version": "==68.1.2", + }, - "six": { + "setuptools": { "hashes": [ "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", @@ -1137,6 +1211,7 
@@ "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4", "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba", "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8", + "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef", "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5", "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd", "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3", diff --git a/README.md b/README.md index 04ad8fc6..7fd8bd28 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ HiClass is an open-source Python library for hierarchical classification compati - [Who is using HiClass?](#who-is-using-hiclass) - [Install](#install) - [Quick start](#quick-start) +- [Explaining Hierarchical Classifiers](#explaining-hierarchical-classifiers) - [Step-by-step walk-through](#step-by-step-walk-through) - [API documentation](#api-documentation) - [FAQ](#faq) @@ -34,6 +35,7 @@ HiClass is an open-source Python library for hierarchical classification compati - **[Hierarchical metrics](https://hiclass.readthedocs.io/en/latest/api/utilities.html#hierarchical-metrics):** HiClass supports the computation of hierarchical precision, recall and f-score, which are more appropriate for hierarchical data than traditional metrics. - **[Compatible with pickle](https://hiclass.readthedocs.io/en/latest/auto_examples/plot_model_persistence.html):** Easily store trained models on disk for future use. - **[BERT sklearn](https://hiclass.readthedocs.io/en/latest/auto_examples/plot_bert.html):** Compatible with the library [BERT sklearn](https://github.com/charles9n/bert-sklearn). +- **[Hierarchical Explanability](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html):** HiClass allows explaining hierarchical models using the [SHAP](https://github.com/shap/shap) package. 
**Any feature missing on this list?** Search our [issue tracker](https://github.com/scikit-learn-contrib/hiclass/issues) to see if someone has already requested it and add a comment to it explaining your use-case. Otherwise, please open a new issue describing the requested feature and possible use-case scenario. We prioritize our roadmap based on user feedback, so we would love to hear from you. @@ -113,7 +115,7 @@ pip install hiclass"[<extras>]" Replace `<extras>` with one of the following options: - ray: Installs the ray package, which is required for parallel processing support. -- xai: Installs the shap and xarray packages, which are required for explaining Hiclass predictions. +- xai: Installs the shap and xarray packages, which are required for explaining HiClass' predictions. ### Option 2: Conda @@ -199,6 +201,9 @@ pipeline.fit(X_train, Y_train) predictions = pipeline.predict(X_test) ``` +## Explaining Hierarchical Classifiers +Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html). + +## Step-by-step walk-through A step-by-step walk-through is available on our documentation hosted on [Read the Docs](https://hiclass.readthedocs.io/en/latest/index.html). 
diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py new file mode 100644 index 00000000..b4864541 --- /dev/null +++ b/docs/examples/plot_lcppn_explainer.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +""" +============================================ +Explaining Local Classifier Per Parent Node +============================================ + +A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPPN model. +A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`. +SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here `_. +""" +from sklearn.ensemble import RandomForestClassifier +from hiclass import LocalClassifierPerParentNode, Explainer +import requests +import pandas as pd +import shap + +# Download training data +url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv" +path = "platypus_diseases.csv" +response = requests.get(url) +with open(path, "wb") as file: + file.write(response.content) + +# Load training data into pandas dataframe +training_data = pd.read_csv(path).fillna(" ") + +# Define data +X_train = training_data.drop(["label"], axis=1) +X_test = X_train[:100] # Use first 100 samples as test set +Y_train = training_data["label"] +Y_train = [eval(my) for my in Y_train] + +# Use random forest classifiers for every node +rfc = RandomForestClassifier() +classifier = LocalClassifierPerParentNode( + local_classifier=rfc, replace_classifiers=False +) + +# Train local classifier per parent node +classifier.fit(X_train, Y_train) + +# Define Explainer +explainer = Explainer(classifier, data=X_train, mode="tree") +explanations = explainer.explain(X_test.values) +print(explanations) + +# Filter samples which only predicted "Respiratory" at first level +respiratory_idx = 
classifier.predict(X_test)[:, 0] == "Respiratory" + +# Specify additional filters to obtain only level 0 +shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx} + +# Use .sel() method to apply the filter and obtain filtered results +shap_val_respiratory = explanations.sel(shap_filter) + +# Plot feature importance on test set +shap.plots.violin( + shap_val_respiratory.shap_values, + feature_names=X_train.columns.values, + plot_size=(13, 8), +) diff --git a/docs/requirements.txt b/docs/requirements.txt index fc59178d..ac1037e3 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -9,3 +9,5 @@ pandas==1.4.2 ray==1.13.0 numpy git+https://github.com/charles9n/bert-sklearn.git@master +shap==0.44.1 +xarray==2023.1.0 \ No newline at end of file diff --git a/docs/source/algorithms/explainer-indexing.png b/docs/source/algorithms/explainer-indexing.png new file mode 100644 index 00000000..60c177e7 Binary files /dev/null and b/docs/source/algorithms/explainer-indexing.png differ diff --git a/docs/source/algorithms/explainer.rst b/docs/source/algorithms/explainer.rst new file mode 100644 index 00000000..b87ced9c --- /dev/null +++ b/docs/source/algorithms/explainer.rst @@ -0,0 +1,131 @@ +.. _explainer-overview: + +=========================== +Hierarchical Explainability +=========================== +HiClass also provides support for eXplainable AI (XAI) using SHAP values. This section demonstrates the Explainer class along with examples and design principles. + +++++++++++++++++++++++++++ +Motivation +++++++++++++++++++++++++++ + +Explainability in machine learning refers to understanding and interpreting how a model arrives at a particular decision. Several explainability methods are available in the literature, which have found applications in various machine learning applications. + +SHAP values are one such approach that provides a unified measure of feature importance that considers the contribution of each feature to the model prediction. 
These values are based on cooperative game theory and provide a fair way to distribute the credit for the prediction among the features. + +Integrating explainability methods into Hierarchical classifiers can yield promising results depending on the application domain. Hierarchical explainability extends the concept of SHAP values to hierarchical classification models. + +++++++++++++++++++++++++++ +Dataset overview +++++++++++++++++++++++++++ +For the remainder of this section, we will utilize a synthetically generated dataset representing platypus diseases. This tabular dataset is created to visualize and test the essence of explainability using SHAP on hierarchical models. The diagram below illustrates the hierarchical structure of the dataset. With nine symptoms as features—fever, diarrhea, stomach pain, skin rash, cough, sniffles, shortness of breath, headache, and body size—the objective is to predict the disease based on these feature values. + +.. figure:: ../algorithms/platypus_diseases_hierarchy.svg + :align: center + :width: 100% + + Hierarchical structure of the synthetic dataset representing platypus diseases. + +++++++++++++++++++++++++++ +Background +++++++++++++++++++++++++++ +This section introduces two main concepts: hierarchical classification and SHAP values. Hierarchical classification leverages the hierarchical structure of data, breaking down the classification task into manageable sub-tasks using models organized in a tree or DAG structure. + +SHAP values, adapted from game theory, show the impact of features on model predictions, thus aiding model interpretation. The SHAP library offers practical implementation of these methods, supporting various machine learning algorithms for explanation generation. + +To demonstrate how SHAP values provide insights into model prediction, consider the following sample from the platypus disease dataset. + +.. code-block:: python + + test_sample = np.array([[35.5, 0. , 1. , 1. , 3. , 3. , 0. , 2. 
, 37.5]]) + sample_target = np.array([['Respiratory', 'Cold', '']]) + +We can calculate SHAP values using the SHAP python package and visualize them. SHAP values tell us how much each symptom "contributes" to the model's decision about which disease a platypus might have. The following diagram illustrates how SHAP values can be visualized using the :literal:`shap.force_plot`. + +.. figure:: ../algorithms/shap_explanation.png + :align: center + :width: 100% + + Force plot illustrating the influence of symptoms on predicting platypus diseases using SHAP values. Each bar represents a symptom, and its length indicates the magnitude of its impact on disease prediction. + +++++++++++++++++++++++++++ +API Design +++++++++++++++++++++++++++ + +Designing an API for hierarchical classifiers and SHAP value computation presents numerous challenges, including complex data structures, difficulties accessing correct SHAP values corresponding to a classifier, and slow computation. We addressed these issues by using :literal:`xarray.Dataset` for organization, filtering, and storage of SHAP values efficiency. We also utilized parallelization using Joblib for speed. These enhancements ensure a streamlined and user-friendly experience for users dealing with hierarchical classifiers and SHAP values. + +.. figure:: ../algorithms/explainer-indexing.png + :align: center + :width: 75% + + Pictorial representation of dimensions along which indexing of hierarchical SHAP values is required. + +The Explainer class takes a fitted HiClass model, training data, and some named parameters as input. After creating an instance of the Explainer, the :literal:`Explainer.explain` method can be called by providing the samples for which SHAP values need to be calculated. + +.. code-block:: python + + explainer = Explainer(fitted_hiclass_model, data=training_data) + +The Explainer returns an :literal:`xarray.Dataset` object which allows users to intuitively access, filter, slice, and plot SHAP values. 
This Explanation object can also be used interactively within the Jupyter notebook environment. The Explanation object along with its respective attributes are depicted in the following UML diagram. + +.. figure:: ../algorithms/hiclass-uml.png + :align: center + :width: 100% + + UML diagram showing the relationship between HiClass Explainer and the returned Explanation object. + +The Explanation object can be obtained by calling the :literal:`explain` method of the Explainer class. + +.. code-block:: python + + explanations = explainer.explain(sample_data) + + +++++++++++++++++++++++++++ +Code sample +++++++++++++++++++++++++++ + +.. code-block:: python + + from sklearn.ensemble import RandomForestClassifier + import numpy as np + from hiclass import LocalClassifierPerParentNode, Explainer + + rfc = RandomForestClassifier() + lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False) + + x_train = np.array([ + [40.7, 1. , 1. , 2. , 5. , 2. , 1. , 5. , 34.3], + [39.2, 0. , 2. , 4. , 1. , 3. , 1. , 2. , 34.1], + [40.6, 0. , 3. , 1. , 4. , 5. , 0. , 6. , 27.7], + [36.5, 0. , 3. , 1. , 2. , 2. , 0. , 2. , 39.9], + ]) + y_train = np.array([ + ['Gastrointestinal', 'Norovirus', ''], + ['Respiratory', 'Covid', ''], + ['Allergy', 'External', 'Bee Allergy'], + ['Respiratory', 'Cold', ''], + ]) + + x_test = np.array([[35.5, 0. , 1. , 1. , 3. , 3. , 0. , 2. , 37.5]]) + + lcppn.fit(x_train, y_train) + explainer = Explainer(lcppn, data=x_train, mode="tree") + explanations = explainer.explain(x_test) + + +++++++++++++++++++++++++++ +Filtering and Manipulation +++++++++++++++++++++++++++ + +The Explanation object returned by the Explainer is built using the :literal:`xarray.Dataset` data structure, that enables the application of any xarray dataset operation. For example, filtering specific values can be quickly done. To illustrate the filtering operation, suppose we have SHAP values stored in the Explanation object named :literal:`explanation`. 
+ +A common use case is to extract SHAP values for only the predicted nodes. In the Local classifier per parent node approach, each node except the leaf nodes represents a classifier. Hence, to find the SHAP values, we can pass the prediction up to the penultimate element as the filter. +To achieve this, we can use xarray's :literal:`.sel()` method: + +.. code-block:: python + + mask = {'class': lcppn.predict(x_test).flatten()[:-1]} + x = explanations.sel(mask).shap_values + +More advanced usage and capabilities can be found at the `Xarray.Dataset <https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html>`_ documentation. diff --git a/docs/source/algorithms/hiclass-uml.png b/docs/source/algorithms/hiclass-uml.png new file mode 100644 index 00000000..2ff13577 Binary files /dev/null and b/docs/source/algorithms/hiclass-uml.png differ diff --git a/docs/source/algorithms/index.rst b/docs/source/algorithms/index.rst index ce639141..092a6079 100644 --- a/docs/source/algorithms/index.rst +++ b/docs/source/algorithms/index.rst @@ -16,3 +16,4 @@ HiClass provides implementations for the most popular machine learning models fo local_classifier_per_level multi_label metrics + explainer diff --git a/docs/source/algorithms/platypus_diseases_hierarchy.svg b/docs/source/algorithms/platypus_diseases_hierarchy.svg new file mode 100644 index 00000000..122a5e7e --- /dev/null +++ b/docs/source/algorithms/platypus_diseases_hierarchy.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/source/algorithms/shap_explanation.png b/docs/source/algorithms/shap_explanation.png new file mode 100644 index 00000000..e255a1d2 Binary files /dev/null and b/docs/source/algorithms/shap_explanation.png differ diff --git a/docs/source/api/explainer_api.rst b/docs/source/api/explainer_api.rst new file mode 100644 index 00000000..1cadc303 --- /dev/null +++ b/docs/source/api/explainer_api.rst @@ -0,0 +1,10 @@ +.. _explainer_api: + +Explainer +======================== + +Explainer +----------------------- +.. 
autoclass:: Explainer.Explainer + :members: + :special-members: __init__ \ No newline at end of file diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 379729a5..d7ba18bf 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -13,3 +13,4 @@ This is done in order to provide a complete list of the callable functions for e classifiers utilities + explainer_api diff --git a/docs/source/get_started/install.rst b/docs/source/get_started/install.rst index 0fdaa61b..7ed92400 100644 --- a/docs/source/get_started/install.rst +++ b/docs/source/get_started/install.rst @@ -7,6 +7,17 @@ To install HiClass from the Python Package Index (PyPI) simply run: pip install hiclass +Additionally, it is possible to install optional packages alongside HiClass. To install optional packages run: + +.. code-block:: bash + + pip install hiclass"[<extras>]" + +:literal:`<extras>` can have one of the following options: + +- ray: Installs the ray package, which is required for parallel processing support. +- xai: Installs the shap and xarray packages, which are required for explaining HiClass' predictions. + It is also possible to install HiClass using :literal:`conda`, as follows: .. 
code-block:: bash diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py new file mode 100644 index 00000000..36510256 --- /dev/null +++ b/hiclass/Explainer.py @@ -0,0 +1,287 @@ +"""Explainer API for explaining predictions using shapley values.""" + +from copy import deepcopy +from joblib import Parallel, delayed +import numpy as np +from sklearn.utils.validation import check_array, check_is_fitted +from hiclass import ( + LocalClassifierPerParentNode, + LocalClassifierPerNode, + LocalClassifierPerLevel, + HierarchicalClassifier, +) + +try: + import xarray as xr +except ImportError: + xarray_installed = False +else: + xarray_installed = True + +try: + import shap +except ImportError: + shap_installed = False +else: + shap_installed = True + + +class Explainer: + """Explainer class for returning shap values for each of the three hierarchical classifiers.""" + + def __init__( + self, + hierarchical_model: HierarchicalClassifier.HierarchicalClassifier, + data: None, + n_jobs: int = 1, + algorithm: str = "auto", + mode: str = "", + ): + """ + Initialize the SHAP explainer for a hierarchical model. + + Parameters + ---------- + hierarchical_model : HierarchicalClassifier + The hierarchical classification model to explain. + data : array-like or None, default=None + The dataset used for creating the SHAP explainer. + n_jobs : int, default=1 + The number of jobs to run in parallel. + algorithm : str, default="auto" + The algorithm to use for SHAP explainer. Possible values are 'linear', 'tree', 'auto', 'permutation', or 'partition' + mode : str, default="" + The mode of the SHAP explainer. Can be 'tree', 'gradient', 'deep', 'linear', or '' for default SHAP explainer. 
+ + Examples + -------- + >>> from sklearn.ensemble import RandomForestClassifier + >>> import numpy as np + >>> from hiclass import LocalClassifierPerParentNode, Explainer + >>> rfc = RandomForestClassifier() + >>> lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False) + >>> x_train = np.array([[1, 3], [2, 5]]) + >>> y_train = np.array([[1, 2], [3, 4]]) + >>> x_test = np.array([[4, 6]]) + >>> lcppn.fit(x_train, y_train) + >>> explainer = Explainer(lcppn, data=x_train, mode="tree") + >>> explanations = explainer.explain(x_test) + + Dimensions: (class: 3, sample: 1, level: 2, feature: 2) + Coordinates: + * class (class) 0: + successors = list( + self.hierarchical_model.hierarchy_.successors(predecessor) + ) + if len(successors) > 0: + classifier = self.hierarchical_model.hierarchy_.nodes[ + predecessor + ]["classifier"] + traversals[mask, level] = classifier.predict( + predecessor_x + ).flatten() + return traversals + + def _calculate_shap_values(self, X): + """ + Return an xarray.Dataset object for a single sample provided. This dataset is aligned on the `level` attribute. + + Parameters + ---------- + X : array-like + Data for single sample for which to generate SHAP values. + + Returns + ------- + explanation : xarray.Dataset + A single explanation for the prediction of given sample. 
+ """ + traversed_nodes = [] + if isinstance(self.hierarchical_model, LocalClassifierPerParentNode): + traversed_nodes = self._get_traversed_nodes_lcppn(X)[0] + datasets = [] + level = 0 + for node in traversed_nodes: + # Skip if classifier is not found, can happen in case of imbalanced hierarchies + if "classifier" not in self.hierarchical_model.hierarchy_.nodes[node]: + continue + + local_classifier = self.hierarchical_model.hierarchy_.nodes[node][ + "classifier" + ] + + # Create a SHAP explainer for the local classifier + local_explainer = deepcopy(self.explainer)(local_classifier, self.data) + + current_node = node.split(self.hierarchical_model.separator_)[-1] + + # Calculate SHAP values for the given sample X + shap_values = np.array( + local_explainer.shap_values(X, check_additivity=False) + ) + + if len(shap_values.shape) < 3: + shap_values = shap_values.reshape( + 1, shap_values.shape[0], shap_values.shape[1] + ) + + if isinstance(self.hierarchical_model, LocalClassifierPerNode): + simplified_labels = [ + f"{current_node}_{int(label)}" + for label in local_classifier.classes_ + ] + predicted_class = current_node + else: + simplified_labels = [ + label.split(self.hierarchical_model.separator_)[-1] + for label in local_classifier.classes_ + ] + predicted_class = ( + local_classifier.predict(X) + .flatten()[0] + .split(self.hierarchical_model.separator_)[-1] + ) + + classes = xr.DataArray( + simplified_labels, + dims=["class"], + coords={"class": simplified_labels}, + ) + + shap_val_local = xr.DataArray( + shap_values, + dims=["class", "sample", "feature"], + coords={"class": simplified_labels}, + ) + + prediction_probability = local_classifier.predict_proba(X)[0] + + predict_proba = xr.DataArray( + prediction_probability, + dims=["class"], + coords={ + "class": simplified_labels, + }, + ) + + local_dataset = xr.Dataset( + { + "node": current_node, + "predicted_class": predicted_class, + "predict_proba": predict_proba, + "classes": classes, + "shap_values": 
shap_val_local, + "level": level, + } + ) + level = level + 1 + datasets.append(local_dataset) + sample_explanation = xr.concat(datasets, dim="level") + return sample_explanation diff --git a/hiclass/__init__.py b/hiclass/__init__.py index 09370436..ec3db00e 100644 --- a/hiclass/__init__.py +++ b/hiclass/__init__.py @@ -1,5 +1,7 @@ """Init module for the library.""" +import os +from ._version import get_versions from .LocalClassifierPerLevel import LocalClassifierPerLevel from .LocalClassifierPerNode import LocalClassifierPerNode from .LocalClassifierPerParentNode import LocalClassifierPerParentNode @@ -7,6 +9,7 @@ from .MultiLabelLocalClassifierPerParentNode import ( MultiLabelLocalClassifierPerParentNode, ) +from .Explainer import Explainer from ._version import get_versions __version__ = get_versions()["version"] @@ -16,6 +19,7 @@ "LocalClassifierPerNode", "LocalClassifierPerParentNode", "LocalClassifierPerLevel", + "Explainer", "MultiLabelLocalClassifierPerNode", "MultiLabelLocalClassifierPerParentNode", ] diff --git a/setup.py b/setup.py index b2a49b97..eb560e2f 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ # 'fancy feature': ['django'],} EXTRAS = { "ray": ["ray>=1.11.0"], - "xai": ["shap", "xarray"], + "xai": ["shap==0.44.1", "xarray==2023.1.0"], "dev": [ "flake8==4.0.1", "pytest==7.1.2", @@ -45,6 +45,8 @@ "black==24.2.0", "pre-commit==2.20.0", "ray", + "shap==0.44.1", + "xarray==2023.1.0", ], } diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py new file mode 100644 index 00000000..c4af7a30 --- /dev/null +++ b/tests/test_Explainer.py @@ -0,0 +1,148 @@ +import numpy as np +import pytest +from sklearn.ensemble import RandomForestClassifier +from hiclass import ( + LocalClassifierPerParentNode, + Explainer, +) + +try: + import shap +except ImportError: + shap_installed = False +else: + shap_installed = True + +try: + import xarray +except ImportError: + xarray_installed = False +else: + xarray_installed = True + + +@pytest.fixture +def 
explainer_data(): + x_train = np.random.randn(4, 3) + y_train = np.array( + [["a", "b", "d"], ["a", "b", "e"], ["a", "c", "f"], ["a", "c", "g"]] + ) + x_test = np.random.randn(5, 3) + + return x_train, x_test, y_train + + +@pytest.fixture +def explainer_data_no_root(): + x_train = np.random.randn(6, 3) + y_train = np.array( + [ + ["a", "b", "c"], + ["x", "y", "z"], + ["a", "b", "c"], + ["x", "y", "z"], + ["a", "b", "c"], + ["x", "y", "z"], + ] + ) + x_test = np.random.randn(5, 3) + return x_train, x_test, y_train + + +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) +def test_explainer_tree_lcppn(data, request): + rfc = RandomForestClassifier() + lcppn = LocalClassifierPerParentNode( + local_classifier=rfc, replace_classifiers=False + ) + + x_train, x_test, y_train = request.getfixturevalue(data) + + lcppn.fit(x_train, y_train) + + explainer = Explainer(lcppn, data=x_train, mode="tree") + explanations = explainer.explain(x_test) + + # Assert if explainer returns an xarray.Dataset object + assert isinstance(explanations, xarray.Dataset) + + # Assert if predictions made are consistent with the explanation object + y_preds = lcppn.predict(x_test) + for i in range(len(x_test)): + y_pred = y_preds[i] + explanation = explanations["predicted_class"][i] + for j in range(len(y_pred)): + assert explanation.data[j].split(lcppn.separator_)[-1] == y_pred[j] + + +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) +def test_traversal_path_lcppn(data, request): + x_train, x_test, y_train = request.getfixturevalue(data) + rfc = RandomForestClassifier() + lcppn = LocalClassifierPerParentNode( + local_classifier=rfc, replace_classifiers=False + ) + + lcppn.fit(x_train, y_train) + explainer = Explainer(lcppn, data=x_train, mode="tree") + traversals = 
explainer._get_traversed_nodes_lcppn(x_test) + preds = lcppn.predict(x_test) + assert len(preds) == len(traversals) + for i in range(len(x_test)): + for j in range(len(traversals[i])): + if traversals[i][j] == lcppn.root_: + continue + label = traversals[i][j].split(lcppn.separator_)[-1] + assert label == preds[i][j - 1] + + +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") +@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) +@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode]) +def test_explain_with_xr(data, request, classifier): + x_train, x_test, y_train = request.getfixturevalue(data) + rfc = RandomForestClassifier() + clf = classifier(local_classifier=rfc, replace_classifiers=False) + + clf.fit(x_train, y_train) + explainer = Explainer(clf, data=x_train, mode="tree") + explanations = explainer._explain_with_xr(x_test) + + # Assert if explainer returns an xarray.Dataset object + assert isinstance(explanations, xarray.Dataset) + + +@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode]) +def test_imports(classifier): + x_train = [[76, 12, 49], [88, 63, 31], [5, 42, 24], [17, 90, 55]] + y_train = [["a", "b", "d"], ["a", "b", "e"], ["a", "c", "f"], ["a", "c", "g"]] + + rfc = RandomForestClassifier() + clf = classifier(local_classifier=rfc, replace_classifiers=False) + clf.fit(x_train, y_train) + + explainer = Explainer(clf, data=x_train, mode="tree") + assert isinstance(explainer.data, np.ndarray) + + +@pytest.mark.parametrize("classifier", [LocalClassifierPerParentNode]) +@pytest.mark.parametrize("data", ["explainer_data"]) +@pytest.mark.parametrize("mode", ["linear", "gradient", "deep", "tree", ""]) +def test_explainers(data, request, classifier, mode): + x_train, x_test, y_train = request.getfixturevalue(data) + rfc = RandomForestClassifier() + clf = classifier(local_classifier=rfc, replace_classifiers=False) 
+ + clf.fit(x_train, y_train) + explainer = Explainer(clf, data=x_train, mode=mode) + mode_mapping = { + "linear": shap.LinearExplainer, + "gradient": shap.GradientExplainer, + "deep": shap.DeepExplainer, + "tree": shap.TreeExplainer, + "": shap.Explainer, + } + assert explainer.explainer == mode_mapping[mode] diff --git a/tests/test_LocalClassifierPerNode.py b/tests/test_LocalClassifierPerNode.py index 670c9823..fb3a0ebb 100644 --- a/tests/test_LocalClassifierPerNode.py +++ b/tests/test_LocalClassifierPerNode.py @@ -11,6 +11,7 @@ from sklearn.utils.validation import check_is_fitted from hiclass import LocalClassifierPerNode + from hiclass.BinaryPolicy import ExclusivePolicy diff --git a/tests/test_LocalClassifierPerParentNode.py b/tests/test_LocalClassifierPerParentNode.py index 922a03a3..9d01d605 100644 --- a/tests/test_LocalClassifierPerParentNode.py +++ b/tests/test_LocalClassifierPerParentNode.py @@ -10,7 +10,6 @@ from sklearn.linear_model import LogisticRegression from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.validation import check_is_fitted - from hiclass import LocalClassifierPerParentNode