-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #395 from knaaptime/skplt
rm vendor silhouette plotting
- Loading branch information
Showing
11 changed files
with
163 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,6 @@ dependencies: | |
|
||
- python-wget | ||
- contextily | ||
- scikit-plot | ||
- python-graphviz | ||
- nbsphinx | ||
- numpydoc | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,5 +35,4 @@ dependencies: | |
- coverage | ||
- python-wget | ||
- contextily | ||
- scikit-plot | ||
- python-graphviz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from sklearn.metrics import ( | ||
silhouette_samples, | ||
silhouette_score, | ||
) | ||
from sklearn.preprocessing import LabelEncoder | ||
|
||
|
||
def plot_silhouette(
    X,
    cluster_labels,
    title="Silhouette Analysis",
    metric="euclidean",
    ax=None,
    figsize=None,
    cmap="nipy_spectral",
    title_fontsize="large",
    text_fontsize="medium",
):
    """Plot a silhouette analysis of the provided clustering.

    Each cluster is drawn as a horizontal band made of the sorted silhouette
    coefficients of its members; a dashed red vertical line marks the mean
    silhouette score over all samples.

    NOTE: this function is vendored from scikit-plot, which is no longer
    maintained.

    Args:
        X (array-like, shape (n_samples, n_features)):
            Data that was clustered, where n_samples is the number of samples
            and n_features is the number of features. If ``metric`` is
            "precomputed", X is the pairwise distance array itself.

        cluster_labels (array-like, shape (n_samples,)):
            Cluster label for each sample.

        title (string, optional): Title of the generated plot. Defaults to
            "Silhouette Analysis".

        metric (string or callable, optional): The metric to use when
            calculating distance between instances in a feature array.
            If metric is a string, it must be one of the options allowed by
            sklearn.metrics.pairwise.pairwise_distances. If X is
            the distance array itself, use "precomputed" as the metric.

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot,
            e.g. (6, 6). Only used when ``ax`` is None. Defaults to ``None``.

        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
            Colormap used to color the clusters. View the Matplotlib colormap
            documentation for available options.

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Example:
        >>> from sklearn.cluster import KMeans
        >>> kmeans = KMeans(n_clusters=4, random_state=1)
        >>> cluster_labels = kmeans.fit_predict(X)
        >>> ax = plot_silhouette(X, cluster_labels)
        >>> plt.show()
    """
    cluster_labels = np.asarray(cluster_labels)

    # Encode arbitrary labels (strings, non-contiguous ints) as 0..k-1 so the
    # clusters can be indexed positionally; the original labels remain
    # available through ``le.classes_`` for annotating the plot.
    le = LabelEncoder()
    cluster_labels_encoded = le.fit_transform(cluster_labels)
    n_clusters = len(le.classes_)

    silhouette_avg = silhouette_score(X, cluster_labels, metric=metric)
    sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)

    # Use the per-sample result for the sample count rather than ``len(X)``,
    # which raises for scipy sparse matrices.
    n_samples = sample_silhouette_values.shape[0]

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlim([-0.1, 1])

    # Leave a 10-unit gap below, between and above the cluster bands.
    ax.set_ylim([0, n_samples + (n_clusters + 1) * 10 + 10])

    ax.set_xlabel("Silhouette coefficient values", fontsize=text_fontsize)
    ax.set_ylabel("Cluster label", fontsize=text_fontsize)

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # ``pyplot.get_cmap`` accepts both names and Colormap instances.
    # Resolve the colormap once instead of on every iteration.
    colormap = plt.get_cmap(cmap)

    y_lower = 10
    for i, label in enumerate(le.classes_):
        # Boolean-mask indexing returns a copy, so the in-place sort below
        # does not disturb ``sample_silhouette_values``.
        ith_cluster_silhouette_values = sample_silhouette_values[
            cluster_labels_encoded == i
        ]
        ith_cluster_silhouette_values.sort()

        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = colormap(float(i) / n_clusters)

        ax.fill_betweenx(
            np.arange(y_lower, y_upper),
            0,
            ith_cluster_silhouette_values,
            facecolor=color,
            edgecolor=color,
            alpha=0.7,
        )

        # Annotate the band with the original (un-encoded) cluster label,
        # vertically centered on the band.
        ax.text(
            -0.05,
            y_lower + 0.5 * size_cluster_i,
            str(label),
            fontsize=text_fontsize,
        )

        y_lower = y_upper + 10

    ax.axvline(
        x=silhouette_avg,
        color="red",
        linestyle="--",
        label=f"Silhouette score: {silhouette_avg:0.3f}",
    )

    ax.set_yticks([])  # Clear the y-axis labels / ticks
    ax.set_xticks(np.arange(-0.1, 1.0, 0.2))

    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc="best", fontsize=text_fontsize)

    return ax
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters