Tensor.scatter_() for Dummies
This little function takes four parameters: dim, index, src, and reduce.
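Tensor.scatter_() uses the information from index to decide where each element of src gets written into our beloved Tensor. The trailing underscore means the operation happens in place: the Tensor you call it on is modified directly.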
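To make the "placing" concrete, here is a small loop-based sketch of what scatter_() does for 2-D tensors. It follows the indexing rule given in the PyTorch documentation (for dim=0, self[index[i][j]][j] = src[i][j]; for dim=1, self[i][index[i][j]] = src[i][j]). The helper name scatter_2d_reference is hypothetical, and it ignores reduce. It only checks itself against the real call on indices that never send two values to the same spot, because the docs warn that the result is nondeterministic when indices repeat.

```python
import torch

def scatter_2d_reference(target, dim, index, src):
    """Hypothetical loop-based sketch of target.scatter_(dim, index, src)
    for 2-D tensors (reduce is ignored), following the documented rule:
        dim=0: out[index[i][j]][j] = src[i][j]
        dim=1: out[i][index[i][j]] = src[i][j]
    Only the positions (i, j) that exist in `index` are visited.
    """
    out = target.clone()  # work on a copy; the real scatter_ is in-place
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])  # where along `dim` src[i][j] should land
            if dim == 0:
                out[k, j] = src[i, j]
            else:
                out[i, k] = src[i, j]
    return out

src = torch.arange(1, 11).reshape((2, 5))
target = torch.zeros(3, 5, dtype=src.dtype)

# 1x4 index along dim=0 (every target position is hit at most once)
idx_rows = torch.tensor([[0, 1, 2, 0]])
# 2x3 index along dim=1 (unique columns within each row)
idx_cols = torch.tensor([[0, 1, 2], [4, 3, 0]])

for dim, idx in [(0, idx_rows), (1, idx_cols)]:
    mine = scatter_2d_reference(target, dim, idx, src)
    theirs = target.clone().scatter_(dim, idx, src)
    print(dim, torch.equal(mine, theirs))
```

Note that index can be smaller than src: idx_rows covers only 4 of the 10 elements of src, and the rest are simply never read.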
Suppose we have the following code:
import torch
src = torch.arange(1, 11).reshape((2, 5))
Our tensor looks like:
>> tensor([[ 1,  2,  3,  4,  5],
           [ 6,  7,  8,  9, 10]])
Now define the index and target tensors:
idx = torch.tensor([[2, 2, 2, 2, 2]])  # shape (1, 5): scatter_ needs the same number of dims as src
target = torch.zeros(3, 5, dtype=src.dtype)
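One detail worth checking here: scatter_() requires index to have the same number of dimensions as the tensor being written into (and as src). A minimal sketch, assuming a recent PyTorch, comparing a flat five-element index with the same values wrapped as a $1\times 5$ tensor; the variable names are just for illustration.

```python
import torch

src = torch.arange(1, 11).reshape((2, 5))
target = torch.zeros(3, 5, dtype=src.dtype)

# A flat index has one dimension, while target and src have two.
flat_idx = torch.tensor([2, 2, 2, 2, 2])
try:
    target.clone().scatter_(0, flat_idx, src)
    flat_ok = True
except RuntimeError as err:  # PyTorch rejects the dimension mismatch
    flat_ok = False
    print("flat index rejected:", err)

# The same values wrapped in one more pair of brackets: shape (1, 5).
idx = torch.tensor([[2, 2, 2, 2, 2]])
print(idx.shape)  # torch.Size([1, 5])
scattered = target.clone().scatter_(0, idx, src)
print(scattered)
```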
Our target tensor will be the victim of our scatter_() bloodbath. It is simply a $3\times 5$ tensor of all zeros. Meanwhile, idx defines “where” the elements of src land in target. With dim=0, the value idx[i][j] names the row of target that receives src[i][j], while the column j stays the same, so idx essentially acts as a middleman between the two tensors. Let’s see what happens if we call scatter_() with dim=0 to keep things simple:
target.scatter_(0, idx, src)
>> tensor([[0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [1, 2, 3, 4, 5]])
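Mamma mia. We have moved the first row of src into the third row of target: every entry of idx is 2, so each src[0][j] lands in target[2][j]. Notice that the second row of src is nowhere to be found. This is because we specified idx as a $1\times 5$ tensor rather than a $2\times 5$ tensor, and scatter_() only moves the elements of src whose positions appear in idx.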
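Following that reasoning, giving idx a second row should bring the second row of src along too. A minimal sketch, assuming a recent PyTorch: the first call uses a $2\times 5$ index that routes src's second row to target row 0; the second call points both rows at target row 2 and passes reduce="add", so the colliding values are summed instead of one of them being kept arbitrarily. Newer PyTorch documentation marks reduce with a tensor src as deprecated in favor of scatter_reduce_(), so this call may emit a deprecation warning.

```python
import torch

src = torch.arange(1, 11).reshape((2, 5))

# 2x5 index: row 0 of src goes to target row 2, row 1 of src to target row 0.
idx2 = torch.tensor([[2, 2, 2, 2, 2],
                     [0, 0, 0, 0, 0]])
both = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, idx2, src)
print(both)

# Both rows of src aimed at target row 2: the writes collide, and
# reduce="add" sums them (on top of the existing zeros) instead of
# keeping an arbitrary one.
same = torch.tensor([[2, 2, 2, 2, 2],
                     [2, 2, 2, 2, 2]])
summed = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, same, src, reduce="add")
print(summed)
```

With reduce="multiply" the colliding values would be multiplied into the target instead of added.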