Skip to content

Commit

Permalink
Merge pull request #395 from knaaptime/skplt
Browse files Browse the repository at this point in the history
rm vendor silhouette plotting
  • Loading branch information
knaaptime authored Feb 28, 2024
2 parents ca5faf3 + b3dc289 commit 4375bae
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 20 deletions.
1 change: 0 additions & 1 deletion .ci/310.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dependencies:

- python-wget
- contextily
- scikit-plot
- python-graphviz
- nbsphinx
- numpydoc
Expand Down
1 change: 0 additions & 1 deletion .ci/311.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@ dependencies:
- coverage
- python-wget
- contextily
- scikit-plot
- python-graphviz
1 change: 0 additions & 1 deletion .ci/312.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dependencies:

- python-wget
- contextily
- scikit-plot
- python-graphviz
- sphinx>=1.4.3
- sphinxcontrib-bibtex==1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
micromamba-version: 'latest'

- name: Install geosnap
run: pip install . ;python geosnap/tests/_dl_data.py;
run: pip install . --no-deps ;python geosnap/tests/_dl_data.py;
env:
COMBO_DATA: ${{ secrets.COMBO_DATA }}

Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ dependencies:
- xlrd
- tobler >=0.8.2
- contextily
- scikit-plot
- mapclassify
- spopt >=0.3.0
- s3fs
Expand Down
2 changes: 1 addition & 1 deletion geosnap/analyze/_cluster_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def kmeans(
verbose=0,
random_state=None,
copy_x=True,
algorithm="auto",
algorithm="lloyd",
**kwargs,
):
"""K-Means clustering.
Expand Down
4 changes: 2 additions & 2 deletions geosnap/analyze/_model_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

import esda
import geopandas as gpd
import scikitplot as skplt
from sklearn.metrics import (
calinski_harabasz_score,
davies_bouldin_score,
silhouette_samples,
)

from ..visualize.mapping import plot_timeseries
from ..visualize.skplt import plot_silhouette as _plot_silhouette
from .dynamics import predict_markov_labels as _predict_markov_labels
from .incs import lincs_from_gdf

Expand Down Expand Up @@ -369,7 +369,7 @@ def plot_silhouette(self, metric="euclidean", title="Silhouette Score"):
elif self.pooling == "pooled":
# if pooled, scale the whole series at once
df.loc[:, self.columns] = self.scaler.fit_transform(df.values)
fig = skplt.metrics.plot_silhouette(
fig = _plot_silhouette(
df[self.columns].values, self.labels, metric=metric, title=title
)

Expand Down
9 changes: 6 additions & 3 deletions geosnap/analyze/_region_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def kmeans_spatial(data, columns, w, n_clusters=5, **kwargs):
return model


def spenc(data, w, columns, n_clusters=5, gamma=500, random_state=None, **kwargs):
def spenc(
data, w, columns, n_clusters=5, gamma=1, random_state=None, n_jobs=-1, **kwargs
):
"""Spatially encouraged spectral clustering.
:cite:`wolf2018`
Expand Down Expand Up @@ -101,6 +103,7 @@ def spenc(data, w, columns, n_clusters=5, gamma=500, random_state=None, **kwargs
attrs_name=columns,
gamma=gamma,
random_state=random_state,
n_jobs=n_jobs,
)

