I’ve recently been implementing the world model [1], which looks like a promising algorithm for learning controls effectively by first learning the environment. Here I share some implementation notes.
Loss of a Gaussian Mixture Model
The memory model of the world model is a Mixture-Density-Network Recurrent Neural Network (MDN-RNN). It takes the current state and action as inputs, and outputs predictions of the next state, the terminal flag, and the reward. Since state transitions might be stochastic, the authors choose to model the next-state output as a Gaussian mixture model: the next state can come from one of several multi-dimensional Gaussian components with certain probability assignments. The probability of the next state coming from component $latex i$ is $latex \pi_i$, and that component itself is a Gaussian distribution with mean $latex \mu_i \in \mathbb{R}^f$ and covariance $latex \Sigma_i \in \mathbb{R}^{f \times f}$, where $latex f$ is the feature size of a state.
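For concreteness, here is a minimal sketch of what such an MDN-RNN could look like in PyTorch. The class and layer names are my own, not the paper's code; it assumes a diagonal covariance per component (so each $latex \Sigma_i$ reduces to a vector of standard deviations of size $latex f$) and omits the terminal and reward heads.

```python
# A minimal MDN-RNN sketch (hypothetical names, not the paper's code), assuming a
# diagonal covariance per component; the terminal and reward heads are omitted.
import torch
import torch.nn as nn


class MDNRNN(nn.Module):
    def __init__(self, state_size, action_size, hidden_size, n_gaussians):
        super().__init__()
        self.state_size = state_size
        self.n_gaussians = n_gaussians
        self.rnn = nn.LSTM(state_size + action_size, hidden_size, batch_first=True)
        # Per mixture component: one mixing logit, f means, f log std devs.
        self.gmm_head = nn.Linear(hidden_size, n_gaussians * (1 + 2 * state_size))

    def forward(self, states, actions):
        # states: (batch_size, seq_len, f), actions: (batch_size, seq_len, action_size)
        h, _ = self.rnn(torch.cat([states, actions], dim=-1))
        out = self.gmm_head(h)
        batch_size, seq_len, _ = out.shape
        k, f = self.n_gaussians, self.state_size
        pi_logits, mu, log_sigma = torch.split(out, [k, k * f, k * f], dim=-1)
        log_pi = torch.log_softmax(pi_logits, dim=-1)               # (batch, seq, k)
        mu = mu.reshape(batch_size, seq_len, k, f)                  # (batch, seq, k, f)
        sigma = log_sigma.reshape(batch_size, seq_len, k, f).exp()  # (batch, seq, k, f)
        return log_pi, mu, sigma
```

Predicting $latex \log \sigma$ and exponentiating keeps the standard deviations positive; the mixing weights come out as log probabilities via log_softmax, which will be convenient for the loss below.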
To learn the parameters of the MDN-RNN that predict $latex \pi_i$, $latex \mu_i$, and $latex \Sigma_i$, we need a loss function that measures how closely the observed next state fits the predicted Gaussian mixture model. The loss function is the negative log likelihood. Since we usually fit data in batches, and one batch is a sequence of consecutive states from the same episodes, the loss function is defined over batch_size and seq_len:
$latex - \sum\limits_{z=1}^{batch\_size} \sum\limits_{s=1}^{seq\_len} \log \left(\sum\limits_i \pi^{z,s}_i \mathcal{N}(x^{z,s} \mid \mu^{z,s}_{i}, \Sigma^{z,s}_{i})\right) &s=2$
If we ignore the negative sign for now and only look at one state for a particular batch $latex z$ and sequence element $latex s$, we get:
$latex \log\left(\sum\limits_i \pi_i \mathcal{N}(x \mid \mu_{i}, \Sigma_{i})\right) &s=2$
In practice, the MDN-RNN usually predicts $latex \log(\pi_{i})$ directly, and we can get $latex \log(\mathcal{N}(x \mid \mu_{i}, \Sigma_{i}))$ from PyTorch's Normal.log_prob [3], so the formula above can be rewritten as:
$latex \log \sum\limits_i e^{\log(\pi_i \mathcal{N}(x \mid \mu_{i}, \Sigma_{i}))} \newline = \log\sum\limits_i e^{\log(\pi_i) + \log(\mathcal{N}(x \mid \mu_{i}, \Sigma_{i}))} &s=2$
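Each exponent, $latex \log(\pi_i) + \log(\mathcal{N}(x \mid \mu_{i}, \Sigma_{i}))$, can be computed per component. Here is a hedged sketch under the same diagonal-covariance assumption as above, where the joint log density of a component is the sum of per-feature log densities from Normal.log_prob [3]; the function name and tensor layout are my own.

```python
# A sketch of computing log(pi_i) + log(N(x | mu_i, sigma_i)) per component,
# assuming diagonal Gaussians so the joint log density is the sum of the
# per-feature log densities returned by torch.distributions.Normal.log_prob.
import torch
from torch.distributions import Normal


def component_log_probs(next_state, log_pi, mu, sigma):
    # next_state: (batch, seq, f); log_pi: (batch, seq, k); mu, sigma: (batch, seq, k, f)
    x = next_state.unsqueeze(-2)                   # (batch, seq, 1, f), broadcast over the k components
    log_n = Normal(mu, sigma).log_prob(x).sum(-1)  # sum per-feature log densities -> (batch, seq, k)
    return log_pi + log_n                          # log(pi_i) + log(N(x | mu_i, sigma_i))
```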
According to the log-sum-exp trick [2], for numerical stability we use
$latex \log\left(\sum\limits_{i=1}^n e^{x_i}\right) = \log\left(\sum\limits_{i=1}^n e^{x_i - c}\right) + c &s=2$
where $latex c = \max_i x_i$ is usually chosen.
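Putting it together, here is a sketch of the GMM negative log-likelihood. The manual log_sum_exp helper is shown only to illustrate the trick; torch.logsumexp applies the same stabilization internally, and is what I would use for the actual loss.

```python
# A sketch of the GMM negative log-likelihood with the log-sum-exp trick.
import torch


def log_sum_exp(x, dim=-1):
    # Manual log-sum-exp: subtract c = max_i x_i before exponentiating for stability.
    c = x.max(dim=dim, keepdim=True).values
    return (x - c).exp().sum(dim=dim).log() + c.squeeze(dim)


def gmm_nll(log_mix):
    # log_mix: (batch_size, seq_len, n_gaussians), e.g. the output of
    # component_log_probs from the sketch above.
    # torch.logsumexp performs the same stabilization as log_sum_exp.
    return -torch.logsumexp(log_mix, dim=-1).sum()
```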
The takeaway: we walked through how the loss function for the GMM is defined, and adopted the log-sum-exp trick when calculating it.
References
[1] https://worldmodels.github.io/
[2] https://blog.feedly.com/tricks-of-the-trade-logsumexp/
[3] https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal.log_prob