From 5af9efbd7aca7b7243efecec57edb552ee140327 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 20:17:14 +0200 Subject: [PATCH 01/31] basic implementation of filter_by_level function --- hiclass/Explainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 8708527b..3893f0b2 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -365,3 +365,12 @@ def _calculate_shap_values(self, X): datasets.append(local_dataset) sample_explanation = xr.concat(datasets, dim="level") return sample_explanation + + def filter_by_level(self, explanations, level): + """ + TODO + Returns Shapley_values filters by given level in the hierarchy + """ + filter_by_level = {"level": level} + filtered_explanations = explanations.sel(**filter_by_level) + return filtered_explanations From 5e52e861a8dd9a515aeb33e90c234c92d6d07147 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 20:18:40 +0200 Subject: [PATCH 02/31] basic implementation of filter_by_class function --- hiclass/Explainer.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 3893f0b2..4e69e947 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -368,9 +368,29 @@ def _calculate_shap_values(self, X): def filter_by_level(self, explanations, level): """ - TODO - Returns Shapley_values filters by given level in the hierarchy + TODO: add docstring """ filter_by_level = {"level": level} filtered_explanations = explanations.sel(**filter_by_level) return filtered_explanations + + def filter_by_class(self, explanations, class_name): + """ + TODO: add docstring + Filters explanations for the given class + + Parameters + __________ + explanations: xarray.Dataset + explanations generated by Explainer + class_name: string + Name of class + __________ + """ + filter_by_class = {"class": class_name} + filtered_explanations = explanations.sel(**filter_by_class) + return filtered_explanations + + + + From 5a8e9bebb0076d19156cbb85a61b287e2b4e4ae0 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 20:19:32 +0200 Subject: [PATCH 03/31] basic implementation of combine_filters function --- hiclass/Explainer.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 4e69e947..2ba50eef 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -391,6 +391,27 @@ def filter_by_class(self, explanations, class_name): filtered_explanations = explanations.sel(**filter_by_class) return filtered_explanations + def combine_filters( + self, explanations, level=None, class_name=None, sample_indices=None + ): + """ + TOOD: add docstring + """ + shap_filter = dict() + if class_name is not None: + shap_filter["class"] = class_name + if level is not None: + shap_filter["level"] = level + else: + level = new_level(self.hierarchical_model, class_name) + shap_filter["level"] = level + if sample_indices is not None: + shap_filter["sample"] = sample_indices + + filtered_explanations = explanations.sel(**shap_filter) + filtered_shap_values = filtered_explanations.shap_values.values + + return filtered_shap_values From cceb6d1686d661cb6bd66df8a50e99f9aa98bb81 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 20:20:03 +0200 Subject: [PATCH 04/31] codestyling --- hiclass/Explainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 2ba50eef..83075931 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -392,7 +392,7 @@ def filter_by_class(self, explanations, class_name): return filtered_explanations def combine_filters( - self, explanations, level=None, class_name=None, sample_indices=None + self, explanations, level=None, class_name=None, sample_indices=None ): """ TOOD: add docstring @@ -412,6 +412,3 @@ def combine_filters( filtered_shap_values = filtered_explanations.shap_values.values return filtered_shap_values - - - From ab1e3abe04cfe786734f31f24441d53e4adc9126 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 20:25:26 +0200 Subject: [PATCH 05/31] helper function get_class_level added --- hiclass/Explainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 83075931..d1c54c84 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -412,3 +412,12 @@ def combine_filters( filtered_shap_values = filtered_explanations.shap_values.values return filtered_shap_values + + def get_class_level(self, class_name): + """ + TODO: add docstring + """ + for node in classifier.hierarchy_.nodes: + if class_name in node: + node_classes = node.split(classifier.separator_) + return node_classes.index(class_name) From 719dd83c09dd5f987fd5ced3ca2ac46db91ed954 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 20:32:33 +0200 Subject: [PATCH 06/31] another helper functions added --- hiclass/Explainer.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index d1c54c84..10647186 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -414,10 +414,17 @@ def combine_filters( return filtered_shap_values def get_class_level(self, class_name): - """ - TODO: add docstring - """ - for node in classifier.hierarchy_.nodes: - if class_name in node: - node_classes = node.split(classifier.separator_) - return node_classes.index(class_name) + """ + TODO: add docstring + """ + for node in classifier.hierarchy_.nodes: + if class_name in node: + node_classes = node.split(classifier.separator_) + return node_classes.index(class_name) + + def get_sample_indices(self, predictions, class_name): + """ + TODO: add docstring + """ + class_level = self.get_class_level(predictions, class_name) + return predictions[:, class_level] == class_name From 343be048a3e279964f790336048713b71929f5e4 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 21:03:48 +0200 Subject: [PATCH 07/31] some bugs fixed --- hiclass/Explainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 10647186..b26c1236 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -403,7 +403,7 @@ def combine_filters( if level is not None: shap_filter["level"] = level else: - level = new_level(self.hierarchical_model, class_name) + level = self.get_class_level(class_name) shap_filter["level"] = level if sample_indices is not None: shap_filter["sample"] = sample_indices @@ -417,6 +417,7 @@ def get_class_level(self, class_name): """ TODO: add docstring """ + classifier = self.hierarchical_model for node in classifier.hierarchy_.nodes: if class_name in node: node_classes = node.split(classifier.separator_) @@ -426,5 +427,5 @@ def get_sample_indices(self, predictions, class_name): """ TODO: add docstring """ - class_level = self.get_class_level(predictions, class_name) + class_level = self.get_class_level(class_name) return predictions[:, class_level] == class_name From af32acb8d29bf2da5b08f63d80797b3570bb6a67 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 21:05:45 +0200 Subject: [PATCH 08/31] small plot_lcpl_explainer actualisation --- docs/examples/plot_lcpl_explainer.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py index d085c791..283e76b1 100644 --- a/docs/examples/plot_lcpl_explainer.py +++ b/docs/examples/plot_lcpl_explainer.py @@ -31,12 +31,17 @@ # Let's filter the Shapley values corresponding to the Covid (level 1) # and 'Respiratory' (level 0) -covid_idx = classifier.predict(X_test)[:, 1] == "Covid" +predictions = classifier.predict(X_test) +covid_lvl = explainer.get_class_level("Covid") +covid_idx = explainer.get_sample_indices(predictions, "Covid") -shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx} -shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx} -shap_val_covid = explanations.sel(**shap_filter_covid) -shap_val_resp = explanations.sel(**shap_filter_resp) + +shap_val_covid = explainer.combine_filters( + explanations, class_name="Covid", sample_indices=covid_idx +) +shap_val_resp = explainer.combine_filters( + explanations, class_name="Respiratory", sample_indices=covid_idx +) # This code snippet demonstrates how to visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases. @@ -45,10 +50,10 @@ feature_names = X_train.columns.values # SHAP values for 'Covid' -shap_values_covid = shap_val_covid.shap_values.values +shap_values_covid = shap_val_covid # SHAP values for 'Respiratory' -shap_values_resp = shap_val_resp.shap_values.values +shap_values_resp = shap_val_resp shap.summary_plot( [shap_values_covid, shap_values_resp], From 3645e38f1dffb9d7ce3cbf17dd7b9d4aec7d26bf Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 21:08:28 +0200 Subject: [PATCH 09/31] some small changes in plot_lcpl_explainer --- docs/examples/plot_lcpl_explainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py index 283e76b1..9e80f7ca 100644 --- a/docs/examples/plot_lcpl_explainer.py +++ b/docs/examples/plot_lcpl_explainer.py @@ -28,8 +28,7 @@ explanations = explainer.explain(X_test.values) print(explanations) -# Let's filter the Shapley values corresponding to the Covid (level 1) -# and 'Respiratory' (level 0) +# Since Covid is a kind of Respiratory diseases, let's filter explanations for these classes predictions = classifier.predict(X_test) covid_lvl = explainer.get_class_level("Covid") From d1cf3e0f1dfd972ceccaaf7d08167874d23f22d4 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Sun, 14 Apr 2024 21:09:32 +0200 Subject: [PATCH 10/31] another changes in plot_lcpl_explainer --- docs/examples/plot_lcpl_explainer.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py index 9e80f7ca..f20bf11d 100644 --- a/docs/examples/plot_lcpl_explainer.py +++ b/docs/examples/plot_lcpl_explainer.py @@ -35,10 +35,10 @@ covid_idx = explainer.get_sample_indices(predictions, "Covid") -shap_val_covid = explainer.combine_filters( +shap_values_covid = explainer.combine_filters( explanations, class_name="Covid", sample_indices=covid_idx ) -shap_val_resp = explainer.combine_filters( +shap_values_resp = explainer.combine_filters( explanations, class_name="Respiratory", sample_indices=covid_idx ) @@ -48,12 +48,6 @@ # Feature names for the X-axis feature_names = X_train.columns.values -# SHAP values for 'Covid' -shap_values_covid = shap_val_covid - -# SHAP values for 'Respiratory' -shap_values_resp = shap_val_resp - shap.summary_plot( [shap_values_covid, shap_values_resp], features=X_test.iloc[covid_idx], From 7e6bcb3099f85e81d7aa04519e7c5207c17a05f1 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 01:49:22 +0200 Subject: [PATCH 11/31] functionality duplication removes + docstrings added (partially) --- hiclass/Explainer.py | 71 +++++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index b26c1236..f309a30b 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -368,43 +368,33 @@ def _calculate_shap_values(self, X): def filter_by_level(self, explanations, level): """ - TODO: add docstring - """ - filter_by_level = {"level": level} - filtered_explanations = explanations.sel(**filter_by_level) - return filtered_explanations - - def filter_by_class(self, explanations, class_name): - """ - TODO: add docstring - Filters explanations for the given class + Returns the explanations filtered by the given level. Parameters __________ - explanations: xarray.Dataset - explanations generated by Explainer - class_name: string - Name of class - __________ + explanations : xarray.DataArray + The explanations to filter + level : int + level in the hierarchy to filter + + Returns + _______ + filtered_explanations : xarray.DataArray + Explanations filtered by the given level """ - filter_by_class = {"class": class_name} - filtered_explanations = explanations.sel(**filter_by_class) + filter_by_level = {"level": level} + filtered_explanations = explanations.sel(**filter_by_level) return filtered_explanations - def combine_filters( - self, explanations, level=None, class_name=None, sample_indices=None - ): + def filter_by_class(self, explanations, class_name, sample_indices=None): """ - TOOD: add docstring + TODO: add docstring """ shap_filter = dict() - if class_name is not None: - shap_filter["class"] = class_name - if level is not None: - shap_filter["level"] = level - else: - level = self.get_class_level(class_name) - shap_filter["level"] = level + shap_filter["class"] = class_name + level = self.get_class_level(class_name) + shap_filter["level"] = level + if sample_indices is not None: shap_filter["sample"] = sample_indices @@ -415,7 +405,17 @@ def combine_filters( def get_class_level(self, class_name): """ - TODO: add docstring + Returns level of the class in the hierarchy. + + Parameters + __________ + class_name : str + Name of the class + + Returns + _______ + class_level : int + Level of the class in the hierarchy """ classifier = self.hierarchical_model for node in classifier.hierarchy_.nodes: @@ -425,7 +425,18 @@ def get_class_level(self, class_name): def get_sample_indices(self, predictions, class_name): """ - TODO: add docstring + Returns indices of samples corresponding to the certain class + + Parameters + __________ + predictions: array-like + Array of predictions of the hierarchical classificator + class_name: str + Name of class + + Returns + _______ + sample_indices: boolean array of indices """ class_level = self.get_class_level(class_name) return predictions[:, class_level] == class_name From 31e88b27be2a637bb6e693de061925b59f85f950 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 01:50:16 +0200 Subject: [PATCH 12/31] some actualization in plot_lcpl_explainer --- docs/examples/plot_lcpl_explainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py index f20bf11d..6921d6d1 100644 --- a/docs/examples/plot_lcpl_explainer.py +++ b/docs/examples/plot_lcpl_explainer.py @@ -35,10 +35,10 @@ covid_idx = explainer.get_sample_indices(predictions, "Covid") -shap_values_covid = explainer.combine_filters( +shap_values_covid = explainer.filter_by_class( explanations, class_name="Covid", sample_indices=covid_idx ) -shap_values_resp = explainer.combine_filters( +shap_values_resp = explainer.filter_by_class( explanations, class_name="Respiratory", sample_indices=covid_idx ) From 849b30f1eca4f37d5f0b064a84fc56348bea3ed1 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 01:56:22 +0200 Subject: [PATCH 13/31] some refactoring --- hiclass/Explainer.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index f309a30b..a85feaf2 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -388,20 +388,39 @@ def filter_by_level(self, explanations, level): def filter_by_class(self, explanations, class_name, sample_indices=None): """ - TODO: add docstring - """ - shap_filter = dict() - shap_filter["class"] = class_name - level = self.get_class_level(class_name) - shap_filter["level"] = level + Filters SHAP values based on a specified class and optionally by sample indices. + + This function filters the provided explanations data array to return SHAP values + for a specific class, determined both by its name and the level it appears in + the hierarchy, with an option to further filter by specific sample indices. + + Parameters + ---------- + explanations : xarray.DataArray + An array of explanations, where dimensions include 'class', 'level', and 'sample'. + class_name : str + The name of the class to filter the explanations by. + sample_indices : list of int, optional + A list of integer indices specifying which samples to include in the filter. + If None, no sample-based filtering is applied. + Returns + ------- + numpy.ndarray + An array of SHAP values filtered according to the specified class and optionally + by the provided sample indices. + + Example + ------- + # Assuming `explanations` is an xarray.DataArray with the proper dimensions: + shap_values = filter_by_class(explanations, 'Dog', sample_indices=[0, 2, 5]) + """ + shap_filter = {"class": class_name, "level": self.get_class_level(class_name)} if sample_indices is not None: shap_filter["sample"] = sample_indices filtered_explanations = explanations.sel(**shap_filter) - filtered_shap_values = filtered_explanations.shap_values.values - - return filtered_shap_values + return filtered_explanations.shap_values.values def get_class_level(self, class_name): """ @@ -421,7 +440,8 @@ def get_class_level(self, class_name): for node in classifier.hierarchy_.nodes: if class_name in node: node_classes = node.split(classifier.separator_) - return node_classes.index(class_name) + class_level = node_classes.index(class_name) + return class_level def get_sample_indices(self, predictions, class_name): """ From 9877374efd82b61fa266fc9f008bad3af5ea6a42 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 01:58:07 +0200 Subject: [PATCH 14/31] some refactoring --- hiclass/Explainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index a85feaf2..ce6ff71a 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -394,6 +394,8 @@ def filter_by_class(self, explanations, class_name, sample_indices=None): for a specific class, determined both by its name and the level it appears in the hierarchy, with an option to further filter by specific sample indices. + As long as class can belong to one level only, the function also provides filtration + by level which is computed based on the class location in the hierarchy. Parameters ---------- explanations : xarray.DataArray From 38a90602b8bf1f991efacb220ac91a8a2387129d Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 05:33:40 +0200 Subject: [PATCH 15/31] some cases handled + tests written --- hiclass/Explainer.py | 60 ++++++++++++++++++++++++++++++----------- tests/test_Explainer.py | 43 +++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 15 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index ce6ff71a..550490cc 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -379,8 +379,35 @@ def filter_by_level(self, explanations, level): Returns _______ - filtered_explanations : xarray.DataArray + filtered_explanations : xarray.Dataset Explanations filtered by the given level + + Examples + -------- + >>> from sklearn.ensemble import RandomForestClassifier + >>> import numpy as np + >>> from hiclass import LocalClassifierPerParentNode, Explainer + >>> rfc = RandomForestClassifier() + >>> lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False) + >>> x_train = np.array([[1, 3], [2, 5]]) + >>> y_train = np.array([[1, 2], [3, 4]]) + >>> x_test = np.array([[4, 6]]) + >>> lcppn.fit(x_train, y_train) + >>> explainer = Explainer(lcppn, data=x_train, mode="tree") + >>> explanations = explainer.explain(x_test) + >>> explanations_level_1 = explainer.filter_by_level(explanations, level=1) + + Dimensions: (class: 3, sample: 1, feature: 2) + Coordinates: + * class (class) Date: Mon, 15 Apr 2024 05:34:10 +0200 Subject: [PATCH 16/31] small changes --- docs/examples/plot_lcpl_explainer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py index 6921d6d1..a8b95d9c 100644 --- a/docs/examples/plot_lcpl_explainer.py +++ b/docs/examples/plot_lcpl_explainer.py @@ -28,18 +28,20 @@ explanations = explainer.explain(X_test.values) print(explanations) -# Since Covid is a kind of Respiratory diseases, let's filter explanations for these classes - +# Predict predictions = classifier.predict(X_test) -covid_lvl = explainer.get_class_level("Covid") -covid_idx = explainer.get_sample_indices(predictions, "Covid") +# Since Covid is a kind of Respiratory diseases, let's filter explanations for these classes + +# Let's get sample indices where 'Covid' is predicted what can be done with .get_sample_indices() method +sample_idx = explainer.get_sample_indices(predictions, "Covid") +# Shapley values filtering by classes with .filter_by_class() method shap_values_covid = explainer.filter_by_class( - explanations, class_name="Covid", sample_indices=covid_idx + explanations, class_name="Covid", sample_indices=sample_idx ) shap_values_resp = explainer.filter_by_class( - explanations, class_name="Respiratory", sample_indices=covid_idx + explanations, class_name="Respiratory", sample_indices=sample_idx ) @@ -50,7 +52,7 @@ shap.summary_plot( [shap_values_covid, shap_values_resp], - features=X_test.iloc[covid_idx], + features=X_test.iloc[sample_idx], feature_names=X_train.columns.values, plot_type="bar", class_names=["Covid", "Respiratory"], From 5faf54be5e16c6a546bfa53d6d55feccba0d5d61 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 06:04:07 +0200 Subject: [PATCH 17/31] pydocstyle --- hiclass/Explainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 550490cc..32394ac3 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -368,7 +368,7 @@ def _calculate_shap_values(self, X): def filter_by_level(self, explanations, level): """ - Returns the explanations filtered by the given level. + Return the explanations filtered by the given level. Parameters __________ @@ -415,7 +415,7 @@ def filter_by_level(self, explanations, level): def filter_by_class(self, explanations, class_name, sample_indices=None): """ - Filters SHAP values based on a specified class and optionally by sample indices. + Filter SHAP values based on a specified class and optionally by sample indices. This function filters the provided explanations data array to return SHAP values for a specific class, determined both by its name and the level it appears in @@ -454,7 +454,7 @@ def filter_by_class(self, explanations, class_name, sample_indices=None): def get_class_level(self, class_name): """ - Returns level of the class in the hierarchy. + Return level of the class in the hierarchy. Parameters __________ @@ -477,7 +477,7 @@ def get_class_level(self, class_name): def get_sample_indices(self, predictions, class_name): """ - Returns indices of predictions corresponding to the certain class + Return indices of predictions corresponding to the certain class. Parameters __________ From c9e6aa205a83f9192e8e1b53b1ac68bfcc7aa993 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 06:23:36 +0200 Subject: [PATCH 18/31] documentation fixed --- hiclass/Explainer.py | 53 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 32394ac3..1d3b529d 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -425,8 +425,8 @@ def filter_by_class(self, explanations, class_name, sample_indices=None): by level which is computed based on the class location in the hierarchy. Parameters ---------- - explanations : xarray.DataArray - An array of explanations, where dimensions include 'class', 'level', and 'sample'. + explanations : xarray.Dataset + A dataset of explanations, where dimensions include 'class', 'level', and 'sample'. class_name : str or int The name of the class to filter the explanations by. sample_indices : list of boolean, optional @@ -438,18 +438,54 @@ def filter_by_class(self, explanations, class_name, sample_indices=None): numpy.ndarray An array of SHAP values filtered according to the specified class and optionally by the provided sample indices. + + Examples + ________ + >>> from sklearn.ensemble import RandomForestClassifier + >>> import numpy as np + >>> from hiclass import LocalClassifierPerParentNode, Explainer + >>> rfc = RandomForestClassifier() + >>> lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False) + >>> x_train = np.array([[1, 3], [2, 5]]) + >>> y_train = np.array([[1, 2], [3, 4]]) + >>> x_test = np.array([[4, 6]]) + >>> lcppn.fit(x_train, y_train) + >>> predictions = lcppn.predict(x_test) + >>> explainer = Explainer(lcppn, data=x_train, mode="tree") + >>> explanations = explainer.explain(x_test) + >>> filtered_shap = explainer.filter_by_class(explanations, level=3) + >>> print(filtered_shap) + [['3' '4']] + [[0.1 0.105]] """ - # Creating filter for the class + # Ensure that explanations are provided and have the expected structure + if not isinstance(explanations, xr.Dataset): + raise ValueError("Explanations should be an xarray.Dataset!") + + # Converting class_name to the string format class_name = str(class_name) + + # Define level level = self.get_class_level(str(class_name)) + # Handling with LocalClassifierPerNode case if isinstance(self.hierarchical_model, LocalClassifierPerNode): - class_name = str(class_name) + "_1" + class_name = f"{class_name}_1" + + # Shap filter shap_filter = {"class": class_name, "level": level} if sample_indices is not None: shap_filter["sample"] = sample_indices - filtered_explanations = explanations.sel(**shap_filter) + # Select the SHAP values according to the filter and handle possible errors + try: + filtered_explanations = explanations.sel(**shap_filter) + except KeyError as e: + raise KeyError( + f"Class name {class_name} with level {level} not found." + ) from e + + # Return the selected SHAP values as a NumPy array return filtered_explanations.shap_values.values def get_class_level(self, class_name): @@ -465,9 +501,16 @@ def get_class_level(self, class_name): _______ class_level : int Level of the class in the hierarchy + + """ + # Set the classifier classifier = self.hierarchical_model + + # Converting class_name to the string formatn class_name = str(class_name) + + # Iterating through the nodes of hierarchy for node in classifier.hierarchy_.nodes: if class_name in node.split(classifier.separator_): node_classes = node.split(classifier.separator_) From 3d2bf3b382013c8c42a8e72fd146fc2c7e485f28 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 07:08:40 +0200 Subject: [PATCH 19/31] Algorithm overviem completed --- docs/source/algorithms/explainer.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/source/algorithms/explainer.rst b/docs/source/algorithms/explainer.rst index b87ced9c..3ed95aa9 100644 --- a/docs/source/algorithms/explainer.rst +++ b/docs/source/algorithms/explainer.rst @@ -128,4 +128,23 @@ To achieve this, we can use xarray's :literal:`.sel()` method: mask = {'class': lcppn.predict(x_test).flatten()[:-1]} x = explanations.sel(mask).shap_values +Also, we developed some helper functions built in the Explainer class as its methods to simplify standard explanation manipulation and filtering such as filtering explanations by level, or filtering explanations by class and returning its Shapley values. A basic example below is a continuation of the example from the beginning of this section: + +.. code-block:: python + + predictions = lcppn.predict(x_test) + # Get the correcponding samples + covid_idx = explainer.get_sample_indices(predictions, 'Covid') + + # Filter the shap values + shap_values_covid = explainer.filter_by_class(explanations, 'Covid', covid_idx) + print(shap_values_covid) + + # Filter explanations by level + level = 1 + explanations_level_1 = explainer.filter_by_level(explanations, level) + print(explanations_level_1) + + + More advanced usage and capabilities can be found at the `Xarray.Dataset `_ documentation. From a2e7700aac4765ab2d1553013fe082a5df72bda4 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 08:22:17 +0200 Subject: [PATCH 20/31] plot_lcppn_explainer actualized --- docs/examples/plot_lcppn_explainer.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py index ab27ce38..d1290a46 100644 --- a/docs/examples/plot_lcppn_explainer.py +++ b/docs/examples/plot_lcppn_explainer.py @@ -25,23 +25,26 @@ # Train local classifier per parent node classifier.fit(X_train, Y_train) +# Get predictions +predictions = classifier.predict(X_test) + # Define Explainer explainer = Explainer(classifier, data=X_train.values, mode="tree") explanations = explainer.explain(X_test.values) print(explanations) -# Filter samples which only predicted "Respiratory" at first level -respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory" - -# Specify additional filters to obtain only level 0 -shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx} +# Filter samples which only predicted "Respiratory" +respiratory_idx = explainer.get_sample_indices(predictions, "Respiratory") # Use .sel() method to apply the filter and obtain filtered results -shap_val_respiratory = explanations.sel(shap_filter) +shap_val_respiratory = explainer.filter_by_class( + explanations, class_name="Respiratory", sample_indices=respiratory_idx +) + # Plot feature importance on test set shap.plots.violin( - shap_val_respiratory.shap_values, + shap_val_respiratory, feature_names=X_train.columns.values, plot_size=(13, 8), ) From 3b8239fcff2eae70abe1c50949e5305f42c2cb80 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 12:32:10 +0200 Subject: [PATCH 21/31] shap_multi_plot added --- hiclass/Explainer.py | 46 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 1d3b529d..3daf66d1 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -535,3 +535,49 @@ def get_sample_indices(self, predictions, class_name): """ class_level = self.get_class_level(class_name) return predictions[:, class_level] == class_name + + def shap_multi_plot(self, class_names, features, pred_class, features_names=None): + """ + Plot shap_values for multi-class case on a bar and return explanations. + "Lazy" function which does not require any additional actions from the user + apart from classifier fitting and explainer initialization. + + Parameters + __________ + class_names : list of str + A list of class names to calculate and visualize the Shapley values for + features: array-like + Matrix of feature values with shape (# features) or (# samples x # features). + Typically, this would be the test set features (X_test). + pred_class : int or str + The class label that the classifier's predictions must match for a sample to be + included in the subset of data used for SHAP value calculation. If not provided, + no filtering is applied, and all samples are considered. + features_names : list, optional + A list of feature names to include in the bar plot for the shap_values. + + Returns + _______ + explanations: xarray.Dataset + Whole explanations of data in features provided. + """ + classifier = self.hierarchical_model + predictions = classifier.predict(features) + explanations = self.explain(features) + + sample_idx = self.get_sample_indices(predictions, pred_class) + shap_array = [] + for class_name in class_names: + shap_val = self.filter_by_class( + explanations, class_name=class_name, sample_indices=sample_idx + ) + shap_array.append(shap_val) + + shap.summary_plot( + shap_array, + features=features[sample_idx], + feature_names=features_names, + plot_type="bar", + class_names=class_names, + ) + return explanations From c06e5e684dd239e0a9993cf2268872360e20671c Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 12:32:33 +0200 Subject: [PATCH 22/31] part of the code substituted with newer methods --- docs/examples/plot_lcpl_explainer.py | 35 ++++++---------------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py index a8b95d9c..b32dec0c 100644 --- a/docs/examples/plot_lcpl_explainer.py +++ b/docs/examples/plot_lcpl_explainer.py @@ -25,35 +25,14 @@ # Define Explainer explainer = Explainer(classifier, data=X_train, mode="tree") -explanations = explainer.explain(X_test.values) -print(explanations) -# Predict -predictions = classifier.predict(X_test) +# Now, our task is to see how feature importance may vary from level to level +# We are going to calculate shap_values for 'Respiratory', 'Covid' and plot what we calculated +# This can be done with a single method .shap_multi_plot, which additionally returns calculated explanations -# Since Covid is a kind of Respiratory diseases, let's filter explanations for these classes - -# Let's get sample indices where 'Covid' is predicted what can be done with .get_sample_indices() method -sample_idx = explainer.get_sample_indices(predictions, "Covid") - -# Shapley values filtering by classes with .filter_by_class() method -shap_values_covid = explainer.filter_by_class( - explanations, class_name="Covid", sample_indices=sample_idx -) -shap_values_resp = explainer.filter_by_class( - explanations, class_name="Respiratory", sample_indices=sample_idx -) - - -# This code snippet demonstrates how to visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases. - -# Feature names for the X-axis -feature_names = X_train.columns.values - -shap.summary_plot( - [shap_values_covid, shap_values_resp], - features=X_test.iloc[sample_idx], - feature_names=X_train.columns.values, - plot_type="bar", +explanations = explainer.shap_multi_plot( class_names=["Covid", "Respiratory"], + features=X_test.values, + pred_class="Respiratory", + features_names=X_train.columns.values, ) From 63c635f9a89871119360e0f4ab7e065b7ce6f838 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 12:42:30 +0200 Subject: [PATCH 23/31] pydocstyling --- hiclass/Explainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 3daf66d1..753ba16d 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -539,13 +539,14 @@ def get_sample_indices(self, predictions, class_name): def shap_multi_plot(self, class_names, features, pred_class, features_names=None): """ Plot shap_values for multi-class case on a bar and return explanations. + "Lazy" function which does not require any additional actions from the user apart from classifier fitting and explainer initialization. Parameters - __________ + ---------- class_names : list of str - A list of class names to calculate and visualize the Shapley values for + A list of class names to calculate and visualize the Shapley values for. features: array-like Matrix of feature values with shape (# features) or (# samples x # features). Typically, this would be the test set features (X_test). @@ -557,7 +558,7 @@ def shap_multi_plot(self, class_names, features, pred_class, features_names=None A list of feature names to include in the bar plot for the shap_values. Returns - _______ + ------- explanations: xarray.Dataset Whole explanations of data in features provided. """ From 6953517d9d6ff69f65a4263ef35ca8ee0f4feb88 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 17:55:09 +0200 Subject: [PATCH 24/31] some cases handled and new tests added --- hiclass/Explainer.py | 18 ++++++++---- tests/test_Explainer.py | 63 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 10 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 753ba16d..86f1fb33 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -465,6 +465,9 @@ def filter_by_class(self, explanations, class_name, sample_indices=None): # Converting class_name to the string format class_name = str(class_name) + if class_name == "": + raise ValueError("Empty class!") + # Define level level = self.get_class_level(str(class_name)) @@ -511,9 +514,9 @@ def get_class_level(self, class_name): class_name = str(class_name) # Iterating through the nodes of hierarchy - for node in classifier.hierarchy_.nodes: - if class_name in node.split(classifier.separator_): - node_classes = node.split(classifier.separator_) + for node_ in classifier.hierarchy_.nodes: + if class_name in node_.split(classifier.separator_): + node_classes = node_.split(classifier.separator_) return node_classes.index(class_name) raise ValueError(f"Class '{class_name}' not found!") @@ -559,13 +562,18 @@ def shap_multi_plot(self, class_names, features, pred_class, features_names=None Returns ------- - explanations: xarray.Dataset + explanations: xarray.Dataset3 Whole explanations of data in features provided. """ classifier = self.hierarchical_model predictions = classifier.predict(features) - explanations = self.explain(features) + if pred_class is not None and not any(pred_class in row for row in predictions): + raise ValueError( + f"The specified class '{pred_class}' was not found in the predictions." + ) + + explanations = self.explain(features) sample_idx = self.get_sample_indices(predictions, pred_class) shap_array = [] for class_name in class_names: diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py index 7ef0781a..1c8dbd69 100644 --- a/tests/test_Explainer.py +++ b/tests/test_Explainer.py @@ -7,6 +7,8 @@ LocalClassifierPerNode, Explainer, ) +from hiclass.datasets import load_platypus + try: import shap @@ -51,6 +53,17 @@ def explainer_data_no_root(): return x_train, x_test, y_train +@pytest.fixture() +def explainer_data_platypus(): + X_train, X_test, Y_train, Y_test = load_platypus() + X_train = X_train.iloc[0:10].values + X_test = X_test.iloc[0:10].values + Y_train = Y_train.iloc[0:10].values + Y_test = Y_test.iloc[0:10].values + + return X_train, X_test, Y_train + + @pytest.mark.skipif(not shap_installed, reason="shap not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_explainer_tree_lcppn(data, request): @@ -250,7 +263,9 @@ def test_explainers(data, request, classifier, mode): "classifier", [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], ) -@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) +@pytest.mark.parametrize( + "data", ["explainer_data", "explainer_data_no_root", "explainer_data_platypus"] +) def test_filter_by_level(data, request, classifier): x_train, x_test, y_train = request.getfixturevalue(data) rfc = RandomForestClassifier() @@ -270,9 +285,15 @@ def test_filter_by_level(data, request, classifier): "classifier", [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], ) -@pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) +@pytest.mark.parametrize( + "data", ["explainer_data", "explainer_data_no_root", "explainer_data_platypus"] +) def test_filter_by_class(data, request, classifier): - x_train, x_test, y_train = request.getfixturevalue(data) + ( + x_train, + x_test, + y_train, + ) = request.getfixturevalue(data) rfc = RandomForestClassifier() clf = classifier(local_classifier=rfc, replace_classifiers=False) @@ -284,5 +305,37 @@ def test_filter_by_class(data, request, classifier): for pred in predictions: for y in pred: - shap_y = explainer.filter_by_class(explanations, y) - assert isinstance(shap_y, np.ndarray) + if y == "": + # Check that ValueError is raised for empty string + with pytest.raises(ValueError): + explainer.filter_by_class(explanations, y) + else: + # Normal behavior for non-empty strings + shap_y = explainer.filter_by_class(explanations, y) + assert isinstance(shap_y, np.ndarray) + + +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.parametrize( + "classifier", + [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], +) +@pytest.mark.parametrize( + "data", ["explainer_data", "explainer_data_no_root", "explainer_data_platypus"] +) +def test_shap_multi_plot(data, request, classifier): + x_train, x_test, y_train = request.getfixturevalue(data) + rfc = RandomForestClassifier() + clf = classifier(local_classifier=rfc, replace_classifiers=False) + + clf.fit(x_train, y_train) + predictions = clf.predict(x_test) + explainer = Explainer(clf, data=x_train) + + class_names = np.random.choice(predictions[0, :], size=2) + explanations = explainer.shap_multi_plot( + class_names=np.random.choice(predictions[0, :], size=2), + features=x_test, + pred_class=class_names[0], + ) + assert isinstance(explanations, xarray.Dataset) From e3a54ee57d86a0651051cc2a33ef2e79b4c936dc Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 18:18:47 +0200 Subject: [PATCH 25/31] matplotlib requirement added --- Pipfile | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Pipfile b/Pipfile index 4a08c49b..7c43d844 100644 --- a/Pipfile +++ b/Pipfile @@ -7,6 +7,7 @@ name = "pypi" networkx = "*" numpy = "*" scikit-learn = "*" +matplotlib = "*" [dev-packages] pytest = "*" @@ -20,4 +21,4 @@ sphinx-rtd-theme = "0.5.2" [extras] ray = "*" shap = "0.44.1" -xarray = "*" +xarray = "*" \ No newline at end of file diff --git a/setup.py b/setup.py index dbe5a1fa..af6a678a 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ KEYWORDS = ["hierarchical classification"] DACS_SOFTWARE = "https://gitlab.com/dacs-hpi" # What packages are required for this module to be executed? -REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13"] +REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13", "matplotlib"] # What packages are optional? # 'fancy feature': ['django'],} From 73f74f61aaada76ed27ae78e2e750bedbfd8d590 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 18:53:33 +0200 Subject: [PATCH 26/31] algorithm explaining updated --- docs/source/algorithms/explainer.rst | 36 +++++++++++++++------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/docs/source/algorithms/explainer.rst b/docs/source/algorithms/explainer.rst index 3ed95aa9..eae7e624 100644 --- a/docs/source/algorithms/explainer.rst +++ b/docs/source/algorithms/explainer.rst @@ -111,6 +111,8 @@ Code sample lcppn.fit(x_train, y_train) explainer = Explainer(lcppn, data=x_train, mode="tree") + + # One of the possible ways to get explanations explanations = explainer.explain(x_test) @@ -118,32 +120,32 @@ Code sample Filtering and Manipulation ++++++++++++++++++++++++++ -The Explanation object returned by the Explainer is built using the :literal:`xarray.Dataset` data structure, that enables the application of any xarray dataset operation. For example, filtering specific values can be quickly done. To illustrate the filtering operation, suppose we have SHAP values stored in the Explanation object named :literal:`explanation`. +When you work with the `Explanation` object generated by the `Explainer`, you're leveraging the power of the `xarray.Dataset`. This structure is not just robust but also flexible, allowing for comprehensive dataset operations—especially filtering. + +**Practical Example: Filtering SHAP Values** -A common use case is to extract SHAP values for only the predicted nodes. In Local classifier per parent node approach, each node except the leaf nodes represents a classifier. Hence, to find the SHAP values, we can pass the prediction until the penultimate element to obtain the SHAP values. -To achieve this, we can use xarray's :literal:`.sel()` method: +Consider a scenario where you need to focus only on SHAP values corresponding to predicted nodes. In the context of our `LocalClassifierPerParentNode` model, each node—except for the leaf nodes—acts as a classifier. This setup is particularly useful when you're looking to isolate SHAP values up to the penultimate node in your predictions. Here’s how you can do this efficiently using the `sel()` method from xarray: .. code-block:: python + # Creating a mask for selecting SHAP values for predicted classes mask = {'class': lcppn.predict(x_test).flatten()[:-1]} - x = explanations.sel(mask).shap_values + selected_shap_values = explanations.sel(mask).shap_values -Also, we developed some helper functions built in the Explainer class as its methods to simplify standard explanation manipulation and filtering such as filtering explanations by level, or filtering explanations by class and returning its Shapley values. A basic example below is a continuation of the example from the beginning of this section: +**Advanced Visualization: Multi-Plot SHAP Values** -.. code-block:: python +For an even deeper analysis, you might want to visualize the SHAP values. The `shap_multi_plot()` method not only filters the data but also provides a visual representation of the SHAP values for specified classes. Below is an example that illustrates how to plot SHAP values for the classes "Covid" and "Respiratory": - predictions = lcppn.predict(x_test) - # Get the correcponding samples - covid_idx = explainer.get_sample_indices(predictions, 'Covid') - - # Filter the shap values - shap_values_covid = explainer.filter_by_class(explanations, 'Covid', covid_idx) - print(shap_values_covid) +.. code-block:: python - # Filter explanations by level - level = 1 - explanations_level_1 = explainer.filter_by_level(explanations, level) - print(explanations_level_1) + # Generating and plotting explanations for specific classes + explanations = explainer.shap_multi_plot( + class_names=["Covid", "Respiratory"], + features=x_test, + pred_class="Covid", + # Feature names specifiaction possible if x_train is a dataframe with specified columns_names + feature_names=x_train.columns.values + ) From 4ab46ad86fe30f728e15c48548930fab8c7fdb02 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Mon, 15 Apr 2024 21:00:05 +0200 Subject: [PATCH 27/31] small changes in indices selection --- docs/examples/plot_lcppn_explainer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py index d1290a46..fd048c3c 100644 --- a/docs/examples/plot_lcppn_explainer.py +++ b/docs/examples/plot_lcppn_explainer.py @@ -33,12 +33,11 @@ explanations = explainer.explain(X_test.values) print(explanations) -# Filter samples which only predicted "Respiratory" -respiratory_idx = explainer.get_sample_indices(predictions, "Respiratory") - # Use .sel() method to apply the filter and obtain filtered results shap_val_respiratory = explainer.filter_by_class( - explanations, class_name="Respiratory", sample_indices=respiratory_idx + explanations, + class_name="Respiratory", + sample_indices=explainer.get_sample_indices(predictions, "Respiratory"), ) From e43180350e517af221b425d8ce001d6598676d8e Mon Sep 17 00:00:00 2001 From: dniprocat Date: Tue, 16 Apr 2024 14:38:06 +0200 Subject: [PATCH 28/31] some skipiff added --- tests/test_Explainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py index 1c8dbd69..226b6361 100644 --- a/tests/test_Explainer.py +++ b/tests/test_Explainer.py @@ -259,6 +259,7 @@ def test_explainers(data, request, classifier, mode): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize( "classifier", [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], @@ -281,6 +282,7 @@ def test_filter_by_level(data, request, classifier): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize( "classifier", [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], @@ -316,6 +318,7 @@ def test_filter_by_class(data, request, classifier): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize( "classifier", [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], From bf64bfefbec8f6a1c5cf8c9638dfb3e8038e793a Mon Sep 17 00:00:00 2001 From: dniprocat Date: Tue, 16 Apr 2024 17:18:36 +0200 Subject: [PATCH 29/31] ray support added and used as a default (instead of joblib) --- hiclass/Explainer.py | 38 ++++++++++++++++++++++++++++++++++---- tests/test_Explainer.py | 3 +++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 86f1fb33..5051a884 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -25,6 +25,13 @@ else: shap_installed = True +try: + import ray +except ImportError: + _has_ray = False +else: + _has_ray = True + class Explainer: """Explainer class for returning shap values for each of the three hierarchical classifiers.""" @@ -125,7 +132,7 @@ def explain(self, X): else: raise ValueError(f"Invalid model: {self.hierarchical_model}.") - def _explain_with_xr(self, X): + def _explain_with_xr(self, X, use_joblib: bool = False): """ Generate SHAP values for each node using the SHAP package. @@ -139,10 +146,28 @@ def _explain_with_xr(self, X): explanation : xarray.Dataset An xarray Dataset consisting of SHAP values for each sample. """ - explanations = Parallel(n_jobs=self.n_jobs, backend="threading")( - delayed(self._calculate_shap_values)(sample.reshape(1, -1)) for sample in X - ) + if self.n_jobs > 1: + if _has_ray and not use_joblib: + if not ray.is_initialized(): + ray.init(num_cpus=self.n_jobs) + + calculate_shap_values_remote = ray.remote(calculate_shap_values_wrapper) + tasks = [ + calculate_shap_values_remote.remote(self, sample.reshape(1, -1)) + for sample in X + ] + + explanations = ray.get(tasks) + else: + explanations = Parallel(n_jobs=self.n_jobs, backend="threading")( + delayed(self._calculate_shap_values)(sample.reshape(1, -1)) + for sample in X + ) + else: + explanations = [ + self._calculate_shap_values(sample.reshape(1, -1)) for sample in X + ] dataset = xr.concat(explanations, dim="sample") return dataset @@ -590,3 +615,8 @@ def shap_multi_plot(self, class_names, features, pred_class, features_names=None class_names=class_names, ) return explanations + + +# A wrapper function for Ray enabling +def calculate_shap_values_wrapper(explainer, sample): + return explainer._calculate_shap_values(sample) diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py index 226b6361..d0f7fd75 100644 --- a/tests/test_Explainer.py +++ b/tests/test_Explainer.py @@ -336,6 +336,9 @@ def test_shap_multi_plot(data, request, classifier): explainer = Explainer(clf, data=x_train) class_names = np.random.choice(predictions[0, :], size=2) + while class_names[0] == "" or class_names[1] == "": + class_names = np.random.choice(predictions[0, :], size=2) + explanations = explainer.shap_multi_plot( class_names=np.random.choice(predictions[0, :], size=2), features=x_test, From 1047a2cd43eff6d5eb3c71d4c1ca1fdf6e1228ab Mon Sep 17 00:00:00 2001 From: dniprocat Date: Tue, 16 Apr 2024 17:34:20 +0200 Subject: [PATCH 30/31] pydocs --- hiclass/Explainer.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 5051a884..3f23ab01 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -587,7 +587,7 @@ def shap_multi_plot(self, class_names, features, pred_class, features_names=None Returns ------- - explanations: xarray.Dataset3 + explanations: xarray.Dataset Whole explanations of data in features provided. """ classifier = self.hierarchical_model @@ -619,4 +619,19 @@ def shap_multi_plot(self, class_names, features, pred_class, features_names=None # A wrapper function for Ray enabling def calculate_shap_values_wrapper(explainer, sample): + """ + Wrapper function for shap_values calculations. + + Parameters + __________ + explainer: Explainer + Explainer + sample: array-like + Sample to calculate SHAP values for. + + Returns + _______ + shap_values: xarray.Dataset + Dataset of explanations for the sample. + """ return explainer._calculate_shap_values(sample) From c1dd74602f6115b5034485b094c81f78c190d612 Mon Sep 17 00:00:00 2001 From: dniprocat Date: Tue, 16 Apr 2024 17:40:30 +0200 Subject: [PATCH 31/31] pydocs --- hiclass/Explainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 3f23ab01..625f29aa 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -620,7 +620,7 @@ def shap_multi_plot(self, class_names, features, pred_class, features_names=None # A wrapper function for Ray enabling def calculate_shap_values_wrapper(explainer, sample): """ - Wrapper function for shap_values calculations. + Wrap the function for shap_values calculations. Parameters __________