Variational Inference Part 1: Intro and Mean-Field

(30 min) Part 1 of a two-part series on variational inference. This part introduces the high-level concepts and the mean-field approximation, then demonstrates mean-field theory on a Bayesian Gaussian Mixture Model (Bayesian GMM).

1. Introduction

Bayesian inference offers an attractive story: principled uncertainty estimation and the ability to incorporate expert knowledge into the model. For reasons that will hopefully become clear later, the challenge in Bayesian inference is making it fast and scalable enough for both development and serving. In this post, I will summarise what variational inference (VI) is trying to solve, its benefits and potential pitfalls. I will then introduce the mean-field method and use it to solve a Bayesian Gaussian Mixture Model. In the second post, I will elaborate on black-box variational inference methods.

VI is an advanced topic, so intermediate-level probability theory and ML knowledge are assumed, as is familiarity with the Kullback–Leibler (KL) divergence. To better follow the Bayesian GMM section (which will be maths-heavy), I recommend my previous post on Expectation Maximisation and this video for intuition. Let’s dive in!

2. Theory of VI

2.1 What Is It Trying to Solve?

Let’s start with the classic Bayesian inference equation: $$p(\theta|X) = \dfrac{p(X|\theta)p(\theta)}{p(X)}$$ or, in words: $$\text{posterior} = \dfrac{\text{likelihood}\times\text{prior}}{\text{evidence}}$$ Following this equation, there are three steps involved in developing a Bayesian inference model.

  1. Build a generative model: choose a prior and a likelihood, i.e. how could our data be generated by a probabilistic model? (GMM for discovering mixtures, LDA for discovering topics.) This step results in an easily evaluated joint distribution $p(X, \theta)$.
  2. Compute the posterior $p(\theta|X)$: from what distribution do our parameters actually come after observing the data, as opposed to our prior belief?
  3. Report a summary, e.g. posterior means and (co)variances. Inform your boss/client about the ‘insights’ provided by the model, e.g. “on average 30% of our data come from distribution A, with an uncertainty of ±5%”, etc.

Each of the three steps could fill a separate blog post (or several). Here we focus only on step 2: how do we compute the posterior? This is in fact a very hard problem because of the normalising $p(X)$ (evidence) term. The only way to evaluate it exactly is to compute: $$p(X) = \int_\theta p(X|\theta)p(\theta)\,d\theta$$ Suppose our parameter space includes things like “cluster means” $\mu_1, \cdots, \mu_M$ and “cluster variances” $\sigma^2_1, \cdots, \sigma^2_M$; this integral over $\theta$ becomes: $$p(X) = \int_{\mu_1}\cdots \int_{\mu_M}\int_{\sigma^2_1}\cdots \int_{\sigma^2_M} p(X|\mu_1, \cdots, \sigma^2_M)\,p(\mu_1, \cdots, \sigma^2_M)\,d\mu_1\cdots d\sigma^2_M$$

Not pretty: we are integrating a complicated, high-dimensional function over a high-dimensional space. As with many real-life problems, this integral typically has no closed form, and numerical integration is hopelessly inefficient because of the dimensionality. The gold-standard approach to obtaining the posterior $p(\theta|X)$ is Markov Chain Monte Carlo (MCMC), which deserves a separate blog post. In short, MCMC provides an unbiased estimate of the posterior: if you run it for long enough, you are guaranteed to get arbitrarily close. But “long enough” is the catch - for large datasets and many parameters it is often too long to be acceptable for development or production. What would be a fast and reliable (good-enough) approximation of the posterior? Enter variational inference!
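To make the intractability concrete, here is a small toy illustration (my own, not part of the original post): for a single parameter we can brute-force the evidence on a grid, but the number of evaluations needed grows exponentially with the number of parameters.

import numpy as np
from scipy import stats

# Toy model (hypothetical): x_n ~ N(mu, 1) with prior mu ~ N(0, 1)
rng = np.random.default_rng(0)
X = rng.normal(2.0, 1.0, size=50)

mu_grid = np.linspace(-10, 10, 2001)       # 2001 grid points for ONE parameter
dmu = mu_grid[1] - mu_grid[0]
log_lik = np.array([stats.norm.logpdf(X, m, 1.0).sum() for m in mu_grid])
log_prior = stats.norm.logpdf(mu_grid, 0.0, 1.0)
evidence = np.sum(np.exp(log_lik + log_prior)) * dmu   # p(X) = ∫ p(X|mu) p(mu) dmu
print(f"grid estimate of the evidence p(X): {evidence:.3e}")

# The same resolution for a 10-parameter model would need 2001^10 evaluations:
print(f"grid points needed for 10 parameters: {2001.0**10:.2e}")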

2.2 VI Overview

Essentially, variational inference turns the numerical integration problem into an optimisation problem. Remember that what we want is the posterior distribution $p(\theta|X)$, which can look pretty complicated (we normally don’t even know what it looks like). In VI, we try to approximate this distribution with a nice distribution $q(\theta)$ (more on “nice” later; for now, picture a Gaussian). Graphically, the left figure below shows what we might want to achieve:

[Figure: left, the complicated posterior $p(\theta|X)$ approximated by a nice distribution $q(\theta)$; right, searching a family $Q$ of nice distributions for the member closest to the posterior]

To think about this in terms of optimisation: the posterior distribution lives in some complicated space, and we wish to find the optimal distribution $q^*(\theta)$ that is closest to the posterior while coming from a family of “nice” distributions (pictorially shown on the right), i.e. $$q^* = \arg\min_{q\in Q}c(q(\theta), p(\theta|X))$$ with $c$ being the “closeness” function. Two things we need to define further - what is “nice” and what is “close”? Let’s talk about “nice” first. Recall that what we really want at the end of the day is to report the mean and variance of the (approximated) posterior. Hence the family of distributions $Q$ is typically chosen to be simple and tractable: low-dimensional distributions (Gaussians being the prototypical example) whose densities and moments are easy to evaluate.

Next, on the topic of “close”: note I did not use the word distance, since a distance (metric) between two distributions requires symmetry and the triangle inequality. The default “closeness” function $c(\cdot)$ used in VI is in fact the Kullback-Leibler (KL) divergence: $$KL(q(\theta)\,||\, p(\theta | X))$$ which is definitely not symmetric. While it seems a reasonable choice for measuring closeness between two distributions, we can certainly ask:

  1. Why KL? Why not some other measure?
  2. Why did we choose the direction of the divergence to be $q||p$ and not $p||q$?

The short answer to the first question: it makes the maths simpler and it has been shown to provide good approximations. The complete maths for VI is as follows. Our objective is: $$q^{*}=\operatorname{argmin}_{q \in Q} \operatorname{KL}(q(\theta) \| p(\theta | X))$$ Under the KL measure, if $Q$ covered every possible distribution, we could just set $q(\theta)=p(\theta | X)$ to achieve $\operatorname{KL}=0$. But we do not know what the posterior is, nor do we want to compute it via that intractable integral. Instead, we write the KL divergence in its full form: $$\operatorname{KL}(q(\theta) \| p(\theta | X)) = \int_\theta q(\theta) \log \frac{q(\theta)}{p(\theta | X)} d \theta \quad [1]$$ and rewrite the posterior as $p(\theta|X) = p(\theta, X)/p(X)$: $$\operatorname{KL}(\cdots)=\int_\theta q(\theta) \log \frac{q(\theta) p(X)}{p(\theta, X)} d \theta=\log p(X)-\int_\theta q(\theta) \log \frac{p(\theta, X)}{q(\theta)} d \theta \quad [2]$$ Recall that the $p(X)$ term is intractable, but it does not depend on $q$, hence to find the optimal $q$ we can avoid this intractable term entirely! Meanwhile $p(\theta, X)$ is easy to evaluate, since we needed it to construct our model in the first place (step one of the three-step procedure in the previous section). Our objective function becomes: $$q^* = \operatorname{argmax}_{q \in Q}\int_\theta q(\theta) \log \frac{p(\theta, X)}{q(\theta)} d \theta$$ (argmax because we are minimising the negative). This objective is given the confusing name “evidence lower bound” (ELBO): looking at equation $[2]$, the KL divergence is non-negative and the evidence term $\log p(X)$ is a constant, hence the remaining term is a lower bound on the evidence. I find the ELBO objective more intuitive when expanded fully: $$q^{*}=\operatorname{argmax}_{q \in Q}\text{ELBO} = \operatorname{argmax}_{q \in Q}\left[\int_\theta q(\theta) \log p(\theta, X)\,d\theta - \int_\theta q(\theta)\log q(\theta)\,d\theta\right]$$ $$ = \operatorname{argmax}_{q \in Q}\; \mathbb{E}_{\theta \sim q}[\log p(\theta, X)] + \mathcal{H}(q(\theta))$$ where $\mathcal{H}(q(\theta))$ is the entropy of our posterior approximation (a.k.a. the variational distribution). We see that maximising the ELBO balances two competing sub-objectives:

  1. $\mathbb{E}_{\theta \sim q}[\log p(\theta, X)]$, which can be seen as maximising the (joint) likelihood of our dataset being generated by parameters drawn from $q$ under our fixed model; this term prefers confident, point-like estimates of $\theta$.
  2. $\mathcal{H}(q(\theta))$, the entropy term, which prefers uncertain (spread-out) estimates of $\theta$.

It is these two competing objectives that make Bayesian inference so attractive and less prone to overfitting: we want a good fit to our dataset, but also honest uncertainty in our parameters.

To summarise, variational inference approximates the intractable posterior $p(\theta|X)$ with some “nice” distribution $q^*(\theta)$. The approximation is cast as an optimisation problem whose objective function is the ELBO: $$q^{*}=\operatorname{argmax}_{q \in Q} \mathrm{ELBO} = \operatorname{argmax}_{q \in Q} \int_{\theta} q(\theta) \log \frac{p(\theta, X)}{q(\theta)} d \theta$$
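As a quick sanity check on this formula, the ELBO is easy to estimate by Monte Carlo for simple models. Below is a minimal sketch (my own toy example, not from the original post) for the model $x_n \sim \mathcal{N}(\mu, 1)$ with prior $\mu \sim \mathcal{N}(0, 1)$ and a Gaussian variational family $q(\mu) = \mathcal{N}(m, s^2)$; maximising this quantity over $(m, s)$ is exactly the optimisation VI performs:

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
X = rng.normal(2.0, 1.0, size=50)                    # toy data

def elbo(m, s, n_samples=5000):
    """Monte Carlo estimate of E_q[log p(X, mu)] + H(q) for q(mu) = N(m, s^2)."""
    mu = rng.normal(m, s, size=n_samples)            # mu ~ q
    log_joint = (stats.norm.logpdf(X[None, :], mu[:, None], 1.0).sum(axis=1)
                 + stats.norm.logpdf(mu, 0.0, 1.0))  # log p(X|mu) + log p(mu)
    entropy = 0.5 * np.log(2 * np.pi * np.e * s**2)  # Gaussian entropy, closed form
    return log_joint.mean() + entropy

# A q centred near the data mean scores a much higher ELBO than a badly placed one:
print(elbo(m=X.mean(), s=0.15), elbo(m=-3.0, s=0.15))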

Now we circle back to an earlier question - why did we take $\text{KL}(q||p)$ and not $\text{KL}(p||q)$? In short, inspecting equation $[1]$, had we taken the divergence the other way round, we would end up taking an expectation over our intractable posterior $p(\theta|X)$. That kind of minimisation is actually still possible, and forms the basis of expectation propagation (not discussed here). What is the effect of the order? Observe that: $$\mathrm{KL}(q(x) \| p(x)) = \int_x q(x)\log\dfrac{q(x)}{p(x)}\, dx $$ In regions where $p(x)$ is very close or equal to zero (i.e. low-likelihood regions), the only way to keep the divergence small is to ensure $q(x)$ is also close to zero there. Hence we say that $\mathrm{KL}(q\| p)$ is zero-seeking. A similar argument can be made for $\mathrm{KL}(p\| q)$, which is called zero-avoiding. This is more intuitive with a picture: suppose we want to approximate a bimodal distribution with a Gaussian, then the graphs below show the behaviour of the two minimisation directions:

[Figure: fitting a unimodal Gaussian approximation (red) to a bimodal posterior (blue) under $\mathrm{KL}(q\|p)$ vs $\mathrm{KL}(p\|q)$]

where the blue line is the true posterior $p(\theta|X)$ and the red lines are our variational approximations. Methods based on $q||p$ tend to underestimate the posterior variance, cf. the overestimation under $p||q$. Both can be quite bad if you think about the real-life impact (underestimating the uncertainty of getting cancer vs. overestimating it). A lot of research goes into alternative methods and metrics, which is of course beyond the scope of this post.
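The zero-seeking vs zero-avoiding behaviour is easy to reproduce numerically. The sketch below (my own illustration, using a hypothetical bimodal target) fits a single Gaussian by naive grid search under both KL directions and recovers exactly the behaviour in the figure:

import numpy as np
from scipy import stats

x = np.linspace(-10, 10, 4001)
dx = x[1] - x[0]
# Bimodal "posterior": a mixture of two well-separated Gaussians
p = 0.5 * stats.norm.pdf(x, -3, 0.7) + 0.5 * stats.norm.pdf(x, 3, 0.7)

def kl(a, b):
    """Grid approximation of KL(a || b); eps keeps the log finite."""
    eps = 1e-300
    return np.sum(a * (np.log(a + eps) - np.log(b + eps))) * dx

best_rev, best_fwd = None, None
for m in np.linspace(-5, 5, 101):
    for s in np.linspace(0.3, 5, 48):
        q = stats.norm.pdf(x, m, s)
        rev, fwd = kl(q, p), kl(p, q)                 # reverse KL(q||p), forward KL(p||q)
        if best_rev is None or rev < best_rev[0]:
            best_rev = (rev, m, s)
        if best_fwd is None or fwd < best_fwd[0]:
            best_fwd = (fwd, m, s)

print("reverse KL(q||p): mean %.1f, std %.2f" % best_rev[1:])  # locks onto one mode
print("forward KL(p||q): mean %.1f, std %.2f" % best_fwd[1:])  # spreads over both modes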

2.3 Mean-Field Approximation

Recall we wish to approximate the posterior using a family of “nice” distributions $Q$, one of the criteria being low dimensionality. $\theta$ contains all the parameters in our model, so it is typically high dimensional. What if we instead fully factorise the approximation: $$Q_{\text{MF}}:=\left\{q: q(\theta)=\prod_{j=1}^{J} q_{j}\left(\theta_{j}\right)\right\}$$ So we now explicitly assume independence between parameters in our approximating distribution. This gives us the form of mean-field (MF) variational inference. Note this is NOT equivalent to assuming the parameters are actually independent of one another: MF is not a modelling assumption, but an approximation assumption. This factorised form of variational inference corresponds to an approximation framework developed in physics called mean field theory. We now seek the distribution within the family $Q_\text{MF}$ for which the ELBO is largest; substituting $q(\theta)$ into the ELBO: $$ \text{ELBO} = \int_\theta \prod_j q_j(\theta_j) \left\{ \log p(X,\theta) - \sum_j \log q_j (\theta_j) \right\} d\theta$$

It can be shown (the maths is quite involved, so it is omitted here) that the general expression for the optimal factor $q^*_k (\theta_k)$ satisfies:

$$\log q_k^*(\theta_k) = \mathbb{E}_{q_{j\neq k}}[ \log p(X, \theta)] + \text{const.}\quad [3]$$ which says that the log of the optimal solution for the factor $q_k$ is obtained by taking the log of the joint distribution and then taking its expectation with respect to all the other factors $\{q_j\}$, $j\neq k$.

This means we can apply coordinate ascent to obtain $q^*(\theta)$, iteratively setting each factor $q_{k}\left(\theta_{k}\right)$ to $q^*_{k}\left(\theta_{k}\right)$ according to equation [3]. The bad news is that, in order to do MF, you have to analytically derive each update step (just like expectation maximisation for Gaussian Mixture Models). Note that in equation [3] the constant term is effectively the normalising denominator once you exponentiate both sides to get the actual density $q_{k}^{*}\left(\theta_{k}\right)$. In practice, people find $q_{k}^{*}\left(\theta_{k}\right)$ by inspection after writing down the analytical form of [3].
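In code, the resulting procedure, usually called CAVI (coordinate ascent variational inference), has a very simple skeleton. The sketch below is schematic: the factor names and update functions are placeholders that you would derive per model from equation [3].

import numpy as np

def cavi(init_factors, update_rules, elbo_fn, max_iter=100, tol=1e-8):
    """Schematic coordinate-ascent VI loop.

    init_factors: dict of variational parameters, one entry per factor q_k
    update_rules: dict of functions; each recomputes q_k given the other factors
    elbo_fn:      function of the factors, used only to monitor convergence
    """
    factors, prev_elbo = dict(init_factors), -np.inf
    for _ in range(max_iter):
        for k, update in update_rules.items():
            factors[k] = update(factors)   # set q_k to its optimum, equation [3]
        cur_elbo = elbo_fn(factors)        # the ELBO never decreases under these updates
        if cur_elbo - prev_elbo < tol:
            break
        prev_elbo = cur_elbo
    return factors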

Suppose we have gone through the trouble of deriving the update rule for each parameter in our model, and we obtain our final $q^*(\theta)$. One natural question to ask: how far are we from the true posterior? MacKay’s book introduces a very simple problem: given a set of one-dimensional data points $X = [x_0, \cdots, x_N]$, we would like to model the distribution with a Gaussian and report its mean and variance.

import numpy as np
μ, σ, N = 10.0, 2.5, 100
X = μ + σ * np.random.randn(N)

I mean … if you are sane, you would probably do maximum likelihood estimation and guess that $\mu_{\text{MLE}} = \bar{X}$ and $\sigma^2_{\text{MLE}} = \text{Var}[X]$. But what if we want to go Bayesian? Let’s place a Gamma prior on the precision of the normal distribution, and a normal prior on our mean:

[Figure: the toy model, with a Gamma prior on the precision and a normal prior on the mean]

Because we have used conjugate priors, we can actually obtain the analytical posterior by computing: $$p(\mu, \sigma |X) = \dfrac{p(X|\mu,\sigma)p(\mu, \sigma)}{p(X)}$$ (the exact form is actually quite involved). One can show that this posterior definitely does not factorise, i.e. $p(\mu, \sigma | X)\neq f_1(\mu)f_2(\sigma)$. For those who are insane, we can still apply the mean-field approximation to this posterior (for which we know the analytical form), i.e. we consider a factorised variational approximation given by: $$q(\mu, \sigma) = q_\mu (\mu)q_\sigma (\sigma)$$ One can painstakingly derive the update rules according to equation [3] and run coordinate ascent; a minimal code sketch of these updates is given below, and the resulting fit is shown in the figure that follows.
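Here is that sketch (my own implementation, assuming the conjugate prior $p(\mu|\tau)=\mathcal{N}(\mu_0,(\lambda_0\tau)^{-1})$, $p(\tau)=\text{Gamma}(a_0,b_0)$ on the precision $\tau = \sigma^{-2}$, and the factorisation $q(\mu,\tau)=q_\mu(\mu)q_\tau(\tau)$; the hyper-parameter values are arbitrary broad choices and may differ from those used for the figure):

import numpy as np

rng = np.random.default_rng(0)
μ, σ, N = 10.0, 2.5, 100
X = μ + σ * rng.standard_normal(N)              # same kind of toy data as above

mu0, lambda0, a0, b0 = 0.0, 1e-3, 1e-3, 1e-3    # broad (arbitrary) prior settings
x_bar = X.mean()

# q(mu) = N(mu_N, 1/lam_N), q(tau) = Gamma(a_N, b_N); two quantities never change:
mu_N = (lambda0 * mu0 + N * x_bar) / (lambda0 + N)
a_N = a0 + (N + 1) / 2

E_tau = a0 / b0                                 # initial guess for E_q[tau]
for _ in range(20):                             # coordinate ascent: update q(mu), then q(tau)
    lam_N = (lambda0 + N) * E_tau
    b_N = b0 + 0.5 * (np.sum((X - mu_N) ** 2) + N / lam_N
                      + lambda0 * ((mu_N - mu0) ** 2 + 1 / lam_N))
    E_tau = a_N / b_N

print(f"q(mu) mean ≈ {mu_N:.2f}, implied sigma ≈ {np.sqrt(1 / E_tau):.2f}")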

[Figure: coordinate-ascent iterations of the mean-field approximation against the true posterior contours]

In the figures above, the green contours represent the true posterior, with the x-axis representing the posterior $\mu$ and the y-axis the posterior $\sigma^{-1}$. Blue lines are our variational mean-field approximation at each iteration, and red lines show the converged posterior approximation. As a sanity check, this shows that MF provides a decent fit for our toy example (the moments seem to match)! What about models where we do not know the true posterior (i.e. most real-life models)? Recall from section 2.1 that another common way to obtain unbiased approximations of the posterior is MCMC. If we trust MCMC (where convergence checking is an art in itself), then we can check the VI-approximated posterior means of the parameters against MCMC’s means. The figures below are from Fosdick (2013), where we see that for a complicated model the VI and MCMC posteriors agree well.

[Figure: VI (VB) vs MCMC posterior estimates from Fosdick (2013), showing good agreement]

Note that in the figures above, VB stands for variational Bayes, which is used interchangeably with variational inference. Do not be misled by these figures: there are plenty of cases where VI is unable to provide good posterior estimates! One practical way to check VI accuracy is therefore to run the more computationally expensive MCMC once (for the experiment above, MCMC took over one day); once you trust your VI (which in this case takes less than 20 minutes), you can ditch MCMC and work with VI from there.

3. VI in Practice: Bayesian GMM

3.1 Bayesian GMM vs. GMM: Theory and “Derivation”

A mixture of Gaussians is a form of latent variable model where we wish to infer a set of hidden variables from the observed data. In our post Expectation Maximisation Deep Dive, we investigated in depth the update rules for maximum likelihood estimation of the parameters of a Gaussian Mixture Model (GMM). As a recap: we have $N$ observed data points whose positions $x_n$ are influenced by the cluster they come from. We define $z_n$, a K-dimensional binary random variable that “one-hot encodes” which cluster the data point comes from (the total number of clusters is $K$). Hence: $$p(x_n|z_{nk}=1) = \mathcal{N}(x_n|\mu_k ,\Sigma_k)$$ We also define the mixing proportions, which give the marginal distribution over $z$: $$p(z_{nk}=1)=\pi_k, \quad p(z_n) = \prod_{k=1}^K \pi_k^{z_{nk}}$$ Hence we can write down the joint distribution of $x_n, z_n$: $$p(x_n, z_n) = p(x_n|z_n)p(z_n) = \prod_{k=1}^K \pi_k^{z_{nk}} \mathcal{N}\left(x_{n} | \mu_{k}, \Sigma_{k}\right)^{z_{nk}} $$ This model can be represented compactly using the graphical representation shown in the left figure below.

[Figure: graphical models of the standard GMM (left) and the Bayesian GMM with priors over $\pi$, $\mu$, $\Lambda$ (right)]

Note that in the basic GMM we did not place a prior over any of the model parameters ($\pi,\; \mu, \; \Sigma\;$); we instead used maximum likelihood (ML) to find point estimates (MLE) of these parameters. We will show in the next section that MLE can easily lead to overfitting. To make a GMM Bayesian, all we do is place priors over the model parameters: let the mixing proportions $\pi$ (which parameterise a multinomial) be drawn from a symmetric Dirichlet distribution: $$p(\pi) = \text{Dir}(\pi|\alpha_0)$$ and let the mean and precision of each Gaussian component be drawn from an independent Gaussian-Wishart prior (the Wishart is the conjugate prior for precision matrices): $$p(\mu, \Lambda)=p(\mu|\Lambda)p(\Lambda)=\prod_{k=1}^K \mathcal{N}(\mu_k|m_0, (\beta_0 \Lambda_k)^{-1})\mathcal{W} (\Lambda_k|W_0, \nu_0)$$ where the parameters $\alpha_0, \beta_0, m_0, \nu_0, W_0$ govern the shape of the prior distributions. The conditional distributions for $z_n$ and $x_n$ are almost identical to those of the GMM: $$p(z_n|\pi) = \prod_{k=1}^{K} \pi_{k}^{z_{n k}}$$ and $$ p(x_n|z_n,\mu, \Lambda) = \prod_{k=1}^{K} \mathcal{N}\left(x_{n} | \mu_{k} ,\Lambda_{k}^{-1}\right)^{z_{n k}}$$ The right figure above shows the complete graphical representation of this Bayesian GMM model.
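To make the generative story concrete, here is a small sketch (my own code, with arbitrary illustrative hyper-parameter values) that samples a dataset from this Bayesian GMM:

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
K, D, N = 3, 2, 500                                  # clusters, dimensions, data points
alpha0, beta0, nu0 = 1.0, 0.05, D + 2                # illustrative hyper-parameters
m0, W0 = np.zeros(D), np.eye(D)

pi = rng.dirichlet(alpha0 * np.ones(K))              # pi ~ Dir(alpha0)
Lambda = stats.wishart.rvs(df=nu0, scale=W0, size=K, random_state=42)   # (K, D, D)
mu = np.array([rng.multivariate_normal(m0, np.linalg.inv(beta0 * L)) for L in Lambda])

z = rng.choice(K, size=N, p=pi)                      # z_n ~ Cat(pi)
X = np.array([rng.multivariate_normal(mu[k], np.linalg.inv(Lambda[k])) for k in z])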

In order to formulate a VI treatment of this model, we explicitly write down the joint distribution over all random variables: $$p(X, Z, \pi, \mu, \Lambda) = p(X|Z, \mu, \Lambda)p(Z|\pi)p(\pi)p(\mu|\Lambda)p(\Lambda) \quad [4]$$ where $X = \{x_0, \cdots x_N\}$ is our observed dataset and similarly $Z = \{z_0, \cdots, z_N\}$.

We wish to obtain the posterior distribution $$p(Z, \pi, \mu, \Lambda|X)$$ which is intractable. Instead, consider a variational distribution that factorises between the latent variables and the parameters: $$q(Z, \pi, \mu, \Lambda) = q(Z) q(\pi, \mu ,\Lambda)$$ This follows the mean-field approach to VI and, as it turns out, is the only assumption we need to make to obtain a tractable solution to our Bayesian mixture model! Following the mean-field framework (equation [3]), to update the factor $q(Z)$ we set $$\log q^{\star}(Z)=\mathbb{E}_{\pi, \mu, \Lambda}[\log p({X}, {Z}, \pi, \mu, \Lambda)]+\mathrm{const.}$$ The exact form of $q^*(Z)$ requires quite a lot of maths, but essentially we substitute in the joint density (equation [4]) and take expectations with respect to the parameters. Jumping to the solution: $$q^{\star}(\mathbf{Z})=\prod_{n=1}^{N} \prod_{k=1}^{K} r_{n k}^{z_{n k}}$$ where $r_{n k} \propto \pi_{k}\left|\boldsymbol{\Lambda}_{k}\right|^{1 / 2} \exp \left\{-\frac{1}{2}\left(\mathbf{x}_{n}-\boldsymbol{\mu}_{k}\right)^{\mathrm{T}} \boldsymbol{\Lambda}_{k}\left(\mathbf{x}_{n}-\boldsymbol{\mu}_{k}\right)\right\}$ (strictly, with $\pi_k$ and $\Lambda_k$ replaced by suitable expectations under $q$) is quite intuitive: it is the “responsibility” term, almost identical to that of the plain GMM, indicating the contribution of each cluster to each observed data point. Next, we obtain the variational posterior for $\pi, \mu, \Lambda$: $$\log q^{\star}({\pi}, {\mu}, {\Lambda}) = \mathbb{E}_Z[\log p(X, Z, \pi, \mu, \Lambda)]+\text { const. }$$ It can be shown that: $$q^{\star}({\pi})=\operatorname{Dir}({\pi} | {\alpha}), \quad \alpha_k = \alpha_0+N_k \quad [5]$$ and $$q^{\star}\left({\mu}_{k}, {\Lambda}_{k}\right)=\mathcal{N}\left({\mu}_{k} | {m}_{k},\left(\beta_{k} {\Lambda}_{k}\right)^{-1}\right) \mathcal{W}\left({\Lambda}_{k} | {W}_{k}, \nu_{k}\right) \quad [6]$$ where

$$ N_{k}=\sum_{n=1}^{N} r_{n k}, \quad\quad \bar{x}_{k}=\frac{1}{N_{k}} \sum_{n=1}^{N} r_{n k} {x}_{n}, \quad\quad {S}_{k}=\frac{1}{N_{k}} \sum_{n=1}^{N} r_{n k}\left({x}_{n}-\bar{x}_{k}\right)\left({x}_{n}-\overline{x}_{k}\right)^{\mathrm{T}}$$

define the “pseudo count”, “cluster mean” and “cluster covariance” respectively;

$$\beta_{k}=\beta_{0}+N_{k}$$ $${m}_{k}=\frac{1}{\beta_{k}}\left(\beta_{0} {m}_{0}+N_{k} \overline{x}_{k}\right)$$ $${W}_{k}^{-1}={W}_{0}^{-1}+N_{k} {S}_{k}+\frac{\beta_{0} N_{k}}{\beta_{0}+N_{k}}\left(\overline{x}_{k}-{m}_{0}\right)\left(\overline{x}_{k}-{m}_{0}\right)^{\mathrm{T}}$$ $$\nu_{k}=\nu_{0}+N_{k}$$

are used to parameterise the VI posterior Gaussian clusters $q^{\star}\left(\mu_{k}, \Lambda_{k}\right)$. We should not be too daunted by the form of our VI posterior. The forms of both [5] and [6] make sense: our priors were Dirichlet and Gaussian-Wishart, and our VI posteriors ended up being the same kinds of distributions. We also see that the VI posteriors are parameterised both by the prior parameters $\alpha_0, \beta_0, m_0, W_0, \nu_0$ and by statistics of the observed data.
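To see how mechanical these updates are, here is a sketch of the step that recomputes the variational parameters in [5] and [6] from a given responsibility matrix $r$ (my own code; the function and variable names are hypothetical, and the companion step that recomputes $r$ is omitted). Alternating this with the responsibility update gives the full coordinate-ascent loop for the Bayesian GMM.

import numpy as np

def update_variational_params(X, r, alpha0, beta0, m0, W0_inv, nu0):
    """Given responsibilities r (shape N x K), recompute the parameters of [5] and [6]."""
    N_k = r.sum(axis=0)                                # pseudo counts
    x_bar = (r.T @ X) / N_k[:, None]                   # cluster means
    alpha = alpha0 + N_k                               # Dirichlet parameters, eq. [5]
    beta = beta0 + N_k                                 # Gaussian-Wishart parameters, eq. [6]
    nu = nu0 + N_k
    m = (beta0 * m0 + N_k[:, None] * x_bar) / beta[:, None]
    W_inv = []
    for k in range(len(N_k)):
        diff = X - x_bar[k]
        S_k = (r[:, k, None] * diff).T @ diff / N_k[k]  # cluster scatter S_k
        dm = (x_bar[k] - m0)[:, None]
        W_inv.append(W0_inv + N_k[k] * S_k
                     + (beta0 * N_k[k]) / (beta0 + N_k[k]) * (dm @ dm.T))
    return alpha, beta, m, np.array(W_inv), nu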

It is worth comparing the computational cost of GMM and Bayesian GMM: inspecting the update equations above, there is nothing fundamentally different in Bayesian GMM (computationally speaking); the most expensive steps are matrix inversions, which appear in GMM as well.

3.2 Implementation and Results

For my sanity, unlike the GMM algorithm, I will not be implementing Bayesian GMM from scratch. Instead, scikit-learn’s readily available BayesianGaussianMixture class does all the hard work for us. Starting with the imports:

from sklearn.datasets import make_blobs  # the old samples_generator path is removed in newer scikit-learn
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture

where we use make_blobs function to generate our dataset:

N, M, σ = 500, 3, 3.
X, Y = make_blobs(n_samples=N, centers=M,
                  cluster_std=σ, random_state=42)
X = X[:, ::-1] # flip the two feature columns to simplify plotting

Both GMM and Bayesian GMM are fitted and presented for comparison:

gmm = GaussianMixture(n_components=3, random_state=42)
gmm.fit(X)
labels = gmm.predict(X)

bgmm = BayesianGaussianMixture(n_components=3,
                               weight_concentration_prior_type="dirichlet_distribution",
                               init_params='random', max_iter=5000, random_state=42)
bgmm.fit(X)
blabels = bgmm.predict(X)

That’s it! Forget about the maths above. The figures below show the fitted models and posteriors with the number of clusters set to 3 (apologies for the notation: $M$ in the plots means the same as $K$, i.e. the number of clusters). We see that both GMM and Bayesian GMM provide very good estimates.

[Figure: GMM and Bayesian GMM fits and posteriors with $M = 3$ components]
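As a quick numerical check of that claim (my own addition), we can compare both sets of predicted labels against the true blob labels Y, up to label permutation:

from sklearn.metrics import adjusted_rand_score

# Scores near 1.0 indicate partitions that essentially match the true clusters
print(adjusted_rand_score(Y, labels), adjusted_rand_score(Y, blabels))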

The power of Bayesian GMM shows when you do not know how many clusters there are, which is very common with high-dimensional, unfamiliar datasets. Let’s set the number of components to 20. The figures below show that GMM completely overfits the dataset, whereas Bayesian GMM is not fooled!

[Figure: GMM vs Bayesian GMM fits with 20 components; GMM overfits while Bayesian GMM effectively uses only three]

Indeed, if we plot the posterior mixing proportions, we see that GMM places almost uniform proportions across all clusters, whereas Bayesian GMM has managed to identify that there are only three clusters!

[Figure: posterior mixing proportions; GMM spreads mass over all 20 components, Bayesian GMM concentrates it on three]
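The proportions behind these plots can be read directly off the fitted models. A sketch, refitting both models with 20 components as in the figures (exact values depend on the random seed):

gmm20 = GaussianMixture(n_components=20, random_state=42).fit(X)
bgmm20 = BayesianGaussianMixture(n_components=20,
                                 weight_concentration_prior_type="dirichlet_distribution",
                                 init_params='random', max_iter=5000,
                                 random_state=42).fit(X)

print(gmm20.weights_.round(3))    # should be spread roughly uniformly over 20 components
print(bgmm20.weights_.round(3))   # most mass should land on ~3 components, rest near zero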

Of course, one can argue that we could tune the number of clusters for GMM via cross-validation. But (a) Bayesian GMM has roughly the same computational cost as GMM, and (b) we get uncertainty over the posterior parameters as an added bonus! So in short, Bayesian GMM beats GMM :)

An obvious downside of the mean-field approach is the amount of maths one needs to go through (we have in fact omitted 90% of the actual derivation); another issue is scalability. Part 2 of this series will cover black-box variational inference.