Introduction
In this post, we introduce a machine learning technique called stochastic variational inference, which is widely used to estimate the posterior distribution of Bayesian models. Suppose that in a Bayesian model the model parameters are denoted by a vector $z$ and the observation is denoted by $x$. According to Bayes' theorem, the posterior distribution of $z$ can be computed as:
$$p(z|x) = \frac{p(x|z)\,p(z)}{p(x)}$$
$p(x)$ is the probability of the observation, marginalized over all possible model parameters:
$$p(x) = \int p(x, z)\,dz$$
$p(x)$ isn't easy to compute, and most of the time it is intractable, because of its integral form. If we are not able to compute $p(x)$, then we are not able to compute $p(z|x)$, which is what we want to know. Therefore, we need to come up with a way to approximate $p(z|x)$. We denote the approximate posterior as $q(z)$. $q(z)$ is also called the variational distribution, hence the name variational inference.
Stochastic variational inference (SVI) is one such method to approximate $p(z|x)$. From the ICML 2018 tutorial [2], we can see the niche where SVI lies: among all possible ways to approximate $p(z|x)$, there is a group of algorithms that use optimization to minimize the difference between $q(z)$ and $p(z|x)$. Those that represent the difference between the two distributions as the Kullback-Leibler (KL) divergence are called variational inference. If we further categorize based on the family of $q(z)$, there is one particular family, the mean-field variational family, to which variational inference is easy to apply. After all levels of categorization, we arrive at some form of objective function which we seek to optimize. SVI is one method for optimizing such an objective function, which pushes $q(z)$ toward minimizing the KL divergence with $p(z|x)$.
Objective function
By definition, the KL divergence between two continuous distributions $q$ and $p$ is [4]:
$$KL(q \,\|\, p) = \int q(z) \log \frac{q(z)}{p(z)}\,dz = E_{q}[\log q(z)] - E_{q}[\log p(z)]$$
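To make the definition concrete, here is a minimal sketch that evaluates $\int q(z) \log \frac{q(z)}{p(z)}\,dz$ on a grid for two univariate Gaussians and compares it with the well-known closed form for that case. The function names, the grid bounds, and the example parameters are assumptions chosen for illustration; they are not part of the references above.

```python
import numpy as np

def normal_logpdf(z, mean, std):
    """Log density of N(mean, std^2) evaluated at z."""
    return -0.5 * ((z - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

def kl_normal_closed_form(m_q, s_q, m_p, s_p):
    """Closed-form KL(q || p) for q = N(m_q, s_q^2) and p = N(m_p, s_p^2)."""
    return np.log(s_p / s_q) + (s_q**2 + (m_q - m_p) ** 2) / (2.0 * s_p**2) - 0.5

def kl_numeric(m_q, s_q, m_p, s_p, lo=-20.0, hi=20.0, n=200_001):
    """Riemann-sum approximation of the integral of q * (log q - log p)."""
    grid = np.linspace(lo, hi, n)
    dz = grid[1] - grid[0]
    log_q = normal_logpdf(grid, m_q, s_q)
    log_p = normal_logpdf(grid, m_p, s_p)
    return float(np.sum(np.exp(log_q) * (log_q - log_p)) * dz)

if __name__ == "__main__":
    print("KL(q||p) closed form:", kl_normal_closed_form(0.0, 1.0, 1.0, 2.0))
    print("KL(q||p) numeric    :", kl_numeric(0.0, 1.0, 1.0, 2.0))
    # KL is not symmetric: swapping the arguments gives a different value.
    print("KL(p||q) closed form:", kl_normal_closed_form(1.0, 2.0, 0.0, 1.0))
```

The two estimates of $KL(q \,\|\, p)$ should agree closely, while swapping the arguments gives a different number. This asymmetry is why the order of the arguments matters in the objective that follows.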
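If we are trying to find the best approximating distribution $q^*(z)$ using variational Bayes, we define the following objective function:
$$q^*(z) = \arg\min_{q} KL(q(z) \,\|\, p(z|x)),$$
where $KL(q(z) \,\|\, p(z|x)) = E[\log q(z)] - E[\log p(z|x)] = E[\log q(z)] - E[\log p(z, x)] + \log p(x)$ (all expectations are taken with respect to $q(z)$). Note that if we optimize with respect to $q$, then $\log p(x)$ can be treated as a constant. Thus, minimizing the KL divergence is equivalent to maximizing:
$$\text{ELBO}(q) = E[\log p(z, x)] - E[\log q(z)]$$
$\text{ELBO}(q)$ (the evidence lower bound) is a lower bound of $\log p(x)$ because of the non-negativity of the KL divergence:
$$\log p(x) = \text{ELBO}(q) + KL(q(z) \,\|\, p(z|x)) \geq \text{ELBO}(q)$$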
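The identity $\log p(x) = \text{ELBO}(q) + KL(q(z) \,\|\, p(z|x))$ can be checked numerically on a toy model. The sketch below assumes a discrete latent variable $z$ with three states, so that $p(x)$ is a finite sum and everything is exactly computable. The prior, likelihood, and $q$ values are made up for illustration.

```python
import numpy as np

# Hypothetical toy model: z takes 3 values, x is a single fixed observation.
prior = np.array([0.5, 0.3, 0.2])          # p(z)
likelihood = np.array([0.10, 0.40, 0.70])  # p(x | z) for the observed x
joint = prior * likelihood                 # p(z, x)
evidence = joint.sum()                     # p(x) = sum_z p(z, x)
posterior = joint / evidence               # p(z | x)

def elbo(q):
    """ELBO(q) = E_q[log p(z, x)] - E_q[log q(z)]."""
    return float(np.sum(q * (np.log(joint) - np.log(q))))

def kl(q, p):
    """KL(q || p) for discrete distributions with full support."""
    return float(np.sum(q * (np.log(q) - np.log(p))))

if __name__ == "__main__":
    q = np.array([0.2, 0.5, 0.3])  # an arbitrary variational distribution
    log_evidence = float(np.log(evidence))
    print("log p(x)           :", log_evidence)
    print("ELBO(q)            :", elbo(q))
    print("ELBO(q) + KL(q||p) :", elbo(q) + kl(q, posterior))
    # The bound is tight when q equals the exact posterior.
    print("ELBO(posterior)    :", elbo(posterior))
```

For any valid $q$ the second line stays below the first, and the gap is exactly the KL divergence to the true posterior.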
Update 2020.4:
The derivation above is also illustrated in [14]:
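There are several other ways to understand the ELBO.
1. Based on Jensen's inequality [13]: for a convex function $f$ and a random variable $X$, $f(E[X]) \leq E[f(X)]$; for a concave function $g$, $g(E[X]) \geq E[g(X)]$. Since $\log$ is concave, we have:
$$\log p(x) = \log \int p(x, z)\,dz = \log \int q(z) \frac{p(x, z)}{q(z)}\,dz = \log E_{q}\left[\frac{p(x, z)}{q(z)}\right] \geq E_{q}\left[\log \frac{p(x, z)}{q(z)}\right] = \text{ELBO}(q)$$
Therefore, $\text{ELBO}(q)$ is a lower bound of $\log p(x)$.
2. By rearranging $\text{ELBO}(q)$, we have:
$$\text{ELBO}(q) = E[\log p(x|z)] + E[\log p(z)] - E[\log q(z)] = E[\log p(x|z)] - KL(q(z) \,\|\, p(z))$$
Therefore, the first part of $\text{ELBO}(q)$ can be thought of as the so-called "reconstruction error", which encourages $q(z)$ to put more probability mass on the area with high $p(x|z)$. The second part encourages $q(z)$ to be close to the parameter prior $p(z)$. $\text{ELBO}(q)$ is the common objective used in Variational Autoencoder models.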
How to optimize?
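Recall that our objective function is $q^*(z) = \arg\min_{q} KL(q(z) \,\|\, p(z|x))$. In practice, minimizing with regard to $q$ translates to parameterizing $q(z)$ and then optimizing the objective function with regard to the parameters. One big assumption we can make to facilitate computation is that all latent variables are independent, so that $q(z)$ factorizes into the product of the distributions of the individual latent variables. The set of such $q$ is called the mean-field variational family:
$$q(z) = \prod_{j=1}^{m} q_j(z_j; \phi_j)$$
From the factorization, you can see that each individual latent variable's distribution is governed by its own parameter $\phi_j$. Hence, the objective function to approximate $p(z|x)$ changes from $q^*(z) = \arg\min_{q} KL(q(z) \,\|\, p(z|x))$ to:
$$\phi^* = \arg\min_{\phi} KL\left(\prod_{j=1}^{m} q_j(z_j; \phi_j) \,\Big\|\, p(z|x)\right)$$
One simple algorithm to optimize this is called coordinate ascent mean-field variational inference (CAVI). Each time, the algorithm optimizes one variational distribution parameter while holding all the others fixed. The algorithm works as follows [3]:
1. Initialize the variational factors $q_j(z_j; \phi_j)$ for $j = 1, \dots, m$.
2. For each $j = 1, \dots, m$: Set $q_j(z_j; \phi_j) \propto \exp\{E_{-j}[\log p(z_j | z_{-j}, x)]\}$.
3. Compute the ELBO and repeat step 2 until it converges.
"Set $q_j(z_j; \phi_j) \propto \exp\{E_{-j}[\log p(z_j | z_{-j}, x)]\}$" may seem hard to understand. It means setting the variational distribution parameter $\phi_j$ such that $q_j(z_j; \phi_j)$ follows the distribution that is equal to $\exp\{E_{-j}[\log p(z_j | z_{-j}, x)]\}$ up to a normalizing constant. $E_{-j}$ means that the expectation is taken with regard to the distribution $\prod_{l \neq j} q_l(z_l; \phi_l)$, that is, over all latent variables except $z_j$.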
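As a concrete illustration of the coordinate-wise update, here is a minimal sketch on a textbook toy case: the target posterior is assumed to be a known correlated bivariate Gaussian with mean `mu` and covariance `cov`, and each factor $q_j(z_j)$ is a univariate Gaussian. For this target, the update "set $q_j \propto \exp\{E_{-j}[\log p(z_j | z_{-j}, x)]\}$" has a closed form: the precision of $q_j$ is the $j$-th diagonal entry of the precision matrix, and its mean depends on the current mean of the other factor. The function name and all numeric values are assumptions for illustration.

```python
import numpy as np

def cavi_bivariate_gaussian(mu, cov, n_sweeps=100, tol=1e-10):
    """CAVI for q(z1) q(z2) approximating N(mu, cov).

    Returns the means and variances of the two Gaussian factors.
    """
    mu = np.asarray(mu, dtype=float)
    prec = np.linalg.inv(np.asarray(cov, dtype=float))  # precision matrix
    m = np.zeros(2)                # initial factor means
    var = 1.0 / np.diag(prec)      # factor variances are fixed by the precision diagonal
    for _ in range(n_sweeps):
        m_old = m.copy()
        for j in range(2):
            k = 1 - j  # index of the other latent variable
            # exp{E_{-j}[log p(z_j | z_k)]} is Gaussian with this mean
            m[j] = mu[j] - prec[j, k] * (m[k] - mu[k]) / prec[j, j]
        if np.max(np.abs(m - m_old)) < tol:
            break
    return m, var

if __name__ == "__main__":
    mu = [1.0, -2.0]
    cov = [[1.0, 0.8], [0.8, 1.0]]
    m, var = cavi_bivariate_gaussian(mu, cov)
    print("factor means     :", m)
    print("factor variances :", var)
    print("true marginal variances:", np.diag(cov))
```

In this toy case the factor means converge to the true mean, while the factor variances come out smaller than the true marginal variances. This is the variance underestimation of mean-field variational Bayes discussed in [1].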
What to do after knowing $q(z)$?
After the optimization (using CAVI for example), we get the variational distribution $q(z) = \prod_{j=1}^{m} q_j(z_j; \phi_j)$. We can use the estimated $\phi_j$ to analytically derive the mean of $z_j$ or to sample $z_j$ from $q_j(z_j; \phi_j)$. One thing to note is that there is no restriction on the parametric form of the individual variational distributions. For example, you may define $q_j(z_j; \phi_j)$ to be an exponential distribution: $q_j(z_j; \phi_j) = \phi_j e^{-\phi_j z_j}$. Then, the mean of $z_j$ is $1/\phi_j$. If $q_j(z_j; \phi_j)$ is a normal distribution, then $\phi_j$ actually contains two parameters, the normal distribution's mean and variance. Thus the mean of $z_j$ is simply the mean parameter.
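As a small sketch of both options, suppose the fitted factor is an exponential distribution with rate `phi_j`; the value 2.5 below is a made-up stand-in for an estimated parameter. The analytic mean is the reciprocal of the rate, and a Monte Carlo average of samples should land close to it. Note that NumPy parameterizes the exponential by its scale, which is the reciprocal of the rate.

```python
import numpy as np

phi_j = 2.5  # hypothetical fitted rate of q_j(z_j) = phi_j * exp(-phi_j * z_j)

# Option 1: analytic mean of the exponential distribution.
analytic_mean = 1.0 / phi_j

# Option 2: draw samples from q_j and average them.
rng = np.random.default_rng(0)
samples = rng.exponential(scale=1.0 / phi_j, size=200_000)
mc_mean = float(samples.mean())

if __name__ == "__main__":
    print("analytic mean   :", analytic_mean)
    print("Monte Carlo mean:", mc_mean)
```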
Stochastic Variational Inference
One big disadvantage of CAVI is its poor scalability. Each update of $\phi_j$ requires a full sweep of the data. Stochastic variational inference (SVI) addresses this because updates of $\phi_j$ using SVI only require sub-samples of the data. The simple idea is to take the gradient of $\text{ELBO}$ and use it to update $\phi$. But there are some more details:
- The update formulas become very succinct if we assume the complete conditionals are in the exponential family: $p(z_j | z_{-j}, x) = h(z_j) \exp\{\eta(z_{-j}, x)^T t(z_j) - a(\eta(z_{-j}, x))\}$, where $t(z_j)$ is the sufficient statistic of $z_j$, and $h(\cdot)$, $\eta(\cdot)$, and $a(\cdot)$ are defined according to the definition of the exponential family [10].
- We also categorize latent variables into local variables and global variables.
- The gradient is not simply taken in the Euclidean space of the parameters but in the distribution space [11]; this is the natural gradient. In other words, the gradient is transformed in a sensible way such that it is in the steepest descent direction of the KL divergence. Also see [12].
All in all, SVI works as follows [8]:
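The following is a minimal sketch of the stochastic update on a deliberately simple conjugate model, not a reproduction of the algorithm in [8]. The assumptions are: a single global latent variable $\mu$ with prior $N(0, 1/\tau_0)$, observations $x_i | \mu \sim N(\mu, 1)$, and no local latent variables. With $q(\mu)$ Gaussian and stored in natural-parameter form, a natural-gradient step with step size $\rho_t$ reduces to a weighted average between the current parameter and an intermediate estimate computed from a minibatch as if it were replicated to the full data size. The function name, the step-size schedule $\rho_t = 1/(t+1)$, and all numeric values are illustrative choices.

```python
import numpy as np

def svi_gaussian_mean(x, prior_precision=1.0, batch_size=10, n_steps=2000, seed=0):
    """SVI sketch for mu ~ N(0, 1/prior_precision), x_i | mu ~ N(mu, 1).

    q(mu) is Gaussian; lam = [precision * mean, precision]. The second
    coordinate is stored as the precision itself, which does not change
    the update because the update is linear in lam.
    """
    x = np.asarray(x, dtype=float)
    rng = np.random.default_rng(seed)
    n = len(x)
    lam = np.array([0.0, prior_precision])  # start q(mu) at the prior
    for t in range(n_steps):
        # Sub-sample the data instead of sweeping all of it.
        batch = x[rng.integers(0, n, size=batch_size)]
        # Intermediate estimate: posterior parameters if the minibatch
        # were replicated n / batch_size times.
        lam_hat = np.array([(n / batch_size) * batch.sum(), prior_precision + n])
        rho = 1.0 / (t + 1)  # decreasing step size (Robbins-Monro conditions)
        lam = (1.0 - rho) * lam + rho * lam_hat  # natural-gradient step
    mean = lam[0] / lam[1]
    variance = 1.0 / lam[1]
    return mean, variance

if __name__ == "__main__":
    data = np.random.default_rng(1).normal(3.0, 1.0, size=1000)
    mean, variance = svi_gaussian_mean(data)
    exact_precision = 1.0 + len(data)
    print("SVI posterior mean   :", mean)
    print("exact posterior mean :", data.sum() / exact_precision)
    print("SVI posterior var    :", variance)
    print("exact posterior var  :", 1.0 / exact_precision)
```

Each step touches only `batch_size` data points, yet the estimate approaches the exact conjugate posterior as the step size decays. In models with local latent variables, each step would additionally optimize the local variational parameters for the sampled data points before forming the intermediate estimate.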
References
[1] https://www.quora.com/Why-and-when-does-mean-field-variational-Bayes-underestimate-variance
[2] ICML 2018 Tutorial Session: Variational Bayes and Beyond: Bayesian Inference for Big Data: https://www.youtube.com/watch?time_continue=2081&v=DYRK0-_K2UU
[3] Variational Inference: A Review for Statisticians
[4] https://towardsdatascience.com/demystifying-kl-divergence-7ebe4317ee68
[5] https://www.zhihu.com/question/31032863
[6] https://blog.csdn.net/aws3217150/article/details/57072827
[7] http://krasserm.github.io/2018/04/03/variational-inference/
[8] NIPS 2016 Tutorial Variational Inference: Foundations and Modern Methods: https://www.youtube.com/watch?v=ogdv_6dbvVQ
[10] https://jwmi.github.io/BMS/chapter3-expfams-and-conjugacy.pdf
[11] https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/
[12] https://czxttkl.com/2019/05/09/gradient-and-natural-gradient-fisher-information-matrix-and-hessian/