diff --git a/python_notebooks/Chapter_1_Exploratory_analysis_and_unsupervised_learning.ipynb b/python_notebooks/Chapter_1_Exploratory_analysis_and_unsupervised_learning.ipynb index 82259ef..5d06518 100644 --- a/python_notebooks/Chapter_1_Exploratory_analysis_and_unsupervised_learning.ipynb +++ b/python_notebooks/Chapter_1_Exploratory_analysis_and_unsupervised_learning.ipynb @@ -3203,6 +3203,197 @@ "# %load -r 89- solutions/solution_01_kmean.py" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[Back to ToC](#toc)\n", + "\n", + "## What is the best method for clustering ? \n", + "\n", + "As you surely suspect by now, there is no perfect method. \n", + "Each algorithm makes different assumptions about the structure of your data and will thus behave well or badly depending on how your data is actually structured.\n", + "\n", + "Let's demonstrate by displaying how different algorithms perform on different datasets (example taken from the [sklearn documentation](https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html)):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import warnings\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from sklearn import cluster, datasets, mixture\n", + "from sklearn.neighbors import kneighbors_graph\n", + "from sklearn.preprocessing import StandardScaler\n", + "from itertools import cycle, islice\n", + "\n", + "np.random.seed(0)\n", + "\n", + "# ============\n", + "# Generate datasets. 
# ============
# Generate the toy datasets.
# We choose the size big enough to see the scalability
# of the algorithms, but not too big to avoid too long running times
# ============
n_samples = 1500
noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5,
                                      noise=.05)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
no_structure = np.random.rand(n_samples, 2), None

# Anisotropically distributed data: blobs sheared by a linear transformation.
# Dedicated names are used (instead of the generic `X` / `y`) so this cell
# does not overwrite same-named data defined by other cells of the notebook.
random_state = 170
X_blobs, y_blobs = datasets.make_blobs(n_samples=n_samples,
                                       random_state=random_state)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X_blobs, transformation)
aniso = (X_aniso, y_blobs)

# blobs with varied variances
varied = datasets.make_blobs(n_samples=n_samples,
                             cluster_std=[1.0, 2.5, 0.5],
                             random_state=random_state)

# ============
# Set up cluster parameters
# ============
plt.figure(figsize=(9 * 2 + 3, 13))
plt.subplots_adjust(left=.02, right=.98, bottom=.001, top=.95, wspace=.05,
                    hspace=.01)

plot_num = 1  # 1-based index of the next subplot, filled row by row

# Defaults shared by all datasets; each dataset may override some of them.
default_base = {'quantile': .3,
                'eps': .3,
                'damping': .9,
                'preference': -200,
                'n_neighbors': 10,
                'n_clusters': 3,
                'min_samples': 20,
                'xi': 0.05,
                'min_cluster_size': 0.1}

# One (dataset, parameter overrides) pair per row of the figure.
# NOTE: deliberately NOT called `datasets` -- that name is the imported
# `sklearn.datasets` module, and rebinding it to a list would break every
# later `datasets.make_*` / `datasets.load_*` call in the notebook.
dataset_configs = [
    (noisy_circles, {'damping': .77, 'preference': -240,
                     'quantile': .2, 'n_clusters': 2,
                     'min_samples': 20, 'xi': 0.25}),
    (noisy_moons, {'damping': .75, 'preference': -220, 'n_clusters': 2}),
    (varied, {'eps': .18, 'n_neighbors': 2,
              'min_samples': 5, 'xi': 0.035, 'min_cluster_size': .2}),
    (aniso, {'eps': .15, 'n_neighbors': 2,
             'min_samples': 20, 'xi': 0.1, 'min_cluster_size': .2}),
    (blobs, {}),
    (no_structure, {})]

# scikit-learn 1.2 renamed AgglomerativeClustering's `affinity` keyword to
# `metric` and 1.4 removed the old one. Pick whichever keyword the installed
# version exposes so the cell runs on both old and new releases.
linkage_metric_kw = (
    'metric'
    if 'metric' in cluster.AgglomerativeClustering().get_params()
    else 'affinity')

