EmbeddingBag in PyTorch is a useful module for consuming batches of sparse ids and producing pooled embeddings.
Here is a minimal example. The embedding table holds 4 ids, each mapped to a 3-dimensional embedding. We have two data points: the first has three ids (0, 1, 2) and the second has the single id 3. This is reflected in the input and offsets variables: the i-th data point owns the ids from input[offsets[i]] (inclusive) to input[offsets[i+1]] (exclusive), with the last data point running to the end of input. Since we are using the "sum" mode, the first data point's output is the sum of the embeddings of ids 0, 1, and 2, and the second data point's output is simply the embedding of id 3.
>>> import torch
>>> import torch.nn as nn
>>> embedding_sum = nn.EmbeddingBag(4, 3, mode='sum')
>>> embedding_sum.weight
Parameter containing:
tensor([[-0.9674, -2.3095, -0.2560],
        [ 0.0061, -0.4309, -0.7920],
        [-1.3457,  0.8978,  0.1271],
        [-1.8232,  0.6509, -1.2162]], requires_grad=True)
>>> input = torch.LongTensor([0,1,2,3])
>>> offsets = torch.LongTensor([0,3])
>>> embedding_sum(input, offsets)
tensor([[-2.3070, -1.8426, -0.9209],
        [-1.8232,  0.6509, -1.2162]], grad_fn=<EmbeddingBagBackward>)
>>> torch.sum(embedding_sum.weight[:3], dim=0)
tensor([-2.3070, -1.8426, -0.9209], grad_fn=<SumBackward1>)
>>> torch.sum(embedding_sum.weight[3:], dim=0)
tensor([-1.8232, 0.6509, -1.2162], grad_fn=<SumBackward1>)
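To make the bag-slicing rule above concrete, here is a small sketch (not part of the PyTorch docs; it simply reuses embedding_sum, input, and offsets from the session above) that reproduces the "sum" pooling by hand with a plain nn.Embedding: each bag is a slice of input between consecutive offsets, summed over dim 0.

>>> # Share embedding_sum's weight table so the numbers match exactly.
>>> embedding = nn.Embedding(4, 3)
>>> embedding.weight = embedding_sum.weight
>>> # Bag 0 covers input[0:3] (ids 0, 1, 2); bag 1 covers input[3:4] (id 3).
>>> manual = torch.stack([embedding(input[0:3]).sum(dim=0),
...                       embedding(input[3:4]).sum(dim=0)])
>>> torch.allclose(manual, embedding_sum(input, offsets))
True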