Variance Control for Black Box Variational Inference Using The James-Stein Estimator

I have just uploaded a preprint to the arXiv titled “Variance Control for Black Box Variational Inference Using The James-Stein Estimator”. It went live Friday and is the first step in realizing an idea I’ve been working on since last year.

In brief, our paper builds on this amazing paper from David Blei’s then-laboratory at Princeton (he has apparently since taken up a professorship at Columbia). It is part of a recent line of work attempting to turn a method called Variational Inference into a general-purpose algorithm applicable to any model.

A primary challenge in Bayesian data analysis is that the resulting posterior distribution p(\theta | x) for parameter \theta given data x tends to be intractable. A standard approach, familiar to many, is to use Markov Chain Monte Carlo techniques, typically some form of Gibbs Sampling, to generate sequential draws of the parameter \theta that eventually converge in distribution to the posterior. This method has its drawbacks, however: for very large models the draws tend to mix very slowly, requiring hours of compute, and because of the sequentially dependent nature of the sampling process, parallelization can be tricky. For this reason, an alternative line of attack can be found in Variational Inference.

Under Variational Inference, instead of attempting to draw samples from the posterior p(\theta | x) directly, we construct a suitable approximation q(\theta | \lambda): a simpler, tractable distribution freely parametrized by \lambda. If the two are close enough, usually in terms of the Kullback-Leibler divergence,

\text{KL}(q(\theta | \lambda) || p(\theta | x)) = \int q(\theta | \lambda) \log \bigg( \frac{q(\theta | \lambda)}{p(\theta | x)} \bigg) \, d\theta

then we can easily perform all posterior analysis (MAP estimation, credible intervals, Bayes factors, etc.) using q(\theta | \lambda) instead. Variational Inference thus recasts the problem from one of sampling (Gibbs) to one of stochastic optimization, where we can take advantage of the acceleration and parallelization techniques that have already been demonstrated for stochastic optimization problems.
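To see why this turns into an optimization problem at all, it helps to recall the standard identity (textbook material, not specific to our paper) relating the KL divergence to the log evidence:

\log p(x) = \text{KL}(q(\theta | \lambda) || p(\theta | x)) + \mathbb{E}_{q(\theta | \lambda)} \big[ \log p(x, \theta) - \log q(\theta | \lambda) \big]

Since \log p(x) does not depend on \lambda, minimizing the KL divergence is equivalent to maximizing the expectation on the right, known as the evidence lower bound (ELBO), which we denote \mathcal{L}(\lambda).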

Here’s the rub: finding q(\theta | \lambda) is actually a very complex undertaking, usually more complicated than deriving the full conditional distributions needed to perform Gibbs. Each time the model is modified (by adding hierarchical priors, or changing a distribution), the stochastic update steps needed to perform the optimization have to be re-derived from scratch. An open problem is therefore how to remove that overhead. If no derivations are required of the analyst, then it should be simple enough to package the algorithm for any application, the same way Gibbs has been packaged in BUGS and JAGS, or MCMC more generally in Stan.

Figure 1. Variance Reduction In VI Update Steps Using the James-Stein Approach

The original paper by Ranganath et al. then shows that this can be done via stochastic gradient ascent on the ELBO, where the parameter \lambda is sequentially updated using

\lambda^t = \lambda^{t-1} + \rho^{t} \hat\nabla_\lambda \mathcal{L}

for some learning rate \rho^{t}, where \theta[1], \dots, \theta[S] are iid draws from q(\theta | \lambda) and

\hat\nabla_\lambda \mathcal{L} = \frac{1}{S} \sum^{S}_{s=1} \nabla_{\lambda} \log q(\theta[s] | \lambda) (\log p(x, \theta[s]) - \log q(\theta[s] | \lambda))
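This is the score-function (REINFORCE) estimator, and it is straightforward to code up. Below is a minimal sketch, assuming a mean-field Gaussian q(\theta | \lambda) with \lambda = (mean, log-std); the function log_joint, standing in for \log p(x, \theta), is a placeholder the analyst would supply, and none of the names here come from the paper.

import numpy as np

def bbvi_gradient(mean, log_std, log_joint, S=64, rng=None):
    # Naive score-function (REINFORCE) estimator of the ELBO gradient for a
    # mean-field Gaussian q(theta | lambda) with lambda = (mean, log_std).
    # log_joint(theta) stands in for log p(x, theta) and is the only
    # model-specific piece the analyst needs to provide.
    rng = np.random.default_rng() if rng is None else rng
    std = np.exp(log_std)
    draws = []
    for _ in range(S):
        theta = mean + std * rng.standard_normal(mean.shape)  # theta[s] ~ q
        z = (theta - mean) / std
        # gradient of log q(theta[s] | lambda) w.r.t. (mean, log_std)
        score = np.concatenate([z / std, z ** 2 - 1.0])
        log_q = (-0.5 * np.sum(z ** 2) - np.sum(log_std)
                 - 0.5 * mean.size * np.log(2.0 * np.pi))
        draws.append(score * (log_joint(theta) - log_q))
    return np.stack(draws)  # shape (S, 2p): one row per Monte Carlo draw

# one stochastic gradient ascent step with the plain average:
# update = bbvi_gradient(mean, log_std, log_joint).mean(axis=0)
# mean    += rho * update[:mean.size]
# log_std += rho * update[mean.size:]

Returning the per-draw terms rather than only their average is deliberate: the James-Stein step discussed next needs the sample variance across draws.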

Our paper argues that this gradient estimation, and other stochastic optimization problems, can be recast as a multivariate estimation problem where the goal is to estimate a true parameter \mu = \nabla_\lambda \mathcal{L} using iid Monte Carlo draws of \nabla_{\lambda} \log q(\theta[s] | \lambda) (\log p(x, \theta[s]) - \log q(\theta[s] | \lambda)). And because this is generally a multivariate sample (one component for each element of \theta and its corresponding variational parameter in \lambda), we can easily borrow the James-Stein estimator, specifically its positive-part version, to control the sampling error via a bias-variance tradeoff:

\hat\mu_{JS+} = \bigg(1 - \frac{(p-3) \sigma^2}{|| \bar{z} ||^2} \bigg)^+ \bar{z}

for

\bar{z} = \frac{1}{S} \sum^{S}_{s=1} \nabla_{\lambda} \log q(\theta[s] | \lambda) (\log p(x, \theta[s]) - \log q(\theta[s] | \lambda))

and (g)^+ = g \, I_{[0, +\infty)}(g), where p is the dimension of \lambda and \sigma^2 is the sampling variance of the components of \bar{z}. Simulation results (Figure 1) demonstrate the variance reduction, and the paper provides a more formal proof of its applicability in the variational inference context.
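In code, the positive-part correction is a one-line change on top of the Monte Carlo average. A minimal sketch, reusing the per-draw gradients from the snippet above; the pooled plug-in for \sigma^2 here is one reasonable choice on my part, not necessarily the exact one used in the paper:

import numpy as np

def js_plus(draws):
    # draws: (S, p) array of per-draw gradient terms (e.g. the output of
    # bbvi_gradient above); rows are the iid Monte Carlo samples.
    S, p = draws.shape
    z_bar = draws.mean(axis=0)  # the plain Monte Carlo average, bar z
    # plug-in estimate of the sampling variance of z_bar, pooled across
    # components; assumes p > 3 so the (p - 3) factor is positive
    sigma2 = draws.var(axis=0, ddof=1).mean() / S
    shrink = 1.0 - (p - 3) * sigma2 / np.sum(z_bar ** 2)
    return max(shrink, 0.0) * z_bar  # the positive part (g)^+

# drop-in replacement for the plain average in the update step:
# update = js_plus(bbvi_gradient(mean, log_std, log_joint))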

The use of the norm as a weighting factor in the stochastic update step should be familiar to anyone working in deep learning, as it is related to gradient clipping. When training deep neural networks with long memory (recurrent networks, for example), exploding gradients can push training into regions of the objective far from any local optimum. Gradient clipping controls this by rescaling the gradient whenever its norm exceeds a maximum radius. In the paper, we show that our method performs the opposite kind of control: instead of capping the gradient once it reaches a set radius, it penalizes the gradient for approaching that radius at all. Effectively, the James-Stein estimator applied to VI forces the gradients to stay very small: rather than preventing the algorithm from making very big jumps, it makes the algorithm take consistently small ones.
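The contrast is easiest to see side by side. A toy sketch of my own (not code from the paper), with the radius and shrinkage constant as free parameters:

import numpy as np

def clip_by_norm(g, radius):
    # standard gradient clipping: rescale only when the norm exceeds the radius
    norm = np.linalg.norm(g)
    return g * (radius / norm) if norm > radius else g

def js_control(g, c):
    # James-Stein-style control: every gradient is scaled by a factor below
    # one, shrinking each step toward zero instead of capping it at a radius
    return max(1.0 - c / np.sum(g ** 2), 0.0) * g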

There’s certainly more work required to polish the paper into a complete scientific finding, but I think the idea is sound enough to stand on its own already. If anything, I think this is something of a step (if a small one) towards a truly black box algorithm for Variational Inference.
