EmbeddingBag from PyTorch

EmbeddingBag in PyTorch is a useful module for consuming sparse ids and producing pooled embeddings: each variable-length “bag” of ids is looked up and reduced into a single embedding vector.

Here is a minimal example. There are 4 ids’ embeddings, each of 3 dimensions. We have two data points: the first has three ids (0, 1, 2) and the second has a single id (3). This is reflected in the input and offsets variables: the i-th data point owns the ids from input[offsets[i]] (inclusive) to input[offsets[i+1]] (exclusive), with the last data point running to the end of input. Since we are using the “sum” mode, the first data point’s output is the sum of the embeddings of ids (0, 1, 2), and the second data point’s output is the embedding of id 3.

>>> import torch
>>> import torch.nn as nn
>>> embedding_sum = nn.EmbeddingBag(4, 3, mode='sum')
>>> embedding_sum.weight
Parameter containing:
tensor([[-0.9674, -2.3095, -0.2560],
        [ 0.0061, -0.4309, -0.7920],
        [-1.3457,  0.8978,  0.1271],
        [-1.8232,  0.6509, -1.2162]], requires_grad=True)
>>> input = torch.LongTensor([0,1,2,3])
>>> offsets = torch.LongTensor([0,3])
>>> embedding_sum(input, offsets)
tensor([[-2.3070, -1.8426, -0.9209],
        [-1.8232,  0.6509, -1.2162]], grad_fn=<EmbeddingBagBackward>)
>>> torch.sum(embedding_sum.weight[:3], dim=0)
tensor([-2.3070, -1.8426, -0.9209], grad_fn=<SumBackward1>)
>>> torch.sum(embedding_sum.weight[3:], dim=0)
tensor([-1.8232,  0.6509, -1.2162], grad_fn=<SumBackward1>)
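
For intuition, the same pooling can be reproduced (less efficiently) with a plain nn.Embedding followed by a manual per-bag sum. The sketch below is just an illustration of the offsets semantics described above, not how EmbeddingBag is implemented internally; it fuses the lookup and reduction so the intermediate per-id embedding tensor is never materialized.

import torch
import torch.nn as nn

# Reproduce EmbeddingBag's 'sum' mode with nn.Embedding plus manual pooling.
embedding_sum = nn.EmbeddingBag(4, 3, mode='sum')
embedding = nn.Embedding.from_pretrained(embedding_sum.weight)  # share weights

input = torch.LongTensor([0, 1, 2, 3])
offsets = torch.LongTensor([0, 3])

# Append len(input) so each bag i spans input[offsets[i]:offsets[i+1]].
bounds = torch.cat([offsets, torch.tensor([len(input)])])
manual = torch.stack([
    embedding(input[start:end]).sum(dim=0)
    for start, end in zip(bounds[:-1], bounds[1:])
])

assert torch.allclose(manual, embedding_sum(input, offsets))

Besides 'sum', EmbeddingBag also supports 'mean' and 'max' pooling via the mode argument.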
