Cross entropy with logits

I keep forgetting the exact formulation of `binary_cross_entropy_with_logits` in PyTorch, so I'm writing it down here for future reference.

The function `binary_cross_entropy_with_logits` takes two kinds of inputs: (1) the logits, i.e. the values right before the probability transformation (sigmoid) layer, whose range is (-infinity, +infinity); (2) the target, whose values are binary.

`binary_cross_entropy_with_logits` computes the following loss (i.e., the negative log likelihood), ignoring sample weights:

    \[\text{loss} = -\bigl[\text{target} \cdot \log(\sigma(\text{input})) + (1 - \text{target}) \cdot \log(1 - \sigma(\text{input}))\bigr]\]

>>> import torch
>>> import torch.nn.functional as F
>>> input = torch.tensor([3.0])
>>> target = torch.tensor([1.0])
>>> F.binary_cross_entropy_with_logits(input, target)
tensor(0.0486)
>>> - (target * torch.log(torch.sigmoid(input)) + (1-target)*torch.log(1-torch.sigmoid(input)))
tensor([0.0486])
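
As an extra sanity check of my own (continuing the session above; the printed values are what I expect, not outputs copied from the docs): the module form `nn.BCEWithLogitsLoss` computes the same number, and the fused function stays finite for large logits, where the manual `log(1 - sigmoid(x))` formula overflows to `inf`.

>>> import torch.nn as nn
>>> nn.BCEWithLogitsLoss()(input, target)  # module form of the same loss
tensor(0.0486)
>>> big, zero = torch.tensor([100.0]), torch.tensor([0.0])
>>> F.binary_cross_entropy_with_logits(big, zero)  # fused version, numerically stable
tensor(100.)
>>> -(zero * torch.log(torch.sigmoid(big)) + (1 - zero) * torch.log(1 - torch.sigmoid(big)))
tensor([inf])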

2019-12-03 Update

Now let's look at the difference between N-class cross entropy and the KL-divergence loss. For one-hot (hard) class targets they compute the same value (https://adventuresinmachinelearning.com/cross-entropy-kl-divergence/); they differ only in the input/output format they expect: `nn.CrossEntropyLoss` takes raw logits and integer class indices, while `nn.KLDivLoss` takes log-probabilities and a target distribution.

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> loss(input, target)
tensor(1.6677, grad_fn=<NllLossBackward>)
>>> loss1 = nn.KLDivLoss(reduction="batchmean")
>>> loss1(nn.LogSoftmax(dim=1)(input), F.one_hot(target, 5).float())
tensor(1.6677, grad_fn=<NllLossBackward>)
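
The two losses coincide here only because a one-hot target has zero entropy. A small sketch of my own (not code from the linked article) showing the general relation, cross entropy = KL divergence + entropy of the target, with a soft target distribution `p`:

>>> p = torch.tensor([[0.7, 0.1, 0.1, 0.05, 0.05]])     # soft target, rows sum to 1
>>> log_q = F.log_softmax(torch.randn(1, 5), dim=1)     # model log-probabilities
>>> ce = -(p * log_q).sum(dim=1).mean()                 # cross entropy H(p, q)
>>> kl = nn.KLDivLoss(reduction="batchmean")(log_q, p)  # KL(p || q)
>>> entropy = -(p * p.log()).sum(dim=1).mean()          # entropy H(p)
>>> torch.allclose(ce, kl + entropy)
True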

References:

[1] https://pytorch.org/docs/stable/nn.html#binary-cross-entropy-with-logits

[2] https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss

[3] https://stackoverflow.com/questions/34240703/what-is-logits-softmax-and-softmax-cross-entropy-with-logits
