Expectation Maximisation Deep Dive
(45 min) This post covers the in-depth theory behind the EM algorithm, with a from-scratch Python implementation of a Gaussian Mixture Model.
1. Introduction
Expectation maximisation (EM) is one of those things I was taught several times at university; each time it was explained slightly differently, and I was confused by it every single time. The confusion was mostly because we never implemented EM in Python, and it was always introduced to solve one specific problem.
Without a doubt, EM serves as a foundation for understanding variational inference. Here, I will try to go in depth and derive the E and M steps under a general framework. Later I will derive and implement a Gaussian Mixture Model using EM.
The maths behind this post is based on Coursera’s Bayesian Methods for Machine Learning specialisation course. Intermediate-level probability theory and ML knowledge are assumed. Familiarity with the Kullback–Leibler (KL) divergence and Jensen’s inequality is also expected (link to videos explaining the concepts).
2. Theory of EM
2.1 What is it trying to solve?
EM aims to tackle a family of models termed latent variable models (e.g. GMM, LDA). Their graphical model can be compactly represented as:
where $x_i$ is a data point in our dataset $X = \{x_1, \cdots, x_N\}$ and $z_i$ is the latent variable associated with each data point. The edge represents that $x_i$ is conditioned on $z_i$. Although this graphical model looks simple, to fully specify the model we need to:
- Specify the prior distribution of $z_i$
- Specify the conditional distribution $p(x_i|z_i)$
Both of these distributions can be arbitrarily complex (which is a problem during both model fitting and inference). If we collect all the model parameters associated with our problem into $\theta$, our goal becomes finding the set of parameters $\theta$ which maximises the likelihood of our dataset $X$ given our model:
$$
\max_{\theta}p(X|\theta)
$$
We assume our data points are i.i.d. under our model, therefore:
$$
\max_{\theta}p(X|\theta) = \max_\theta \prod_i p(x_i|\theta)
$$
We simplify the product into a sum by maximising the log-likelihood instead:
$$
\max_{\theta}\log p(X|\theta) = \max_\theta \sum_i \log p(x_i|\theta)
$$
Now we explicitly bring in the latent variable using the rule of marginalisation:
$$
\max_{\theta}\sum_i \log\left(\sum_c p(x_i, z_i=c|\theta)\right)
$$
where for simplicity we assume a discrete distribution for $z_i$ (the derivation holds if you replace the sum with an integral). The support of $z_i$ is $\{1, \cdots, C\}$, i.e. it is categorical.
That’s it! EM tries to find the $\theta$ that maximises the likelihood of our data, and through some maths we have shown:
$$
\max _{\theta} \log p(X | \theta) = \max _{\theta} \sum_{i} \log \left(\sum_{c} p\left(x_{i}, z_{i}=c | \theta\right)\right)
$$
Of course, one question is: why can’t we directly differentiate the (log-)likelihood? As we will see later with the GMM, we cannot obtain an analytical solution for $\theta$ via direct differentiation!
2.2 Graphical representation of the solution
Here we jump straight into the solution of EM and obtain a graphical representation of the iterative update rule.
We wish to find a set of parameters $\theta$ which maximises the likelihood $\log p(X|\theta)$. This distribution can be arbitrarily complex, so at each timestep we approximate it using a simpler distribution that forms a lower bound $\mathcal{L}(q, \theta)$, which lies below the likelihood at all values of $\theta$. At timestep $k$, the EM update rules are as follows:
- E-step: fix $\theta$, find an arbitrary variational distribution $q(z_i)$ which maximises the lower bound, i.e. $q^{k+1} = \arg\max_{q}\mathcal{L}(\theta^k, q)$.
- M-step: fix $q$, find model parameters $\theta$ that maximise our current lower bound, i.e. $\theta^{k+1} = \arg\max_\theta \mathcal{L}(\theta, q^{k+1})$.
One thing to get our heads around is that in the E-step, we are essentially taking a derivative w.r.t. a function $q(z_i)$. These two steps can be seen more clearly in the picture below.
The derivation later will show exactly how we arrive at the E and M steps, but jumping ahead, these two steps are equivalent to:
- E-step: set the variational distribution $q(z_i)$ to be equal (or approximately equal) to the posterior distribution, i.e. $q^{k+1}(z_i) = p(z_i|x_i, \theta^k)$, where $k$ denotes the timestep in our iteration.
- M-step: under our distribution $q(z_i)$, find a set of parameters $\theta$ which maximises the data likelihood, i.e. $\theta^{k+1} = \arg\max_\theta \sum_i \mathbb{E}_{z_i\sim q^{k+1}}\left[\log p(x_i, z_i | \theta)\right]$.
Hence the EM algorithm boils down to this: we cannot directly maximise our likelihood function, so we instead maximise a lower bound on it under our own variational distribution, which is a simple function. At each timestep, we set (or approximate) this distribution to the posterior $p(Z|X,\theta)$ (E-step), and we then maximise the likelihood of our data points under this distribution instead (M-step). Alternating these two procedures, we are guaranteed to reach a local optimum (in the picture above, depending on initialisation, we could easily have ended up on the left peak as our final solution).
2.3 Detailed Derivation
E-Step
TL;DR: $\arg\max _{q\left(z_{i}\right)} \mathcal{L}\left(\theta, q\right)=p\left(z_{i} | x_{i}, \theta\right)$: to minimise the gap between our lower bound and the likelihood under the current parameters $\theta$, set our variational distribution to be the posterior.
Proof: first, let’s define exactly how the lower bound arises and where on earth this variational distribution $q$ comes from. Decompose the likelihood:
$$\log p(X | \theta) = \sum_{i} \log \left(\sum_{c} p\left(x_{i}, z_{i}=c | \theta\right)\right)$$
The idea here is that we cannot directly maximise this equation (if you can, then do not use EM ;)). Instead, we multiply and divide each summand by a distribution of our choice $q$ (because why not?):
$$\log p(X | \theta) = \sum_{i} \log\left( \sum_{c} q\left(z_{i}=c\right) \frac{p\left(x_{i}, z_{i}=c | \theta\right)}{q\left(z_{i}=c\right)}\right)$$
which is equivalent to:
$$\log p(X|\theta) = \sum_{i} \log \mathbb{E}_{z_{i} \sim q} \frac{p\left(x_{i}, z_{i} | \theta\right)}{q\left(z_{i}\right)}$$
Now we can utilise Jensen’s inequality to obtain a lower bound on this quantity. As a reminder, for any concave function $f$ (such as $\log(x)$ in our case), $f(\mathbb{E}[x]) \geqslant \mathbb{E}[f(x)]$. We can therefore write the equation above as:
$$
\log p(X|\theta)\geq \sum_{i} \mathbb{E}_{z_{i}\sim q} \log \frac{p\left(x_{i}, z_{i} | \theta\right)}{q\left(z_{i}\right)}
$$
expanding the expectation out:
$$
\log p(X|\theta)\geq\sum_{i} \sum_{c} q\left(z_{i}=c\right) \log \frac{p\left(x_{i}, z_{i}=c | \theta\right)}{q\left(z_{i}=c\right)} = \mathcal{L}(\theta, q)
$$
Our objective then becomes minimising the gap between this lower bound and the likelihood, by optimising our arbitrary distribution $q$:
$$
\arg\min_q \text{Objective} = \arg\min_q \left( \log p(X|\theta) - \mathcal{L}(\theta, q)\right)
$$
more maths:
$$
\text{Objective} = \sum_{i} \log p\left(x_{i} | \theta\right) - \sum_{i} \sum_{c} q\left(z_{i}=c\right) \log \frac{p\left(x_{i}, z_{i}=c | \theta\right)}{q\left(z_{i}=c\right)}
$$
for the first term, we multiply it by $1 = \sum_c q(z_i=c)$:
$$
=\sum_{i} \left(\log p\left(x_{i} | \theta\right)\sum_{c} q\left(z_{i}=c\right)\right) - \sum_{i} \sum_{c} \cdots
$$
now we can move both summations out:
$$
=\sum_{i}\sum_{c}\left(\log p\left(x_{i} | \theta\right) q\left(z_{i}=c\right) - q\left(z_{i}=c\right) \log \frac{p\left(x_{i}, z_{i}=c | \theta\right)}{q\left(z_{i}=c\right)} \right)
$$
$$
=\sum_{i} \sum_{c} \left(q\left(z_{i}=c\right) \left(\log p\left(x_{i} | \theta\right) - \log \frac{p\left(x_{i}, z_{i}=c | \theta\right)}{q\left(z_{i}=c\right)}\right)\right)
$$
$$
=\sum_{i} \sum_{c}\left(q\left(z_{i}=c\right)\left(\log \left(\frac{p\left(x_{i} | \theta\right) q\left(z_{i}=c\right)}{p\left(x_{i}, z_{i}=c | \theta\right)}\right)\right)\right)
$$
rewrite the joint distribution in the denominator inside the log using the product rule:
$$
=\sum_{i} \sum_{c}\left(q\left(z_{i}=c\right)\left(\log \left(\frac{p\left(x_{i} | \theta\right) q\left(z_{i}=c\right)}{p(x_i|\theta)p(z_i=c|\theta, x_i)}\right)\right)\right)
$$
cancelling out the $p\left(x_{i} | \theta\right)$ terms, we obtain:
$$
=\sum_{i} \sum_{c}\left(q\left(z_{i}=c\right)\log \left(\frac{ q\left(z_{i}=c\right)}{ p\left(z_{i}=c | \theta, x_{i}\right)}\right)\right)
$$
Observe that the summation over $c$ corresponds to the KL divergence between $q$ and the posterior $p$; we therefore obtain:
$$
\text{Objective} = \sum_i \text{KL}\left(q(z_i)\,\|\, p(z_i|\theta, x_i)\right)
$$
after all this faff, we see that minimising the gap (objective) between the likelihood and our lower bound is equivalent to:
$$
\arg \min _{q}(\log p(X | \theta)-\mathcal{L}(\theta, q)) = \arg \min _{q}\sum_i\mathrm{KL}\left(q\left(z_{i}\right) \| p\left(z_{i} | \theta, x_{i}\right)\right)
$$
So effectively, the E-step becomes an optimisation problem whose objective function is the KL divergence of the variational distribution $q$ from the model posterior $p\left(z_{i} | \theta, x_{i}\right)$. There is no guarantee that the posterior has a closed-form representation; this is where the rich literature of variational inference comes in. Fortunately, with a GMM (Gaussian mixture model), the posterior does have a closed-form solution! Hence at each iteration, we are able to set $q$ to be equal to the posterior.
M-Step
TL;DR: given a fixed variational distribution $q$, maximise the joint likelihood of the data and latent variables w.r.t. the model parameters $\theta$, i.e. $\arg\max_\theta \mathcal{L}(\theta, q) = \arg \max _{\theta} \mathbb{E}_{Z \sim q}[\log p(X, Z | \theta)]$, where the latent variables come from our variational distribution $q$.
Recall from the E-step that maximising the lower bound is essentially maximising the likelihood w.r.t. $\theta$; the only difference is that we can easily perform this maximisation, since we know our distribution $q$ explicitly (we chose it ourselves in the E-step). Writing the whole expectation term out:
$$
\arg\max_\theta \mathcal{L}(\theta, q) = \arg\max_\theta \sum_{i} \sum_{c} q\left(z_{i}=c\right) \log \frac{p\left(x_{i}, z_{i}=c | \theta\right)}{q\left(z_{i}=c\right)}
$$
splitting the terms within the log:
$$
\arg \max _{\theta} \sum_{i} \sum_{c} q\left(z_{i}=c\right) \left(\log p\left(x_{i}, z_{i}=c | \theta\right) - \log q\left(z_{i}=c\right)\right)
$$
the last term is independent of $\theta$:
$$
\arg \max _{\theta} \sum_{i} \sum_{c} q\left(z_{i}=c\right)\log p\left(x_{i}, z_{i}=c | \theta\right) + \text{const}.
$$
condensing the sum over $c$ into an expectation and dropping the constant:
$$
\arg \max _{\theta}\sum_{i}\mathbb{E}_{z_i\sim q}\left[ \log p(x_i,z_i|\theta)\right] = \arg \max _{\theta}\,\mathbb{E}_{Z\sim q}\left[ \log p(X,Z|\theta)\right]
$$
So the M-step simply maximises the joint likelihood of the collection of latent variables $Z$ and the data points $X$ under the variational distribution $q$:
$$
\hat{\theta} = \arg \max _{\theta}\mathbb{E}_{Z \sim q}[\log p(X, Z | \theta)]
$$
Again, we definitely know what $p(X, Z | \theta)$ is, since we needed it to set up our graphical model. However, the maximisation w.r.t. $\theta$ can be non-trivial. In the case of GMM, with a lot of maths, we can obtain closed-form solutions for the model parameters $\theta$.
3. EM in Practice: Gaussian Mixture Model
Here we unleash the EM algorithm on the Gaussian Mixture Model (GMM). For a quick intuition on how the GMM update rules work, here is a link to StatQuest’s video. I have to say, the update rule for GMM is fairly intuitive, and videos like this are a tiny bit misleading, as it is not easy to shoehorn the general EM algorithm into this sort of intuition. Instead, we will not skip (too much of) the maths and will derive exactly how the GMM update rules arise from EM.
3.1 Derivation
Model Definition
Remember the graph earlier?
Here we define exactly what the model parameters are for a GMM. First let’s look at what GMM is trying to achieve:
The left figure above shows the data points that you have obtained. Through some prior knowledge or by inspecting the plot, you believe that the data points come from three Gaussian distributions, and you wish to fit the Gaussian centres and variances to the data points (shown on the right).
To apply latent variable modelling, we ask: “how would we generate these data points with a probabilistic model?”. Let $z_i$ follow a categorical distribution (three categories) which defines the proportion of data points that belong to each cluster. Conditioned on the category, the data point $x_i$ follows a Gaussian distribution. The model parameters are therefore $\theta = [\pi_1, \cdots, \pi_C, \mu_1, \cdots, \mu_C, \sigma^2_1, \cdots, \sigma^2_C]$: the distribution over categories, and the mean and variance of each cluster. We restrict ourselves to fitting isotropic Gaussians, although in practice you can find the MLE covariance matrix (the maths will of course be more disgusting).
Given this model, we can write down the likelihood of our data points under it:
$$
p(x_i |\theta) = \sum_c p(x_i|z_i=c, \theta)p(z_i=c|\theta) = \sum_c p(x_i|z_i=c, \theta) \pi_c
$$
$$
p(X |\theta) = \prod_i \sum_c p(x_i|z_i=c, \theta) \pi_c= \prod_i \sum_c \mathcal{N}(x_i;\mu_c, \sigma^2_c) \pi_c
$$
This completes our model definition for a GMM. Why don’t we directly maximise the likelihood w.r.t. model parameters? Let’s see:
$$
\log p(X |\theta) = \sum_i \log \sum_{c} \mathcal{N}\left(x_{i} ; \mu_{c}, \sigma_{c}^{2}\right) \pi_{c}
$$
$$
\nabla_\theta \log p(X|\theta) = \sum_i \nabla_\theta \log\left( \sum_{c} \mathcal{N}\left(x_{i} ; \mu_{c}, \sigma_{c}^{2}\right) \pi_{c}\right)
$$
And we are stuck… good luck dealing with the log of a sum of exponentials above. This is why EM is used to solve GMMs.
E-Step
Recall that in EM, the E-step is equivalent to setting our variational distribution $q(z_i)$ to be equal or approximately equal to the posterior $p(z_i|x_i,\theta)$. In the case of GMM, we can obtain the analytical posterior. Applying Bayes’ rule:
$$
q(z_i=c) = p(z_i=c|x_i,\theta) = \dfrac{p(x_i|z_i=c,\theta)p(z_i=c|\theta)}{\sum_k p(x_i|z_i=k,\theta)p(z_i=k|\theta)}
$$
Don’t be confused by the index $k$ in the denominator; we are just marginalising over all cluster categories (in our example, 1 to 3). In the E-step we fix all model parameters, so the posterior above takes the form:
$$
q(z_i=c) = \dfrac{\mathcal{N}(x_i|\mu_c,\sigma^2_c)\pi_c}{\sum_k \mathcal{N}(x_i|\mu_k,\sigma^2_k)\pi_k}
$$
This result is fairly intuitive - you are effectively asking: what is the chance of this data point belonging to cluster $c$, given the current cluster centres and deviations? For each data point, we make a soft (probabilistic) assignment of which cluster it came from, which turns out to be the probability of it being generated by cluster $c$ (numerator), divided by the sum of probabilities of it being generated by each of the clusters (denominator).
M-Step
Recall that in the M-step, we maximise the expected joint log-likelihood under our own variational distribution $q(z_i)$. I will just copy-paste from the previous section:
$$
\theta^* = \arg \max _{\theta} \sum_{i} \sum_{c} q\left(z_{i}=c\right) \log p\left(x_{i}, z_{i}=c | \theta\right) = \arg\max_\theta \mathcal{G}
$$
where $\mathcal{G}$ is just shorthand for this expected joint log-likelihood, and $\theta$ consists of the mixture proportion terms $\pi_c$, the Gaussian centres $\mu_c$, and the variance terms $\sigma^2_c$.
Note that here, $q(z_i)$ is fixed, even though from the E-step it depends on our parameters. Writing $\mathcal{G}$ out fully:
$$
\mathcal{G} = \sum_{i} \sum_{c} q\left(z_{i}=c\right) \log\left(\dfrac{1}{\sigma_c \sqrt{2\pi}}\exp\left(\dfrac{-(x_i-\mu_c)^2}{2\sigma_c^2}\right)\pi_c\right)
$$
$$
= \sum_{i} \sum_{c} q\left(z_{i}=c\right) \left(\log \dfrac{\pi_c}{\sigma_c\sqrt{2\pi}} - \dfrac{(x_i-\mu_c)^2}{2\sigma_c^2}\right)
$$
This is all we need to obtain the analytical gradient with respect to the model parameters $\theta$. For example, to get $\mu_k$ (the $k$-th cluster mean):
$$
\nabla_{\mu_k}\mathcal{G} = \sum_i q(z_i=k) \dfrac{x_i-\mu_k}{\sigma_k^2}
$$
since for all $c\neq k$ the gradient is zero. Setting the above to zero and solving for the optimal $\mu_k$:
$$
\mu_k = \dfrac{\sum_i q(z_i=k)x_i}{\sum_i q(z_i=k)}
$$
The same can be done for $\sigma_k^2$ (via high school calculus) to obtain:
$$
\sigma_k^2 =\frac{\sum_{i}\left(x_{i}-\mu_{k}\right)^{2} q\left(z_{i}=k\right)}{\sum_{i} q\left(z_{i}=k\right)}
$$
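For completeness, here is a quick sketch of that calculation. Differentiating the $k$-th cluster’s terms of $\mathcal{G}$ with respect to $\sigma_k^2$ and setting the gradient to zero:
$$
\nabla_{\sigma_k^2}\mathcal{G} = \sum_i q(z_i=k)\left(-\dfrac{1}{2\sigma_k^2} + \dfrac{(x_i-\mu_k)^2}{2\sigma_k^4}\right) = 0
\quad\Rightarrow\quad
\sigma_k^2 = \dfrac{\sum_i q(z_i=k)(x_i-\mu_k)^2}{\sum_i q(z_i=k)}
$$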
Getting the optimal mixing proportions $\pi_k$ is more complicated, since we need to do constrained optimisation with $\pi_k \geq 0$ and $\sum_c \pi_c = 1$ (it has to be a probability distribution). This can be done through Lagrange multipliers. Here we omit the maths and jump straight to the result (a quick sketch of the Lagrangian step is given at the end of this subsection):
$$
\pi_k = \dfrac{\sum_i q(z_i=k)}{N}
$$
where $N$ is the total number of data points. The three equations above should all be fairly intuitive; the common theme is to re-estimate the model parameters with our ‘self-labelled’ data points (labelled by our own variational distribution $q$). One could have jumped straight to these formulae (I couldn’t, but I’m sure some mathmos can), but here we followed the EM algorithm in order to maximise the likelihood of the data points under our model.
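And for those curious about the omitted Lagrange multiplier step, a quick sketch: only the $\log \pi_c$ part of $\mathcal{G}$ depends on the mixing proportions, so we maximise it subject to $\sum_c \pi_c = 1$:
$$
\mathcal{J} = \sum_{i} \sum_{c} q(z_i=c)\log\pi_c + \lambda\left(1 - \sum_c \pi_c\right), \qquad
\nabla_{\pi_k}\mathcal{J} = \dfrac{\sum_i q(z_i=k)}{\pi_k} - \lambda = 0
$$
so $\pi_k = \frac{1}{\lambda}\sum_i q(z_i=k)$; summing over $k$ and using $\sum_k \pi_k = 1$ gives $\lambda = \sum_k \sum_i q(z_i=k) = N$, which recovers the formula above (the $\pi_k \geq 0$ constraint is satisfied automatically, since the $q$ values are non-negative).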
3.2 Implementation
Of course, any sane person would call sklearn.mixture.GaussianMixture and just do GMM.fit(X). But that is way too efficient. We start with some standard imports:
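A minimal setup might look like the following sketch (NumPy for arrays, SciPy for Gaussian densities, Matplotlib for the plots later; the fixed seed and all variable names below are my own choices):

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)  # fixed seed so the run is reproducible
```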
3.2.1 Data/Ground Truth Generation
First, let’s generate our dataset. Define the ground truth parameters that we wish to estimate:
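For instance, with three 2D clusters (the specific numbers here are illustrative choices, not the exact values behind the figures):

```python
n_clusters, n_dims = 3, 2
pi_true = np.array([0.5, 0.3, 0.2])        # ground-truth mixing proportions
mu_true = np.array([[0.0, 0.0],
                    [4.0, 4.0],
                    [-4.0, 4.0]])          # ground-truth cluster means
sigma2_true = np.array([1.0, 0.5, 2.0])    # ground-truth isotropic variances
```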
The dataset is drawn according to our model:
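That is, draw each $z_i$ from the categorical distribution $\pi$, then draw $x_i$ from the corresponding Gaussian; one way to write this:

```python
N = 1000                                                      # number of data points
z_true = rng.choice(n_clusters, size=N, p=pi_true)            # z_i ~ Categorical(pi)
X = np.stack([rng.multivariate_normal(mu_true[c], sigma2_true[c] * np.eye(n_dims))
              for c in z_true])                               # x_i | z_i=c ~ N(mu_c, sigma_c^2 I)
```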
This gives us the dataset shown at the beginning of the section (three fairly distinct clusters).
We also need to initialise guesses for our model parameters. Here we sample the cluster centres from a normal distribution, and guess an identity covariance matrix and uniform mixing proportions:
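Concretely, something along these lines (any reasonable initialisation works):

```python
mu = rng.normal(size=(n_clusters, n_dims))                    # random initial cluster centres
cov = np.stack([np.eye(n_dims) for _ in range(n_clusters)])   # identity initial covariances
pi = np.full(n_clusters, 1.0 / n_clusters)                    # uniform initial mixing proportions
```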
3.2.2 E-Step
The E-step update rule is reproduced here for convenience:
$$
q\left(z_{i}=c\right)=\frac{\mathcal{N}\left(x_{i} | \mu_{c}, \sigma_{c}^{2}\right) \pi_{c}}{\sum_{k} \mathcal{N}\left(x_{i} | \mu_{k}, \sigma_{k}^{2}\right) \pi_{k}}
$$
the corresponding code is simply:
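A direct translation might look like this (the function name and the array layout, one row of responsibilities per data point, are my own choices):

```python
def e_step(X, pi, mu, cov):
    """Return q(z_i = c) for every data point as an (N, n_clusters) array."""
    q = np.zeros((X.shape[0], len(pi)))
    for c in range(len(pi)):
        # numerator: N(x_i | mu_c, cov_c) * pi_c
        q[:, c] = pi[c] * multivariate_normal.pdf(X, mean=mu[c], cov=cov[c])
    # denominator: sum over clusters, so each row becomes a valid distribution
    return q / q.sum(axis=1, keepdims=True)
```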
As a sanity check, for each data point the variational distribution over clusters should sum to one:
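For example, with the e_step above:

```python
q = e_step(X, pi, mu, cov)
assert np.allclose(q.sum(axis=1), 1.0)   # each row of q sums to one
```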
3.2.3 M-Step
The M-step update rules are copied here for all three classes of parameters:
$$
\mu_{k}=\frac{\sum_{i} q\left(z_{i}=k\right) x_{i}}{\sum_{i} q\left(z_{i}=k\right)}
$$
$$
\Sigma_{k}=\frac{\sum_{i} q\left(z_{i}=k\right)\left(x_{i}-\mu_{k}\right)\left(x_{i}-\mu_{k}\right)^{T}}{\sum_{i} q\left(z_{i}=k\right)}
$$
$$
\pi_{k}=\frac{\sum_{i} q\left(z_{i}=k\right)}{N}
$$
Note that we have slightly modified the covariance update rule so that it uses the outer product $\left(x_{i}-\mu_{k}\right)\left(x_{i}-\mu_{k}\right)^{T}$ and yields a full covariance matrix $\Sigma_k$ (I realised I probably should have worked in vector form earlier, but oh well). The code becomes:
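Sticking with the names used above, an M-step along these lines should work:

```python
def m_step(X, q):
    """Re-estimate (pi, mu, cov) from responsibilities q of shape (N, n_clusters)."""
    N, n_dims = X.shape
    Nk = q.sum(axis=0)                      # effective number of points per cluster
    pi = Nk / N                             # new mixing proportions
    mu = (q.T @ X) / Nk[:, None]            # new (responsibility-weighted) cluster means
    cov = np.zeros((len(Nk), n_dims, n_dims))
    for c in range(len(Nk)):
        diff = X - mu[c]                                      # (N, n_dims)
        cov[c] = (q[:, c, None] * diff).T @ diff / Nk[c]      # weighted outer products
    return pi, mu, cov
```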
3.2.4 Wrapping Up
In theory, we need to check for convergence by evaluating the variational lower bound; this is omitted here. Instead, we run the simulation for 150 iterations and obtain:
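The main loop is then just the two steps alternated (the history list is only there so we can plot the parameter trajectories afterwards):

```python
history = []
for _ in range(150):
    q = e_step(X, pi, mu, cov)                           # E-step: recompute responsibilities
    pi, mu, cov = m_step(X, q)                           # M-step: re-estimate parameters
    history.append((pi.copy(), mu.copy(), cov.copy()))
```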
So as anticipated, our GMM indeed fits three Gaussians onto the data points. Visually the fit is not bad. But how do the estimated parameters fare against the ground truth? We visualise the phase portrait for the cluster means (2D) and the variances within the covariance matrices (also 2D) across iterations. The plots below show the ground-truth model parameters in black, and the trajectories the GMM generated across the iterations. We observe very good agreement between the predicted values and the ground truth!
In summary, we have shown how the EM algorithm, at heart, performs maximum likelihood estimation of the dataset under a given model. Again, it is worth reflecting on the fact that neither the E nor the M step is guaranteed to have a closed-form solution. This sets up the motivation for future blog posts on variational inference and Monte Carlo based methods. Finally, this post wouldn’t be complete if I didn’t include a link to Carl Rasmussen’s lecture notes on EM (they are super condensed, but good as a quick memory refresher). The best talk I can find online is given by Andrew Ng, who explains the overall concept really well and does not shy away from the derivations.