A simple visualization tool of how GMMs are fit to data using simple animations

friedmanroy/GMM-visualizations

Gaussian Mixture Models

Gaussian Mixture Models (GMMs for short) are statistical models that are commonly used for clustering. In GMMs we assume that the data points we are trying to model originated from a mixture of Gaussian distributions (hence the name). A GMM distribution can, in general, be very complex (try and play around with this example), but it can be broken down into multiple simple components: a number of Gaussian distributions.

GMMs are typically trained using the Expectation Maximization (EM) algorithm - you can read about that in more detail in this comprehensive blog post or in this summary I wrote. The whole point of this repository is to make the optimization process a bit more intuitive.

Simple GMM Visualization

To begin with, let's visualize how a GMM is fitted to 2D points.

Fitting a GMM to a simple distribution

The simplest demo of how a GMM is fitted is to sample some points from a GMM and then fit a GMM to those same points. Recovering the generating distribution is exactly what the optimization process is trying to do, so the model usually comes close. The process can "break", however: if only a few points were generated from a large number of Gaussians, or if all of the Gaussians are squeezed into a small area, the fitted model will diverge significantly from the distribution the points were generated from. This makes sense, but is quite interesting to actually see. Note that the fitted model will almost always differ at least slightly from the actual distribution, since there is only a finite number of points to fit against. Try playing around for yourself to see these effects and gain an intuition for what's happening.
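To make "create some points from a GMM" concrete, here is a minimal NumPy sketch of sampling from a random mixture. It is not the code used by visualize_gmm.py; the function name `sample_gmm`, the range the means are drawn from, and the way the covariances are randomized are all assumptions made for illustration.

```python
import numpy as np

def sample_gmm(weights, means, covs, n_points, rng):
    """Draw n_points 2D samples from a GMM; returns (points, component labels)."""
    weights = np.asarray(weights, dtype=float)
    # Pick a generating component for each point according to the mixture weights
    labels = rng.choice(len(weights), size=n_points, p=weights)
    points = np.empty((n_points, 2))
    for j in range(len(weights)):
        idx = labels == j
        # Sample every point assigned to component j from that component's Gaussian
        points[idx] = rng.multivariate_normal(means[j], covs[j], size=int(idx.sum()))
    return points, labels

rng = np.random.default_rng(0)
m = 5  # number of generating Gaussians (hypothetical stand-in for the -m argument)
weights = np.full(m, 1.0 / m)
means = rng.uniform(-10, 10, size=(m, 2))
# Random symmetric positive-definite covariances: A @ A.T plus a small diagonal term
A = rng.normal(size=(m, 2, 2))
covs = A @ A.transpose(0, 2, 1) + 0.1 * np.eye(2)
points, labels = sample_gmm(weights, means, covs, 500, rng)
```

Lowering the number of points or shrinking the range of the means in this sketch gives the kind of data on which the fit tends to "break", as described above.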

Using the visualization tool

Running the following line:

python visualize_gmm.py -k 5 -m 5

produces a simple demo in which data points are generated from 5 random Gaussians (controlled by the argument -m) and a GMM with 5 clusters (controlled by the argument -k) is fitted to them, similar to the following: Demo GMM

What are we seeing here? This is an animation of the optimization process for a GMM using the EM algorithm. Each frame is an iteration of the EM algorithm, with:

  • The points are the samples the GMM is being fitted to, colored according to the cluster that the model thinks they belong to
  • The gray ellipses are contour lines of the Gaussians that the data was generated from (the real Gaussians)
  • The colored crosses are the means of the GMM
  • The colored ellipses are the contour lines of the fitted Gaussians (the ones that the GMM is comprised of)
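Since each frame corresponds to one EM iteration, it may help to see what a single iteration computes. The following is a minimal NumPy sketch under stated assumptions, not the repository's implementation: the names `log_gaussian` and `em_step`, the small `1e-6` term added to the covariances, and the toy data are all hypothetical choices.

```python
import numpy as np

def log_gaussian(x, mean, cov):
    """Log-density of a multivariate Gaussian evaluated at each row of x."""
    d = x.shape[1]
    diff = x - mean
    # Mahalanobis distances, using a linear solve instead of an explicit inverse
    maha = np.einsum("ni,ni->n", diff, np.linalg.solve(cov, diff.T).T)
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)

def em_step(x, weights, means, covs):
    """One EM iteration; returns new parameters, responsibilities and log-likelihood."""
    n, d = x.shape
    k = len(weights)
    # E-step: log(weight_j * N(x_i | mean_j, cov_j)) for every point/component pair
    log_p = np.stack(
        [np.log(weights[j]) + log_gaussian(x, means[j], covs[j]) for j in range(k)],
        axis=1,
    )
    # Log-sum-exp over components for numerical stability
    top = log_p.max(axis=1, keepdims=True)
    log_norm = top + np.log(np.exp(log_p - top).sum(axis=1, keepdims=True))
    resp = np.exp(log_p - log_norm)  # responsibilities, each row sums to 1
    log_likelihood = log_norm.sum()  # evaluated at the parameters passed in
    # M-step: responsibility-weighted estimates of weights, means and covariances
    nk = resp.sum(axis=0)
    new_weights = nk / n
    new_means = (resp.T @ x) / nk[:, None]
    new_covs = np.empty((k, d, d))
    for j in range(k):
        diff = x - new_means[j]
        # The tiny diagonal term (an assumption) keeps the covariance invertible
        new_covs[j] = (resp[:, j, None] * diff).T @ diff / nk[j] + 1e-6 * np.eye(d)
    return new_weights, new_means, new_covs, resp, log_likelihood

# Toy data: three well-separated blobs
rng = np.random.default_rng(0)
x = np.concatenate(
    [rng.normal(loc=c, scale=1.0, size=(100, 2)) for c in ([0, 0], [6, 0], [0, 6])]
)
k = 3
weights = np.full(k, 1.0 / k)
means = x[rng.choice(len(x), size=k, replace=False)]  # start from random data points
covs = np.stack([np.eye(2)] * k)
log_likelihoods = []
for _ in range(20):
    weights, means, covs, resp, ll = em_step(x, weights, means, covs)
    log_likelihoods.append(ll)
```

In terms of the animation, `resp.argmax(axis=1)` would give the color of each point, `means` the crosses and `covs` the colored ellipses at a given frame.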

One thing that really stands out is that this animation looks almost exactly like the optimization process for k-means; in fact, it becomes even more similar when you don't plot the ellipses (the covariances). You can see for yourself by using the flag --hide_covs together with --hide_real, as in the following line: python visualize_gmm.py -k 5 -m 5 --hide_covs --hide_real. So why does it look so much like k-means? That's because k-means is also fitted using a (sort of) EM algorithm! In fact, the GMM is a sort of generalized k-means: instead of simply finding the center of mass of each cluster, we also model how the points of each cluster are distributed around that center. The big advantage over k-means is that we now actually hold the distribution itself, and can generate new points from it or ask how likely it is to see a point at any location in the space.
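The connection to k-means can be sketched numerically. In the simplified setting below (equal mixing weights and one shared spherical covariance, both assumptions made only for this illustration), the GMM-style mean update uses soft responsibilities, and it approaches the hard k-means centroid update as the shared variance shrinks. The function names and toy data are hypothetical.

```python
import numpy as np

def squared_dists(x, centers):
    """Squared Euclidean distance from every point to every center, shape (n, k)."""
    return ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)

def kmeans_step(x, centers):
    """One k-means iteration: hard-assign each point, then recompute centroids."""
    labels = squared_dists(x, centers).argmin(axis=1)
    new_centers = centers.copy()
    for j in range(len(centers)):
        if np.any(labels == j):  # leave a center untouched if it owns no points
            new_centers[j] = x[labels == j].mean(axis=0)
    return new_centers, labels

def soft_means_step(x, centers, variance):
    """GMM-style mean update with equal weights and a shared spherical covariance."""
    logits = -0.5 * squared_dists(x, centers) / variance
    logits -= logits.max(axis=1, keepdims=True)  # stabilize the exponentials
    resp = np.exp(logits)
    resp /= resp.sum(axis=1, keepdims=True)  # soft assignments, rows sum to 1
    return (resp.T @ x) / resp.sum(axis=0)[:, None], resp

rng = np.random.default_rng(0)
x = np.concatenate(
    [rng.normal(loc=c, scale=1.0, size=(100, 2)) for c in ([0, 0], [6, 0], [0, 6])]
)
start = np.array([[1.0, 1.0], [5.0, 1.0], [1.0, 5.0]])

hard_centers, labels = kmeans_step(x, start)
soft_centers, resp = soft_means_step(x, start, variance=1.0)
tiny_centers, tiny_resp = soft_means_step(x, start, variance=1e-6)
```

With `variance=1.0` the soft update lets every point pull on every center a little, while with a very small variance the responsibilities become practically one-hot and the update behaves like the k-means one.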

Fitting a GMM to a more complex distribution

A GMM can also be fitted to any 2D data saved as a numpy ndarray with shape [N, 2] in a .npy file; the following line demonstrates how:

python visualize_gmm.py -k <choose number of clusters> --load_path <your .npy file path here> [--print_ll]

You can add the flag --print_ll to track progress. An example of how to load data:

python visualize_gmm.py -k 20 --load_path examples/circles.npy --print_ll -i 100 --fps 15

The result should be something similar to: Demo Circles

Note that the number of clusters we would ideally like to fit in the above example is 3, one for each "ring". The GMM fails miserably at this, instead separating all of the points into three portions (try it out for yourself).
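To try the tool on your own data, the only requirement stated above is a .npy file holding an array of shape [N, 2]. Here is a minimal sketch of producing one; the ring-shaped data, the function name `make_rings` and the file name `my_rings.npy` are made up for illustration, and this is not necessarily how examples/circles.npy was created.

```python
import numpy as np

def make_rings(radii, points_per_ring, noise, rng):
    """Sample noisy points on concentric circles; returns an array of shape [N, 2]."""
    rings = []
    for r in radii:
        angles = rng.uniform(0, 2 * np.pi, size=points_per_ring)
        ring = r * np.stack([np.cos(angles), np.sin(angles)], axis=1)
        # Add a little Gaussian noise so each ring has some thickness
        rings.append(ring + rng.normal(scale=noise, size=ring.shape))
    return np.concatenate(rings)

rng = np.random.default_rng(0)
data = make_rings(radii=[1.0, 3.0, 5.0], points_per_ring=300, noise=0.1, rng=rng)
np.save("my_rings.npy", data)  # hypothetical output path
```

The saved file could then be passed to the tool with, for example: python visualize_gmm.py -k 3 --load_path my_rings.npy --print_ll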