# Colour cycle for cluster labels (loop-invariant, so built once).
palette = ['#377eb8', '#ff7f00', '#4daf4a',
           '#f781bf', '#a65628', '#984ea3',
           '#999999', '#e41a1c', '#dede00']

for i_dataset, (dataset, algo_params) in enumerate(dataset_configs):
    # update parameters with dataset-specific values
    params = default_base.copy()
    params.update(algo_params)

    # the ground-truth labels are not used: this is unsupervised clustering
    X_raw, _ = dataset

    # normalize dataset for easier parameter selection
    X_scaled = StandardScaler().fit_transform(X_raw)

    # estimate bandwidth for mean shift
    bandwidth = cluster.estimate_bandwidth(X_scaled,
                                           quantile=params['quantile'])

    # connectivity matrix for structured Ward
    connectivity = kneighbors_graph(
        X_scaled, n_neighbors=params['n_neighbors'], include_self=False)
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)

    # ============
    # Create cluster objects
    # ============
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(n_clusters=params['n_clusters'])
    ward = cluster.AgglomerativeClustering(
        n_clusters=params['n_clusters'], linkage='ward',
        connectivity=connectivity)
    spectral = cluster.SpectralClustering(
        n_clusters=params['n_clusters'], eigen_solver='arpack',
        affinity="nearest_neighbors")
    dbscan = cluster.DBSCAN(eps=params['eps'])
    optics = cluster.OPTICS(min_samples=params['min_samples'],
                            xi=params['xi'],
                            min_cluster_size=params['min_cluster_size'])
    affinity_propagation = cluster.AffinityPropagation(
        damping=params['damping'], preference=params['preference'])
    average_linkage = cluster.AgglomerativeClustering(
        linkage="average",
        n_clusters=params['n_clusters'], connectivity=connectivity,
        **{linkage_metric_kw: "cityblock"})
    birch = cluster.Birch(n_clusters=params['n_clusters'])
    gmm = mixture.GaussianMixture(
        n_components=params['n_clusters'], covariance_type='full')

    # (column title, estimator) pairs, one per column of the figure
    clustering_algorithms = (
        ('MiniBatch\nKMeans', two_means),
        ('Affinity\nPropagation', affinity_propagation),
        ('MeanShift', ms),
        ('Spectral\nClustering', spectral),
        ('Ward', ward),
        ('Agglomerative\nClustering', average_linkage),
        ('DBSCAN', dbscan),
        ('OPTICS', optics),
        ('BIRCH', birch),
        ('Gaussian\nMixture', gmm)
    )

    for name, algorithm in clustering_algorithms:
        # perf_counter is monotonic, hence the right clock for durations
        t0 = time.perf_counter()

        # catch warnings related to kneighbors_graph
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the " +
                "connectivity matrix is [0-9]{1,2}" +
                " > 1. Completing it to avoid stopping the tree early.",
                category=UserWarning)
            warnings.filterwarnings(
                "ignore",
                message="Graph is not fully connected, spectral embedding" +
                " may not work as expected.",
                category=UserWarning)
            algorithm.fit(X_scaled)

        t1 = time.perf_counter()

        # Estimators that expose `labels_` after fit are read directly;
        # the others (e.g. the Gaussian mixture) go through predict().
        if hasattr(algorithm, 'labels_'):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X_scaled)

        plt.subplot(len(dataset_configs), len(clustering_algorithms),
                    plot_num)
        if i_dataset == 0:
            # only the first row carries the algorithm names
            plt.title(name, size=18)

        colors = np.array(list(islice(cycle(palette),
                                      int(max(y_pred) + 1))))
        # add black color for outliers (if any): a label of -1 indexes the
        # last entry of the array, i.e. the black appended here
        colors = np.append(colors, ["#000000"])
        plt.scatter(X_scaled[:, 0], X_scaled[:, 1], s=10,
                    color=colors[y_pred])

        plt.xlim(-2.5, 2.5)
        plt.ylim(-2.5, 2.5)
        plt.xticks(())
        plt.yticks(())
        # fit duration in the bottom-right corner of each panel
        plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),
                 transform=plt.gca().transAxes, size=15,
                 horizontalalignment='right')
        plot_num += 1

plt.show()
"cell_type": "markdown", "metadata": {},