Variational Inference Part 2: Black-Box VI

(30 min) Part 2 of a two-part series on variational inference. This part dives into the more practical black-box variational inference. We discuss the REINFORCE algorithm and gradient-variance reduction techniques (including the neural baseline).

1. Introduction

In part I of the series, we explored the origin of variational inference. As a quick recap, the posterior distribution $p(\theta |x)$ is intractable to compute, as we need to explicitly compute the evidence term:

$$p(x)=\int_{\theta} p(x \mid \theta)\, p(\theta)\,d\theta$$

this integration over the entire parameter space is prohibitively expensive.

To address this problem, we cook up a variational distribution $q(\theta)$ that approximates the posterior. The problem of finding the posterior therefore becomes an optimisation problem. Specifically, we wish to minimise $$\text{KL}(q(\theta) \,\|\, p(\theta \mid x))$$ (we discussed the implications of the ordering of $q$ and $p$ inside the brackets in part I). To perform this optimisation, we showed that minimising this KL divergence is equivalent to maximising the evidence lower bound (ELBO):

$$q^{*}=\operatorname{argmax}_{q \in Q}\text{ELBO} = \operatorname{argmax}_{q \in Q}\left[\int_\theta q(\theta) \log p(\theta, x)\,d\theta - \int_\theta q(\theta)\log q(\theta)\,d\theta\right]$$ $$ = \operatorname{argmax}_{q \in Q}\; \mathbb{E}_{\theta \sim q}[\log p(\theta, x)] + \mathcal{H}(q(\theta))\quad [1]$$

which, intuitively speaking, contains the maximum-likelihood term $\log p(\theta, x)$ that prefers point estimates, and the entropy term $\mathcal{H}(q(\theta))$ that prefers diffuse estimates. Also note that this ELBO can be estimated easily with Monte-Carlo sampling, since by construction we can draw samples from $q(\theta)$ and evaluate $\log p(\theta, x)$ pointwise.
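
Concretely, drawing $S$ samples $\theta^{(s)} \sim q(\theta)$ gives the unbiased Monte-Carlo estimate

$$\widehat{\mathrm{ELBO}} = \frac{1}{S}\sum_{s=1}^{S}\left[\log p(\theta^{(s)}, x) - \log q(\theta^{(s)})\right]$$

where the second term is just a sample estimate of the entropy $\mathcal{H}(q(\theta))$ (for many common guides the entropy is also available in closed form).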

We then introduced an analytical method called mean-field approximation, which falls under the framework of ELBO maximisation. The major problems with the mean-field approach are that the update equations have to be re-derived by hand for every new model, and closed-form updates only exist for a restricted (conditionally conjugate) class of models.

This post describes the more modern techniques used to do variational inference: the black-box family. Specifically, we will look at the reparameterisation trick and the famous REINFORCE algorithm. We will also discuss some variance reduction techniques such as the baseline method. To make this post less dry, I will also include some Pyro code (a probabilistic programming language built on PyTorch) just to demonstrate how easy black-box VI is.

2. Theory of Black-Box VI

2.1 Why Is ELBO Backprop Hard?

One natural question to ask is: why can’t we just backpropagate through the ELBO objective? Let us first write down the ELBO in a more workable form:

$$\mathrm{ELBO} = \mathbb{E}_{q_{\phi}(\theta)}\left[\log p(x, \theta)-\log q_{\phi}(\theta)\right]$$

This is not much of a change from equation [1]; we have simply made it explicit that our variational distribution $q$ is parameterised by $\phi$. For simplicity, let’s call the variational distribution $q$ the guide (this terminology is common in the literature). The problem is that the ELBO is an expectation, so we need to be able to compute unbiased estimates of

$$\nabla_{\phi} \mathrm{ELBO}=\nabla_{\phi} \mathbb{E}_{q_{\phi}(\theta)}\left[\log p(x, \theta)-\log q_{\phi}(\theta)\right]$$

Unfortunately, we cannot simply move the gradient inside the expectation, since the expectation is taken under the guide $q$, which itself depends on $\phi$. This problem can be framed more generically: let $f_{\phi}(\theta)$ denote an arbitrary function of $\theta$ (standing in for the bracketed term), and consider

$$\nabla_{\phi} \mathbb{E}_{q_{\phi}(\theta)}\left[f_{\phi}(\theta)\right] \quad [2]$$

Equation [2] describes a gradient through an expectation, and the same quantity appears in policy-gradient algorithms in reinforcement learning. The remainder of the post is about obtaining unbiased estimates of equation [2].

2.2 Reparameterisation Trick

The reparameterisation trick (also known as pathwise gradients) allows us to compute equation [2] by rewriting samples from the guide $q$ in terms of a noise variable $\epsilon$:

$$\mathbb{E}_{q_{\phi}(\theta)}\left[f_{\phi}(\theta)\right]=\mathbb{E}_{q(\epsilon)}\left[f_{\phi}\left(g_{\phi}(\epsilon)\right)\right]$$

concretely, $\epsilon \sim q(\epsilon)$ and $\theta = g_{\phi}(\epsilon)$. Crucially, all the $\phi$-dependent terms have been moved inside the expectation. This kind of reparameterisation can be done for many distributions (e.g. the normal distribution). For example, if $\theta$ is sampled from a Gaussian, $\theta \sim \mathcal{N}(\mu_\phi, \sigma_\phi)$, the function $g_{\phi}(\epsilon)$ can be expressed as

$$g_{\phi}(\epsilon)=\mu_{\phi}+\epsilon\, \sigma_{\phi}$$

where $\epsilon \sim \mathcal{N}(0, 1)$. Assuming $f(.)$ and $g(.)$ are sufficiently smooth, we can now get unbiased estimates of the gradient of interest by taking a Monte Carlo estimate of this expectation:

$$\nabla_{\phi} \mathbb{E}_{q(\epsilon)}\left[f_{\phi}\left(g_{\phi}(\epsilon)\right)\right]=\mathbb{E}_{q(\epsilon)}\left[\nabla_{\phi} f_{\phi}\left(g_{\phi}(\epsilon)\right)\right]$$

For this to work, not only does $\theta$ need to come from a distribution that admits such a reparameterisation (e.g. the Gaussian and other location-scale families), we also need to be able to differentiate through $f_\phi$. In particular, discrete latent variables cannot be handled this way: any mapping $g_\phi$ that produces discrete values is piecewise constant, so its gradient is zero almost everywhere.
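
Before moving on, here is a minimal PyTorch sketch of the Gaussian case above (this is the same pathwise sampling that `rsample` provides in `torch.distributions`). The function `f` is a made-up placeholder for the bracketed ELBO term, not anything from a real model:

import torch

# variational parameters phi = (mu, log_sigma), both learnable
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)

def f(theta):
    # placeholder for log p(x, theta) - log q(theta); any smooth function will do
    return -(theta - 2.0) ** 2

eps = torch.randn(1000)                  # noise samples, eps ~ N(0, 1)
theta = mu + eps * log_sigma.exp()       # pathwise samples theta = g_phi(eps)
loss = -f(theta).mean()                  # Monte-Carlo estimate of -E_q[f]
loss.backward()                          # gradients flow through g_phi into mu and log_sigma
print(mu.grad, log_sigma.grad)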

2.3 REINFORCE for Non-Reparameterizable Random Variables

I am introducing REINFORCE in the context of variational inference; however, this trick is most commonly seen in reinforcement learning (policy gradients). The idea behind REINFORCE is quite simple: rewrite equation [2] into a form whose unbiased estimates we can obtain via Monte-Carlo. Specifically, we want to massage a gradient of an expectation into an expectation of some gradient. We begin by expanding the terms:

$$\nabla_{\phi}\text{ELBO}=\nabla_{\phi} \mathbb{E}_{q_{\phi}(\theta)}\left[f_{\phi}(\theta)\right]=\nabla_{\phi} \int d\theta q_{\phi}(\theta) f_{\phi}(\theta)$$

applying the product rule:

$$\int d \theta\left[\left(\nabla_{\phi} q_{\phi}(\theta)\right) f_{\phi}(\theta)+q_{\phi}(\theta)\left(\nabla_{\phi} f_{\phi}(\theta)\right)\right]$$

next, we use the log-derivative trick:

$$\nabla_{\phi} q_{\phi}(\theta)=q_{\phi}(\theta) \nabla_{\phi} \log q_{\phi}(\theta)$$

to obtain

$$\nabla_{\phi}\text{ELBO}=\mathbb{E}_{q_{\phi}(\theta)}\left[\left(\nabla_{\phi} \log q_{\phi}(\theta)\right) f_{\phi}(\theta)+\nabla_{\phi} f_{\phi}(\theta)\right] \quad [3]$$

Equation [3] is the REINFORCE estimator. It is quite ugly looking and I do not have much intuition for what it represents (frankly, the ELBO objective itself took me a while to digest). But the key point is that both terms are now expectations under $q$, so we can obtain unbiased estimates of the ELBO gradient by sampling! One way to package this result (for implementation) is by introducing a surrogate objective:

$$\text { surrogate objective }= \log q_{\phi}(\theta) \overline{f_{\phi}(\theta)}+f_{\phi}(\theta)$$

to be passed through autograd, where $\overline{f_{\phi}(\theta)}$ indicates that the term is held constant (detached from the autograd graph). Equation [3] then becomes:

$$\nabla_{\phi} \mathrm{ELBO}=\mathbb{E}_{q_{\phi}(\theta)}\left[\nabla_{\phi}(\text { surrogate objective })\right] \quad [4]$$
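
To make the surrogate concrete, here is a minimal, self-contained PyTorch sketch of the score-function estimator for a discrete (Bernoulli) guide. The linear function `f` below is a made-up placeholder for the bracketed ELBO term rather than anything from a real model:

import torch
from torch.distributions import Bernoulli

phi = torch.tensor(0.0, requires_grad=True)   # variational parameter (Bernoulli logit)

def f(theta):
    # placeholder for log p(x, theta) - log q(theta); it has no direct dependence
    # on phi here, so only the score-function term carries gradient
    return 3.0 * theta - 1.0

q = Bernoulli(logits=phi)
theta = q.sample((1000,))                     # discrete, non-reparameterisable samples
log_q = q.log_prob(theta)

# surrogate objective: log q_phi(theta) * detach(f(theta)) + f(theta)
surrogate = log_q * f(theta).detach() + f(theta)
loss = -surrogate.mean()
loss.backward()                               # unbiased (but noisy) estimate of -grad ELBO
print(phi.grad)

Since the samples are discrete, there is no pathwise gradient here; all of the gradient signal comes from the score-function term.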

It would be nice to finish the story here. Unfortunately, equation [4] suffers from high variance for a wide range of $f(.)$. So although we can theoretically obtain unbiased estimates of the true gradient, in practice no one is going to wait for 10,000 samples per backprop iteration. We therefore need ways to reduce the variance of our estimated ELBO gradient.

3. Practical VI with Pyro

Let’s demonstrate how black-box VI actually works using Pyro. We will build a very dumb model that estimates whether a coin is biased based on a handful of tosses. The nice thing about this model is that an analytical solution exists. This blog is not about Pyro, so the code won’t be explained in depth; however, it is pretty much copy-pasted from the Pyro tutorial itself, and frankly Pyro is so high level that anyone who knows some PyTorch/numpy can follow it.

Say we wish to find the posterior distribution of the parameter $p$, which governs the probability that a coin toss yields a head:

$$X\sim \text{Bernoulli}(p)$$

we place a Beta prior on $p$:

$$p\sim \text{Beta}(10, 10)$$

first, we toss the coin 10 times and observe 6 heads and 4 tails:

import math
import numpy as np                  # used to store the loss trace
import torch
import torch.distributions.constraints as constraints
from tqdm import tqdm               # progress bar for the training loop

import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

N_STEPS = 5000  # number of gradient updates
N_HEADS = 6     # observed number of heads
N_TAILS = 4     # observed number of tails

# clear param store
pyro.clear_param_store()

# create some data with some observed heads and some observed tails
data = torch.zeros(N_HEADS+N_TAILS)
data[0:N_HEADS] = torch.ones(N_HEADS)

We then build this generative model describing how we may draw coin toss observations from the prior:

def model(data):
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)

    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))

    for i in pyro.plate("data_loop", len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

with latent_fairness denoting $p$, the parameter whose posterior we are interested in. Note that we used pyro.plate to mark the conditional independence of our observations, since the $x\in X$ are i.i.d. Not only is this important for scaling to many observations, it also helps with ELBO gradient variance reduction (see section 4.1 below).

The analytical posterior can be calculated:

$$p \mid X \sim \text{Beta}(10+h,\; 10+t)$$

where $h$ and $t$ are the numbers of observed heads and tails respectively.
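
Plugging in our observations ($h = 6$, $t = 4$) gives

$$p \mid X \sim \text{Beta}(16, 14), \qquad \mathbb{E}[p \mid X] = \frac{16}{30} \approx 0.533, \qquad \operatorname{sd}[p \mid X] = \sqrt{\frac{16 \times 14}{30^{2} \times 31}} \approx 0.090$$

which is the number we will compare the SVI result against below.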

For the purpose of demonstrating SVI, we specify a variational posterior distribution $q$ (called a guide in Pyro), where

$$q(p) = \text{Beta}(\alpha, \beta)$$

we need to ensure that $\alpha$ and $\beta$ are constrained to be positive:

def guide(data):
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

it is important that the sample site names in the guide match those of the latent variables we want to approximate in the model (in this case, latent_fairness). Finally, we use the Adam optimiser, specify the ELBO loss, and just do inference!

# setup the optimizer
adam_params = {"lr": 0.00025, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
losses = np.empty(N_STEPS)

# do gradient steps
for step in tqdm(range(N_STEPS)):
    losses[step] = svi.step(data)

That’s it! To evaluate model correctness:

# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

inferred_mean = alpha_q / (alpha_q + beta_q)
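# std of a Beta(a, b) is sqrt(a*b / ((a+b)**2 * (a+b+1))); 'factor' below rearranges this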
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("Posterior alpha and beta values are {:.1f} and {:.1f} respectively\n".format(alpha_q,
                                                                                    beta_q))
print("The fairness of the coin is {:.3f} ± {:.3f}".format(inferred_mean, inferred_std))

We get:

Posterior alpha and beta values are 16.1 and 13.9 respectively

The fairness of the coin is 0.536 ± 0.090

which is pretty much the same as our analytical result (16 and 14 for the posterior alpha, beta values).

4. ELBO Gradient Variance Reduction

4.1 Rao-Blackwellization

The first strategy exploits the dependency structure of the model. Rao-Blackwellization replaces the term multiplying the score function with a conditional expectation over the latent variables (and data points) that a given gradient component does not actually depend on; the resulting estimator has the same mean but lower (or equal) variance. In Pyro we do not have to derive any of this by hand: marking conditional independence with pyro.plate, as we did in the coin-toss model above, gives the ELBO machinery the structure it needs to apply this reduction automatically.

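As an illustration (this is just an alternative, vectorised way of writing the model from section 3, reusing its imports and data, not new functionality), the observation loop can be expressed with a single vectorised plate so that the conditional independence is explicit:

def model(data):
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))

    # the plate declares the tosses conditionally independent given f,
    # which is exactly the structure the variance reduction exploits
    with pyro.plate("data_plate", len(data)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)
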
4.2 Baselines

Another strategy for reducing variance in the ELBO gradient estimator goes under the name of baselines. The key point is to note that for any constant $b$,

$$\mathbb{E}_{q_{\phi}(\theta)}\left[\nabla_{\phi}\left(\log q_{\phi}(\theta) \times b\right)\right]=0$$

the proof can be done in one line:

$$\mathbb{E}_{q_{\phi}(\theta)}\left[\nabla_{\phi} \log q_{\phi}(\theta)\right]=\int d \theta q_{\phi}(\theta) \nabla_{\phi} \log q_{\phi}(\theta)=\int d \theta \nabla_{\phi} q_{\phi}(\theta)=\nabla_{\phi} \int d \theta q_{\phi}(\theta)=\nabla_{\phi} 1=0$$

hence we can replace our surrogate objective

$$\log q_{\phi}(\theta) \overline{f_{\phi}(\theta)}$$

with

$$\log q_{\phi}(\theta)\left( \overline{f_{\phi}(\theta)} - b\right)$$

Therefore, the mean of our gradient estimator does not change, but its variance can be reduced if $b$ is chosen well (showing this rigorously is a bit more involved, but intuitively, if $b$ tracks the mean of $\overline{f_{\phi}(\theta)}$, the term multiplying the score function stays small and the variance should decrease).

So our variance reduction task becomes predicting what $\overline{f_{\phi}(\theta)}$ will be at each time-step. The simplest baseline is a decaying (moving) average of recent samples of $\overline{f_{\phi}(\theta)}$. In Pyro, this looks like:

z = pyro.sample("z", dist.My_Dist(...),
                infer=dict(baseline={'use_decaying_avg_baseline': True,
                                     'baseline_beta': 0.95}))
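
The decaying-average baseline simply keeps an exponential moving average of recent samples of $\overline{f_{\phi}(\theta)}$ and subtracts it inside the surrogate. Here is a rough, self-contained PyTorch sketch of the idea (this is not Pyro’s actual implementation; the function f, the logit parameter and the decay rate are toy placeholders):

import torch
from torch.distributions import Bernoulli

phi = torch.tensor(0.0, requires_grad=True)   # guide parameter (Bernoulli logit)
opt = torch.optim.SGD([phi], lr=0.05)

def f(theta):
    # placeholder for the bracketed ELBO term
    return 3.0 * theta - 1.0

decay, b = 0.95, 0.0                          # baseline decay rate and running average
for step in range(2000):
    q = Bernoulli(logits=phi)
    theta = q.sample()
    f_val = f(theta)
    # baseline-corrected surrogate: the score function is scaled by (f - b)
    surrogate = q.log_prob(theta) * (f_val.detach() - b) + f_val
    opt.zero_grad()
    (-surrogate).backward()
    opt.step()
    b = decay * b + (1.0 - decay) * float(f_val)   # decaying average of recent f samples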

Of course, since reducing ELBO variance essentially boils down to predicting $\overline{f_{\phi}(\theta)}$, why not use an auxiliary neural network to do this for us? This is the idea behind the neural baseline. In Pyro this is not too much work. First, we construct some multilayer perceptron (or any DNN of your choice):

import torch.nn as nn
import torch.nn.functional as F

class BaselineNN(nn.Module):
    def __init__(self, dim_input, dim_hidden):
        super().__init__()
        self.fc1 = nn.Linear(dim_input, dim_hidden)
        self.fc2 = nn.Linear(dim_hidden, 1)   # the baseline is a single scalar per input
        # add more layers/activations as desired

    def forward(self, x):
        out = self.fc2(F.relu(self.fc1(x)))
        return out

for our variational distribution (guide), we simply define:

baseline_module = BaselineNN(2, 50)

def guide(x):
    pyro.module("my_baseline", baseline_module)
    # ... other computations ...
    z = pyro.sample("z", dist.My_Dist(...),
                    infer=dict(baseline={'nn_baseline': baseline_module,
                                         'nn_baseline_input': x}))

Under the hood, Pyro constructs a loss of the form $$\text{baseline loss} \equiv\left(\overline{f_{\phi}(\theta)}-b\right)^{2}$$ which intuitively makes sense: the baseline is trained to track $\overline{f_{\phi}(\theta)}$. There is no theorem saying this is the optimal loss function to use in this context, but it works well in practice.

In practice it can also be important to use a different set of optimiser hyperparameters for the baseline parameters (typically a higher learning rate for the baseline than for the actual model). In Pyro this is done through per-parameter arguments:

def per_param_args(module_name, param_name):
    if 'baseline' in param_name or 'baseline' in module_name:
        return {"lr": 1e-2}
    else:
        return {"lr": 1e-3}

optimizer = Adam(per_param_args)

5. Conclusion

The main take-away from this post is that the ELBO gradient is not straightforward to obtain. In many cases, we turn a gradient of an expectation into an expectation of some gradient (the REINFORCE estimator). This means the ELBO gradient is sampled/estimated rather than obtained analytically, which unfortunately suffers from high variance. To combat this, we mark all conditional independence where we can, and when that is not enough, we introduce a baseline that estimates the value of the function at each update step.

Pyro makes all of the above easy. The SVI class uses the reparameterisation trick when possible and falls back to the REINFORCE estimator automatically. Variational inference therefore ends up looking almost identical to an ordinary gradient-update loop.