EmbeddingBag from PyTorch

EmbeddingBag in PyTorch is a useful module that consumes bags of sparse ids and produces pooled embeddings.

Here is a minimal example. There are embeddings for 4 ids, each of 3 dimensions. We have two data points: the first has three ids (0, 1, 2) and the second has a single id (3). This is reflected in the input and offsets variables: the i-th data point owns the ids from input[offsets[i]] (inclusive) to input[offsets[i+1]] (exclusive). Since we are using the “sum” mode, the first data point’s output is the sum of the embeddings of ids (0, 1, 2), and the second data point’s output is the embedding of id 3.

>>> embedding_sum = nn.EmbeddingBag(4, 3, mode='sum')
>>> embedding_sum.weight
Parameter containing:
tensor([[-0.9674, -2.3095, -0.2560],
        [ 0.0061, -0.4309, -0.7920],
        [-1.3457,  0.8978,  0.1271],
        [-1.8232,  0.6509, -1.2162]], requires_grad=True)
>>> input = torch.LongTensor([0,1,2,3])
>>> offsets = torch.LongTensor([0,3])
>>> embedding_sum(input, offsets)
tensor([[-2.3070, -1.8426, -0.9209],
        [-1.8232,  0.6509, -1.2162]], grad_fn=<EmbeddingBagBackward>)
>>> torch.sum(embedding_sum.weight[:3], dim=0)
tensor([-2.3070, -1.8426, -0.9209], grad_fn=<SumBackward1>)
>>> torch.sum(embedding_sum.weight[3:], dim=0)
tensor([-1.8232,  0.6509, -1.2162], grad_fn=<SumBackward1>)

Test with torch.multiprocessing and DataLoader

As we know, PyTorch’s DataLoader is a great tool for speeding up data loading. Through my experiments with DataLoader, I consolidated my understanding of Python multiprocessing.

Here is a didactic code snippet:

from torch.utils.data import DataLoader, Dataset
import torch
import time
import datetime
import torch.multiprocessing as mp
num_batches = 110

print("File init")

class DataClass:
    def __init__(self, x):
        self.x = x


class SleepDataset(Dataset):
    def __len__(self):
        return num_batches

    def __getitem__(self, idx):
        print(f"sleep on {idx}")
        time.sleep(5)
        print(f"finish sleep on {idx} at {datetime.datetime.now()}")
        return DataClass(torch.randn(5))


def collate_fn(batch):
    assert len(batch) == 1
    return batch[0]


def _set_seed(worker_id):
    torch.manual_seed(worker_id)
    torch.cuda.manual_seed(worker_id)


if __name__ == "__main__":
    mp.set_start_method("spawn")
    num_workers = mp.cpu_count() - 1
    print(f"num of workers {num_workers}")
    dataset = SleepDataset()
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        worker_init_fn=_set_seed,
        collate_fn=collate_fn,
    )

    dataloader = iter(dataloader)
    for i in range(num_batches):  # one pass over the dataset; going beyond it would raise StopIteration
        print(next(dataloader).x)

We have a Dataset called SleepDataset which is faked to be computationally expensive. We allow the DataLoader to use all available processes (except the main process) to load the dataset. Python 3 now has three ways to start processes: fork, spawn, and forkserver. I couldn’t find much online information regarding forkserver, but the difference between fork and spawn has been discussed a lot: fork is only supported on Unix systems. It creates a new process by copying the exact memory of the parent process into a new memory space, and the child process continues to execute from the forking point [3]. The system can still distinguish the parent and child processes by their process ids [1]. On the other hand, spawn creates new processes by initializing them from executable images (files) rather than directly copying the memory of the parent process [2].

Based on these differences, if we call mp.set_start_method("spawn"), we find “File init” is printed first in the main process and then printed every time a DataLoader worker process is created (110 times since num_batches = 110). If we call mp.set_start_method("fork"), “File init” is only printed once. The “forkserver” method behaves similarly to “spawn”: we also see “File init” printed 110 times.
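To see the fork/spawn difference in isolation, here is a minimal standalone sketch (independent of DataLoader) that I would expect to reproduce the same re-import behavior; it assumes a Unix-like system where both "fork" and "spawn" are available:

import torch.multiprocessing as mp

print("module init")  # printed once under "fork"; printed again in every child under "spawn"

def work(i):
    print(f"hello from child {i}")

if __name__ == "__main__":
    mp.set_start_method("spawn")  # switch to "fork" to compare
    procs = [mp.Process(target=work, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

With spawn, each child starts a fresh interpreter and re-imports the module, so the module-level print runs again; with fork, the child inherits the parent’s memory and no re-import happens.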

[1] https://qr.ae/TS6uaJ

[2] https://www.unix.com/unix-for-advanced-and-expert-users/178644-spawn-vs-fork.html

[3] https://www.python-course.eu/forking.php

Indexing data on GPU

This corresponds to a question I asked on the PyTorch forum. When we use indexing to extract data that is already on the GPU, should the indexing arrays be on the GPU as well? The answer is yes. Here is the evidence:

I also created some other examples showing that if you generate indexing arrays on the fly, they are best created with torch.xxx(..., device=torch.device("cuda")) rather than torch.xxx(...) on the CPU.
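Here is a rough timing sketch of the comparison (a hypothetical benchmark, assuming a CUDA device is available; exact numbers depend on hardware):

import time
import torch

data = torch.randn(1_000_000, 64, device="cuda")
idx_cpu = torch.randint(0, 1_000_000, (100_000,))
idx_gpu = idx_cpu.to("cuda")

def bench(idx, n_iter=100):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iter):
        _ = data[idx]          # advanced indexing on GPU-resident data
    torch.cuda.synchronize()
    return time.time() - start

print("CPU index tensor:", bench(idx_cpu))
print("GPU index tensor:", bench(idx_gpu))

The CPU-index version pays an extra host-to-device copy of the indices on every call, which is the main reason keeping the indexing arrays on the GPU is faster.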

TRPO, PPO, Graph NN + RL

Writing this post to share my notes on Trust Region Policy Optimization [2], Proximal Policy Optimization [3], and some recent works leveraging graph neural networks on RL problems. 

We start from the objective of TRPO. The expected return of a policy is \eta(\pi)=\mathbb{E}_{s_0, a_0, \cdots}[\sum\limits_{t=0}^{\infty}\gamma^t r(s_t)]. The return of another policy \hat{\pi} can be expressed as \eta(\pi) plus a relative difference term: \eta(\hat{\pi})=\eta(\pi) + \sum\limits_s \rho_{\hat{\pi}}(s)\sum\limits_a \hat{\pi}(a|s) A_{\pi}(s,a), where for any given \pi, \rho_{\pi}(s)=P(s_0=s) + \gamma P(s_1=s) + \gamma^2 P(s_2=s) + \cdots is called the discounted state visitation frequency.

To facilitate the optimization process, the authors propose to replace \rho_{\hat{\pi}}(s) with \rho_{\pi}(s). If the new policy \hat{\pi} is close to \pi, this approximation is not that bad since the discounted state visitation frequency shouldn’t change too much. Thus \eta(\hat{\pi}) \approx L_{\pi}(\hat{\pi})=\eta(\pi) + \sum\limits_s \rho_{\pi}(s)\sum\limits_a \hat{\pi}(a|s)A_{\pi}(s,a).

Using some theorems (see Section 3 of [2]) we can prove that \eta(\hat{\pi}) \geq L_{\pi}(\hat{\pi})-C \cdot D^{max}_{KL}(\pi, \hat{\pi}), where C=\frac{4\epsilon\gamma}{(1-\gamma)^2}, \epsilon=\max_{s,a}|A_{\pi}(s,a)|, and D^{max}_{KL}(\pi, \hat{\pi})=\max_s D_{KL}(\pi(\cdot|s) \| \hat{\pi}(\cdot|s)). The inequality becomes an equality when \hat{\pi}=\pi. Therefore, \max_{\hat{\pi}}[L_{\pi}(\hat{\pi})-C\cdot D^{max}_{KL}(\pi, \hat{\pi})] \geq L_{\pi}(\pi)-C\cdot D^{max}_{KL}(\pi, \pi)=L_{\pi}(\pi)=\eta(\pi), so maximizing this lower bound never decreases the true return. Therefore, we can use Algorithm 1 to monotonically improve \pi.

Algorithm 1 can also be understood through the diagram below (from [10]), where we set M_i(\pi)=L_{\pi_i}(\pi)-C D^{max}_{KL}(\pi_i, \pi).

In practice, we parameterize the policy with \theta, we move the coefficient C from the objective into a constraint (the paper argues this allows improving the policy with larger steps), and finally we use the average KL divergence between the two policies rather than D^{max}_{KL}. Putting these practical treatments together, we get:

\arg\max\limits_{\theta} L_{\pi_{\theta_{old}}}(\pi_{\theta})=\sum\limits_s \rho_{\pi_{\theta_{old}}}(s)\sum\limits_a \pi_\theta(a|s)A_{\pi_{\theta_{old}}}(s,a) \newline \text{subject to } \bar{D}^{\rho_{\pi_{\theta_{old}}}}_{KL}(\pi_{\theta_{old}}, \pi_{\theta})\leq \delta

If we change the sum to an expectation \mathbb{E}_{s\sim \rho_{\pi_{\theta_{old}}}, a\sim \pi_{\theta_{old}}}, we need an importance-sampling re-weighting. And in practice, we can estimate Q-values from trajectories more easily than advantages, because computing advantages would require exhausting all actions at each step. Thus the objective function using empirical samples eventually becomes:

\arg\max\limits_{\theta} \mathbb{E}_{s\sim \rho_{\pi_{\theta_{old}}}, a\sim \pi_{\theta_{old}}} \big[\frac{\pi_{\theta}(a|s)}{\pi_{\theta_{old}}(a|s)} Q_{\pi_{\theta_{old}}}(s,a)\big] \newline \text{subject to } \mathbb{E}_{s\sim \rho_{\pi_{\theta_{old}}}} \big[D_{KL}(\pi_{\theta_{old}}(\cdot|s) \| \pi_{\theta}(\cdot|s))\big]\leq \delta

Suppose the objective is f(\theta)=\frac{\pi_{\theta}(a|s)}{\pi_{\theta_{old}}(a|s)} Q_{\pi_{\theta_{old}}}(s,a). Using a Taylor series expansion, we have f(\theta) \approx f(\theta_{old}) + \nabla_\theta f(\theta)\big|^T_{\theta=\theta_{old}}(\theta-\theta_{old})=f(\theta_{old}) + g^T (\theta-\theta_{old}). f(\theta_{old}) can be seen as a constant and thus can be ignored during optimization.

For the constraint, we can also use a Taylor series expansion (approximating the KL divergence between two nearby distributions with its Taylor expansion is a very common trick). Suppose h(\theta)=D_{KL}(\pi_{\theta_{old}}(\cdot|s) \| \pi_{\theta}(\cdot|s)); then h(\theta)\approx h(\theta_{old}) + \nabla_\theta h(\theta) \big|^T_{\theta=\theta_{old}} (\theta-\theta_{old})+\frac{1}{2}(\theta-\theta_{old})^T \nabla^2_\theta h(\theta)|_{\theta=\theta_{old}}(\theta-\theta_{old}). We know that h(\theta_{old})=0 because it is the KL divergence between a distribution and itself. We also know that \nabla_\theta h(\theta) |_{\theta=\theta_{old}}=0 because the KL divergence attains its minimum of 0 at \theta=\theta_{old}, hence the derivative at \theta_{old} must be 0.

Removing all constant and zero terms, and with two more notations s=\theta-\theta_{old} and H=\nabla^2_\theta h(\theta)|_{\theta=\theta_{old}}, we rewrite the objective function as well as the constraint:

\arg\min\limits_{s} -g^T s \newline\text{ subject to } \frac{1}{2}s^T H s - \delta \leq 0

Now Lagrange multiplier optimization kicks in. Intuitively, the step towards the next \theta is along the direction that minimizes -g^T s while stretching the constraint to its limit (the equality of the constraint holds at the solution). Denote the Lagrange multiplier as \lambda (a single scalar because we only have one constraint) and the auxiliary objective function as \mathcal{L}(s, \lambda)=-g^T s + \lambda (\frac{1}{2}s^T H s - \delta). From the Karush-Kuhn-Tucker (KKT) conditions, we want to find a unique \lambda^* and the local minimum solution s^* such that:

  • \nabla_s \mathcal{L}(s^*,\lambda^*)=0
  • \lambda^* \geq 0
  • \lambda^*\cdot (\frac{1}{2}s^{*T} H s^* - \delta) = 0
  • \frac{1}{2}s^{*T} H s^* - \delta \leq 0 (actually, the equality should hold because s^* should be obtained at the border of the constraint, although I am not sure how to prove it. This means, as long as we find some non-negative \lambda^* such that \nabla_s \mathcal{L}(s^*,\lambda^*)=0, the first four conditions are satisfied.)
  • \nabla^2_s \mathcal{L}(s^*, \lambda^*) is positive semi-definite. \nabla^2_s \mathcal{L}(s^*, \lambda^*) is just H times some constant, and we know it is positive semi-definite from the facts that: (1) H is the Fisher information matrix because it is the Hessian of the KL divergence [12]; (2) a Fisher information matrix is positive semi-definite [13].

Since we don’t have an analytical solution for \frac{1}{2}s^{*T} H s^* - \delta = 0, we can’t find s^* directly. We can start by looking at \nabla_s \mathcal{L}(s^*,\lambda^*)=-g + \lambda^* Hs^* = 0. Moving -g to the right side, we have \lambda^* H s^*=g. Here, both s^* and \lambda^* are unknown to us, but we know that s^*=\lambda^{*-1}H^{-1}g. Therefore, s^* must be in the direction of H^{-1}g. How do we compute H^{-1}g? If we computed H^{-1} as a plain matrix inversion, we would need O(n^3) time, where n is the width/height of H. Instead, we can first obtain the solution x of the equation H x=g using the conjugate gradient method; x is exactly H^{-1}g. We will spend one section below introducing the conjugate gradient method, but for now just remember that it can compute x much faster than the O(n^3) matrix inversion. Back to our focus: once we get x=H^{-1}g, we take the step size \beta that makes the constraint hold with equality: \frac{1}{2} (\beta x)^T H (\beta x) - \delta=0. Thus \beta=\sqrt{\frac{2\delta}{x^T H x}}. Therefore, \theta = \theta_{old}+\beta x. This is exactly one iteration of the TRPO update!
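As a toy numerical sketch of this last step (not the full algorithm: here H and g are small stand-in matrices/vectors, whereas in practice H x is computed via Hessian-vector products and x via conjugate gradient):

import numpy as np

delta = 0.01                                 # trust-region size
H = np.array([[2.0, 0.3], [0.3, 1.0]])       # toy stand-in for the KL Hessian / Fisher matrix
g = np.array([0.5, -0.2])                    # toy stand-in for the policy gradient

x = np.linalg.solve(H, g)                    # plays the role of H^{-1} g
beta = np.sqrt(2 * delta / (x @ H @ x))      # step size making (1/2) s^T H s = delta hold

theta_old = np.zeros(2)
theta_new = theta_old + beta * x             # one TRPO-style update
print(theta_new)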

 

We now introduce the conjugate gradient method, which is used to find the update direction in TRPO.

Conjugate Gradient Method

My summary is primarily based on [6]. CG is an efficient, iterative method for solving Ax=b, whose solution is also the minimizer of the quadratic form \arg\min_x f(x)=\arg\min_x \frac{1}{2}x^TAx-b^Tx+c (for symmetric positive-definite A). The most straightforward way to solve Ax=b is to compute A^{-1} and let x=A^{-1}b. However, computing A^{-1} takes O(n^3) time with Gaussian elimination, where n is the width (or height) of A.

If we can’t afford to compute A^{-1}, the best bet is to rely on iterative methods. We list some notations used across iterative methods before we proceed. The error e_i=x_i - x is the distance between the i-th solution x_i and the real solution x. The residual r_i=-f'(x_i)=b-Ax_i indicates how far we are from b.

Steepest gradient descent with line search works as follows. At the current solution x_i, take a step in the direction of r_i, with the step size chosen such that the arrival point x_{i+1} has the smallest possible f(x_{i+1}). Equivalently, r_{i+1} is orthogonal to r_i. The two pictures below help illustrate the idea:

 

As you can see, the steepest descent method often takes a zig-zag path to reach a sufficient proximity of the optimal solution. From the theory (Section 6.2 of [6]), the number of steps it needs depends on the condition number of A and the starting point x_0. Intuitively, the number of steps is no smaller than the number of bases of the solution space. The luckiest situation is to pick an x_0 whose following steps all have orthogonal steepest-descent directions. Such an x_0 can also be characterized as having e_0 parallel to one of the eigenvectors of A:

The idea of conjugate direction methods (conjugate gradient is one of them) is to make the iteration reach the optimal solution in exactly n steps, where, by our definition, n is the width (or height) of A and also the number of bases of the solution space.

The procedure of conjugate direction methods starts by finding n search directions d_0, d_1, \cdots, d_{n-1}. If d_0, d_1, \cdots, d_{n-1} could be orthogonal to each other, we could easily approach x in n steps, as in Figure 21, but in general we are not able to find such n search directions. Instead, we can use Gram-Schmidt conjugation on n linearly independent vectors u_0, u_1, \cdots, u_{n-1} to construct n A-orthogonal search directions, and it is provable that using n A-orthogonal search directions we can also reach x in n steps. However, for arbitrary u_0, u_1, \cdots, u_{n-1}, Gram-Schmidt conjugation requires O(n^2) space and O(n^3) time, which is a disadvantage in practice.

The method of conjugate gradients is the conjugate direction method in which u_0, u_1, \cdots, u_{n-1} are the residuals r_0, r_1, \cdots, r_{n-1}. Fortunately, it has the nice property that the space and time complexity can be reduced to O(mn), where m is the number of nonzero entries of A. Conjugate gradient can be summarized as follows:
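Since the summary is originally given as a figure, here is a minimal sketch of the conjugate gradient iteration in code, assuming A is symmetric positive definite (as the Fisher matrix is):

import numpy as np

def conjugate_gradient(A, b, n_iters=50, tol=1e-10):
    x = np.zeros_like(b)
    r = b - A @ x                        # residual
    d = r.copy()                         # first search direction is the residual
    rs_old = r @ r
    for _ in range(n_iters):
        Ad = A @ d
        alpha = rs_old / (d @ Ad)        # exact line search along d
        x = x + alpha * d
        r = r - alpha * Ad
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        d = r + (rs_new / rs_old) * d    # keep the new direction A-orthogonal to the previous ones
        rs_old = rs_new
    return x

A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
print(conjugate_gradient(A, b))          # close to np.linalg.solve(A, b)

In TRPO, A is never materialized explicitly; the product A @ d is replaced by a Hessian-vector product of the average KL divergence (see [9]).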

 

Once you understand TRPO, PPO is much easier to understand. PPO simplifies TRPO’s constrained optimization into an unconstrained optimization problem. The main objective function of PPO is:

L^{CLIP}\newline=\mathbb{E}_{s\sim \rho_{\pi_{\theta_{old}}}, a\sim \pi_{\theta_{old}}} \left[\min\left(r(\theta)A_{\pi_{\theta_{old}}}(s,a),\; clip(r(\theta), 1-\epsilon, 1+\epsilon) A_{\pi_{\theta_{old}}}(s,a)\right)\right],\newline\text{where } r(\theta)=\frac{\pi_{\theta}(a|s)}{\pi_{\theta_{old}}(a|s)} \text{ is the probability ratio from the TRPO surrogate objective}

Let’s understand this clipped objective through an example. If A_{\pi_{\theta_{old}}}(s,a) is positive, the model tries to increase r(\theta), but gains nothing beyond 1+\epsilon; if A_{\pi_{\theta_{old}}}(s,a) is negative, the model tries to decrease r(\theta), but gains nothing below 1-\epsilon. This way, the change of the policy is limited.
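A minimal PyTorch sketch of this clipped loss (the tensor names and shapes are my own assumptions: 1-D tensors of per-sample log-probabilities and advantages):

import torch

def ppo_clip_loss(log_prob_new, log_prob_old, advantage, eps=0.2):
    ratio = torch.exp(log_prob_new - log_prob_old)                # r(theta)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
    return -torch.min(unclipped, clipped).mean()                  # negate: we maximize the objective

log_prob_new = torch.log(torch.tensor([0.30, 0.60]))
log_prob_old = torch.log(torch.tensor([0.25, 0.70]))
advantage = torch.tensor([1.0, -0.5])
print(ppo_clip_loss(log_prob_new, log_prob_old, advantage))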

Now let me examine the paper [5], which is what got me investigating TRPO and PPO in the first place. [5] models each joint of a robot in a locomotion task as a node in a graph. In graph learning, each node is represented as an embedding. At any given time step, the state feature can be summarized by all nodes’ embeddings. In [5], the policy network is learned through PPO.

[4] uses graph learning to model job scheduling graphs. But they use REINFORCE to learn their policy because their actions are defined as selecting one node out of all available nodes for scheduling. 

Besides using graph learning to encode a graph into some feature representation, I’ve also seen people using tree-LSTM to encode tree-like graphs [14]. 

 

 

Reference

[1] Spinning Up doc about TRPO: https://spinningup.openai.com/en/latest/algorithms/trpo.html

[2] Trust Region Policy Optimization: https://arxiv.org/abs/1502.05477

[3] Proximal Policy Optimization Algorithms: https://arxiv.org/abs/1707.06347

[4] Learning Scheduling Algorithms for Data Processing Clusters: https://arxiv.org/abs/1810.01963

[5] NerveNet: Learning Structured Policy with Graph Neural Networks: https://openreview.net/pdf?id=S1sqHMZCb

[6] An Introduction to the Conjugate Gradient Method Without the Agonizing Pain: https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf

[7] CS294-112 10/11/17: https://www.youtube.com/watch?v=ycCtmp4hcUs&feature=youtu.be&list=PLkFD6_40KJIznC9CDbVTjAF2oyt8_VAe3

[8] Towards Data Science PPO vs. TRPO: https://towardsdatascience.com/introduction-to-various-reinforcement-learning-algorithms-part-ii-trpo-ppo-87f2c5919bb9

[9] Efficiently Computing the Fisher Vector Product in TRPO: http://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/

[10] TRPO (Trust Region Policy Optimization) : In depth Research Paper Review: https://www.youtube.com/watch?v=CKaN5PgkSBc

[11] Optimization Overview: https://czxttkl.com/2016/02/22/optimization-overview/

[12] https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/

[13] https://stats.stackexchange.com/questions/49942/why-is-the-fisher-information-matrix-positive-semidefinite

[14] Learning to Perform Local Rewriting for Combinatorial Optimization https://arxiv.org/abs/1810.00337

 

Notes on “Recommending What Video to Watch Next: A Multitask Ranking System”

Share some thoughts on this paper: Recommending What Video to Watch Next: A Multitask Ranking System [1]

The main contribution of this work is two-fold: (1) a network architecture that learns multiple objectives; (2) handling position bias within the same model.

The first contribution is achieved by “softly” shared layers: each objective does not necessarily use one fixed set of expert parameters; instead, it can combine multiple sets of expert parameters, weighted by a gating network:
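In code, a minimal sketch of such “softly” shared expert layers could look like the following (the layer sizes, the two-task setup, and all module names are illustrative assumptions, not the paper’s exact architecture):

import torch
import torch.nn as nn

class SoftSharedLayers(nn.Module):
    def __init__(self, in_dim=32, expert_dim=16, n_experts=4, n_tasks=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(in_dim, expert_dim), nn.ReLU()) for _ in range(n_experts)]
        )
        self.gates = nn.ModuleList([nn.Linear(in_dim, n_experts) for _ in range(n_tasks)])
        self.heads = nn.ModuleList([nn.Linear(expert_dim, 1) for _ in range(n_tasks)])

    def forward(self, x):
        expert_out = torch.stack([e(x) for e in self.experts], dim=1)   # (B, E, D)
        outputs = []
        for gate, head in zip(self.gates, self.heads):
            w = torch.softmax(gate(x), dim=-1).unsqueeze(-1)            # per-task gate over experts
            mixed = (w * expert_out).sum(dim=1)                         # task-specific expert mixture
            outputs.append(head(mixed))
        return outputs                                                   # one prediction per objective

x = torch.randn(8, 32)
print([o.shape for o in SoftSharedLayers()(x)])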

 

The second contribution is to have a shallow network directly accounting for positions.

One thing that the paper is not clear about is how to finally use the multi-objective prediction. At the end of Section 4.2, it says:

we take the input of these multiple predictions, and output a combined score using a combination function in the form of weighted multiplication. The weights are manually tuned to achieve best performance on both user engagements and user satisfactions.

But I guess some tedious manual tuning is needed to arrive at the optimal combination.

The final result looks pretty good, with improvement in all the metrics:

 

Since we are talking about position bias, I am interested in some previous works on battling position bias, which is usually linked to contextual bandit work. Work on position bias and work on contextual bandits both battle disproportionate data distributions. Position bias means certain logged samples are either over- or under-represented due to the different levels of attention paid to different positions; contextual bandit models try to derive the optimal policy from a data distribution that was generated by sub-optimal policies and is therefore disproportionate to the data distribution the optimal policy would generate.

[2] introduces a fundamental framework for learning contextual bandit problems, coined counterfactual risk minimization (CRM). A good mental model for contextual bandit models is that their outputs are always stochastic, and the output distribution keeps getting updated such that the more probable actions eventually correlate with the ones leading to higher returns.

Note the objective used for batch w/bandit in Table 1. If you look at the paper [2], you’ll know that \hat{R}^M(h) is some clipped version of importance sampling-based error:

A typical loss function for contextual bandit models would be \hat{R}^M(h), and [2]’s main contribution is to augment the loss function with C\cdot Reg(\mathcal{H}) + \lambda \cdot \sqrt{Var(h)/n} (I need to verify whether this sentence is correct). But as [2] suggests, the \lambda \cdot \sqrt{Var(h)/n} part is said to “penalize the hypotheses with large variance during training using a data-dependent regularizer”.

Another solid previous work is [3]. The main idea is that a randomized dataset is used to estimate position bias. The intuition is simple: if you randomize the order of a model’s ranking output, the click rate at each position is not affected by the content placed at that position, because content effects are cancelled out by randomization. So the click rate at each position is only related to the position itself, and that’s exactly the position bias.
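A toy sketch of this estimation (with made-up numbers): under randomized ordering, the per-position click-through rate is itself an estimate of the position effect.

import numpy as np

# clicks[i, p] = 1 if the item randomly placed at position p in session i was clicked
clicks = np.random.binomial(1, [0.30, 0.18, 0.12, 0.08], size=(10000, 4))
position_bias = clicks.mean(axis=0)        # click rate per position
position_bias /= position_bias[0]          # normalize relative to the top position
print(position_bias)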

[4, 5] claim that you can account for position bias through Inverse Propensity Scores (IPS) when you evaluate a new ranking function based on implicit feedback (observable signals like clicks are implicit feedback, while the real user intent is explicit feedback, which is not always easy to obtain). Intuitively, when we collect implicit feedback datasets like click logs, we need to account for observation bias because top positions certainly draw more feedback than bottom positions.

The learning-to-rank loss augmented by IPS is:

\triangle_{IPS}(\mathbf{y} | \mathbf{x}_i, \bar{\mathbf{y}_i}, o_i)\newline=\sum\limits_{y:o_i(y)=1} \frac{rank(y|\mathbf{y})\cdot r_i(y)}{Q(o_i(y)=1|\mathbf{x}_i, \bar{\mathbf{y}_i}, r_i(y))}\newline=\sum\limits_{y:o_i(y)=1 \land r_i(y)=1}\frac{rank(y|\mathbf{y})}{Q(o_i(y)=1|\mathbf{x}_i, \bar{\mathbf{y}_i}, r_i(y))},

where rank(y|\mathbf{y}) is the rank of item y in the list \mathbf{y} ranked by the new model, \mathbf{\bar{y}_i} is the logged list, and Q(\cdot) denotes the probability of observing item y’s click in the logged data. In the paper, the authors propose to express Q(\cdot) as the probability of browsing the position of y times the probability of clicking on y.
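A toy sketch of computing \triangle_{IPS} for one logged list (all arrays are hypothetical; propensity stands in for Q(o_i(y)=1|\cdot) and rank_new for rank(y|\mathbf{y}) under the new model):

import numpy as np

clicked = np.array([1, 0, 1, 0])                  # r_i(y): click used as relevance signal
observed = np.array([1, 1, 1, 0])                 # o_i(y): whether y was observed
propensity = np.array([0.9, 0.6, 0.3, 0.1])       # estimated observation propensities
rank_new = np.array([2, 1, 4, 3])                 # rank of each item under the new ranking function

mask = (observed == 1) & (clicked == 1)
delta_ips = np.sum(rank_new[mask] / propensity[mask])   # IPS-weighted sum of ranks of clicked items
print(delta_ips)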

Finally, this loss function can be optimized through SVM-Rank [5].

 

Reference

[1] Recommending What Video to Watch Next: A Multitask Ranking System

[2] Batch Learning from Logged Bandit Feedback through Counterfactual Risk Minimization

[3] Learning to Rank with Selection Bias in Personal Search

[4] Unbiased Learning-to-Rank with Biased Feedback ARXIV

[5] Unbiased Learning-to-Rank with Biased Feedback ICJAI

 

Convergence of Q-learning and SARSA

Here, I am listing some classic proofs regarding the convergence of Q-learning and SARSA in finite MDPs (by definition, in a finite Markov Decision Process the sets of states, actions, and rewards are finite [1]).

The very first Q-learning convergence proof comes from [4]. The proof is based on a very useful theorem:


Note that this theorem is general enough to be applied to multi-dimensional spaces (x can be multi-dimensional).

We can apply Theorem 1 to Q-learning whose update rule is as follows:
Q_{t+1}(s_t, a_t) = (1-\alpha_t(s_t, a_t))Q_t(s_t, a_t) + \alpha_t(s_t, a_t)[r_t +\gamma \max_{b \in \mathcal{A}}Q_t(s_{t+1}, b)]

Subtracting Q^*(s_t, a_t) from both sides and letting \triangle_t(s,a)=Q_t(s,a)-Q^*(s,a), we have:
\triangle_{t+1}(s_t,a_t)\newline=(1-\alpha_t(s_t, a_t))\triangle_t(s_t, a_t) + \alpha_t(s_t, a_t)[r_t + \gamma\max_{b\in \mathcal{A}}Q_t(s_{t+1}, b) - Q^*(s_t, a_t)]\newline=(1-\alpha_t(s_t, a_t))\triangle_t(s_t, a_t) + \alpha_t(s_t, a_t)F_t(s_t,a_t)

We also let \alpha_t(s,a), the learning rate, be zero for all (s,a) \neq (s_t, a_t). At this point, we have modeled the Q-learning process exactly as the random iterative process stated in Theorem 1. Hence, as long as we can prove the 4 assumptions hold, we can prove that Q-learning converges to the optimal Q-values (i.e., \triangle_t(s,a) converges to 0).
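For concreteness, a minimal sketch of this tabular update rule (toy sizes and a hypothetical transition):

import numpy as np

n_states, n_actions, gamma = 5, 3, 0.9
Q = np.zeros((n_states, n_actions))

def q_learning_update(Q, s, a, r, s_next, alpha):
    # Q_{t+1}(s,a) = (1 - alpha) Q_t(s,a) + alpha [r + gamma * max_b Q_t(s', b)]
    target = r + gamma * Q[s_next].max()
    Q[s, a] = (1 - alpha) * Q[s, a] + alpha * target

q_learning_update(Q, s=0, a=1, r=1.0, s_next=2, alpha=0.1)
print(Q[0, 1])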

Assumption 1: trivially holds since we are focusing on finite MDPs.

Assumption 2: \beta_n(x)=\alpha_n(x)=\text{the Q-learning learning rate}, based on our formulation. If we apply a GLIE learning policy (“greedy in the limit with infinite exploration”) [2], then assumption 2 holds:

  

Assumption 3: \|E\{F_t(s,a) | P_n\}\|_w \leq \gamma \|\triangle_t\|_w. There are some notations that may not look familiar to a general audience. First, \| x \|_w := \|w \cdot x\|_\infty=\max_i |w_i x_i| is the weighted maximum norm [5]. Therefore, \|\triangle_t\|_w=\max\limits_{s,a} \big|Q_t(s,a)-Q^*(s,a)\big| (with unit weights). Second, F_t(s,a) | P_n just means F_t(s,a) is estimated using all past interactions, i.e., Q_t inside F_t(s,a) is estimated conditional on all past information rather than just (s_t, a_t).

To prove assumption 3 holds, we first look at how we define the optimal Q-function:
Q^*(s,a)=\sum\limits_{s'\in\mathcal{S}}P_a(s,s')[r(s,a,s')+\gamma \max\limits_{a'\in\mathcal{A}}Q^*(s',a')]

We can show that Q^*(s,a) is a fixed point of a contraction operator \textbf{H}, defined over a generic function q: \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}:
(\textbf{H}q)(s,a)=\sum\limits_{s'\in\mathcal{S}}P_a(s, s')[r(s,a,s')+\gamma \max\limits_{a' \in \mathcal{A}}q(s',a')]

Actually, \|\textbf{H}q_1 - \textbf{H}q_2\|_\infty \leq \gamma \|q_1-q_2\|_\infty, i.e., \textbf{H} is a \gamma-contraction, because:
\|\textbf{H}q_1 - \textbf{H}q_2\|_\infty \newline=\max\limits_{s,a} \big|\sum\limits_{s'\in \mathcal{S}} P_a(s,s')[r(s,a,s')+\gamma\max\limits_{a' \in \mathcal{A}}q_1(s',a') - r(s,a, s')-\gamma \max\limits_{a'\in\mathcal{A}}q_2(s',a')]\big|\newline=\max\limits_{s,a}\gamma \big|\sum\limits_{s'\in\mathcal{S}}P_a(s,s') [\max\limits_{a' \in \mathcal{A}}q_1(s',a')-\max\limits_{a' \in \mathcal{A}}q_2(s',a')]\big|\newline\text{(triangle inequality, using } P_a(s,s')\geq 0\text{)}\newline\leq\max\limits_{s,a}\gamma \sum\limits_{s'\in \mathcal{S}}P_a(s,s')\big|\max\limits_{a' \in \mathcal{A}}q_1(s',a')-\max\limits_{a' \in \mathcal{A}}q_2(s',a')\big|\newline\text{(property of max: } |\max_x f(x)-\max_x g(x)| \leq \max_x |f(x)-g(x)|\text{)}\newline\leq \max\limits_{s,a}\gamma \sum\limits_{s'\in \mathcal{S}}P_a(s,s')\max\limits_{a' \in \mathcal{A}}\big|q_1(s',a')-q_2(s',a')\big| \newline \text{(enlarge the domain of the max)} \newline\leq \max\limits_{s,a}\gamma \sum\limits_{s'\in \mathcal{S}}P_a(s,s')\max\limits_{s'' \in \mathcal{S},a'' \in \mathcal{A}}\big|q_1(s'',a'')-q_2(s'',a'')\big|\newline=\max\limits_{s,a}\gamma \sum\limits_{s'\in \mathcal{S}}P_a(s,s')\|q_1-q_2\|_\infty\newline=\gamma \|q_1-q_2\|_\infty

Note that in one of our previous notes [6], we talked about the intuitive meaning of a \gamma-contraction: consecutive applications of \textbf{H} bring the function q closer and closer to the fixed point q^* at the rate \gamma. Another intuitive understanding of \|\textbf{H}q_1 - \textbf{H}q_2\| \leq \gamma \|q_1-q_2\| (say, in terms of L2 norm) is that the distance between q_1 and q_2 becomes smaller after applying \textbf{H}:
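As a quick numeric sanity check of the contraction (a toy MDP with hypothetical random dynamics), repeatedly applying \textbf{H} to two arbitrary Q-functions shrinks their max-norm distance by at least a factor of \gamma per application:

import numpy as np

n_s, n_a, gamma = 4, 2, 0.9
rng = np.random.default_rng(0)
P = rng.dirichlet(np.ones(n_s), size=(n_s, n_a))   # P[s, a, s'], rows sum to 1 over s'
R = rng.normal(size=(n_s, n_a, n_s))               # r(s, a, s')

def H(q):
    # (Hq)(s,a) = sum_{s'} P(s,a,s') [r(s,a,s') + gamma * max_{a'} q(s',a')]
    return np.einsum("sap,sap->sa", P, R + gamma * q.max(axis=1)[None, None, :])

q1 = rng.normal(size=(n_s, n_a))
q2 = rng.normal(size=(n_s, n_a))
for _ in range(5):
    print(np.abs(q1 - q2).max())                   # shrinks geometrically
    q1, q2 = H(q1), H(q2)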

After proving \textbf{H} is a \gamma-contraction, we now prove assumption 3:
E\{F_t(s,a) | P_n\} \newline=\sum\limits_{s'\in\mathcal{S}}P_a(s,s')[r(s,a,s')+\gamma\max\limits_{a'\in\mathcal{A}}Q_t(s', a')-Q^*(s,a)]\newline=\sum\limits_{s'\in\mathcal{S}}P_a(s,s')[r(s,a,s')+\gamma\max\limits_{a'\in\mathcal{A}}Q_t(s', a')]-Q^*(s,a)\newline=(\textbf{H}Q_t)(s,a)-Q^*(s,a)\newline\text{Using the fact that }Q^*=\textbf{H}Q^*\newline=(\textbf{H}Q_t)(s,a)-(\textbf{H}Q^*)(s,a)\newline\text{Therefore } \|E\{F_t(s,a) | P_n\}\|_\infty \leq \|\textbf{H}Q_t-\textbf{H}Q^*\|_\infty\newline\leq \gamma \|Q_t-Q^*\|_\infty\newline=\gamma \|\triangle_t\|_\infty

It is less obvious to me how to prove assumption 4, even after reading [3] and Proposition 5.5 in [8]. But the takeaway of assumption 4 is that the variance of F_t(s,a) should be bounded.

Finally, you can prove the convergence of SARSA [2] in a similar fashion. I hope I’ll have time to cover it in the future.

I think the most important takeaway from this post is the pattern for proving the convergence of a learning algorithm. Researchers often need to propose variants of Q-learning (such as soft Q-values in the maximum-entropy setting [6], or SlateQ for dealing with combinatorial actions [9]), and they usually need to prove convergence at least in the case of finite MDPs. One can start from the optimal Q-function of the learning algorithm and prove it is a fixed point of a contraction operator. Then, look at the update rule and construct a random iterative process as stated in Theorem 1. Finally, prove that the 4 assumptions hold.

 

Reference

[1] https://medium.com/harder-choices/finite-markov-decision-process-8a34f5e571f9

[2] Convergence Results for Single-Step On-Policy
Reinforcement-Learning Algorithms https://link.springer.com/content/pdf/10.1023/A:1007678930559.pdf

[3] Convergence of Q-learning: a simple proof http://users.isr.ist.utl.pt/~mtjspaan/readingGroup/ProofQlearning.pdf

[4] Convergence of Stochastic Iterative Dynamic Programming Algorithms https://papers.nips.cc/paper/764-convergence-of-stochastic-iterative-dynamic-programming-algorithms.pdf

[5] https://math.stackexchange.com/questions/182054/is-this-a-norm-triangle-inequality-for-weighted-maximum-norm

[6] https://czxttkl.com/2018/10/30/notes-on-soft-actor-critic-off-policy-maximum-entropy-deep-reinforcement-learning-with-a-stochastic-actor/

[7] norm properties: https://en.wikipedia.org/wiki/Norm_(mathematics)#Properties

[8] Neuro Dynamic Programming https://www.dropbox.com/s/naenlfpy0f9vmvt/neuro-dynamic-programming-optimization-and-neural-.pdf?dl=0

[9] SlateQ: https://arxiv.org/abs/1905.12767

Cross entropy with logits

I keep forgetting the exact formulation of `binary_cross_entropy_with_logits` in PyTorch, so I am writing this down for future reference.

The function binary_cross_entropy_with_logits takes two kinds of inputs: (1) the logit, i.e., the value right before the probability transformation (sigmoid) layer, whose range is (-infinity, +infinity); (2) the target, whose values are binary.

binary_cross_entropy_with_logits calculates the following loss (i.e., negative log likelihood), ignoring sample weights:

    \[loss = -[target * log(\sigma(input)) + (1-target) * log(1 - \sigma(input))]\]

>>> import torch
>>> import torch.nn.functional as F
>>> input = torch.tensor([3.0])
>>> target = torch.tensor([1.0])
>>> F.binary_cross_entropy_with_logits(input, target)
tensor(0.0486)
>>> - (target * torch.log(torch.sigmoid(input)) + (1-target)*torch.log(1-torch.sigmoid(input)))
tensor([0.0486])

2019-12-03 Update

Now let’s look at the difference between N-class cross entropy and the KL-divergence loss. With one-hot targets they compute the same quantity (https://adventuresinmachinelearning.com/cross-entropy-kl-divergence/); they differ only in input/output format.

import torch.nn as nn
import torch
import torch.nn.functional as F

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
loss(input, target)
>>> tensor(1.6677, grad_fn=<NllLossBackward>)

loss1 = nn.KLDivLoss(reduction="batchmean")
loss1(nn.LogSoftmax(dim=1)(input), F.one_hot(target, 5).float())
>>> tensor(1.6677, grad_fn=<NllLossBackward>)

Reference:

[1] https://pytorch.org/docs/stable/nn.html#binary-cross-entropy-with-logits

[2] https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss

[3] https://stackoverflow.com/questions/34240703/what-is-logits-softmax-and-softmax-cross-entropy-with-logits

mujoco only works with gcc8

pip install mujoco-py only builds with gcc 8. On a Mac, use ll /usr/local/Cellar/gcc* to list all the gcc versions you have installed. Uninstall them and install only gcc@8.

brew uninstall gcc
brew uninstall gcc@7
brew uninstall gcc@9
brew install gcc@8

Another time I saw the following error when using pip install mujoco-py:

  Building wheel for mujoco-py (PEP 517) ... error
  ERROR: Command errored out with exit status 1:
   command: /Users/czxttkl/anaconda2/envs/softlearning/bin/python /Users/czxttkl/anaconda2/envs/softlearning/lib/python3.7/site-packages/pip/_vendor/pep517/_in_process.py build_wheel /var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/tmp8nhhe2ui
       cwd: /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py
  Complete output (34 lines):
  running bdist_wheel
  running build
  Removing old mujoco_py cext /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/cymj_2.0.2.9_37_macextensionbuilder_37.so
  Compiling /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/cymj.pyx because it depends on /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-build-env-phfzl25m/overlay/lib/python3.7/site-packages/Cython/Includes/numpy/__init__.pxd.
  [1/1] Cythonizing /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/cymj.pyx
  running build_ext
  building 'mujoco_py.cymj' extension
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py
  creating /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/gl
  /usr/local/bin/gcc-8 -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -g -fwrapv -O3 -Wall -I/Users/czxttkl/anaconda2/envs/softlearning/include -arch x86_64 -I/Users/czxttkl/anaconda2/envs/softlearning/include -arch x86_64 -DONMAC -Imujoco_py -I/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py -I/Users/czxttkl/.mujoco/mujoco200/include -I/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-build-env-phfzl25m/overlay/lib/python3.7/site-packages/numpy/core/include -I/Users/czxttkl/anaconda2/envs/softlearning/include/python3.7m -c /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/cymj.c -o /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/generated/_pyxbld_2.0.2.9_37_macextensionbuilder/temp.macosx-10.9-x86_64-3.7/private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/cymj.o -fopenmp -w
  In file included from /usr/local/Cellar/gcc@8/8.3.0/lib/gcc/8/gcc/x86_64-apple-darwin18/8.3.0/include-fixed/syslimits.h:7,
                   from /usr/local/Cellar/gcc@8/8.3.0/lib/gcc/8/gcc/x86_64-apple-darwin18/8.3.0/include-fixed/limits.h:34,
                   from /Users/czxttkl/anaconda2/envs/softlearning/include/python3.7m/Python.h:11,
                   from /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/cymj.c:58:
  /usr/local/Cellar/gcc@8/8.3.0/lib/gcc/8/gcc/x86_64-apple-darwin18/8.3.0/include-fixed/limits.h:194:61: error: no include path in which to search for limits.h
   #include_next <limits.h>  /* recurse down to the real one */
                                                               ^
  In file included from /Users/czxttkl/anaconda2/envs/softlearning/include/python3.7m/Python.h:25,
                   from /private/var/folders/kt/bt_h1rbx0658mqgh49cmt1xwb9g9_0/T/pip-install-g5afauff/mujoco-py/mujoco_py/cymj.c:58:
  /usr/local/Cellar/gcc@8/8.3.0/lib/gcc/8/gcc/x86_64-apple-darwin18/8.3.0/include-fixed/stdio.h:78:10: fatal error: _stdio.h: No such file or directory
   #include <_stdio.h>
            ^~~~~~~~~~
  compilation terminated.
  error: command '/usr/local/bin/gcc-8' failed with exit status 1
  ----------------------------------------
  ERROR: Failed building wheel for mujoco-py
Failed to build mujoco-py
ERROR: Could not build wheels for mujoco-py which use PEP 517 and cannot be installed directly

This error is suspected to be due to a corrupted gcc@8. I solved this by using brew reinstall gcc@8.

However during the reinstallation, I encountered another error:

Error: Your Xcode (11.0) is outdated.
Please update to Xcode 11.3 (or delete it).
Xcode can be updated from the App Store.

, which I solved by running xcode-select --install.

Notes for “Defensive Quantization: When Efficiency Meets Robustness”

I have been reading “Defensive Quantization: When Efficiency Meets Robustness” recently. Neural network quantization is a brand-new topic to me so I am writing some notes down for learning. 

The first introduction I read is [1], from which I learned that the term “quantization” generally refers to reducing the memory usage of model weights by lowering representation precision. It can cover several related concepts: (1) low precision: converting FP32 (floating point of 32 bits) to FP16, INT8, etc.; (2) mixed precision: using FP16 for some weights while keeping FP32 for other weights (so that accuracy can be maintained); (3) exact quantization: basically using INT8 for all weights.

The next thing to understand is the basics of fixed-point and floating-point representation:

 

The fixed-point format stores the integer part and the fractional part with fixed widths. Therefore, the precision between two consecutive fixed-point numbers is fixed. (For example, fixed-point numbers can be 123.456, 123.457, 123.458, …, with 0.001 between every two consecutive numbers.) The floating-point format stores a number as significand \times base^{exponent}, so the significand and exponent are saved as two separate numbers. Note that the precision between two consecutive floating-point numbers is not fixed but depends on the exponent. For every exponent, the number of representable floating-point numbers is the same, but for a smaller exponent the density of representable floating-point numbers is high, while for a larger exponent the density is low. As a result, the closer a real-valued number is to 0, the more accurately it can be represented as a floating-point number. In practice, floating-point numbers are quite accurate and precise even at the largest exponents.
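A quick way to see this non-uniform spacing is NumPy’s spacing function, which returns the gap between a number and the next representable one:

import numpy as np

for v in [1e-3, 1.0, 1e3, 1e6]:
    print(v, np.spacing(np.float32(v)))   # the gap grows as the magnitude (exponent) grows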

The procedure to quantize FP32 to INT8 is shown in Eqn. 6~9 of [1]. In short, you need to find a scale and a reference point (X_{zero_point} in the equations) to make sure every FP32 value falls into INT8’s value range [0, 255].
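A toy sketch of this scale / zero-point idea (my own simplified version, not the exact equations of [1]; it assumes the tensor has max > min):

import numpy as np

def quantize(x, num_bits=8):
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)              # maps the FP32 range onto [0, 255]
    zero_point = np.clip(np.round(qmin - x.min() / scale), qmin, qmax)
    q = np.clip(np.round(x / scale + zero_point), qmin, qmax).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    return scale * (q.astype(np.float32) - zero_point)

x = np.random.randn(5).astype(np.float32)
q, scale, zp = quantize(x)
print(x)
print(dequantize(q, scale, zp))   # close to x, up to quantization error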

Eqn. 6~9 only describe how to quantize weights, but there is no guarantee that arithmetic performed on quantized weights still falls into the quantized range. For example, consider an operator that adds two FP32 weights. While we can quantize the two weights into INT8, adding the two quantized weights may overflow. Therefore, people have designed dedicated procedures to perform quantized arithmetic, such as multiplication, addition, subtraction, and so on. See Eqn. 10~16 and 17~26 for the examples of multiplication and addition, respectively.

While quantization is one way to reduce memory usage of models, compression is another alternative. [3] is one very famous work on compressing deep learning models. 

 

Since I am an RL guy, I’d also like to introduce another work [4], which uses RL for model compression. They design an MDP where the state is each layer’s features, the action is the compression ratio for that layer (so the action is one-dimensional with range (0,1)), and the reward is the model’s performance on the validation dataset, evaluated only at the last step of the MDP. The reward can also be tweaked to add resource constraints, such as latency. The takeaway result is that accuracy can be almost preserved while the model size is compressed by a good ratio.

 

References

[1] Neural Network Quantization Introduction

[2] Defensive Quantization: When Efficiency Meets Robustness

[3] Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding

[4] AMC: AutoML for Model Compression and Acceleration on Mobile Devices

[5] https://www.geeksforgeeks.org/floating-point-representation-digital-logic/

Gradient and Natural Gradient, Fisher Information Matrix and Hessian

Here I am writing down some notes summarizing my understanding of natural gradient. There are many online materials covering similar topics; I am not adding anything new, just making a personal summary.

Assume we have a model with model parameter \theta. We have training data x. Then, the Hessian of log likelihood, log p(x|\theta), is:

    \[       H_{log p(x|\theta)} = \begin{bmatrix} \frac{\partial^2 logp(x|\theta)}{\partial \theta_1^2} & \dots & \frac{\partial^2 logp(x|\theta)}{\partial \theta_1 \partial \theta_n} \\ \dots & \dots & \dots \\ \frac{\partial^2 logp(x|\theta)}{\partial \theta_n \partial \theta_1} & \dots & \frac{\partial^2 logp(x|\theta)}{\partial \theta_n^2} \end{bmatrix} \]

The Fisher information matrix F is defined as the covariance matrix of \nabla_\theta log p(x|\theta) (note that \nabla_\theta log p(x|\theta) is itself a vector). If we define s(\theta) = \nabla_\theta log p(x|\theta), then F=\mathbb{E}_{p(x|\theta)}\left[\left(s(\theta) - mean(s(\theta))\right)\left(s(\theta) - mean(s(\theta))\right)^T\right]=\mathbb{E}_{p(x|\theta)}\left[s(\theta)s(\theta)^T\right]. The last equality holds because mean(s(\theta)) is 0 [1]. It can be shown that the negative of \mathbb{E}_{p(x|\theta)}[H_{log p(x|\theta)}] is the Fisher information matrix [1]. Therefore, the Fisher information matrix can be thought of as the (negative) curvature of the log likelihood. The curvature (Hessian) of a function is its second-order derivative, which depicts how quickly the gradient of the function changes. Computing the Hessian (curvature) takes longer than computing just the gradient, but knowing the Hessian can accelerate learning convergence: if the curvature is high, the gradient changes quickly, so the gradient update of the parameters should be more cautious (i.e., a smaller step); if the curvature is low, the gradient doesn’t change much, so the gradient update can be more aggressive (i.e., a larger step). Moreover, the eigenvalues of the Hessian determine the convergence speed [3]. Therefore, knowing the Fisher information matrix is quite important for optimizing a function.

Suppose we are going to optimize a loss function: \min \mathcal{L}(x, \theta). The normal gradient update is \theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(x,\theta). The natural gradient update has a slightly different form: \theta \leftarrow \theta - \alpha F^{-1} \nabla_\theta \mathcal{L}(x,\theta). The natural gradient formula is actually derived from [5]:

    \[ \nabla_\theta^{NAT} \mathcal{L}(x,\theta) = \arg\min_{d} \mathcal{L}(x, \theta+d), \quad s.t. \quad KL\left( p(x|\theta) || p(x|\theta+d)\right) = c \]

This formula shows the intuition behind the natural gradient: the natural gradient should minimize the loss as much as possible while not radically changing p(x|\theta). Another way to think about the natural gradient is that since the Fisher information matrix F encodes the curvature of the log likelihood, the natural gradient is the normal gradient scaled by the inverse of that curvature: if the log likelihood’s curvature is large, meaning a small change in \theta could radically change the likelihood, then we should be conservative in the gradient update.
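A toy sketch of a natural-gradient step using the empirical Fisher F = E[s(\theta)s(\theta)^T], estimated from per-sample score vectors (all numbers are hypothetical):

import numpy as np

grad = np.array([0.4, -0.1])                  # gradient of the loss w.r.t. theta
scores = np.random.randn(1000, 2)             # stand-in for per-sample grad log p(x|theta)
F = scores.T @ scores / len(scores)           # empirical Fisher information matrix

alpha = 0.1
theta = np.zeros(2)
theta_plain = theta - alpha * grad                        # ordinary gradient step
theta_natural = theta - alpha * np.linalg.solve(F, grad)  # natural gradient step: F^{-1} grad
print(theta_plain, theta_natural)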

 

Reference

[1] https://wiseodd.github.io/techblog/2018/03/11/fisher-information/

[2] https://towardsdatascience.com/its-only-natural-an-excessively-deep-dive-into-natural-gradient-optimization-75d464b89dbb

[3] http://mlexplained.com/2018/02/02/an-introduction-to-second-order-optimization-for-deep-learning-practitioners-basic-math-for-deep-learning-part-1/

[4] http://kvfrans.com/a-intuitive-explanation-of-natural-gradient-descent/

[5] https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/