model.solve()
Expand All @@ -115,7 +118,7 @@ def skater(
floor=-np.inf,
islands="increase",
cluster_args=None,
**kwargs
**kwargs,
):
"""SKATER spatial clustering algorithm.
Expand Down Expand Up @@ -193,7 +196,7 @@ def max_p(
threshold=10,
max_iterations_construction=99,
top_n=2,
**kwargs
**kwargs,
):
"""Max-p clustering algorithm :cite:`Duque2012`.
Expand Down
14 changes: 6 additions & 8 deletions geosnap/harmonize/harmonize.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,7 @@ def harmonize(
times.remove(target_year)
target_df = dfs[dfs[temporal_index] == target_year]

if target_df.index.name:
unit_index = target_df.index.name
else:
unit_index = "id"
unit_index = target_df.index.name if target_df.index.name else "id"
target_df[unit_index] = target_df.index.values

geom_name = target_df.geometry.name
Expand Down Expand Up @@ -209,10 +206,11 @@ def harmonize(
pixel_values=pixel_values,
raster=raster,
)
except IOError:
raise IOError(
"Unable to locate raster. If using the `dasymetric` or model-based methods. You"
"must provide a raster file and indicate which pixel values contain developed land"
except OSError as e:
raise OSError from e(
"Unable to locate raster. If using the `dasymetric` or model-based "
"methods. You must provide a raster file and indicate which pixel "
"values contain developed land"
)
else:
raise ValueError('weights_method must of one of ["area", "dasymetric"]')
Expand Down
147 changes: 147 additions & 0 deletions geosnap/visualize/skplt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
silhouette_samples,
silhouette_score,
)
from sklearn.preprocessing import LabelEncoder


def plot_silhouette(
X,
cluster_labels,
title="Silhouette Analysis",
metric="euclidean",
ax=None,
figsize=None,
cmap="nipy_spectral",
title_fontsize="large",
text_fontsize="medium",
):
"""Plots silhouette analysis of clusters provided.
NOTE: this function is vendored from scikit-plot which is no longer maintained
Args:
X (array-like, shape (n_samples, n_features)):
Data to cluster, where n_samples is the number of samples and
n_features is the number of features.
cluster_labels (array-like, shape (n_samples,)):
Cluster label for each sample.
title (string, optional): Title of the generated plot. Defaults to
"Silhouette Analysis"
metric (string or callable, optional): The metric to use when
calculating distance between instances in a feature array.
If metric is a string, it must be one of the options allowed by
sklearn.metrics.pairwise.pairwise_distances. If X is
the distance array itself, use "precomputed" as the metric.
copy (boolean, optional): Determines whether ``fit`` is used on
**clf** or on a copy of **clf**.
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
plot the curve. If None, the plot is drawn on a new set of axes.
figsize (2-tuple, optional): Tuple denoting figure size of the plot
e.g. (6, 6). Defaults to ``None``.
cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
Colormap used for plotting the projection. View Matplotlib Colormap
documentation for available options.
https://matplotlib.org/users/colormaps.html
title_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"large".
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
Use e.g. "small", "medium", "large" or integer-values. Defaults to
"medium".
Returns:
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
drawn.
Example:
>>> import scikitplot as skplt
>>> kmeans = KMeans(n_clusters=4, random_state=1)
>>> cluster_labels = kmeans.fit_predict(X)
>>> skplt.metrics.plot_silhouette(X, cluster_labels)
<matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
>>> plt.show()
.. image:: _static/examples/plot_silhouette.png
:align: center
:alt: Silhouette Plot
"""
cluster_labels = np.asarray(cluster_labels)

le = LabelEncoder()
cluster_labels_encoded = le.fit_transform(cluster_labels)

n_clusters = len(np.unique(cluster_labels))

silhouette_avg = silhouette_score(X, cluster_labels, metric=metric)

sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)

if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)

ax.set_title(title, fontsize=title_fontsize)
ax.set_xlim([-0.1, 1])

ax.set_ylim([0, len(X) + (n_clusters + 1) * 10 + 10])

ax.set_xlabel("Silhouette coefficient values", fontsize=text_fontsize)
ax.set_ylabel("Cluster label", fontsize=text_fontsize)

y_lower = 10

for i in range(n_clusters):
ith_cluster_silhouette_values = sample_silhouette_values[
cluster_labels_encoded == i
]

ith_cluster_silhouette_values.sort()

size_cluster_i = ith_cluster_silhouette_values.shape[0]
y_upper = y_lower + size_cluster_i

color = plt.cm.get_cmap(cmap)(float(i) / n_clusters)

ax.fill_betweenx(
np.arange(y_lower, y_upper),
0,
ith_cluster_silhouette_values,
facecolor=color,
edgecolor=color,
alpha=0.7,
)

ax.text(
-0.05,
y_lower + 0.5 * size_cluster_i,
str(le.classes_[i]),
fontsize=text_fontsize,
)

y_lower = y_upper + 10

ax.axvline(
x=silhouette_avg,
color="red",
linestyle="--",
label="Silhouette score: {0:0.3f}".format(silhouette_avg),
)

ax.set_yticks([]) # Clear the y-axis labels / ticks
ax.set_xticks(np.arange(-0.1, 1.0, 0.2))

ax.tick_params(labelsize=text_fontsize)
ax.legend(loc="best", fontsize=text_fontsize)

return ax
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ dependencies = [
"quilt3>=3.6",
"pyarrow>=0.14.1",
"contextily",
"scikit-plot",
"tobler>=0.8.2",
"spopt>=0.3.0",
"fsspec",
Expand Down

0 comments on commit 4375bae

Please sign in to comment.