diff --git a/Pipfile b/Pipfile index 4a08c49b..7c43d844 100644 --- a/Pipfile +++ b/Pipfile @@ -7,6 +7,7 @@ name = "pypi" networkx = "*" numpy = "*" scikit-learn = "*" +matplotlib = "*" [dev-packages] pytest = "*" @@ -20,4 +21,4 @@ sphinx-rtd-theme = "0.5.2" [extras] ray = "*" shap = "0.44.1" -xarray = "*" +xarray = "*" \ No newline at end of file diff --git a/docs/examples/plot_lcpl_explainer.py b/docs/examples/plot_lcpl_explainer.py index d085c791..b32dec0c 100644 --- a/docs/examples/plot_lcpl_explainer.py +++ b/docs/examples/plot_lcpl_explainer.py @@ -25,35 +25,14 @@ # Define Explainer explainer = Explainer(classifier, data=X_train, mode="tree") -explanations = explainer.explain(X_test.values) -print(explanations) -# Let's filter the Shapley values corresponding to the Covid (level 1) -# and 'Respiratory' (level 0) +# Now, our task is to see how feature importance may vary from level to level +# We are going to calculate shap_values for 'Respiratory', 'Covid' and plot what we calculated +# This can be done with a single method .shap_multi_plot, which additionally returns calculated explanations -covid_idx = classifier.predict(X_test)[:, 1] == "Covid" - -shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx} -shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx} -shap_val_covid = explanations.sel(**shap_filter_covid) -shap_val_resp = explanations.sel(**shap_filter_resp) - - -# This code snippet demonstrates how to visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases. - -# Feature names for the X-axis -feature_names = X_train.columns.values - -# SHAP values for 'Covid' -shap_values_covid = shap_val_covid.shap_values.values - -# SHAP values for 'Respiratory' -shap_values_resp = shap_val_resp.shap_values.values - -shap.summary_plot( - [shap_values_covid, shap_values_resp], - features=X_test.iloc[covid_idx], - feature_names=X_train.columns.values, - plot_type="bar", +explanations = explainer.shap_multi_plot( class_names=["Covid", "Respiratory"], + features=X_test.values, + pred_class="Respiratory", + features_names=X_train.columns.values, ) diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py index ab27ce38..fd048c3c 100644 --- a/docs/examples/plot_lcppn_explainer.py +++ b/docs/examples/plot_lcppn_explainer.py @@ -25,23 +25,25 @@ # Train local classifier per parent node classifier.fit(X_train, Y_train) +# Get predictions +predictions = classifier.predict(X_test) + # Define Explainer explainer = Explainer(classifier, data=X_train.values, mode="tree") explanations = explainer.explain(X_test.values) print(explanations) -# Filter samples which only predicted "Respiratory" at first level -respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory" - -# Specify additional filters to obtain only level 0 -shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx} - # Use .sel() method to apply the filter and obtain filtered results -shap_val_respiratory = explanations.sel(shap_filter) +shap_val_respiratory = explainer.filter_by_class( + explanations, + class_name="Respiratory", + sample_indices=explainer.get_sample_indices(predictions, "Respiratory"), +) + # Plot feature importance on test set shap.plots.violin( - shap_val_respiratory.shap_values, + shap_val_respiratory, feature_names=X_train.columns.values, plot_size=(13, 8), ) diff --git a/docs/source/algorithms/explainer.rst b/docs/source/algorithms/explainer.rst index b87ced9c..eae7e624 100644 --- a/docs/source/algorithms/explainer.rst +++ b/docs/source/algorithms/explainer.rst @@ -111,6 +111,8 @@ Code sample lcppn.fit(x_train, y_train) explainer = Explainer(lcppn, data=x_train, mode="tree") + + # One of the possible ways to get explanations explanations = explainer.explain(x_test) @@ -118,14 +120,33 @@ Code sample Filtering and Manipulation ++++++++++++++++++++++++++ -The Explanation object returned by the Explainer is built using the :literal:`xarray.Dataset` data structure, that enables the application of any xarray dataset operation. For example, filtering specific values can be quickly done. To illustrate the filtering operation, suppose we have SHAP values stored in the Explanation object named :literal:`explanation`. +When you work with the `Explanation` object generated by the `Explainer`, you're leveraging the power of the `xarray.Dataset`. This structure is not just robust but also flexible, allowing for comprehensive dataset operations—especially filtering. + +**Practical Example: Filtering SHAP Values** -A common use case is to extract SHAP values for only the predicted nodes. In Local classifier per parent node approach, each node except the leaf nodes represents a classifier. Hence, to find the SHAP values, we can pass the prediction until the penultimate element to obtain the SHAP values. -To achieve this, we can use xarray's :literal:`.sel()` method: +Consider a scenario where you need to focus only on SHAP values corresponding to predicted nodes. In the context of our `LocalClassifierPerParentNode` model, each node—except for the leaf nodes—acts as a classifier. This setup is particularly useful when you're looking to isolate SHAP values up to the penultimate node in your predictions. Here’s how you can do this efficiently using the `sel()` method from xarray: .. code-block:: python + # Creating a mask for selecting SHAP values for predicted classes mask = {'class': lcppn.predict(x_test).flatten()[:-1]} - x = explanations.sel(mask).shap_values + selected_shap_values = explanations.sel(mask).shap_values + +**Advanced Visualization: Multi-Plot SHAP Values** + +For an even deeper analysis, you might want to visualize the SHAP values. The `shap_multi_plot()` method not only filters the data but also provides a visual representation of the SHAP values for specified classes. Below is an example that illustrates how to plot SHAP values for the classes "Covid" and "Respiratory": + +.. code-block:: python + + # Generating and plotting explanations for specific classes + explanations = explainer.shap_multi_plot( + class_names=["Covid", "Respiratory"], + features=x_test, + pred_class="Covid", + # Feature names specifiaction possible if x_train is a dataframe with specified columns_names + feature_names=x_train.columns.values + ) + + More advanced usage and capabilities can be found at the `Xarray.Dataset `_ documentation. diff --git a/hiclass/Explainer.py b/hiclass/Explainer.py index 8708527b..625f29aa 100644 --- a/hiclass/Explainer.py +++ b/hiclass/Explainer.py @@ -25,6 +25,13 @@ else: shap_installed = True +try: + import ray +except ImportError: + _has_ray = False +else: + _has_ray = True + class Explainer: """Explainer class for returning shap values for each of the three hierarchical classifiers.""" @@ -125,7 +132,7 @@ def explain(self, X): else: raise ValueError(f"Invalid model: {self.hierarchical_model}.") - def _explain_with_xr(self, X): + def _explain_with_xr(self, X, use_joblib: bool = False): """ Generate SHAP values for each node using the SHAP package. @@ -139,10 +146,28 @@ def _explain_with_xr(self, X): explanation : xarray.Dataset An xarray Dataset consisting of SHAP values for each sample. """ - explanations = Parallel(n_jobs=self.n_jobs, backend="threading")( - delayed(self._calculate_shap_values)(sample.reshape(1, -1)) for sample in X - ) + if self.n_jobs > 1: + if _has_ray and not use_joblib: + if not ray.is_initialized(): + ray.init(num_cpus=self.n_jobs) + + calculate_shap_values_remote = ray.remote(calculate_shap_values_wrapper) + tasks = [ + calculate_shap_values_remote.remote(self, sample.reshape(1, -1)) + for sample in X + ] + + explanations = ray.get(tasks) + else: + explanations = Parallel(n_jobs=self.n_jobs, backend="threading")( + delayed(self._calculate_shap_values)(sample.reshape(1, -1)) + for sample in X + ) + else: + explanations = [ + self._calculate_shap_values(sample.reshape(1, -1)) for sample in X + ] dataset = xr.concat(explanations, dim="sample") return dataset @@ -365,3 +390,248 @@ def _calculate_shap_values(self, X): datasets.append(local_dataset) sample_explanation = xr.concat(datasets, dim="level") return sample_explanation + + def filter_by_level(self, explanations, level): + """ + Return the explanations filtered by the given level. + + Parameters + __________ + explanations : xarray.DataArray + The explanations to filter + level : int + level in the hierarchy to filter + + Returns + _______ + filtered_explanations : xarray.Dataset + Explanations filtered by the given level + + Examples + -------- + >>> from sklearn.ensemble import RandomForestClassifier + >>> import numpy as np + >>> from hiclass import LocalClassifierPerParentNode, Explainer + >>> rfc = RandomForestClassifier() + >>> lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False) + >>> x_train = np.array([[1, 3], [2, 5]]) + >>> y_train = np.array([[1, 2], [3, 4]]) + >>> x_test = np.array([[4, 6]]) + >>> lcppn.fit(x_train, y_train) + >>> explainer = Explainer(lcppn, data=x_train, mode="tree") + >>> explanations = explainer.explain(x_test) + >>> explanations_level_1 = explainer.filter_by_level(explanations, level=1) + + Dimensions: (class: 3, sample: 1, feature: 2) + Coordinates: + * class (class) >> from sklearn.ensemble import RandomForestClassifier + >>> import numpy as np + >>> from hiclass import LocalClassifierPerParentNode, Explainer + >>> rfc = RandomForestClassifier() + >>> lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False) + >>> x_train = np.array([[1, 3], [2, 5]]) + >>> y_train = np.array([[1, 2], [3, 4]]) + >>> x_test = np.array([[4, 6]]) + >>> lcppn.fit(x_train, y_train) + >>> predictions = lcppn.predict(x_test) + >>> explainer = Explainer(lcppn, data=x_train, mode="tree") + >>> explanations = explainer.explain(x_test) + >>> filtered_shap = explainer.filter_by_class(explanations, level=3) + >>> print(filtered_shap) + [['3' '4']] + [[0.1 0.105]] + """ + # Ensure that explanations are provided and have the expected structure + if not isinstance(explanations, xr.Dataset): + raise ValueError("Explanations should be an xarray.Dataset!") + + # Converting class_name to the string format + class_name = str(class_name) + + if class_name == "": + raise ValueError("Empty class!") + + # Define level + level = self.get_class_level(str(class_name)) + + # Handling with LocalClassifierPerNode case + if isinstance(self.hierarchical_model, LocalClassifierPerNode): + class_name = f"{class_name}_1" + + # Shap filter + shap_filter = {"class": class_name, "level": level} + if sample_indices is not None: + shap_filter["sample"] = sample_indices + + # Select the SHAP values according to the filter and handle possible errors + try: + filtered_explanations = explanations.sel(**shap_filter) + except KeyError as e: + raise KeyError( + f"Class name {class_name} with level {level} not found." + ) from e + + # Return the selected SHAP values as a NumPy array + return filtered_explanations.shap_values.values + + def get_class_level(self, class_name): + """ + Return level of the class in the hierarchy. + + Parameters + __________ + class_name : int or str + Name of the class + + Returns + _______ + class_level : int + Level of the class in the hierarchy + + + """ + # Set the classifier + classifier = self.hierarchical_model + + # Converting class_name to the string formatn + class_name = str(class_name) + + # Iterating through the nodes of hierarchy + for node_ in classifier.hierarchy_.nodes: + if class_name in node_.split(classifier.separator_): + node_classes = node_.split(classifier.separator_) + return node_classes.index(class_name) + + raise ValueError(f"Class '{class_name}' not found!") + + def get_sample_indices(self, predictions, class_name): + """ + Return indices of predictions corresponding to the certain class. + + Parameters + __________ + predictions: array-like + Array of predictions of the hierarchical classificator + class_name: str + Name of class + + Returns + _______ + sample_indices: boolean array of indices + """ + class_level = self.get_class_level(class_name) + return predictions[:, class_level] == class_name + + def shap_multi_plot(self, class_names, features, pred_class, features_names=None): + """ + Plot shap_values for multi-class case on a bar and return explanations. + + "Lazy" function which does not require any additional actions from the user + apart from classifier fitting and explainer initialization. + + Parameters + ---------- + class_names : list of str + A list of class names to calculate and visualize the Shapley values for. + features: array-like + Matrix of feature values with shape (# features) or (# samples x # features). + Typically, this would be the test set features (X_test). + pred_class : int or str + The class label that the classifier's predictions must match for a sample to be + included in the subset of data used for SHAP value calculation. If not provided, + no filtering is applied, and all samples are considered. + features_names : list, optional + A list of feature names to include in the bar plot for the shap_values. + + Returns + ------- + explanations: xarray.Dataset + Whole explanations of data in features provided. + """ + classifier = self.hierarchical_model + predictions = classifier.predict(features) + + if pred_class is not None and not any(pred_class in row for row in predictions): + raise ValueError( + f"The specified class '{pred_class}' was not found in the predictions." + ) + + explanations = self.explain(features) + sample_idx = self.get_sample_indices(predictions, pred_class) + shap_array = [] + for class_name in class_names: + shap_val = self.filter_by_class( + explanations, class_name=class_name, sample_indices=sample_idx + ) + shap_array.append(shap_val) + + shap.summary_plot( + shap_array, + features=features[sample_idx], + feature_names=features_names, + plot_type="bar", + class_names=class_names, + ) + return explanations + + +# A wrapper function for Ray enabling +def calculate_shap_values_wrapper(explainer, sample): + """ + Wrap the function for shap_values calculations. + + Parameters + __________ + explainer: Explainer + Explainer + sample: array-like + Sample to calculate SHAP values for. + + Returns + _______ + shap_values: xarray.Dataset + Dataset of explanations for the sample. + """ + return explainer._calculate_shap_values(sample) diff --git a/setup.py b/setup.py index dbe5a1fa..af6a678a 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ KEYWORDS = ["hierarchical classification"] DACS_SOFTWARE = "https://gitlab.com/dacs-hpi" # What packages are required for this module to be executed? -REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13"] +REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13", "matplotlib"] # What packages are optional? # 'fancy feature': ['django'],} diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py index 303216f6..d0f7fd75 100644 --- a/tests/test_Explainer.py +++ b/tests/test_Explainer.py @@ -7,6 +7,8 @@ LocalClassifierPerNode, Explainer, ) +from hiclass.datasets import load_platypus + try: import shap @@ -51,6 +53,17 @@ def explainer_data_no_root(): return x_train, x_test, y_train +@pytest.fixture() +def explainer_data_platypus(): + X_train, X_test, Y_train, Y_test = load_platypus() + X_train = X_train.iloc[0:10].values + X_test = X_test.iloc[0:10].values + Y_train = Y_train.iloc[0:10].values + Y_test = Y_test.iloc[0:10].values + + return X_train, X_test, Y_train + + @pytest.mark.skipif(not shap_installed, reason="shap not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_explainer_tree_lcppn(data, request): @@ -243,3 +256,92 @@ def test_explainers(data, request, classifier, mode): "": shap.Explainer, } assert explainer.explainer == mode_mapping[mode] + + +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") +@pytest.mark.parametrize( + "classifier", + [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], +) +@pytest.mark.parametrize( + "data", ["explainer_data", "explainer_data_no_root", "explainer_data_platypus"] +) +def test_filter_by_level(data, request, classifier): + x_train, x_test, y_train = request.getfixturevalue(data) + rfc = RandomForestClassifier() + clf = classifier(local_classifier=rfc, replace_classifiers=False) + + clf.fit(x_train, y_train) + explainer = Explainer(clf, data=x_train) + explanations = explainer.explain(x_test) + + for i in range(3): + filtered_explanations = explainer.filter_by_level(explanations, i) + assert isinstance(filtered_explanations, xarray.Dataset) + + +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") +@pytest.mark.parametrize( + "classifier", + [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], +) +@pytest.mark.parametrize( + "data", ["explainer_data", "explainer_data_no_root", "explainer_data_platypus"] +) +def test_filter_by_class(data, request, classifier): + ( + x_train, + x_test, + y_train, + ) = request.getfixturevalue(data) + rfc = RandomForestClassifier() + clf = classifier(local_classifier=rfc, replace_classifiers=False) + + clf.fit(x_train, y_train) + predictions = clf.predict(x_test) + + explainer = Explainer(clf, data=x_train) + explanations = explainer.explain(x_test) + + for pred in predictions: + for y in pred: + if y == "": + # Check that ValueError is raised for empty string + with pytest.raises(ValueError): + explainer.filter_by_class(explanations, y) + else: + # Normal behavior for non-empty strings + shap_y = explainer.filter_by_class(explanations, y) + assert isinstance(shap_y, np.ndarray) + + +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") +@pytest.mark.parametrize( + "classifier", + [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], +) +@pytest.mark.parametrize( + "data", ["explainer_data", "explainer_data_no_root", "explainer_data_platypus"] +) +def test_shap_multi_plot(data, request, classifier): + x_train, x_test, y_train = request.getfixturevalue(data) + rfc = RandomForestClassifier() + clf = classifier(local_classifier=rfc, replace_classifiers=False) + + clf.fit(x_train, y_train) + predictions = clf.predict(x_test) + explainer = Explainer(clf, data=x_train) + + class_names = np.random.choice(predictions[0, :], size=2) + while class_names[0] == "" or class_names[1] == "": + class_names = np.random.choice(predictions[0, :], size=2) + + explanations = explainer.shap_multi_plot( + class_names=np.random.choice(predictions[0, :], size=2), + features=x_test, + pred_class=class_names[0], + ) + assert isinstance(explanations, xarray.Dataset)