Gradient and Natural Gradient, Fisher Information Matrix and Hessian

Here I am writing down some notes summarizing my understanding of natural gradient. There are many online materials covering similar topics; I am not adding anything new, just writing a personal summary.

Assume we have a model with parameters \theta and training data x. The Hessian of the log likelihood, \log p(x|\theta), is:

    \[ H_{\log p(x|\theta)} = \begin{bmatrix} \frac{\partial^2 \log p(x|\theta)}{\partial \theta_1^2} & \cdots & \frac{\partial^2 \log p(x|\theta)}{\partial \theta_1 \partial \theta_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial^2 \log p(x|\theta)}{\partial \theta_n \partial \theta_1} & \cdots & \frac{\partial^2 \log p(x|\theta)}{\partial \theta_n^2} \end{bmatrix} \]
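To make the matrix concrete, here is a small sketch assuming a one-dimensional Gaussian model, p(x|\theta) = N(x; \mu, \sigma^2) with \theta = (\mu, \sigma). The model, the NumPy dependency, and the helper names (log_p, hessian_fd, hessian_analytic) are illustrative choices, not taken from the references. It estimates each entry of H by central finite differences and compares the result with the closed-form second derivatives.

```python
import numpy as np

# Hypothetical model for illustration: p(x|theta) = N(x; mu, sigma^2), theta = (mu, sigma).
def log_p(x, theta):
    """Log density of N(mu, sigma^2) evaluated at x."""
    mu, sigma = theta
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)

def hessian_fd(f, theta, eps=1e-4):
    """Central finite-difference Hessian of a scalar function f at theta."""
    theta = np.asarray(theta, dtype=float)
    n = theta.size
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i = np.zeros(n)
            e_i[i] = eps
            e_j = np.zeros(n)
            e_j[j] = eps
            H[i, j] = (f(theta + e_i + e_j) - f(theta + e_i - e_j)
                       - f(theta - e_i + e_j) + f(theta - e_i - e_j)) / (4 * eps ** 2)
    return H

def hessian_analytic(x, theta):
    """Closed-form second derivatives of log N(x; mu, sigma^2) w.r.t. (mu, sigma)."""
    mu, sigma = theta
    r = x - mu
    return np.array([[-1 / sigma**2, -2 * r / sigma**3],
                     [-2 * r / sigma**3, 1 / sigma**2 - 3 * r**2 / sigma**4]])

x, theta = 1.3, np.array([0.5, 2.0])
H_fd = hessian_fd(lambda t: log_p(x, t), theta)
print(H_fd)
print(hessian_analytic(x, theta))
```

For x = 1.3 and \theta = (0.5, 2), the closed-form Hessian is [[-0.25, -0.2], [-0.2, 0.13]], and the finite-difference estimate should agree with it to several decimal places.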

The Fisher information matrix F is defined as the covariance matrix of the score \nabla_\theta \log p(x|\theta), which is itself a vector. If we define s(\theta) = \nabla_\theta \log p(x|\theta), then F=\mathbb{E}_{p(x|\theta)}\left[\left(s(\theta) - \mathbb{E}_{p(x|\theta)}[s(\theta)]\right)\left(s(\theta) - \mathbb{E}_{p(x|\theta)}[s(\theta)]\right)^T\right]=\mathbb{E}_{p(x|\theta)}\left[s(\theta)s(\theta)^T\right]. The last equality holds because the expected score \mathbb{E}_{p(x|\theta)}[s(\theta)] is 0 [1]. It can also be shown that F equals the negative expected Hessian, F = -\mathbb{E}_{p(x|\theta)}[H_{\log p(x|\theta)}] [1]. Therefore, the Fisher information matrix can be thought of as the expected curvature of the negative log likelihood.

The curvature (Hessian) of a function is its second-order derivative, which describes how quickly the gradient of the function changes. Computing the Hessian is more expensive than computing the gradient alone, but knowing it can speed up convergence. If the curvature is high, the gradient changes quickly, so the parameter update should be more cautious (i.e., a smaller step); if the curvature is low, the gradient changes little, so the update can be more aggressive (i.e., a larger step). Moreover, the eigenvalues of the Hessian determine the convergence speed of gradient descent [3]. Therefore, knowing the Fisher information matrix is quite important when optimizing a function.
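These identities can be checked numerically. The sketch below again assumes a hypothetical Gaussian model N(x; \mu, \sigma^2) with \theta = (\mu, \sigma) and uses NumPy; it draws samples from p(x|\theta) itself and estimates the mean score, Cov[s], \mathbb{E}[s s^T] and -\mathbb{E}[H] by Monte Carlo. For this model the exact Fisher matrix is diag(1/\sigma^2, 2/\sigma^2), so all three estimates should land close to it.

```python
import numpy as np

# Hypothetical model: p(x|theta) = N(x; mu, sigma^2) with theta = (mu, sigma).
def score(x, mu, sigma):
    """s(theta) = grad_theta log p(x|theta), one row per sample: shape (N, 2)."""
    r = x - mu
    return np.stack([r / sigma**2, -1 / sigma + r**2 / sigma**3], axis=-1)

def hessian(x, mu, sigma):
    """Hessian of log p(x|theta) w.r.t. (mu, sigma), one matrix per sample: shape (N, 2, 2)."""
    r = x - mu
    h_mm = np.full_like(r, -1 / sigma**2)
    h_ms = -2 * r / sigma**3
    h_ss = 1 / sigma**2 - 3 * r**2 / sigma**4
    return np.stack([np.stack([h_mm, h_ms], axis=-1),
                     np.stack([h_ms, h_ss], axis=-1)], axis=-2)

rng = np.random.default_rng(0)
mu, sigma = 0.5, 2.0
x = rng.normal(mu, sigma, size=200_000)       # expectations are under p(x|theta) itself

s = score(x, mu, sigma)
mean_score = s.mean(axis=0)                   # close to [0, 0]
F_cov = np.cov(s, rowvar=False)               # Cov[s(theta)]
F_outer = s.T @ s / len(x)                    # E[s(theta) s(theta)^T]
F_hess = -hessian(x, mu, sigma).mean(axis=0)  # -E[H]
F_exact = np.diag([1 / sigma**2, 2 / sigma**2])

print("mean score:", mean_score)
print("Cov[s]:\n", F_cov)
print("E[s s^T]:\n", F_outer)
print("-E[H]:\n", F_hess)
print("exact F:\n", F_exact)
```

Because the mean score is (approximately) zero, Cov[s] and \mathbb{E}[s s^T] nearly coincide, and both match -\mathbb{E}[H] up to Monte Carlo noise.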

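Suppose we want to solve \min_\theta \mathcal{L}(x, \theta). The standard gradient update is \theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(x,\theta). The natural gradient update has a slightly different form: \theta \leftarrow \theta - \alpha F^{-1} \nabla_\theta \mathcal{L}(x,\theta). It comes from choosing the step d by solving a constrained problem [5]:

    \[ d^* = \arg\min_{d} \mathcal{L}(x, \theta+d) \quad \text{s.t.} \quad KL\left( p(x|\theta) \,\|\, p(x|\theta+d)\right) = c \]

Approximating the loss to first order and the KL divergence to second order, KL\left( p(x|\theta) \,\|\, p(x|\theta+d)\right) \approx \frac{1}{2} d^T F d, the solution is d^* \propto -F^{-1}\nabla_\theta \mathcal{L}(x,\theta). This is why the natural gradient is defined as \nabla_\theta^{NAT} \mathcal{L}(x,\theta) = F^{-1}\nabla_\theta \mathcal{L}(x,\theta).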
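A minimal sketch comparing the two update rules, \theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L} and \theta \leftarrow \theta - \alpha F^{-1} \nabla_\theta \mathcal{L}. It assumes the same hypothetical Gaussian model with \theta = (\mu, \sigma), takes \mathcal{L} to be the average negative log likelihood of the data, and uses the exact Fisher matrix diag(1/\sigma^2, 2/\sigma^2). The data, starting point, step size and step count are arbitrary illustration choices.

```python
import numpy as np

def nll_grad(theta, x):
    """Gradient of the average negative log likelihood of N(mu, sigma^2) w.r.t. (mu, sigma)."""
    mu, sigma = theta
    r = x - mu
    return np.array([-r.mean() / sigma**2,
                     1 / sigma - (r**2).mean() / sigma**3])

def fisher(theta):
    """Exact per-sample Fisher matrix of N(mu, sigma^2) in the (mu, sigma) parameterization."""
    _, sigma = theta
    return np.diag([1 / sigma**2, 2 / sigma**2])

def fit(x, theta0, alpha=0.1, steps=300, natural=False):
    """Plain gradient descent, or natural gradient descent when natural=True."""
    theta = np.array(theta0, dtype=float)
    for _ in range(steps):
        g = nll_grad(theta, x)
        if natural:
            g = np.linalg.solve(fisher(theta), g)  # F^{-1} grad, without forming the inverse
        theta = theta - alpha * g
    return theta

rng = np.random.default_rng(0)
x = rng.normal(3.0, 1.0, size=1000)
mle = np.array([x.mean(), x.std()])  # closed-form maximum-likelihood estimate

theta_gd = fit(x, [0.0, 10.0])
theta_ng = fit(x, [0.0, 10.0], natural=True)
print("MLE:    ", mle)
print("GD:     ", theta_gd)
print("natural:", theta_ng)
```

Starting from \sigma = 10, the \mu component of the plain gradient is divided by \sigma^2, so the mean barely moves. Multiplying by F^{-1} turns the \mu update into \alpha(\bar{x} - \mu), which no longer depends on \sigma; in this setup the natural-gradient run reaches the MLE while the plain run is still far from it.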

This formulation shows the intuition behind natural gradient: the update should decrease the loss as much as possible without radically changing the distribution p(x|\theta). Another way to see it: since the Fisher information matrix F encodes the (expected) curvature of the log likelihood, the natural gradient is the ordinary gradient scaled by the inverse of that curvature. If the curvature is large, a small change in \theta can change the likelihood drastically, so the update should be conservative.
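The KL constraint is what brings F into the update. The sketch below, again assuming the hypothetical Gaussian model with \theta = (\mu, \sigma) and NumPy, checks two things: for a small step d, the exact KL(p(x|\theta) || p(x|\theta+d)) approaches \frac{1}{2} d^T F d; and on the ellipse \frac{1}{2} d^T F d = c, the linearized loss g^T d is smallest along -F^{-1} g. The gradient vector g and the constant c are arbitrary example values, not taken from any real model.

```python
import numpy as np

def kl_gauss(theta_p, theta_q):
    """Exact KL( N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2) ), with theta = (mu, sigma)."""
    (mu_p, s_p), (mu_q, s_q) = theta_p, theta_q
    return np.log(s_q / s_p) + (s_p**2 + (mu_p - mu_q)**2) / (2 * s_q**2) - 0.5

def fisher(theta):
    """Exact Fisher matrix of N(mu, sigma^2) in the (mu, sigma) parameterization."""
    _, sigma = theta
    return np.diag([1 / sigma**2, 2 / sigma**2])

theta = np.array([0.5, 2.0])
F = fisher(theta)

# 1) For small d, KL(p(x|theta) || p(x|theta + d)) is close to 0.5 * d^T F d.
ratios = []
for scale in [1e-1, 1e-2, 1e-3]:
    d = scale * np.array([1.0, -0.5])
    ratio = kl_gauss(theta, theta + d) / (0.5 * d @ F @ d)
    ratios.append(ratio)
    print(f"step scale {scale:g}: KL / (0.5 d^T F d) = {ratio:.4f}")

# 2) Minimize the linearized loss g^T d over the ellipse 0.5 * d^T F d = c by grid search.
g = np.array([-0.03, 0.09])  # hypothetical loss gradient at theta
c = 1e-4
t = np.linspace(0.0, 2 * np.pi, 200_001)
# F is diagonal here, so every row of D satisfies 0.5 * d^T F d = c.
D = np.sqrt(2 * c) * np.stack([np.cos(t) / np.sqrt(F[0, 0]),
                               np.sin(t) / np.sqrt(F[1, 1])], axis=1)
d_best = D[np.argmin(D @ g)]

# Closed-form solution of the quadratic problem: d* = -sqrt(2c / g^T F^{-1} g) * F^{-1} g
Finv_g = np.linalg.solve(F, g)
d_star = -np.sqrt(2 * c / (g @ Finv_g)) * Finv_g
print("grid search:", d_best, " closed form:", d_star)
```

The ratio approaches 1 as the step shrinks, and the grid-search minimizer lines up with the -F^{-1} g direction, which is the natural gradient step.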


References

[1] https://wiseodd.github.io/techblog/2018/03/11/fisher-information/

[2] https://towardsdatascience.com/its-only-natural-an-excessively-deep-dive-into-natural-gradient-optimization-75d464b89dbb

[3] http://mlexplained.com/2018/02/02/an-introduction-to-second-order-optimization-for-deep-learning-practitioners-basic-math-for-deep-learning-part-1/

[4] http://kvfrans.com/a-intuitive-explanation-of-natural-gradient-descent/

[5] https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/