diff --git a/CHANGELOG b/CHANGELOG
index 0944ed489..679df31ae 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -16,6 +16,13 @@ The rules for CHANGELOG file:
 - Updating ``FPS`` to allow a numpy array of ints as an initialize parameter (#145)
 - Supported Python versions are now ranging from 3.9 - 3.12.
 - Updating ``skmatter.datasets`` submodule to support sklearn 1.5.0 (#229)
+- Add `SparseKDE` class (#222)
+- Add `QuickShift` class (#222)
+- Add an example on how to conduct the PAMM algorithm with `SparseKDE` and `QuickShift`
+  (#222)
+- Add H2O-BLYP-Piglet dataset (#222)
+- Add two distance metrics that support the periodic boundary condition,
+  `periodic_pairwise_euclidean_distances` and `pairwise_mahalanobis_distances` (#222)
 
 0.2.0 (2023/08/24)
 ------------------
diff --git a/docs/src/bibliography.rst b/docs/src/bibliography.rst
index 428925508..896b24478 100644
--- a/docs/src/bibliography.rst
+++ b/docs/src/bibliography.rst
@@ -6,6 +6,12 @@ References
     "Principal covariates regression: Part I. Theory", Chemom. intell. lab. syst. 14
     (1992) 155-164 https://doi.org/10.1016/0169-7439(92)80100-I
 
+.. [Gasparotto2014]
+    Piero Gasparotto, Michele Ceriotti,
+    "Recognizing molecular patterns by machine learning: An agnostic structural
+    definition of the hydrogen bond", J. Chem. Phys., 141 (17): 174110.
+    https://doi.org/10.1063/1.4900655.
+
 .. [Imbalzano2018]
     Giulio Imbalzano, Andrea Anelli, Daniele Giofré,Sinja Klees, Jörg Behler,
     and Michele Ceriotti, “Automatic selection of atomic fingerprints and reference
diff --git a/docs/src/conf.py b/docs/src/conf.py
index 02bd041db..7c72f4cef 100644
--- a/docs/src/conf.py
+++ b/docs/src/conf.py
@@ -54,7 +54,7 @@
     "sphinx_toggleprompt",
 ]
 
-example_subdirs = ["pcovr", "selection", "regression", "reconstruction"]
+example_subdirs = ["pcovr", "selection", "regression", "reconstruction", "neighbors"]
 sphinx_gallery_conf = {
     "filename_pattern": "/*",
     "examples_dirs": [f"../../examples/{p}" for p in example_subdirs],
diff --git a/docs/src/references/clustering.rst b/docs/src/references/clustering.rst
new file mode 100644
index 000000000..3ac355806
--- /dev/null
+++ b/docs/src/references/clustering.rst
@@ -0,0 +1,11 @@
+Clustering
+==========
+
+.. automodule:: skmatter.clustering
+
+.. _quick-shift-api:
+
+Quick Shift
+------------
+
+.. autoclass:: skmatter.clustering.QuickShift
diff --git a/docs/src/references/datasets.rst b/docs/src/references/datasets.rst
index 5d2968735..98c09023f 100644
--- a/docs/src/references/datasets.rst
+++ b/docs/src/references/datasets.rst
@@ -5,6 +5,8 @@ Datasets
 
 .. include:: ../../../src/skmatter/datasets/descr/degenerate_CH4_manifold.rst
 
+.. include:: ../../../src/skmatter/datasets/descr/h2o-blyp-piglet.rst
+
 .. include:: ../../../src/skmatter/datasets/descr/nice_dataset.rst
 
 .. include:: ../../../src/skmatter/datasets/descr/who_dataset.rst
diff --git a/docs/src/references/index.rst b/docs/src/references/index.rst
index 52125238b..e7bfe54a8 100644
--- a/docs/src/references/index.rst
+++ b/docs/src/references/index.rst
@@ -10,7 +10,9 @@ API Reference
     preprocessing
     selection
     linear_models
+    clustering
     decomposition
     metrics
+    neighbors
     datasets
     utils
diff --git a/docs/src/references/metrics.rst b/docs/src/references/metrics.rst
index 2ea0bb634..f01146f3b 100644
--- a/docs/src/references/metrics.rst
+++ b/docs/src/references/metrics.rst
@@ -40,3 +40,18 @@ Component-wise Prediction Rigidity
 ----------------------------------
 
 .. autofunction:: skmatter.metrics.componentwise_prediction_rigidity
+
+
+.. 
_pairwise-euclidian-api: + +Pairwise Euclidean Distances +---------------------------- + +.. autofunction:: skmatter.metrics.periodic_pairwise_euclidean_distances + +.. _pairwise-mahalanobis-api: + +Pairwise Mahalanobis Distance +----------------------------- + +.. autofunction:: skmatter.metrics.pairwise_mahalanobis_distances diff --git a/docs/src/references/neighbors.rst b/docs/src/references/neighbors.rst new file mode 100644 index 000000000..e96df5fcc --- /dev/null +++ b/docs/src/references/neighbors.rst @@ -0,0 +1,16 @@ +Neighbors +========= + +.. automodule:: skmatter.neighbors + +.. _sparse-kde-api: + +Sparse Kernel Density Estimation +-------------------------------- + +.. autoclass:: skmatter.neighbors.SparseKDE + :show-inheritance: + + .. automethod:: fit + .. automethod:: score_samples + .. automethod:: score diff --git a/docs/src/references/utils.rst b/docs/src/references/utils.rst index 41a017156..fbdec1a96 100644 --- a/docs/src/references/utils.rst +++ b/docs/src/references/utils.rst @@ -30,3 +30,14 @@ Random Partitioning with Overlaps --------------------------------- .. autofunction:: skmatter.model_selection.train_test_split + + +Effective Dimension of Covariance Matrix +---------------------------------------- + +.. autofunction:: skmatter.utils.effdim + +Oracle Approximating Shrinkage +------------------------------ + +.. autofunction:: skmatter.utils.oas diff --git a/docs/src/tutorials.rst b/docs/src/tutorials.rst index fa3461fd0..96381ea39 100644 --- a/docs/src/tutorials.rst +++ b/docs/src/tutorials.rst @@ -6,3 +6,4 @@ examples/selection/index examples/regression/index examples/reconstruction/index + examples/neighbors/index diff --git a/examples/neighbors/README.rst b/examples/neighbors/README.rst new file mode 100644 index 000000000..bdbac52bf --- /dev/null +++ b/examples/neighbors/README.rst @@ -0,0 +1,2 @@ +Neighbors +========= diff --git a/examples/neighbors/pamm.py b/examples/neighbors/pamm.py new file mode 100644 index 000000000..db36b4cb2 --- /dev/null +++ b/examples/neighbors/pamm.py @@ -0,0 +1,405 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +Probabilistic Analysis of Molecular Motifs (PAMM) +================================================= + +Probabilistic analysis of molecular motifs (`PAMM `_) +is a method identifying molecular patterns based on an analysis of the probability +distribution of fragments observed in an atomistic simulation. With the help of sparse +KDE, it can be easily conducted. + +Here we define some functions to help us. `quick_shift_refinement` is used to refine the +clusters generated by `QuickShift` by merging outlier clusters into their nearest +neighbours. `generate_probability_model` is to interpret the quick shift results into +a probability model. `cluster_distribution_3D` is to plot the probability model +of the H-bond motif. 
+""" + + +# %% +from typing import Callable, Union + +import matplotlib.pyplot as plt +import numpy as np +from scipy.special import logsumexp + +from skmatter.clustering import QuickShift +from skmatter.datasets import load_hbond_dataset +from skmatter.feature_selection import FPS +from skmatter.metrics import periodic_pairwise_euclidean_distances +from skmatter.neighbors import SparseKDE +from skmatter.neighbors._sparsekde import _covariance +from skmatter.utils import oas + + +# %% +def quick_shift_refinement( + X: np.ndarray, + cluster_centers_idx: np.ndarray, + labels: np.ndarray, + probs: np.ndarray, + metric: Callable = periodic_pairwise_euclidean_distances, + metric_params: Union[dict, None] = None, + thrpcl: float = 0.0, +): + """ + Parameters + ---------- + X : np.ndarray + Input data for fitting of quick shift + cluster_centers_idx : np.ndarray + Index of the cluster centers in `X` + labels : np.ndarray + Labels of the input data, generated by `QuickShift` + probs : numpy.ndarray + Probability density of the input data + metric : Callable, default=pairwise_euclidean_distances + The metric to use. + metric_params : dict, default=None + Additional parameters to be passed to the use of + metric. i.e. the cell dimension for `periodic_euclidean` + {'cell': [2, 2]} + thrpcl : float, default=0.0 + Clusters with a pk lower than this value are merged with the nearest cluster. + """ + if metric_params is not None: + cell = metric_params["cell_length"] + if len(cell) != X.shape[1]: + raise ValueError("Cell dimension does not match the data dimension.") + else: + cell = None + + normpks = logsumexp(probs) + nk = len(cluster_centers_idx) + to_merge = np.full(nk, False) + + for k in range(nk): + dummd1 = np.exp(logsumexp(probs[labels == cluster_centers_idx[k]]) - normpks) + to_merge[k] = dummd1 > thrpcl + # merge the outliers + for i in range(nk): + if not to_merge[k]: + continue + dummd1yi1 = cluster_centers_idx[i] + dummd1 = np.inf + for j in range(nk): + if to_merge[k]: + continue + dummd2 = metric(X[labels[dummd1yi1]], X[labels[j]], cell=cell) + if dummd2 < dummd1: + dummd1 = dummd2 + cluster_centers_idx[i] = j + labels[labels == dummd1yi1] = cluster_centers_idx[i] + if sum(to_merge) > 0: + cluster_centers_idx = np.concatenate( + np.argwhere(labels == np.arange(len(labels))) + ) + nk = len(cluster_centers_idx) + for i in range(nk): + dummd1yi1 = cluster_centers_idx[i] + cluster_centers_idx[i] = np.argmax( + np.ma.array(probs, mask=labels != cluster_centers_idx[i]) + ) + labels[labels == dummd1yi1] = cluster_centers_idx[i] + + return cluster_centers_idx, labels + + +# %% +def generate_probability_model( + cluster_center_idx: np.ndarray, + labels: np.ndarray, + X: np.ndarray, + descriptors: np.ndarray, + descriptor_labels: np.ndarray, + descriptor_weights: np.ndarray, + probs: np.ndarray, + cell: np.ndarray = None, +): + """ + Generates a probability model based on the given inputs. 
+ + Parameters + ---------- + cluster_center_idx : np.ndarray + Index of the cluster centers in `X` + labels : np.ndarray + Labels of the input data, generated by `QuickShift` + X : np.ndarray + Input data + descriptors : np.ndarray + Descriptors from original data set + descriptor_labels : np.ndarray + Labels of the descriptors, generated by + `skmatter.neighbors._sparsekde._NearestGridAssigner` + descriptor_weights : np.ndarray + Weights of the descriptors + probs : np.ndarray + Probability density of the input data + cell : np.ndarray + Cell dimension for distance metrics + """ + + def _update_cluster_cov( + X: np.ndarray, + k: int, + sample_labels: np.ndarray, + probs: np.ndarray, + idxroot: np.ndarray, + center_idx: np.ndarray, + ): + + if cell is not None: + cov = _get_lcov_clusterp( + len(X), nsamples, X, idxroot, center_idx[k], probs, cell + ) + if np.sum(idxroot == center_idx[k]) == 1: + cov = _get_lcov_clusterp( + nsamples, + nsamples, + descriptors, + sample_labels, + center_idx[k], + descriptor_weights, + cell, + ) + print("Warning: single point cluster!") + else: + cov = _get_lcov_cluster(len(X), X, idxroot, center_idx[k], probs, cell) + if np.sum(idxroot == center_idx[k]) == 1: + cov = _get_lcov_cluster( + nsamples, + descriptors, + sample_labels, + center_idx[k], + descriptor_weights, + cell, + ) + print("Warning: single point cluster!") + cov = oas( + cov, + logsumexp(probs[idxroot == center_idx[k]]) * nsamples, + X.shape[1], + ) + + return cov + + def _get_lcov_cluster( + N: int, + x: np.ndarray, + clroots: np.ndarray, + idcl: int, + probs: np.ndarray, + cell: np.ndarray, + ): + + ww = np.zeros(N) + normww = logsumexp(probs[clroots == idcl]) + ww[clroots == idcl] = np.exp(probs[clroots == idcl] - normww) + cov = _covariance(x, ww, cell) + + return cov + + def _get_lcov_clusterp( + N: int, + Ntot: int, + x: np.ndarray, + clroots: np.ndarray, + idcl: int, + probs: np.ndarray, + cell: np.ndarray, + ): + + ww = np.zeros(N) + totnormp = logsumexp(probs) + cov = np.zeros((x.shape[1], x.shape[1]), dtype=float) + xx = np.zeros(x.shape, dtype=float) + ww[clroots == idcl] = np.exp(probs[clroots == idcl] - totnormp) + ww *= Ntot + nlk = np.sum(ww) + for i in range(x.shape[1]): + xx[:, i] = x[:, i] - np.round(x[:, i] / cell[i]) * cell[i] + r2 = (np.sum(ww * np.cos(xx[:, i])) / nlk) ** 2 + ( + np.sum(ww * np.sin(xx[:, i])) / nlk + ) ** 2 + re2 = (nlk / (nlk - 1)) * (r2 - (1 / nlk)) + cov[i, i] = 1 / (np.sqrt(re2) * (2 - re2) / (1 - re2)) + + return cov + + if cell is not None and (X.shape[1] != len(cell)): + raise ValueError("Cell dimension does not match the data dimension.") + nclusters = len(cluster_center_idx) + nsamples = len(descriptors) + dimension = X.shape[1] + cluster_mean = np.zeros((nclusters, dimension), dtype=float) + cluster_cov = np.zeros((nclusters, dimension, dimension), dtype=float) + cluster_weight = np.zeros(nclusters, dtype=float) + center_idx = np.unique(labels) + normpks = logsumexp(probs) + + for k in range(nclusters): + cluster_weight[k] = np.exp(logsumexp(probs[labels == center_idx[k]]) - normpks) + cluster_cov[k] = _update_cluster_cov( + X, k, descriptor_labels, probs, labels, center_idx + ) + for k in range(nclusters): + labels[labels == center_idx[k]] = k + 1 + + return cluster_weight, cluster_mean, cluster_cov, labels + + +# %% +def cluster_distribution_3D( + grids: np.ndarray, + grid_weights: np.ndarray, + grid_label_: np.ndarray = None, + use_index: list[int] = None, + label_text: list[str] = None, + size_scale: float = 1e4, + fig_size: tuple[int, int] 
= (12, 12), +) -> tuple[plt.Figure, plt.Axes]: + """ + Generate a 3D scatter plot of the cluster distribution. + + Parameters + ---------- + grids (numpy.ndarray): The array containing the grid data. + use_index (Optional[list[int]]): The indices of the features to use for the + scatter plot. + If None, the first three features will be used. + label_text (Optional[list[str]]): The labels for the x, y, and z axes. + If None, the labels will be set to + 'Feature 0', 'Feature 1', and 'Feature 2'. + size_scale (float): The scale factor for the size of the scatter points. + Default is 1e4. + fig_size (tuple[int, int]): The size of the figure. Default is (12, 12) + + Returns + ------- + tuple[plt.Figure, plt.Axes]: A tuple containing the matplotlib + Figure and Axes objects. + """ + if use_index is None: + use_index = [0, 1, 2] + if label_text is None: + label_text = [f"Feature {i}" for i in range(3)] + + fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, figsize=fig_size, dpi=100) + scatter = ax.scatter( + grids[:, use_index[0]], + grids[:, use_index[1]], + grids[:, use_index[2]], + c=grid_label_, + s=grid_weights * size_scale, + ) + legend1 = ax.legend(*scatter.legend_elements(), loc="lower left", title="Gaussian") + ax.add_artist(legend1) + ax.set_xlabel(label_text[0]) + ax.set_ylabel(label_text[1]) + ax.set_zlabel(label_text[2]) + + return fig, ax + + +# %% +# We first load our dataset: +# +# + +# %% +hbond_data = load_hbond_dataset() +descriptors = hbond_data["descriptors"] +weights = hbond_data["weights"] + +# %% +# We use the `FPS` class to select the `ngrid` descriptors with the highest. It is +# recommended to set the number of grids as the square root of the number of +# descriptors. Then we do the fit of the KDE. + + +# %% +ngrid = int(len(descriptors) ** 0.5) +selector = FPS(initialize=26310, n_to_select=ngrid) +selector.fit(descriptors.T) +selector.selected_idx_ +grids = descriptors[selector.selected_idx_] + +# %% +estimator = SparseKDE(descriptors, weights) +estimator.fit(grids) + +# %% +# Now we visualize the distribution and the weight of clusters. + +# %% +cluster_distribution_3D( + grids, estimator._sample_weights, label_text=[r"$\nu$", r"$\mu$", r"r"] +) + +# %% +# We need to estimate the probability at each grid point to do quick shift, which can +# further partition the set of grid points into several clusters. The resulting +# clusters can be interpreted as (meta-)stable states of the system. +# +# + +# %% +probs = estimator.score_samples(grids) +qscuts = np.array([np.trace(cov) for cov in estimator._covariance]) +clustering = QuickShift( + qscuts**2, + metric_params=estimator.metric_params, +) +clustering.fit(grids, samples_weight=probs) +cluster_centers_idx = clustering.cluster_centers_idx_ +labels = clustering.labels_ +normpks = logsumexp(probs) + +cluster_centers, labels = quick_shift_refinement( + grids, + cluster_centers_idx, + labels, + probs, + estimator.metric, + estimator.cell, +) + +# %% +# Based on the results, the Gaussian mixture model of the system can be generated: +# +# + +# %% +cluster_weights, cluster_means, cluster_covs, labels = generate_probability_model( + cluster_centers_idx, + labels, + grids, + estimator.descriptors, + estimator._sample_labels_, + estimator.weights, + probs, + estimator.cell, +) + +# %% +# The final result shows seven (meta-)stable states of hydrogen bond. Here we also show +# the reference hydrogen bond descriptor. The Gaussian with the largest weight locates +# closest to the reference point. 
This result shows that, with the help of the
+# `SparseKDE` and `QuickShift` algorithms, we can identify the (meta-)stable
+# states of the system objectively, without any prior knowledge of the system.
+#
+#
+
+# %%
+REF_HB = np.array([0.82, 2.82, 2.74])  # The coordinates of the "standard" hydrogen bond
+
+fig, ax = cluster_distribution_3D(
+    grids, estimator._sample_weights, labels, label_text=[r"$\nu$", r"$\mu$", r"r"]
+)
+ax.scatter(REF_HB[0], REF_HB[1], REF_HB[2], marker="+", color="red", s=1000)
+
+# %%
+f"The Gaussian with the highest probability is {np.argmax(cluster_weights) + 1}"
diff --git a/examples/neighbors/sparse-kde.py b/examples/neighbors/sparse-kde.py
new file mode 100644
index 000000000..a8148d42a
--- /dev/null
+++ b/examples/neighbors/sparse-kde.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+"""
+Sparse KDE examples
+===================
+
+Example for the usage of the :class:`skmatter.neighbors.SparseKDE` class. This class is
+specifically designed for conducting probabilistic analysis of molecular motifs
+(`PAMM `_),
+which is quite useful for analyzing motifs like H-bonds, coordination polyhedra, and
+protein secondary structure.
+
+Here we show how to use the sparse KDE model to fit the probability distribution based
+on sampled data and how to use PAMM to analyze the H-bond.
+
+We start from a simple system consisting of three 2D Gaussians. Our task is to
+estimate the parameters of these Gaussians from our sampled data.
+
+Here we first sample from these three Gaussians.
+"""
+
+
+# %%
+import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import gaussian_kde
+
+from skmatter.feature_selection import FPS
+from skmatter.neighbors import SparseKDE
+
+
+# %%
+means = np.array([[0, 0], [4, 4], [6, -2]])
+covariances = np.array(
+    [[[1, 0.5], [0.5, 1]], [[1, 0.5], [0.5, 0.5]], [[1, -0.5], [-0.5, 1]]]
+)
+N_SAMPLES = 100_000
+samples = np.concatenate(
+    [
+        np.random.multivariate_normal(means[0], covariances[0], N_SAMPLES),
+        np.random.multivariate_normal(means[1], covariances[1], N_SAMPLES),
+        np.random.multivariate_normal(means[2], covariances[2], N_SAMPLES),
+    ]
+)
+
+# %%
+# We can visualize our sampling result:
+#
+#
+
+# %%
+fig, ax = plt.subplots()
+ax.scatter(samples[:, 0], samples[:, 1], alpha=0.05, s=1)
+ax.scatter(means[:, 0], means[:, 1], marker="+", color="red", s=100)
+ax.set_xlabel("x")
+ax.set_ylabel("y")
+plt.show()
+
+# %%
+# Sparse KDE requires a discretization of the sample space. Here, we use
+# the FPS method to generate grid points in the sample space:
+#
+#
+
+# %%
+start1 = time.time()
+selector = FPS(n_to_select=int(np.sqrt(3 * N_SAMPLES)))
+grids = selector.fit_transform(samples.T).T
+end1 = time.time()
+fig, ax = plt.subplots()
+ax.scatter(samples[:, 0], samples[:, 1], alpha=0.05, s=1)
+ax.scatter(means[:, 0], means[:, 1], marker="+", color="red", s=100)
+ax.scatter(grids[:, 0], grids[:, 1], color="orange", s=1)
+ax.set_xlabel("x")
+ax.set_ylabel("y")
+plt.show()
+
+# %%
+# Now we can do sparse KDE (usually takes tens of seconds):
+#
+#
+
+# %%
+start2 = time.time()
+estimator = SparseKDE(samples, None, fpoints=0.5)
+estimator.fit(grids)
+end2 = time.time()
+
+# %%
+# We can compare the fitted model with the original sampling result by plotting both.
+#
+# For convenience, we create a class for the Gaussian mixture model to help us plot
+# the result.
+ + +# %% +class GaussianMixtureModel: + + def __init__( + self, + weights: np.ndarray, + means: np.ndarray, + covariances: np.ndarray, + period: np.ndarray = None, + ): + self.weights = weights + self.means = means + self.covariances = covariances + self.period = period + self.dimension = self.means.shape[1] + self.cov_inv = np.linalg.inv(self.covariances) + self.cov_det = np.linalg.det(self.covariances) + self.norm = 1 / np.sqrt((2 * np.pi) ** self.dimension * self.cov_det) + + def __call__(self, x: np.ndarray, i: int = None): + + if len(x.shape) == 1: + x = x[np.newaxis, :] + if self.period is not None: + xij = np.zeros(self.means.shape) + xij = rij(self.period, xij, x) + else: + xij = x - self.means + p = ( + self.weights + * self.norm + * np.exp( + -0.5 * (xij[:, np.newaxis, :] @ self.cov_inv @ xij[:, :, np.newaxis]) + ).reshape(-1) + ) + sum_p = np.sum(p) + if i is None: + return sum_p + + return np.sum(p[i]) / sum_p + + +# %% +def rij(period: np.ndarray, xi: np.ndarray, xj: np.ndarray) -> np.ndarray: + """Get the position vectors between two points. PBC are taken into account.""" + xij = xi - xj + if period is not None: + xij -= np.round(xij / period) * period + + return xij + + +# %% +# The original model that we want to fit: +original_model = GaussianMixtureModel(np.full(3, 1 / 3), means, covariances) +# The fitted model: +fitted_model = GaussianMixtureModel( + estimator._sample_weights, estimator._grids, estimator.bandwidth_ +) + +# To plot the probability density contour, we need to create a grid of points: +x, y = np.meshgrid(np.linspace(-6, 12, 100), np.linspace(-8, 8)) +points = np.concatenate(np.stack([x, y], axis=-1)) +probs = np.array([original_model(point) for point in points]) +fitted_probs = np.array([fitted_model(point) for point in points]) + +fig, ax = plt.subplots() +ct1 = ax.contour(x, y, probs.reshape(x.shape), colors="blue") +ct2 = ax.contour(x, y, fitted_probs.reshape(x.shape), colors="orange") +h1, _ = ct1.legend_elements() +h2, _ = ct2.legend_elements() +ax.legend( + [h1[0], h2[0]], + ["original", "fitted"], +) +ax.set_xlabel("x") +ax.set_ylabel("y") +plt.show() + +# %% +# The performance of the probability density estimation can be characterized by the +# Mean Integrated Squared Error (MISE), which is defined as: +# :math:`\text{MISE}=\text{E}[\int (\hat{P}(\textbf{x})-P(\textbf{x}))^2 d\textbf{x}]` + +# %% +RMSE = np.sum((probs - fitted_probs) ** 2 * (x[0][1] - x[0][0]) * (y[1][0] - y[0][0])) +print(f"Time sparse-kde: {end2 - start2} s") +print(f"RMSE = {RMSE:.2e}") + +# %% +# We can compare the result with the KDE class from scipy. (Usually takes +# several minutes to run) + +# %% +data = np.vstack([x.ravel(), y.ravel()]) +start = time.time() +kde = gaussian_kde(samples.T) +sklearn_probs = kde(data).T +end = time.time() +print(f"Time scipy: {end - start} s") +RMSE_kde = np.sum( + (probs - sklearn_probs) ** 2 * (x[0][1] - x[0][0]) * (y[1][0] - y[0][0]) +) +print(f"RMSE_kde = {RMSE_kde:.2e}") + +# %% +# We can see that the fitted model can perfectly capture the original one. Even though +# we have not specified the number of the Gaussians, it can still perform well. This +# allows us to fit distributions of the data automatically at a comparable quality +# within a much shorter time than scipy. 
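As an optional final check, here is a minimal sketch (editorial illustration, reusing the variables defined in this example and the ``sample`` method that ``SparseKDE`` provides in ``src/skmatter/neighbors/_sparsekde.py`` below): draw new points from the fitted model and overlay them on the original data to see whether the estimated density reproduces the three Gaussians.

# %%
# Optional check (sketch): resample from the fitted sparse KDE and compare visually
# with the original samples.
resampled = estimator.sample(n_samples=5_000, random_state=0)

fig, ax = plt.subplots()
ax.scatter(samples[:, 0], samples[:, 1], alpha=0.05, s=1, label="original samples")
ax.scatter(resampled[:, 0], resampled[:, 1], alpha=0.2, s=1, label="resampled from fit")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
plt.show()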
diff --git a/src/skmatter/clustering/__init__.py b/src/skmatter/clustering/__init__.py new file mode 100644 index 000000000..d6ce15c41 --- /dev/null +++ b/src/skmatter/clustering/__init__.py @@ -0,0 +1,11 @@ +r""" +The module implements the quick shift clustering algorithm, which is used in +probabilistic analysis of molecular motifs (PAMM). See `Gasparotto and Ceriotti +`_ for more details. +""" + +from ._quick_shift import QuickShift + +__all__ = [ + "QuickShift", +] diff --git a/src/skmatter/clustering/_quick_shift.py b/src/skmatter/clustering/_quick_shift.py new file mode 100644 index 000000000..92fc4eb6e --- /dev/null +++ b/src/skmatter/clustering/_quick_shift.py @@ -0,0 +1,252 @@ +from typing import Callable, Union + +import numpy as np +from numpy.typing import ArrayLike +from sklearn.base import BaseEstimator +from tqdm import tqdm + +from ..metrics._pairwise import periodic_pairwise_euclidean_distances + + +class QuickShift(BaseEstimator): + """Conducts quick shift clustering. + + This class is used to implement the quick shift clustering algorithm, + which is used in probabilistic analysis of molecular motifs (PAMM). There + are two ways of searching the next point: (1) search for the point within the given + distance cutoff and (2) search for the point within the given number of neighbor + shell of the Gabriel graph. If both of them are set, the distance cutoff + is used. + + Parameters + ---------- + dist_cutoff_sq : float, default=None + The squared distance cutoff for searching for the next point. Two points are + considered as neighbors if they are within this distance. If :obj:`None`, + the scheme of Gabriel graph is used. + gabriel_shell : int, default=None + The number of neighbor shell of Gabriel graph for searching for the next point. + For example, if the number is 1, two points will be considered as neighbors if + they have at least one common neighbor, like for the case "A-B-C", we will + consider "A-C" as neighbors. If the number is 2, for the case "A-B-C-D", + we will consider "A-D" as neighbors. If :obj:`None`, the scheme of distance + cutoff is used. + scale : float, default=1.0 + Distance cutoff scaling factor used during the QS clustering. It will be squared + since the squared distance is used in this class. + metric : Callable[[ArrayLike, ArrayLike, bool, dict], ArrayLike], \ + default= :func:`skmatter.metrics.pairwise_euclidean_distances()` + The metric to use. Your metric should be able to take at least three arguments + in secquence: `X`, `Y`, and `squared=True`. Here, `X` and `Y` are two array-like + of shape (n_samples, n_components). The return of the metric is an array-like of + shape (n_samples, n_samples). If you want to use periodic boundary + conditions, be sure to provide the cell length in the ``metric_params`` and + provide a metric that can take the cell argument. + metric_params : dict, default=None + Additional parameters to be passed to the use of + metric. i.e. the dimension of a rectangular cell of side length :math:`a_i` + for :func:`skmatter.metrics.pairwise_euclidean_distances()` + `{'cell_length': [a_1, a_2, ..., a_n]}` + + Attributes + ---------- + labels_ : numpy.ndarray + An array of labels for each input data. + cluster_centers_idx_ : numpy.ndarray + An array of indices of cluster centers. + cluster_centers_ : numpy.ndarray + An array of cluster centers. 
+ + Examples + -------- + >>> import numpy as np + >>> from skmatter.clustering import QuickShift + + Create some points and their weights for quick shift clustering + + >>> feature1 = np.array([-1.72, -4.44, 0.54, 3.19, -1.13, 0.55]) + >>> feature2 = np.array([-1.32, -2.13, -2.43, -0.49, 2.33, 0.18]) + >>> points = np.vstack((feature1, feature2)).T + >>> weights = np.array([-3.94, -12.68, -7.07, -9.03, -8.26, -2.61]) + + Set cutoffs for seraching + + >>> cuts = np.array([6.99, 8.80, 7.68, 9.51, 8.07, 6.22]) + + Do the clustering + + >>> model = QuickShift(cuts).fit(points, samples_weight=weights) + >>> print(model.labels_) + [0 0 0 5 5 5] + >>> print(model.cluster_centers_idx_) + [0 5] + + We can also apply a periodic boundary condition + + >>> model = QuickShift(cuts, metric_params={"cell_length": [3, 3]}) + >>> model = model.fit(points, samples_weight=weights) + >>> print(model.labels_) + [5 5 5 5 5 5] + >>> print(model.cluster_centers_idx_) + [5] + + Since the searching cuts are all larger than the maximum distance in the PBC box, + it can be expected that all points are assigned to the same cluster, of the center + that has the largest weight. + """ + + def __init__( + self, + dist_cutoff_sq: Union[float, None] = None, + gabriel_shell: Union[int, None] = None, + scale: float = 1.0, + metric: Callable[ + [ArrayLike, ArrayLike, bool, dict], ArrayLike + ] = periodic_pairwise_euclidean_distances, + metric_params: Union[dict, None] = None, + ): + if (dist_cutoff_sq is None) and (gabriel_shell is None): + raise ValueError("Either dist_cutoff or gabriel_depth must be set.") + self.dist_cutoff_sq = dist_cutoff_sq + self.gabriel_shell = gabriel_shell + self.scale = scale + if self.dist_cutoff_sq is not None: + self.dist_cutoff_sq *= self.scale**2 + self.metric_params = ( + metric_params if metric_params is not None else {"cell_length": None} + ) + self.metric = lambda X, Y: metric(X, Y, squared=True, **self.metric_params) + if isinstance(self.metric_params, dict): + self.cell = self.metric_params["cell_length"] + else: + self.cell = None + + def fit(self, X, y=None, samples_weight=None): + """Fit the model using X as training data and y as target values. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data, where `n_samples` is the number of samples + and `n_features` is the number of features. + Y : None + Ignored. This parameter exists only for compatibility with + :class:`~sklearn.pipeline.Pipeline`. + samples_weight : array-like of shape (n_samples,), default=None + List of sample weights attached to the data X. This parameter + must be given in order to do the quick shift clustering. + """ + if (self.cell is not None) and (X.shape[1] != len(self.cell)): + raise ValueError( + "Dimension of the cell length does not match the data dimension." 
+ ) + dist_matrix = self.metric(X, X) + np.fill_diagonal(dist_matrix, np.inf) + if self.dist_cutoff_sq is None: + gabrial = _get_gabriel_graph(dist_matrix) + idmindist = np.argmin(dist_matrix, axis=1) + idxroot = np.full(dist_matrix.shape[0], -1, dtype=int) + for i in tqdm(range(dist_matrix.shape[0]), desc="Quick-Shift"): + if idxroot[i] != -1: + continue + qspath = [] + qspath.append(i) + current = qspath[-1] + while current != idxroot[current]: + if self.gabriel_shell is not None: + idxroot[current] = self._gs_next( + current, samples_weight, dist_matrix, gabrial + ) + else: + idxroot[current] = self._qs_next( + current, + idmindist[current], + samples_weight, + dist_matrix, + self.dist_cutoff_sq[current], + ) + if idxroot[idxroot[current]] != -1: + # Found a path to a root + break + qspath.append(idxroot[current]) + current = qspath[-1] + idxroot[qspath] = idxroot[idxroot[current]] + + self.labels_ = idxroot + self.cluster_centers_idx_ = np.concatenate( + np.argwhere(idxroot == np.arange(dist_matrix.shape[0])) + ) + self.cluster_centers_ = X[self.cluster_centers_idx_] + + return self + + def _gs_next( + self, + idx: int, + probs: np.ndarray, + distmm: np.ndarray, + gabriel: np.ndarray, + ): + """Find next cluster in Gabriel graph.""" + ngrid = len(probs) + neighs = np.copy(gabriel[idx]) + for _ in range(1, self.gabriel_shell): + nneighs = np.full(ngrid, False) + for j in range(ngrid): + if neighs[j]: + # j can be accessed from idx + # j's neighbors can also be accessed from idx + nneighs |= gabriel[j] + neighs |= nneighs + + next_idx = idx + dmin = np.inf + for j in range(ngrid): + if probs[j] > probs[idx] and distmm[idx, j] < dmin and neighs[j]: + # find the closest neighbor + next_idx = j + dmin = distmm[idx, j] + + return next_idx + + def _qs_next( + self, idx: int, idxn: int, probs: np.ndarray, distmm: np.ndarray, cutoff: float + ): + """Find next cluster with respect to cutoff.""" + ngrid = len(probs) + dmin = np.inf + next_idx = idx + if probs[idxn] > probs[idx]: + next_idx = idxn + for j in range(ngrid): + if probs[j] > probs[idx] and distmm[idx, j] < min(dmin, cutoff): + next_idx = j + dmin = distmm[idx, j] + + return next_idx + + +def _get_gabriel_graph(dist_matrix_sq: np.ndarray): + """ + Generate the Gabriel graph based on the given squared distance matrix. + + Parameters + ---------- + dist_matrix_sq : np.ndarray + The squared distance matrix of shape (n_points, n_points). + + Returns + ------- + np.ndarray + The Gabriel graph matrix of shape (n_points, n_points). 
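+
+    Examples
+    --------
+    A minimal illustration of the expected behaviour (three collinear points given
+    as a squared-distance matrix; the middle point blocks the outer pair, so only
+    adjacent points remain Gabriel neighbours):
+
+    >>> import numpy as np
+    >>> dist_sq = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+    >>> _get_gabriel_graph(dist_sq)
+    array([[False,  True, False],
+           [ True, False,  True],
+           [False,  True, False]])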
+ """ + n_points = dist_matrix_sq.shape[0] + gabriel = np.full((n_points, n_points), True) + for i in tqdm(range(n_points), desc="Calculating Gabriel graph"): + gabriel[i, i] = False + for j in range(i, n_points): + if np.sum(dist_matrix_sq[i] + dist_matrix_sq[j] < dist_matrix_sq[i, j]): + gabriel[i, j] = False + gabriel[j, i] = False + + return gabriel diff --git a/src/skmatter/datasets/__init__.py b/src/skmatter/datasets/__init__.py index c72113195..b0846e7a1 100644 --- a/src/skmatter/datasets/__init__.py +++ b/src/skmatter/datasets/__init__.py @@ -3,6 +3,7 @@ from ._base import ( load_csd_1000r, load_degenerate_CH4_manifold, + load_hbond_dataset, load_nice_dataset, load_roy_dataset, load_who_dataset, @@ -12,6 +13,7 @@ __all__ = [ "load_degenerate_CH4_manifold", "load_csd_1000r", + "load_hbond_dataset", "load_nice_dataset", "load_roy_dataset", "load_who_dataset", diff --git a/src/skmatter/datasets/_base.py b/src/skmatter/datasets/_base.py index 90dd7f38c..b7cb701a7 100644 --- a/src/skmatter/datasets/_base.py +++ b/src/skmatter/datasets/_base.py @@ -143,3 +143,30 @@ def load_roy_dataset(): structure_types=properties["structure_types"], features=properties["feats"], ) + + +def load_hbond_dataset(): + """Load and returns the hydrogen bond dataset, which contains + a set of 3D descriptors for 27233 hydrogen bonds and corresponding + weights, from [Gasparotto et Al, The Journal of Chemical Physics] + (https://doi.org/10.1063/1.4900655) + + Returns + ------- + hbond_dataset : sklearn.utils.Bunch + Dictionary-like object, with the following attributes: + descriptors : `numpy.ndarray` -- the descriptors of hydrogen bond dataset + weights : `numpy.ndarray` -- the weights of each sample in the dataset + """ + module_path = dirname(__file__) + target_filename = join(module_path, "data", "h2o-blyp-piglet.npz") + raw_data = np.load(target_filename) + + with open(join(module_path, "descr", "h2o-blyp-piglet.rst")) as rst_file: + fdescr = rst_file.read() + + return Bunch( + descriptors=raw_data["descriptors"], + weights=raw_data["weights"], + DESCR=fdescr, + ) diff --git a/src/skmatter/datasets/data/h2o-blyp-piglet.npz b/src/skmatter/datasets/data/h2o-blyp-piglet.npz new file mode 100644 index 000000000..9f9112f0d Binary files /dev/null and b/src/skmatter/datasets/data/h2o-blyp-piglet.npz differ diff --git a/src/skmatter/datasets/descr/h2o-blyp-piglet.rst b/src/skmatter/datasets/descr/h2o-blyp-piglet.rst new file mode 100644 index 000000000..92387415a --- /dev/null +++ b/src/skmatter/datasets/descr/h2o-blyp-piglet.rst @@ -0,0 +1,35 @@ +.. _water: + +H2O-BLYP-Piglet +############### + +This dataset contains 27233 hydrogen bond descriptors and corresponding weights from a +trajectory of a classical simulation performed with a BLYP exchange-correlation +functional and a DZVP basis set. The simulation box contined 64 water molecules. This +dataset was originally published in +[Gasparotto2014]_. + +Function Call +------------- + +.. 
function:: skmatter.datasets.load_hbond_dataset + +Data Set Characteristics +------------------------ + +:Number of Instances: 27233 + +:Number of Features: 3 + +Reference +--------- + +[1] https://github.com/lab-cosmo/pamm/tree/master/examples/water + +Reference Code +-------------- + +[2] https://github.com/GardevoirX/pypamm/blob/master/tutorials/water/tutorial.ipynb + +[3] https://github.com/lab-cosmo/pamm/blob/master/examples/water/README + diff --git a/src/skmatter/metrics/__init__.py b/src/skmatter/metrics/__init__.py index 16cfe8f04..50fa04acf 100644 --- a/src/skmatter/metrics/__init__.py +++ b/src/skmatter/metrics/__init__.py @@ -36,6 +36,20 @@ kernel model. * :ref:`CPR-api` (CPR) computes the component-wise prediction rigidity of a linear or kernel model. + +There are also two distance metrics compatible with the periodic boundary conditions +available. + + .. note:: + Currently only rectangular cells are supported. + Cell format: [side_length_1, ..., side_length_n] + +* :ref:`pairwise-euclidian-api` computes the euclidean distance between two sets + of points. It is compatible with the periodic boundary conditions. + If the cell length is not provided, it will fall back to the ``scikit-learn`` version + of the euclidean distance :func:`sklearn.metrics.pairwise.euclidean_distances`. +* :ref:`pairwise-mahalanobis-api` computes the Mahalanobis distance between two sets + of points. It is compatible with the periodic boundary conditions. """ from ._reconstruction_measures import ( @@ -54,6 +68,11 @@ componentwise_prediction_rigidity, ) +from ._pairwise import ( + periodic_pairwise_euclidean_distances, + pairwise_mahalanobis_distances, +) + __all__ = [ "pointwise_global_reconstruction_error", "global_reconstruction_error", @@ -65,4 +84,10 @@ "check_local_reconstruction_measures_input", "local_prediction_rigidity", "componentwise_prediction_rigidity", + "periodic_pairwise_euclidean_distances", + "pairwise_mahalanobis_distances", ] + +DIST_METRICS = { + "periodic_euclidean": periodic_pairwise_euclidean_distances, +} diff --git a/src/skmatter/metrics/_pairwise.py b/src/skmatter/metrics/_pairwise.py new file mode 100644 index 000000000..4455a7b3a --- /dev/null +++ b/src/skmatter/metrics/_pairwise.py @@ -0,0 +1,174 @@ +from typing import Union + +import numpy as np +from sklearn.metrics.pairwise import _euclidean_distances, check_pairwise_arrays + + +def periodic_pairwise_euclidean_distances( + X, + Y=None, + *, + squared=False, + cell_length=None, +): + r""" + Compute the pairwise distance matrix between each pair from a vector array X and Y. + + .. math:: + d_{i, j} = \\sqrt{\\sum_{k=1}^n (x_{i, k} - y_{j, k})^2} + + For efficiency reasons, the euclidean distance between a pair of row + vector x and y is computed as:: + + dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y)) + + This formulation has two advantages over other ways of computing distances. First, + it is computationally efficient when dealing with sparse data. Second, if one + argument varies but the other remains unchanged, then `dot(x, x)` and/or `dot(y, y)` + can be pre-computed. + + However, this is not the most precise way of doing this computation, because this + equation potentially suffers from "catastrophic cancellation". Also, the distance + matrix returned by this function may not be exactly symmetric as required by, e.g., + ``scipy.spatial.distance`` functions. + + Read more in the :ref:`User Guide `. 
+ + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples_X, n_components) + An array where each row is a sample and each column is a component. + Y : {array-like, sparse matrix} of shape (n_samples_Y, n_components), \ + default=None + An array where each row is a sample and each column is a component. + If `None`, method uses `Y=X`. + cell_length : array-like of shape (n_components,), default=None + The side length of rectangular cell used for periodic boundary conditions. + `None` for non-periodic boundary conditions. + + .. note:: + Only side lengths of rectangular cells are supported. + Cell format: `[side_length_1, ..., side_length_n]` + + Returns + ------- + distances : ndarray of shape (n_samples_X, n_samples_Y) + Returns the distances between the row vectors of `X` + and the row vectors of `Y`. + + Examples + -------- + >>> import numpy as np + >>> from skmatter.metrics import periodic_pairwise_euclidean_distances + >>> X = np.array([[0, 1], [1, 1]]) + >>> origin = np.array([[0, 0]]) + >>> # distance between rows of X + >>> periodic_pairwise_euclidean_distances(X, X) + array([[0., 1.], + [1., 0.]]) + >>> # get distance to origin + >>> periodic_pairwise_euclidean_distances(X, origin, cell_length=[0.5, 0.7]) + array([[0.3], + [0.3]]) + """ + _check_dimension(X, cell_length) + X, Y = check_pairwise_arrays(X, Y) + + if cell_length is None: + return _euclidean_distances(X, Y, squared=squared) + else: + return _periodic_euclidean_distances(X, Y, squared=squared, cell=cell_length) + + +def _periodic_euclidean_distances(X, Y=None, *, squared=False, cell=None): + X, Y = np.array(X).astype(float), np.array(Y).astype(float) + XY = np.concatenate([x - Y for x in X]) + XY -= np.round(XY / cell) * cell + distance = np.linalg.norm(XY, axis=1).reshape(X.shape[0], Y.shape[0]) + if squared: + distance **= 2 + return distance + + +def pairwise_mahalanobis_distances( + X: np.ndarray, + Y: np.ndarray, + cov_inv: np.ndarray, + cell_length: Union[np.ndarray, None] = None, + squared: bool = False, +): + r""" + Calculate the pairwise Mahalanobis distance between two arrays. + + This metric is used for calculating the distances between observations from Gaussian + distributions. It is defined as: + + .. math:: + d_{\Sigma}(x, y)^2 = (x - y)^T \Sigma^{-1} (x - y) + + where :math:`\Sigma` is the covariance matrix, :math:`x` and :math:`y` are + observations from the same distribution. + + Parameters + ---------- + X : numpy.ndarray of shape (n_samples_X, n_components) + An array where each row is a sample and each column is a component. + Y : np.ndarray of shape (n_samples_Y, n_components) + An array where each row is a sample and each column is a component. + cov_inv : np.ndarray + The inverse covariance matrix of shape (n_components, n_components). + cell_length : np.ndarray, optinal, default=None + The cell size for periodic boundary conditions. + None for non-periodic boundary conditions. + + .. note:: + Only cubic cells are supported. + Cell format: `[side_length_1, ..., side_length_n]` + + squared : bool, default=False + Whether to return the squared distance. + + Returns + ------- + np.ndarray + The pairwise Mahalanobis distance between the two input arrays, + of shape `(cov_inv.shape[0], x.shape[0], y.shape[0])`. 
+ + Examples + -------- + >>> import numpy as np + >>> from skmatter.metrics import pairwise_mahalanobis_distances + >>> iv = np.array([[1, 0.5, 0.5], [0.5, 1, 0.5], [0.5, 0.5, 1]]) + >>> X = np.array([[1, 0, 0], [0, 2, 0], [2, 0, 0]]) + >>> Y = np.array([[0, 1, 0]]) + >>> pairwise_mahalanobis_distances(X, Y, iv) + array([[[1. ], + [1. ], + [1.73205081]]]) + """ + + def _mahalanobis( + cell: np.ndarray, X: np.ndarray, Y: np.ndarray, cov_inv: np.ndarray + ): + + XY = np.concatenate([x - Y for x in X]) + if cell is not None: + XY -= np.round(XY / cell) * cell + + return np.sum(XY * np.transpose(cov_inv @ XY.T, (0, 2, 1)), axis=-1).reshape( + (cov_inv.shape[0], X.shape[0], Y.shape[0]) + ) + + _check_dimension(X, cell_length) + X, Y = check_pairwise_arrays(X, Y) + if len(cov_inv.shape) == 2: + cov_inv = cov_inv[np.newaxis, :, :] + dists = _mahalanobis(cell_length, X, Y, cov_inv) + if not squared: + dists **= 0.5 + return dists + + +def _check_dimension(X, cell_length): + if (cell_length is not None) and (X.shape[1] != len(cell_length)): + raise ValueError("Cell dimension does not match the data dimension.") diff --git a/src/skmatter/neighbors/__init__.py b/src/skmatter/neighbors/__init__.py new file mode 100644 index 000000000..6acc8a54d --- /dev/null +++ b/src/skmatter/neighbors/__init__.py @@ -0,0 +1,22 @@ +""" +The module implements the sparse kernel density estimator. + +A large dataset can be generated during the molecular dynamics sampling. The +distribution of the sampled data reflects the (free) energetic stability of molecular +patterns. The KDE model can be used to characterize the probability distribution, and +thus to identify the stable patterns in the system. However, the computational +cost of KDE is `O(N^2)` where `N` is the number of sampled points, which is very +expensive. Here we offer a sparse implementation of the KDE model with a +`O(MN)` computational cost, where `M` is the number of grid points generated from the +sampled data. + +The following class is available: + +* :ref:`sparse-kde-api` computes the kernel density estimator based on a set of grid + points generated from the sampled data. + +""" + +from ._sparsekde import SparseKDE + +__all__ = ["SparseKDE"] diff --git a/src/skmatter/neighbors/_sparsekde.py b/src/skmatter/neighbors/_sparsekde.py new file mode 100644 index 000000000..b98c6a9e1 --- /dev/null +++ b/src/skmatter/neighbors/_sparsekde.py @@ -0,0 +1,672 @@ +import warnings +from typing import Callable, Union + +import numpy as np +from numpy.typing import ArrayLike +from scipy.special import logsumexp as LSE +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted, check_random_state +from tqdm import tqdm + +from ..metrics._pairwise import ( + pairwise_mahalanobis_distances, + periodic_pairwise_euclidean_distances, +) +from ..utils._sparsekde import effdim, oas + + +class SparseKDE(BaseEstimator): + """A sparse implementation of the Kernel Density Estimation. + This class is used to build a sparse kernel density estimator. + It takes a set of descriptors and a set of weights as input, + and fit the KDE model on the sampled point (e.g. the grid point + selected by FPS). + + .. note:: + Currently only the Gaussian kernel is supported. + + Parameters + ---------- + descriptors: numpy.ndarray + Descriptors of the system where you want to build a sparse KDE. + It should be an array of shape `(n_descriptors, n_features)`. + weights: numpy.ndarray, default=None + Weights of the descriptors. 
+ If None, all weights are set to `1/n_descriptors`. + metric : Callable[[ArrayLike, ArrayLike, bool, dict], ArrayLike], + default=:func:`skmatter.metrics.pairwise_euclidean_distances()` + The metric to use. Your metric should be able to take at least three arguments + in secquence: `X`, `Y`, and `squared=True`. Here, `X` and `Y` are two array-like + of shape (n_samples, n_components). The return of the metric is an array-like of + shape (n_samples, n_samples). If you want to use periodic boundary + conditions, be sure to provide the cell size in the metric_params and + provide a metric that can take the cell argument. + metric_params : dict, default=None + Additional parameters to be passed to the use of + metric. i.e. the cell dimension for + :func:`skmatter.metrics.pairwise_euclidean_distances()` + `{'cell_length': [side_length_1, ..., side_length_n]}` + fspread : float, default=-1.0 + The fractional "space" occupied by the voronoi cell of each grid. Use this when + each cell is of a similar size. + fpoints : float, default=0.15 + The fractional number of points in the voronoi cell of each grid points. Use + this when each cell has a similar number of points. + kernel : str, default=gaussian + The kernel used here. Now only the Gaussian kernel is available. + verbose : bool, default=False + Whether to print progress. + + + Attributes + ---------- + n_samples : int + The number of descriptors. + kdecut_squared : float + The cut-off value for the KDE. If the mahalanobis distance between two grid + points is larger than kdecut2, they are considered to be far away. + cell : numpy.ndarray + The cell dimension for the metric. + bandwidth_: numpy.ndarray + The bandwidth of the KDE. + + + Examples + -------- + >>> import numpy as np + >>> from skmatter.neighbors import SparseKDE + >>> from skmatter.feature_selection import FPS + >>> np.random.seed(0) + >>> n_samples = 10_000 + + To create two Gaussians with different means and covariance and sample from them + + >>> cov1 = [[1, 0.5], [0.5, 1]] + >>> cov2 = [[1, 0.5], [0.5, 0.5]] + >>> sample1 = np.random.multivariate_normal([0, 0], cov1, n_samples) + >>> sample2 = np.random.multivariate_normal([4, 4], cov2, n_samples) + >>> samples = np.concatenate([sample1, sample2]) + + To select grid points using FPS + + >>> selector = FPS(n_to_select=int(np.sqrt(2 * n_samples))) + >>> result = selector.fit_transform(samples.T).T + + Conduct sparse KDE based on the grid points + + >>> estimator = SparseKDE(samples, None, fpoints=0.5) + >>> _ = estimator.fit(result) + + The total log-likelihood under the model + + >>> print(round(estimator.score(result), 3)) + -759.831 + """ + + def __init__( + self, + descriptors: np.ndarray, + weights: Union[np.ndarray, None] = None, + metric: Callable[ + [ArrayLike, ArrayLike, bool, dict], ArrayLike + ] = periodic_pairwise_euclidean_distances, + metric_params: Union[dict, None] = None, + fspread: float = -1.0, + fpoints: float = 0.15, + kernel: str = "gaussian", + verbose: bool = False, + ): + self.metric_params = ( + metric_params if metric_params is not None else {"cell_length": None} + ) + self.metric = lambda X, Y: metric(X, Y, squared=True, **self.metric_params) + self.cell = metric_params["cell_length"] if metric_params is not None else None + self._check_dimension(descriptors) + self.descriptors = descriptors + self.weights = weights if weights is not None else np.ones(len(descriptors)) + self.weights /= np.sum(self.weights) + self.fspread = fspread + self.fpoints = fpoints + self.kernel = kernel + if self.kernel != 
"gaussian": + raise NotImplementedError + self.verbose = verbose + + if self.fspread > 0: + self.fpoints = -1.0 + + self.bandwidth_ = None + self._covariance = None + + @property + def nsamples(self): + return len(self.descriptors) + + @property + def ndimension(self): + return self.descriptors.shape[1] + + @property + def kdecut_squared(self): + return (3 * (np.sqrt(self.descriptors.shape[1]) + 1)) ** 2 + + @property + def _bandwidth_inv(self): + if self.fitted_: + if self._bandwidth_inv_ is None: + self._bandwidth_inv_ = np.array( + [np.linalg.inv(h) for h in self.bandwidth_] + ) + else: + raise ValueError("The model is not fitted yet.") + return self._bandwidth_inv_ + + @property + def _normkernels(self): + if self.fitted_: + if self._normkernels_ is None: + self._normkernels_ = np.array( + [ + self.ndimension * np.log(2 * np.pi) + np.linalg.slogdet(h)[1] + for h in self.bandwidth_ + ] + ) + else: + raise ValueError("The model is not fitted yet.") + return self._normkernels_ + + def fit(self, X, y=None, sample_weight=None): + """Fit the Kernel Density model on the data. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + y : None + Ignored. This parameter exists only for compatibility with + :class:`~sklearn.pipeline.Pipeline`. + + sample_weight : array-like of shape (n_samples,), default=None + List of sample weights attached to the data X. This parameter + is ignored. Instead of reading sample_weight from the input, + it is calculated internally. + + + Returns + ------- + self : object + Returns the instance itself. + """ + # Initialize/reset the cached properties, _bandwidth_inv and _normkernels + self._bandwidth_inv_ = None + self._normkernels_ = None + self._check_dimension(X) + self._grids = X + grid_dist_mat = self.metric(X, X) + np.fill_diagonal(grid_dist_mat, np.inf) + min_grid_dist = np.min(grid_dist_mat, axis=1) + _, self._grid_neighbour, self._sample_labels_, self._sample_weights = ( + self._assign_descriptors_to_grids(X) + ) + self._computes_localized_bandwidth(X, self._sample_weights, min_grid_dist) + + self.fitted_ = True + + return self + + def score_samples(self, X): + """Compute the log-likelihood of each sample under the model. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + An array of points to query. Last dimension should match dimension + of training data (n_features). + + Returns + ------- + density : ndarray of shape (n_samples,) + Log-likelihood of each sample in `X`. These are normalized to be + probability densities, so values will be low for high-dimensional + data. + """ + return self._computes_kernel_density_estimation(X) + + def score(self, X, y=None): + """Compute the total log-likelihood under the model. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. + + y : None + Ignored. This parameter exists only for compatibility with + :class:`~sklearn.pipeline.Pipeline`. + + Returns + ------- + logprob : float + Total log-likelihood of the data in X. This is normalized to be a + probability density, so the value will be low for high-dimensional + data. + """ + return np.sum(self.score_samples(X)) + + def sample(self, n_samples=1, random_state=None): + """Generate random samples from the model. + + Parameters + ---------- + n_samples : int, default=1 + Number of samples to generate. 
+ + random_state : int, RandomState instance or None, default=None + Determines random number generation used to generate + random samples. Pass an int for reproducible results + across multiple function calls. + See :term:`Glossary `. + + Returns + ------- + X : array-like of shape (n_samples, n_features) + List of samples. + """ + check_is_fitted(self) + rng = check_random_state(random_state) + u = rng.uniform(0, 1, size=n_samples) + cumsum_weight = np.cumsum(np.asarray(self._sample_weights)) + sum_weight = cumsum_weight[-1] + idxs = np.searchsorted(cumsum_weight, u * sum_weight) + + return np.concatenate( + [ + np.atleast_2d( + rng.multivariate_normal(self._grids[i], self.bandwidth_[i]) + ) + for i in idxs + ] + ) + + def _check_dimension(self, X): + if (self.cell is not None) and (X.shape[1] != len(self.cell)): + raise ValueError("Cell dimension does not match the data dimension.") + + def _assign_descriptors_to_grids(self, X): + + assigner = _NearestGridAssigner(self.metric, self.metric_params, self.verbose) + assigner.fit(X) + labels = assigner.predict(self.descriptors, sample_weight=self.weights) + grid_npoints = assigner.grid_npoints + grid_neighbour = assigner.grid_neighbour + + return grid_npoints, grid_neighbour, labels, assigner.grid_weight + + def _computes_localized_bandwidth( + self, X, sample_weights: np.ndarray, mindist: np.ndarray + ): + """Compute the localized bandwidth of the kernel density estimator + on grid points. + """ + # estimate the sigma + cov = _covariance(X, sample_weights, self.cell) + if self.cell is not None: + tune = sum(self.cell**2) + else: + tune = np.trace(cov) + sigma2 = np.full(len(X), tune, dtype=float) + + # initialize the localization based on fraction of data spread + if self.fspread > 0: + sigma2 *= self.fspread**2 + flocal = np.zeros(len(X)) + self.bandwidth_ = np.zeros((len(X), X.shape[1], X.shape[1])) + self._covariance = np.zeros((len(X), X.shape[1], X.shape[1])) + + for i in tqdm( + range(len(X)), + desc="Estimating kernel density bandwidths", + disable=not self.verbose, + ): + wlocal, flocal[i] = _local_population( + self.cell, X, X[i], sample_weights, sigma2[i] + ) + if self.fpoints > 0: + sigma2, flocal, wlocal = ( + self._tune_localization_factor_based_on_fraction_of_points( + X, sample_weights, sigma2, flocal, i, 1 / self.nsamples, tune + ) + ) + elif sigma2[i] < flocal[i]: + sigma2, flocal, wlocal = ( + self._tune_localization_factor_based_on_fraction_of_spread( + X, sample_weights, sigma2, flocal, i, mindist + ) + ) + self.bandwidth_[i], self._covariance[i] = ( + self._bandwidth_estimation_from_localization(X, wlocal, flocal, i) + ) + + def _tune_localization_factor_based_on_fraction_of_points( + self, X, sample_weights, sigma2, flocal, idx, delta, tune + ): + """Used in cases where one expects clusters with very different spreads, + but similar populations + """ + lim = self.fpoints + if lim <= sample_weights[idx]: + lim = sample_weights[idx] + delta + warnings.warn( + " Warning: localization smaller than voronoi," + " increase grid size (meanwhile adjusted localization)!", + stacklevel=2, + ) + while flocal[idx] < lim: + sigma2[idx] += tune + wlocal, flocal[idx] = _local_population( + self.cell, X, X[idx], sample_weights, sigma2[idx] + ) + j = 1 + while True: + if flocal[idx] > lim: + sigma2[idx] -= tune / 2**j + else: + sigma2[idx] += tune / 2**j + wlocal, flocal[idx] = _local_population( + self.cell, X, X[idx], sample_weights, sigma2[idx] + ) + if abs(flocal[idx] - lim) < delta: + break + j += 1 + + return sigma2, flocal, 
wlocal + + def _tune_localization_factor_based_on_fraction_of_spread( + self, X, sample_weights, sigma2, flocal, idx, mindist + ): + """Used in cases where one expects the spatial extent of clusters to be + relatively homogeneous + """ + sigma2[idx] = mindist[idx] + wlocal, flocal[idx] = _local_population( + self.cell, self.descriptors, X, sample_weights, sigma2[idx] + ) + + return sigma2, flocal, wlocal + + def _bandwidth_estimation_from_localization(self, X, wlocal, flocal, idx): + """Compute the bandwidth based on localized version of Silverman's rule""" + cov = _covariance(X, wlocal, self.cell) + nlocal = flocal[idx] * self.nsamples + local_dimension = effdim(cov) + cov = oas(cov, nlocal, X.shape[1]) + # localized version of Silverman's rule + h = (4.0 / nlocal / (local_dimension + 2.0)) ** ( + 2.0 / (local_dimension + 4.0) + ) * cov + + return h, cov + + def _computes_kernel_density_estimation(self, X: np.ndarray): + + prob = np.full(len(X), -np.inf) + dummd1s_mat = pairwise_mahalanobis_distances( + X, self._grids, self._bandwidth_inv, self.cell, squared=True + ) + for i in tqdm( + range(len(X)), + desc="Computing kernel density on reference points", + disable=not self.verbose, + ): + for j, dummd1 in enumerate(np.diagonal(dummd1s_mat[:, i, :])): + # The second point is the mean corresponding to the cov + if dummd1 > self.kdecut_squared: + lnk = -0.5 * (self._normkernels[j] + dummd1) + np.log( + self._sample_weights[j] + ) + prob[i] = LSE([prob[i], lnk]) + else: + neighbours = self._grid_neighbour[j][ + np.any( + self.descriptors[self._grid_neighbour[j]] != X[i], axis=1 + ) + ] + if neighbours.size == 0: + continue + dummd1s = pairwise_mahalanobis_distances( + self.descriptors[neighbours], + X[i][np.newaxis, ...], + self._bandwidth_inv[j], + self.cell, + squared=True, + ).reshape(-1) + lnks = -0.5 * (self._normkernels[j] + dummd1s) + np.log( + self.weights[neighbours] + ) + prob[i] = LSE(np.concatenate([[prob[i]], lnks])) + + prob -= np.log(np.sum(self._sample_weights)) + + return prob + + +class _NearestGridAssigner: + """Assign descriptor to its nearest grid. This is an auxilirary class. + + Parameters + ---------- + metric : + The metric to use. + Currently only `sklearn.metrics.pairwise.pairwise_euclidean_distances`. + metric_params : dict, default=None + Additional parameters to be passed to the use of + metric. i.e. the cell dimension for ``periodic_euclidean`` + {'cell_length': [2, 2]} + verbose : bool, default=False + Whether to print progress. + + Attributes + ---------- + grid_pos : np.ndarray + An array of grid positions. + grid_npoints : np.ndarray + An array of number of points in each grid. + grid_weight : np.ndarray + An array of weights in each grid. + grid_neighbour : dict + A dictionary of neighbor lists for each grid. + labels_ : np.ndarray + An array of labels for each descriptor. + """ + + def __init__( + self, + metric, + metric_params: Union[dict, None] = None, + verbose: bool = False, + ) -> None: + + self.labels_ = None + self.metric = metric + self.metric_params = metric_params + self.verbose = verbose + if isinstance(self.metric_params, dict): + self.cell = self.metric_params["cell_length"] + else: + self.cell = None + self.grid_pos = None + self.grid_npoints = None + self.grid_weight = None + self.grid_neighbour = None + + def fit(self, X: np.ndarray, y: Union[np.ndarray, None] = None) -> None: + """Fit the data. + + Parameters + ---------- + X : np.ndarray + An array of grid positions. + y : np.ndarray, optional, default=None + Igonred. 
+ """ + ngrid = len(X) + self.grid_pos = X + self.grid_npoints = np.zeros(ngrid, dtype=int) + self.grid_weight = np.zeros(ngrid, dtype=float) + self.grid_neighbour = {i: [] for i in range(ngrid)} + + def predict( + self, + X: np.ndarray, + y: Union[np.ndarray, None] = None, + sample_weight: Union[np.ndarray, None] = None, + ) -> np.ndarray: + """ + Predicts labels for input data and returns an array of labels. + + Parameters + ---------- + X : np.ndarray + Input data to predict labels for. + y : np.ndarray, optional, default=None + Igonred. + sample_weight : np.ndarray, optional + Sample weights for each data point. + + Returns + ------- + np.ndarray + Array of predicted labels. + """ + if sample_weight is None: + sample_weight = np.ones(len(X)) / len(X) + self.labels_ = [] + for i, point in tqdm( + enumerate(X), + desc="Assigning samples to grids...", + total=len(X), + disable=not self.verbose, + ): + descriptor2grid = self.metric(point.reshape(1, -1), self.grid_pos) + self.labels_.append(np.argmin(descriptor2grid)) + self.grid_npoints[self.labels_[-1]] += 1 + self.grid_weight[self.labels_[-1]] += sample_weight[i] + self.grid_neighbour[self.labels_[-1]].append(i) + + for key in self.grid_neighbour: + self.grid_neighbour[key] = np.array(self.grid_neighbour[key]) + + return self.labels_ + + +def _covariance(X: np.ndarray, sample_weights: np.ndarray, cell: np.ndarray): + """ + Calculate the covariance matrix for a given set of grid positions and weights. + + Parameters + ---------- + X : np.ndarray + An array of shape (nsample, dimension) representing the grid positions. + sample_weights : np.ndarray + An array of shape (nsample,) representing the weights of the grid positions. + cell : np.ndarray + An array of shape (dimension,) representing the periodicity of each dimension. + + Returns + ------- + np.ndarray + The covariance matrix of shape (dimension, dimension). + + Notes + ----- + The function assumes that the grid positions, weights, + and total weight are provided correctly. + The function handles periodic and non-periodic dimensions differently to + calculate the covariance matrix. + """ + totw = np.sum(sample_weights) + + if cell is None: + xm = np.average(X, axis=0, weights=sample_weights / totw) + else: + sumsin = np.average( + np.sin(X) * (2 * np.pi) / cell, + axis=0, + weights=sample_weights / totw, + ) + sumcos = np.average( + np.cos(X) * (2 * np.pi) / cell, + axis=0, + weights=sample_weights / totw, + ) + xm = np.arctan2(sumsin, sumcos) + + xxm = X - xm + if cell is not None: + xxm -= np.round(xxm / cell) * cell + xxmw = xxm * sample_weights.reshape(-1, 1) / totw + cov = xxmw.T.dot(xxm) + cov /= 1 - sum((sample_weights / totw) ** 2) + + return cov + + +def _local_population( + cell: np.ndarray, + grid_j: np.ndarray, + grid_i: np.ndarray, + grid_j_weight: np.ndarray, + sigma_squared: float, +): + r""" + Calculates the local population of a selected grid. The local population is defined + as a sum of the weighting factors for each other grid arond it. + + .. math:: + N_i = \\sum_j u_{ij} + + where :math:`u_{ij}` is the weighting factor. The weighting factor is calculated + from an spherical Gaussian + + .. math:: + u_{ij} = \\exp\\left[-\\frac{(x_i - x_j)^2}{2\\sigma^2} \\right] N w_j / + \\sum_j w_j + + where :math:`w_j` is the weighting factor for each other grid, :math:`N` is the + number of grid points, and :math:`\\sigma` is the localization factor. + + Parameters + ---------- + cell : np.ndarray + An array of periods for each dimension of the grid. 
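+        If ``None``, no periodic wrapping is applied.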
+ grid_j : np.ndarray + An array of vectors of the grid around the selected grid. + grid_i : np.ndarray + An array of the vector of the selected grid. + grid_j_weight : np.ndarray + An array of weights for each target vector. + sigma_squared : float + The localization factor for the spherical Gaussian. + + + Returns + ------- + tuple + A tuple containing two numpy arrays: + wl : np.ndarray + An array of localized weights for each vector. + num : np.ndarray + The sum of the localized weights. + + """ + xy = grid_j - grid_i + if cell is not None: + xy -= np.round(xy / cell) * cell + + wl = np.exp(-0.5 / sigma_squared * np.sum(xy**2, axis=1)) * grid_j_weight + num = np.sum(wl) + + return wl, num diff --git a/src/skmatter/utils/__init__.py b/src/skmatter/utils/__init__.py index 459e1caa4..2f0c6b969 100644 --- a/src/skmatter/utils/__init__.py +++ b/src/skmatter/utils/__init__.py @@ -19,6 +19,11 @@ no_progress_bar, ) +from ._sparsekde import ( + effdim, + oas, +) + __all__ = [ "get_progress_bar", "no_progress_bar", @@ -29,4 +34,6 @@ "X_orthogonalizer", "Y_sample_orthogonalizer", "Y_feature_orthogonalizer", + "effdim", + "oas", ] diff --git a/src/skmatter/utils/_sparsekde.py b/src/skmatter/utils/_sparsekde.py new file mode 100644 index 000000000..ba366ca00 --- /dev/null +++ b/src/skmatter/utils/_sparsekde.py @@ -0,0 +1,80 @@ +"""The file holds utility functions and classes for the sparse KDE.""" + +import numpy as np + + +def effdim(cov): + """ + Calculate the effective dimension of a covariance matrix based on Shannon entropy. + + Parameters + ---------- + cov : numpy.ndarray + The covariance matrix. + + Returns + ------- + float + The effective dimension of the covariance matrix. + + Examples + -------- + >>> import numpy as np + >>> from skmatter.utils import effdim + >>> cov = np.array([[25, 15, -5], [15, 18, 0], [-5, 0, 11]], dtype=np.float64) + >>> print(round(effdim(cov), 3)) + 2.214 + + References + ---------- + https://ieeexplore.ieee.org/document/7098875 + """ + eigval = np.linalg.eigvals(cov) + if (lowest_eigval := np.min(eigval)) <= -np.max(cov.shape) * np.finfo( + cov.dtype + ).eps: + raise np.linalg.LinAlgError( + f"Matrix is not positive definite." + f"Lowest eigenvalue {lowest_eigval} is " + f"above numerical threshold." 
+ ) + eigval[eigval < 0.0] = 0.0 + eigval /= sum(eigval) + eigval *= np.log(eigval) + + return np.exp(-sum(eigval)) + + +def oas(cov: np.ndarray, n: float, D: int) -> np.ndarray: + """ + Oracle approximating shrinkage (OAS) estimator + + Parameters + ---------- + cov : numpy.ndarray + A covariance matrix + n : float + The local population + D : int + Dimension + + Examples + -------- + >>> import numpy as np + >>> from skmatter.utils import oas + >>> cov = np.array([[0.5, 1.0], [0.7, 0.4]]) + >>> oas(cov, 10, 2) + array([[0.48903924, 0.78078484], + [0.54654939, 0.41096076]]) + + Returns + ------- + np.ndarray + Covariance matrix + """ + tr = np.trace(cov) + tr2 = tr**2 + tr_cov2 = np.trace(cov**2) + phi = ((1 - 2 / D) * tr_cov2 + tr2) / ((n + 1 - 2 / D) * tr_cov2 - tr2 / D) + + return (1 - phi) * cov + phi * np.eye(D) * tr / D diff --git a/tests/test_clustering.py b/tests/test_clustering.py new file mode 100644 index 000000000..73445eaea --- /dev/null +++ b/tests/test_clustering.py @@ -0,0 +1,59 @@ +import unittest + +import numpy as np + +from skmatter.clustering import QuickShift + + +class QuickShiftTests(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.points = np.array( + [ + [-1.72779275, -1.32763554], + [-4.44991964, -2.13474901], + [0.54817734, -2.43319467], + [3.19881307, -0.49547222], + [-1.1335991, 2.33478428], + [0.55437388, 0.18745963], + ] + ) + cls.cuts = np.array( + [6.99485011, 8.80292681, 7.68486852, 9.5115009, 8.07736919, 6.22057056] + ) + cls.weights = np.array( + [ + -3.94008092, + -12.68095664, + -7.07512499, + -9.03064023, + -8.26529849, + -2.61132267, + ] + ) + cls.qs_labels_ = np.array([0, 0, 0, 5, 5, 5]) + cls.qs_cluster_centers_idx_ = np.array([0, 5]) + cls.gabriel_labels_ = np.array([5, 5, 5, 5, 5, 5]) + cls.gabriel_cluster_centers_idx_ = np.array([5]) + cls.cell = [3, 3] + cls.gabriel_shell = 2 + + def test_fit_qs(self): + model = QuickShift(dist_cutoff_sq=self.cuts) + model.fit(self.points, samples_weight=self.weights) + self.assertTrue(np.all(model.labels_ == self.qs_labels_)) + self.assertTrue( + np.all(model.cluster_centers_idx_ == self.qs_cluster_centers_idx_) + ) + + def test_fit_garbriel(self): + model = QuickShift(gabriel_shell=self.gabriel_shell) + model.fit(self.points, samples_weight=self.weights) + self.assertTrue(np.all(model.labels_ == self.gabriel_labels_)) + self.assertTrue( + np.all(model.cluster_centers_idx_ == self.gabriel_cluster_centers_idx_) + ) + + def test_dimension_check(self): + model = QuickShift(self.cuts, metric_params={"cell_length": self.cell}) + self.assertRaises(ValueError, model.fit, np.array([[2]])) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 4d7340e6a..fbe278b47 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -5,6 +5,7 @@ from skmatter.datasets import ( load_csd_1000r, load_degenerate_CH4_manifold, + load_hbond_dataset, load_nice_dataset, load_roy_dataset, load_who_dataset, @@ -119,5 +120,21 @@ def test_dataset_content(self): self.assertEqual(len(self.roy["energies"]), self.size) +class HBondTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.size = 27233 + cls.shape = (27233, 3) + cls.hbond = load_hbond_dataset() + + def test_dataset_size_and_shape(self): + """ + Check if the correct number of datapoints are present in the dataset. + Also check if the size of the dataset is correct. 
+ """ + self.assertEqual(self.hbond["descriptors"].shape, self.shape) + self.assertEqual(self.hbond["weights"].size, self.size) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 4e94d6848..cdaaf5519 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -11,6 +11,8 @@ global_reconstruction_error, local_prediction_rigidity, local_reconstruction_error, + pairwise_mahalanobis_distances, + periodic_pairwise_euclidean_distances, pointwise_local_reconstruction_error, ) @@ -214,5 +216,65 @@ def test_local_reconstruction_error_test_idx(self): ) +class DistanceTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.X = np.array([[1, 2], [3, 4], [5, 6]]) + cls.Y = np.array([[7, 8], [9, 10]]) + cls.covs = np.array([[[1, 0.5], [0.5, 1]], [[1, 0.0], [0.0, 1]]]) + cls.cell = [5, 7] + cls.distances = np.array( + [ + [8.48528137, 11.3137085], + [5.65685425, 8.48528137], + [2.82842712, 5.65685425], + ] + ) + cls.periodic_distances = np.array( + [ + [1.41421356, 2.23606798], + [3.16227766, 1.41421356], + [2.82842712, 3.16227766], + ] + ) + cls.mahalanobis_distances = np.array( + [ + [ + [10.39230485, 13.85640646], + [6.92820323, 10.39230485], + [3.46410162, 6.92820323], + ], + cls.distances, + ] + ) + + def test_euclidean_distance(self): + distances = periodic_pairwise_euclidean_distances(self.X, self.Y) + self.assertTrue( + np.allclose(distances, self.distances), + f"Calculated distance does not match expected value" + f"Calculated: {distances} Expected: {self.distances}", + ) + + def test_periodic_euclidean_distance(self): + distances = periodic_pairwise_euclidean_distances( + self.X, self.Y, cell_length=self.cell + ) + self.assertTrue( + np.allclose(distances, self.periodic_distances), + f"Calculated distance does not match expected value" + f"Calculated: {distances} Expected: {self.periodic_distances}", + ) + + def test_mahalanobis_distance(self): + distances = pairwise_mahalanobis_distances(self.X, self.Y, self.covs) + self.assertTrue( + np.allclose(distances, self.mahalanobis_distances), + f"Calculated distance does not match expected value" + f"Calculated: {distances} Expected: {self.mahalanobis_distances}", + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py new file mode 100644 index 000000000..fd6b4c0af --- /dev/null +++ b/tests/test_neighbors.py @@ -0,0 +1,119 @@ +import unittest + +import numpy as np + +from skmatter.feature_selection import FPS +from skmatter.neighbors import SparseKDE +from skmatter.neighbors._sparsekde import _covariance +from skmatter.utils import effdim, oas + + +class SparseKDETests(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + np.random.seed(0) + cls.n_samples_per_cov = 10000 + cls.samples = np.concatenate( + [ + np.random.multivariate_normal( + [0, 0], [[1, 0.5], [0.5, 1]], cls.n_samples_per_cov + ), + np.random.multivariate_normal( + [4, 4], [[1, 0.5], [0.5, 0.5]], cls.n_samples_per_cov + ), + ] + ) + cls.sample_results = np.array( + [[4.56393465, 4.20566218], [0.73562454, 1.11116178]] + ) + cls.selector = FPS(n_to_select=int(np.sqrt(2 * cls.n_samples_per_cov))) + cls.grids = cls.selector.fit_transform(cls.samples.T).T + cls.expect_score_fp = -759.831 + cls.expect_score_fs = -781.567 + + cls.cell = np.array([4, 4]) + cls.expect_score_periodic = -456.744 + + def test_sparse_kde(self): + estimator = SparseKDE(self.samples, None, fpoints=0.5) + estimator.fit(self.grids) + 
       self.assertTrue(round(estimator.score(self.grids), 3) == self.expect_score_fp)
+        self.assertTrue(np.allclose(estimator.sample(2), self.sample_results))
+
+    def test_sparse_kde_fs(self):
+        estimator = SparseKDE(self.samples, None, fspread=0.5)
+        estimator.fit(self.grids)
+        self.assertTrue(round(estimator.score(self.grids), 3) == self.expect_score_fs)
+
+    def test_sparse_kde_periodic(self):
+        estimator = SparseKDE(
+            self.samples,
+            None,
+            metric_params={"cell_length": self.cell},
+            fpoints=0.5,
+        )
+        estimator.fit(self.grids)
+        self.assertTrue(
+            round(estimator.score(self.grids), 3) == self.expect_score_periodic
+        )
+
+    def test_dimension_check(self):
+        estimator = SparseKDE(
+            self.samples, None, metric_params={"cell_length": self.cell}, fpoints=0.5
+        )
+        self.assertRaises(ValueError, estimator.fit, np.array([[4]]))
+
+    def test_fs_fp_incompatibility(self):
+        estimator = SparseKDE(
+            self.samples,
+            None,
+            metric_params={"cell_length": self.cell},
+            fspread=2,
+            fpoints=0.5,
+        )
+        self.assertTrue(estimator.fpoints == -1)
+
+
+class CovarianceTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.X = np.array([[1, 2], [3, 3], [4, 6]])
+        cls.expected_cov = np.array(
+            [[2.33333333, 2.83333333], [2.83333333, 4.33333333]]
+        )
+        cls.expected_cov_periodic = np.array(
+            [[1.12597216, 0.45645371], [0.45645371, 0.82318948]]
+        )
+        cls.cell = np.array([3, 3])
+
+    def test_covariance(self):
+        cov = _covariance(self.X, np.full(len(self.X), 1 / len(self.X)), None)
+        self.assertTrue(np.allclose(cov, self.expected_cov))
+
+    def test_covariance_periodic(self):
+        cov = _covariance(self.X, np.full(len(self.X), 1 / len(self.X)), self.cell)
+        self.assertTrue(np.allclose(cov, self.expected_cov_periodic))
+
+
+class EffdimTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.cov = np.array([[1, 1, 0], [1, 1.5, 0], [0, 0, 1]], dtype=np.float64)
+        cls.expected_effdim = 2.24909102090124
+
+    def test_effdim(self):
+        self.assertTrue(np.allclose(effdim(self.cov), self.expected_effdim))
+
+
+class OASTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.cov = np.array([[0.5, 1.0], [0.7, 0.4]])
+        cls.n = 10
+        cls.D = 2
+        cls.expected_oas = np.array(
+            [[0.48903924, 0.78078484], [0.54654939, 0.41096076]]
+        )
+
+    def test_oas(self):
+        self.assertTrue(np.allclose(oas(self.cov, self.n, self.D), self.expected_oas))
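+        # Illustrative extra check: by construction the OAS estimate
+        # ``(1 - phi) * cov + phi * eye(D) * tr(cov) / D`` preserves the
+        # trace of ``cov``, whatever the shrinkage ``phi`` turns out to be.
+        self.assertTrue(
+            np.isclose(np.trace(oas(self.cov, self.n, self.D)), np.trace(self.cov))
+        )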