DEV Community

Lance Galletti

Posted on

GMM Clustering From Scratch

In this article you will learn how to implement, from scratch, the Expectation Maximization (EM) algorithm for Gaussian Mixture Model (GMM) clustering.


Your friend, who works at Jurassic Park, routinely records the weights of the various dinosaurs to monitor their health and make sure each one is in the normal range for its species. This time, though, they forgot to label which weight corresponds to which species, so they don’t know which range to compare each weight against… Luckily, they know how many different species are in the park, but they need your help to figure out which species a given weight is most likely associated with.

By the end of this article, you will be able to help your friend.

Breaking down the task

This is not an easy problem. For each weight in our sample, we need to report, for each species, the probability that the weight comes from that species. Formally, we need to find the following conditional probability:

$$P(S_j \mid X_i)$$

where $S_j$ is the $j^{th}$ species and $X_i$ is a specific animal weight from the dataset your friend gave you. Computing this value is hard because:

  1. Some dinosaurs are more common than others: for example, there are many more Stegosauruses than Raptors in the park. This means that a given data point, before we know anything about it, has a higher chance of being a Stegosaurus than a Raptor.
  2. The weights of different species vary differently. For example, the weights of Sauropods might follow a bell curve, symmetric around an average weight of about 100 tons. But the weights of Maiasaura might differ greatly between males and females, so we might observe a bimodal distribution (with a peak at the average weight of each sex).


Applying Bayes' Theorem, we can rewrite this probability as:

$$P(S_j \mid X_i) = \frac{P(X_i \mid S_j)\, P(S_j)}{P(X_i)}$$

where:

  1. $P(S_j)$ is the prior probability of seeing species $S_j$ (that probability would be higher for Stegosauruses than for Raptors, for example).
  2. $P(X_i \mid S_j)$ is the PDF of the weights of species $S_j$, evaluated at weight $X_i$ (a Sauropod weighing 100 tons is far more likely than a Raptor weighing 100 tons).

What about $P(X_i)$?

To compute $P(X_i)$, we need to understand the distribution of $X_i$. Let’s work with a simple example where there are only two species in the park: the Stegosaurus and the Ankylosaurus. If we looked at the distribution of the weights of all the Stegosauruses, we would see a normal distribution around 4 metric tons with a standard deviation of about 0.1 tons. For the Ankylosaurus, we would observe a normal distribution around a mean of 5 tons with a standard deviation of about 0.5 tons.


To get the distribution of the weights of both species, we need to account for how likely it is to meet a Stegosaurus compared to an Ankylosaurus. Keeping things simple, we can assume there are equal numbers of Stegosauruses and Ankylosauruses in the park, so the probability of encountering either one is 50%.


In general, from any number of species each with individual weight distributions, we can construct a combined distribution as above by simply specifying the proportion of each individual species we expect.
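As a minimal sketch of this construction (NumPy only, using the two-species numbers from the example above: Stegosaurus $N(4, 0.1)$, Ankylosaurus $N(5, 0.5)$, equal proportions), the combined density is just a weighted sum of normal PDFs. The helper names `normal_pdf` and `mixture_pdf` are illustrative, not from any library.

```python
import numpy as np

def normal_pdf(x, mean, std):
    """Density of a normal distribution N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def mixture_pdf(x, proportions, means, stds):
    """Combined density: sum over species j of P(S_j) * N(x; mean_j, std_j^2)."""
    return sum(p * normal_pdf(x, m, s) for p, m, s in zip(proportions, means, stds))

# Two-species example: Stegosaurus ~ N(4, 0.1), Ankylosaurus ~ N(5, 0.5), 50/50 split
proportions = [0.5, 0.5]
means = [4.0, 5.0]
stds = [0.1, 0.5]

x = np.linspace(0, 10, 100_001)
density = mixture_pdf(x, proportions, means, stds)
dx = x[1] - x[0]
print(density.sum() * dx)  # Riemann-sum integral of the mixture, close to 1
```

Because the proportions sum to 1 and each component integrates to 1, the mixture is itself a valid density.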

For example, the park could contain 10% Raptors, 25% Sauropods, 5% T-Rexes, 30% Stegosauruses, 15% Ankylosauruses and 15% Maiasaura. Each species has its own weight distribution $P(X_i \mid S_j)$ (the PDFs defined in the code below). To combine them, we weight each one by its relative proportion $P(S_j)$:

import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

# P(X | S_j): weight PDF (in tons) of each species
maiasaura = lambda x : .5 * norm.pdf(x, 7, .3) + .5 * norm.pdf(x, 8, .3)  # bimodal: one peak per sex
stegosaurus = lambda x : norm.pdf(x, 4, .3)
ankylosaurus = lambda x : norm.pdf(x, 5, .5)
trex = lambda x : norm.pdf(x, 10, 1.5)
raptor = lambda x : norm.pdf(x, .7, .2)
sauropod = lambda x : norm.pdf(x, 20, 3)

x = np.arange(0, 30, .01)
# P(X) = sum_j P(S_j) P(X | S_j), where P(S_j) is the proportion of each species in the park
mixture = (.1 * raptor(x) + .25 * sauropod(x) + .05 * trex(x)
           + .3 * stegosaurus(x) + .15 * ankylosaurus(x) + .15 * maiasaura(x))
plt.plot(x, mixture, color='blue')
plt.xlabel('weight (tons)')
plt.ylabel('density')
plt.show()


Hence, for a given weight $X_i$ in the dataset, $P(X_i)$ is the sum of the PDFs of each species’ weights, weighted by the species proportions:

$$P(X_i) = \sum_j P(S_j)\, P(X_i \mid S_j)$$

We say that $X_i$ follows a mixture distribution with $k$ components. When every component is a Normal distribution, we call this special case a Gaussian Mixture Distribution.
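Putting Bayes' Theorem and this sum together, here is a minimal sketch (NumPy only, using the two-species numbers from the example above) that returns $P(S_j \mid X_i)$ for a single weight. The function name `posterior` is illustrative.

```python
import numpy as np

def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x (works element-wise on arrays)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def posterior(x, proportions, means, stds):
    """Return the array [P(S_1 | x), ..., P(S_k | x)] via Bayes' Theorem."""
    # numerator for each species: P(x | S_j) * P(S_j)
    joint = np.array([p * normal_pdf(x, m, s)
                      for p, m, s in zip(proportions, means, stds)])
    # denominator P(x) = sum_j P(S_j) P(x | S_j)
    return joint / joint.sum()

species = ["Stegosaurus", "Ankylosaurus"]
proportions, means, stds = [0.5, 0.5], [4.0, 5.0], [0.1, 0.5]

for weight in [4.0, 4.3, 6.0]:
    probs = posterior(weight, proportions, means, stds)
    print(weight, {s: round(float(p), 3) for s, p in zip(species, probs)})
```

Note how narrow the Stegosaurus distribution is: a 4.3-ton animal is already far more likely to be an Ankylosaurus, even though 4.3 is closer to 4 than to 5.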

Gaussian Mixture Model

Recall that our goal is to report $P(S_j \mid X_i)$ for all weights and all species. So if there are $k = 10$ species, we would report 10 probabilities per data point in the dataset. To compute $P(S_j \mid X_i)$ we need $P(X_i \mid S_j)$, which could be any distribution with any number of parameters… To simplify things, GMM assumes that the data follows a Gaussian Mixture Distribution, i.e. that every $P(X_i \mid S_j)$ is a Normal distribution.

With this assumption, what do we need to know in order to compute $P(S_j \mid X_i)$?

  1. The relative proportion of each species in the park, $P(S_j)$
  2. The parameters of each normal distribution $P(X_i \mid S_j)$, namely its mean $\mu_j$ and standard deviation $\sigma_j$

Maximum Likelihood Estimation

Suppose you are given a dataset of coin tosses and are asked to estimate the probability of Heads. How would you go about it? Let’s take the following sequence of coin tosses (which we can assume are independent):

H, T, T, H, T

Given the limited information, the best we can do is find the probability of Heads that maximizes the probability of having observed this particular sequence of coin tosses. Knowing that this coin can be modeled as a Bernoulli random variable with probability $p$ of Heads, and that the tosses are independent, we can write the probability of observing the above data as:

$$P(H, T, T, H, T) = p^2 (1-p)^3$$

To find the value of $p$ that maximizes the probability of observing the data we saw, we can take the derivative of the above with respect to $p$, set it equal to zero and solve for $p$.
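Dropping any constant factor (it cannot move the maximum), the derivative of $p^2(1-p)^3$ follows from the product rule, and factoring it shows where it vanishes:

```latex
\frac{d}{dp}\Big[p^2 (1-p)^3\Big] = 2p(1-p)^3 - 3p^2(1-p)^2 = p(1-p)^2\,(2 - 5p) = 0
\quad\Longrightarrow\quad p \in \left\{0,\ 1,\ \tfrac{2}{5}\right\}
```

The likelihood is 0 at $p = 0$ and $p = 1$, so the maximum is at $p = 2/5$.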


Our estimate for $p$ is then $2/5$, which is the sample proportion of Heads in our dataset, and it is the best we can do given the information we have. This approach is called Maximum Likelihood Estimation (MLE): we estimate the parameters of the distribution that generated the dataset by finding the parameter values that maximize the probability of observing that dataset (i.e. treating the dataset as a sample that perfectly represents the distribution).
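A quick numerical sanity check (a sketch, not part of the original derivation): evaluating the log-likelihood on a grid of candidate values of $p$ and keeping the best one recovers the sample proportion of Heads.

```python
import numpy as np

tosses = ["H", "T", "T", "H", "T"]
heads = tosses.count("H")
tails = len(tosses) - heads

# Log-likelihood log(p^heads * (1-p)^tails) on a grid of p values.
# The endpoints 0 and 1 are excluded because log(0) is undefined.
p_grid = np.linspace(0.001, 0.999, 999)
log_likelihood = heads * np.log(p_grid) + tails * np.log(1 - p_grid)

p_hat = p_grid[np.argmax(log_likelihood)]
print(p_hat, heads / len(tosses))  # grid maximizer vs. sample proportion of Heads
```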

Maximum Likelihood Estimation of Gaussian Mixture Distribution parameters

We can use the same approach to estimate the parameters of the Gaussian Mixture Distribution that generated the data. Recall:

$$P(X_i) = \sum_j P(S_j)\, P(X_i \mid S_j)$$

Assuming the $N$ weights in the dataset are independent, the probability (likelihood) of seeing the whole dataset is the product of these densities:

$$\prod_i P(X_i) = \prod_i \sum_j P(S_j)\, P(X_i \mid S_j)$$

Taking the derivative of a product is hard… To make things easier, we take the log, which turns the product into a sum. Because the log is strictly increasing, this does not change where the maximum is:

$$\log \left( \prod_i P(X_i) \right) = \sum_i \log \left( \sum_j P(S_j)\, P(X_i \mid S_j) \right)$$
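In code, this objective is usually evaluated in log space. Here is a minimal sketch for 1D weights (NumPy only; `gmm_log_likelihood` is a hypothetical helper, not from the article's code). It uses the log-sum-exp trick, shifting by the largest term per point, so that weights far from every component do not underflow to $\log 0$.

```python
import numpy as np

def log_normal_pdf(x, mean, std):
    """log N(x; mean, std^2), computed directly in log space."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std * np.sqrt(2 * np.pi))

def gmm_log_likelihood(data, proportions, means, stds):
    """sum_i log( sum_j P(S_j) P(X_i | S_j) ) for 1D data."""
    data = np.asarray(data, dtype=float)[:, None]  # shape (N, 1)
    # log P(S_j) + log P(X_i | S_j), shape (N, k)
    log_joint = np.log(proportions) + log_normal_pdf(data, np.asarray(means), np.asarray(stds))
    # log-sum-exp over species j, shifted by the row max to avoid underflow
    row_max = log_joint.max(axis=1, keepdims=True)
    log_px = row_max[:, 0] + np.log(np.exp(log_joint - row_max).sum(axis=1))
    return log_px.sum()

weights = [3.9, 4.1, 4.6, 5.2, 5.8]
print(gmm_log_likelihood(weights, [0.5, 0.5], [4.0, 5.0], [0.1, 0.5]))
```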

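Taking the derivative with respect to each parameter, setting it equal to zero (subject to the constraint $\sum_j P(S_j) = 1$) and solving gives the following estimates. Here $X_i$ may be a vector, so $\Sigma_j$ is the covariance matrix of species $j$ (the multivariate counterpart of $\sigma_j^2$):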

$$\hat{\mu}_j = \frac{\sum_i P(S_j \mid X_i)\, X_i}{\sum_i P(S_j \mid X_i)}$$

$$\hat{\Sigma}_j = \frac{\sum_i P(S_j \mid X_i)\, (X_i - \hat{\mu}_j)(X_i - \hat{\mu}_j)^T}{\sum_i P(S_j \mid X_i)}$$

$$\hat{P}(S_j) = \frac{1}{N} \sum_i P(S_j \mid X_i)$$
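These three updates vectorize naturally. A minimal NumPy sketch, assuming the probabilities $P(S_j \mid X_i)$ are stored in an `(N, k)` matrix `R` (the name `m_step` and this layout are illustrative choices): with one-hot responsibilities it reduces to per-cluster sample proportions, means and covariances.

```python
import numpy as np

def m_step(X, R):
    """
    Closed-form updates given responsibilities.
    X: (N, d) data; R: (N, k) with R[i, j] = P(S_j | X_i), rows summing to 1.
    Returns proportions (k,), means (k, d), covariances (k, d, d).
    """
    Nk = R.sum(axis=0)                    # effective number of points per component
    proportions = Nk / X.shape[0]         # P(S_j) = (1/N) sum_i P(S_j | X_i)
    means = (R.T @ X) / Nk[:, None]       # responsibility-weighted means
    covs = []
    for j in range(R.shape[1]):
        diff = X - means[j]               # (N, d)
        # sum_i R[i, j] (X_i - mu_j)(X_i - mu_j)^T / Nk[j]
        covs.append((R[:, j, None] * diff).T @ diff / Nk[j])
    return proportions, means, np.array(covs)

# Usage: with one-hot (hard) responsibilities, these are plain per-cluster statistics
rng = np.random.default_rng(0)
X = rng.normal(size=(10, 2))
R = np.eye(2)[[0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]
proportions, means, covs = m_step(X, R)
print(proportions)  # [0.4 0.6]
```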

Something is strange here… Recall that the entire reason we need these values is to report $P(S_j \mid X_i)$! But in order to compute them, we already need to know $P(S_j \mid X_i)$.

Expectation Maximization Algorithm

Since we need one value to compute the others and we need the others to compute the one, the EM Algorithm proposes the following approach:

  1. Start with random $\mu_j, \Sigma_j, P(S_j)$
  2. Compute $P(S_j \mid X_i)$ for all $X_i$ using $\mu_j, \Sigma_j, P(S_j)$ (the expectation step)
  3. Update $\mu_j, \Sigma_j, P(S_j)$ from $P(S_j \mid X_i)$ (the maximization step)
  4. Repeat steps 2 and 3 until convergence
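To see these steps end to end on dinosaur-like data, here is a compact 1D sketch on simulated weights for the two-species example (NumPy only). The name `em_1d`, the simulated data and the initialization are illustrative assumptions, not the article's code: instead of fully random starting values, it spreads the initial means over data quantiles, and it runs a fixed number of iterations rather than testing for convergence.

```python
import numpy as np

def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x (element-wise)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def em_1d(x, k, iterations=200):
    """Fit a k-component 1D Gaussian mixture to x with EM."""
    x = np.asarray(x, dtype=float)
    # Step 1: initial values (means at quantiles, shared std, equal proportions)
    means = np.quantile(x, np.linspace(0.25, 0.75, k))
    stds = np.full(k, x.std())
    props = np.full(k, 1.0 / k)
    for _ in range(iterations):
        # Step 2 (E-step): R[i, j] = P(S_j | X_i)
        joint = props * normal_pdf(x[:, None], means, stds)  # (N, k)
        R = joint / joint.sum(axis=1, keepdims=True)
        # Step 3 (M-step): re-estimate the parameters from the responsibilities
        Nk = R.sum(axis=0)
        props = Nk / len(x)
        means = (R * x[:, None]).sum(axis=0) / Nk
        stds = np.sqrt((R * (x[:, None] - means) ** 2).sum(axis=0) / Nk)
        stds = np.maximum(stds, 1e-6)  # guard against a component collapsing onto one point
    return props, means, stds

# Simulated weights (tons): equal numbers of each species
rng = np.random.default_rng(42)
weights = np.concatenate([rng.normal(4.0, 0.1, 1000),   # Stegosaurus
                          rng.normal(5.0, 0.5, 1000)])  # Ankylosaurus
props, means, stds = em_1d(weights, k=2)
print(props.round(2), means.round(2), stds.round(2))
```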


Generating the data

Let’s start by generating a dataset that follows a Gaussian Mixture Distribution:

from numpy import array, argmax
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from numpy.random import multivariate_normal as mvn_random
from scipy.stats import multivariate_normal
from numpy.random import uniform

class Component:
    def __init__(self, mixture_prop, mean, variance):
        self.mixture_prop = mixture_prop  # P(S_j), the mixture proportion
        self.mean = mean                  # mean vector mu_j
        self.variance = variance          # covariance matrix Sigma_j

def generate_gmm_dataset(gmm_params, sample_size):
    def get_random_component(gmm_params):
        """
        returns component with prob
        proportional to mixture_prop
        """
        r = uniform()
        for c in gmm_params:
            r -= c.mixture_prop
            if r <= 0:
                return c
        return gmm_params[-1]  # guard against floating-point round-off in the proportions

    dataset = []
    for _ in range(sample_size):
        comp = get_random_component(gmm_params)
        # draw one point from this component's multivariate normal distribution
        dataset += [mvn_random(comp.mean, comp.variance)]
    return array(dataset)  # (sample_size, 2) array, so it can be sliced as data[:, 0]

gmm = [
    Component(.25, [-3, 3], [[1, 0], [0, 1]]),
    Component(.50, [0, 0], [[1, 0], [0, 1]]),
    Component(.25, [3, 3], [[1, 0], [0, 1]])
]
sample_size = 1000  # example value; the original does not specify one
data = generate_gmm_dataset(gmm, sample_size)
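The generator above samples one component at a time. An equivalent vectorized sketch (an alternative, not the article's code; it uses NumPy's `Generator.choice` and `Generator.multivariate_normal`, and `sample_gmm` is a hypothetical name) draws all component labels at once and also returns them, which is handy for checking a clustering against the ground truth.

```python
import numpy as np

def sample_gmm(proportions, means, covs, sample_size, seed=0):
    """Draw sample_size points from a Gaussian mixture; also return component labels."""
    rng = np.random.default_rng(seed)
    # choose a component for every sample at once, with probability = mixture proportion
    labels = rng.choice(len(proportions), size=sample_size, p=proportions)
    points = np.empty((sample_size, len(means[0])))
    for j in range(len(proportions)):
        idx = labels == j
        # draw all points belonging to component j in one call
        points[idx] = rng.multivariate_normal(means[j], covs[j], size=idx.sum())
    return points, labels

identity = [[1, 0], [0, 1]]
points, labels = sample_gmm([.25, .50, .25], [[-3, 3], [0, 0], [3, 3]],
                            [identity] * 3, sample_size=1000)
print(points.shape, np.bincount(labels) / len(labels))
```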


EM Algorithm

First we need reasonable initial values for $\mu_j, \Sigma_j, P(S_j)$, which we can get by applying a clustering algorithm like K-means (which tends to find globular clusters like the ones in this dataset).

def gmm_init(k, dataset):
    """Initialize P(S_j), mu_j, Sigma_j from a K-means clustering of the data."""
    kmeans = KMeans(k, init='k-means++').fit(dataset)
    gmm_params = []

    for j in range(k):
        members = dataset[kmeans.labels_ == j]    # points assigned to cluster j
        p_cj = len(members) / len(dataset)        # proportion of points in cluster j
        mean_j = members.mean(axis=0)             # cluster mean
        diffs = members - mean_j
        var_j = diffs.T @ diffs / len(members)    # cluster covariance (divides by n)

        gmm_params.append(Component(p_cj, mean_j, var_j))

    return gmm_params

From the clusters generated by K-means, we take the mean and covariance of each cluster, as well as the proportion of points in that cluster, as initial values for $\mu_j, \Sigma_j, P(S_j)$.
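K-means is one choice; step 1 of the EM outline only asks for a starting point, so a cheaper random initialization is also common. A minimal sketch of one such scheme (an assumption, not the article's method; `random_init` is a hypothetical name): pick $k$ distinct data points as means, give every component the covariance of the whole dataset, and use equal proportions.

```python
import numpy as np

def random_init(X, k, seed=0):
    """Random EM initialization: k distinct data points as means, the overall data
    covariance for every component, and equal mixture proportions."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    means = X[rng.choice(len(X), size=k, replace=False)]
    cov = np.cov(X.T, bias=True)                  # (d, d) covariance of the whole dataset
    covs = np.repeat(cov[None, :, :], k, axis=0)  # same starting covariance for each component
    proportions = np.full(k, 1.0 / k)
    return proportions, means, covs

X = np.random.default_rng(1).normal(size=(200, 2))
proportions, means, covs = random_init(X, k=3)
print(proportions, means.shape, covs.shape)
```

Random starts are more sensitive to bad luck than K-means, so EM is often run from several starts, keeping the fit with the highest log-likelihood.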

The EM algorithm then alternates between its two steps:

def expectation_maximization(k, dataset, iterations):
    gmm_params = gmm_init(k, dataset)

    for _ in range(iterations):
        # expectation step
        probs = compute_probs(k, dataset, gmm_params)

        # maximization step
        gmm_params = compute_gmm(k, dataset, probs)

    return probs, gmm_params

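The helper functions `compute_probs` (expectation step) and `compute_gmm` (maximization step) used in `expectation_maximization` are defined as follows: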

def compute_gmm(k, dataset, probs):
    """
    Maximization step: compute P(C_j), mean_j, var_j for every cluster j.
    Here mean_j is a vector and var_j is a matrix.
    """
    gmm_params = []

    for j in range(k):
        # P(C_j): average responsibility of cluster j over all points
        p_cj = sum([probs[i][j] for i in range(len(dataset))]) / len(dataset)
        # responsibility-weighted mean and covariance of cluster j
        mean_j = sum([probs[i][j] * dataset[i] for i in range(len(dataset))]) / sum([probs[i][j] for i in range(len(dataset))])
        var_j = sum([probs[i][j] * (dataset[i] - mean_j).reshape(-1, 1) * (dataset[i] - mean_j).reshape(1, -1) for i in range(len(dataset))]) / sum([probs[i][j] for i in range(len(dataset))])

        gmm_params.append(Component(p_cj, mean_j, var_j))

    return gmm_params

def compute_probs(k, dataset, gmm_params):
    """
    Expectation step: for all x_i in dataset, compute
    P(C_j | X_i) = P(X_i | C_j)P(C_j) / P(X_i) for all C_j.
    Return the list of all P(C_j | X_i), one array of k values per x_i.
    """
    probs = []

    for i in range(len(dataset)):
        # unnormalized P(X_i | C_j) P(C_j) for each cluster j
        p_cj_xi = array([gmm_params[j].mixture_prop * multivariate_normal.pdf(dataset[i], gmm_params[j].mean, gmm_params[j].variance) for j in range(k)])
        # dividing by P(X_i) = sum_j P(X_i | C_j) P(C_j) turns these into probabilities
        probs.append(p_cj_xi / p_cj_xi.sum())

    return probs
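Two practical caveats about these helpers: `compute_probs` works with raw densities, which can underflow to 0 for points far from every component (giving 0/0), and `expectation_maximization` runs a fixed number of iterations rather than "until convergence". A vectorized sketch addressing both (NumPy only; `log_gaussian`, `e_step`, `m_step`, `em` and `tol` are hypothetical names, and stopping once the log-likelihood improves by less than `tol` is one common criterion, assumed here) might look like:

```python
import numpy as np

def log_gaussian(X, mean, cov):
    """log N(x; mean, cov) for every row x of X."""
    d = X.shape[1]
    diff = X - mean
    _, logdet = np.linalg.slogdet(cov)
    # (x - mu)^T cov^{-1} (x - mu) via a linear solve instead of an explicit inverse
    maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)

def e_step(X, proportions, means, covs):
    """Responsibilities R (N, k) and the log-likelihood, computed in log space."""
    log_joint = np.column_stack([np.log(proportions[j]) + log_gaussian(X, means[j], covs[j])
                                 for j in range(len(proportions))])
    row_max = log_joint.max(axis=1, keepdims=True)
    log_px = row_max + np.log(np.exp(log_joint - row_max).sum(axis=1, keepdims=True))
    return np.exp(log_joint - log_px), log_px.sum()

def m_step(X, R):
    """Same closed-form updates as compute_gmm, vectorized."""
    Nk = R.sum(axis=0)
    means = (R.T @ X) / Nk[:, None]
    covs = np.array([(R[:, j, None] * (X - means[j])).T @ (X - means[j]) / Nk[j]
                     for j in range(R.shape[1])])
    return Nk / len(X), means, covs

def em(X, proportions, means, covs, tol=1e-6, max_iter=500):
    """Alternate M- and E-steps until the log-likelihood stops improving."""
    R, ll = e_step(X, proportions, means, covs)
    history = [ll]
    for _ in range(max_iter):
        proportions, means, covs = m_step(X, R)
        R, ll = e_step(X, proportions, means, covs)
        history.append(ll)
        if history[-1] - history[-2] < tol:
            break
    return R, proportions, means, covs, history

# Usage on data shaped like the article's three 2D clusters
rng = np.random.default_rng(0)
true_means = np.array([[-3, 3], [0, 0], [3, 3]])
labels = rng.choice(3, size=1000, p=[.25, .5, .25])
X = true_means[labels] + rng.normal(size=(1000, 2))  # identity covariance per component
init_means = X[rng.choice(len(X), size=3, replace=False)]
R, props, means, covs, history = em(X, np.full(3, 1 / 3), init_means,
                                    np.array([np.eye(2)] * 3))
print(len(history), history[-1])
```

Each EM iteration never decreases the log-likelihood, which is why tracking `history` is a convenient convergence test.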


To plot the clustering result, with each point colored by its most likely cluster and sized according to its probability of belonging to that cluster, you can do the following:

num_clusters = 3  # matches the three components used to generate the data
probs, gmm_p = expectation_maximization(num_clusters, data, 3)  # 3 EM iterations
labels = [argmax(array(p)) for p in probs]  # create a hard assignment
size = 50 * array(probs).max(1) ** 2  # emphasizes the difference in probability

plt.scatter(data[:, 0], data[:, 1], c=labels, cmap='viridis', s=size)
plt.title('GMM with {} clusters and {} samples'.format(num_clusters, sample_size))
plt.show()


Now you can help your friend figure out which weight most likely corresponds to which dino species.
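To close the loop on the original task, here is a minimal sketch of the final check for 1D weights (NumPy only). Both `check_weights` and the "normal range" rule (within `z_max = 2` standard deviations of the species mean) are illustrative assumptions, and the parameters below are the two-species example values; in practice they would come from the fitted mixture.

```python
import numpy as np

def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x (element-wise)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def check_weights(weights, species, proportions, means, stds, z_max=2.0):
    """For each weight: most likely species, its probability, and whether the weight
    lies within z_max standard deviations of that species' mean."""
    weights = np.asarray(weights, dtype=float)
    means, stds = np.asarray(means), np.asarray(stds)
    joint = np.asarray(proportions) * normal_pdf(weights[:, None], means, stds)
    post = joint / joint.sum(axis=1, keepdims=True)  # P(S_j | X_i)
    best = post.argmax(axis=1)                       # hard assignment
    z = np.abs(weights - means[best]) / stds[best]   # distance from that species' mean
    return [(float(w), species[b], round(float(post[i, b]), 3), bool(z[i] <= z_max))
            for i, (w, b) in enumerate(zip(weights, best))]

for row in check_weights([3.95, 4.6, 6.4], ["Stegosaurus", "Ankylosaurus"],
                         [0.5, 0.5], [4.0, 5.0], [0.1, 0.5]):
    print(row)
```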


Thank you to Reshab Chhabra for their contributions.
