From da67f977082090b154770c8d7ca661445887fc34 Mon Sep 17 00:00:00 2001 From: JuMr <10944593+jumr@user.noreply.gitee.com> Date: Mon, 3 Oct 2022 16:23:04 +0800 Subject: [PATCH] add **metrickw to KMedoids for the convinence of the metric like 'minkowski' where additional args are need --- sklearn_extra/cluster/_k_medoids.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn_extra/cluster/_k_medoids.py b/sklearn_extra/cluster/_k_medoids.py index f9d964df..f8e23a3e 100644 --- a/sklearn_extra/cluster/_k_medoids.py +++ b/sklearn_extra/cluster/_k_medoids.py @@ -155,6 +155,7 @@ def __init__( init="heuristic", max_iter=300, random_state=None, + **metrickw ): self.n_clusters = n_clusters self.metric = metric @@ -162,6 +163,7 @@ def __init__( self.init = init self.max_iter = max_iter self.random_state = random_state + self.metrickw=metrickw def _check_nonnegative_int(self, value, desc, strict=True): """Validates if value is a valid integer > 0""" @@ -235,7 +237,7 @@ def fit(self, X, y=None): % (self.n_clusters, X.shape[0]) ) - D = pairwise_distances(X, metric=self.metric) + D = pairwise_distances(X, metric=self.metric,**self.metrickwm) medoid_idxs = self._initialize_medoids( D, self.n_clusters, random_state_, X @@ -379,10 +381,10 @@ def transform(self, X): check_is_fitted(self, "cluster_centers_") Y = self.cluster_centers_ - kwargs = {} + if self.metric == "seuclidean": kwargs["V"] = np.var(np.vstack([X, Y]), axis=0, ddof=1) - DXY = pairwise_distances(X, Y=Y, metric=self.metric, **kwargs) + DXY = pairwise_distances(X, Y=Y, metric=self.metric, **self.metrickw) return DXY @@ -421,7 +423,7 @@ def predict(self, X): X, Y=self.cluster_centers_, metric=self.metric, - metric_kwargs=kwargs, + metric_kwargs=self.metrickw, ) return pd_argmin