Setting results of torch.gather(…) calls

I have a 2D PyTorch tensor of shape n by m. I want to index the second dimension using a list of indices (which could be done with torch.gather) and then also assign new values to the indexed elements.

Example:

data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

I want to select the specified indices per row (which would give [1,5,7]) and then also set these values to another number, e.g. 42.

I can select the desired columns row wise by doing:

data.gather(1, indices)
tensor([[1],
        [5],
        [7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather 
                                # does not use the same storage as the original tensor

which is fine, but I would like to change these values now, and have the change also affect the data tensor.

I can achieve what I want with the following loop, but it seems very un-Pythonic:

max_index = torch.max(indices)
for i in range(0, max_index + 1):
  mask = (indices == i).nonzero(as_tuple=True)[0]
  data[mask, i] = 42
print(data)
# tensor([[ 0, 42,  2],
#         [ 3,  4, 42],
#         [ 6, 42,  8]])

Any hints on how to do that more elegantly?

Answer

What you are looking for is Tensor.scatter_ with the value option.

Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

With 2D tensors as input and dim=1, the operation is:
self[i][index[i][j]] = src[i][j]
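To make the indexing rule above concrete, here is a minimal plain-Python sketch of the same update on nested lists. The helper name scatter_dim1 is hypothetical and only mirrors the documented rule for 2D inputs with dim=1; it is not how PyTorch implements the operation.

```python
def scatter_dim1(target, index, src):
    """Illustrative stand-in for scatter_ along dim=1 on nested lists."""
    for i, index_row in enumerate(index):
        for j, col in enumerate(index_row):
            # Same rule as the docs: self[i][index[i][j]] = src[i][j]
            target[i][col] = src[i][j]
    return target


data = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
index = [[1], [2], [1]]        # one column index per row, shape (3, 1)
src = [[42], [42], [42]]       # values to write, same shape as index

result = scatter_dim1(data, index, src)
print(result)
```

The row position comes from the position in index, while the column position comes from the value stored in index.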

This excerpt makes no mention of the value parameter, though.


With value=42, and dim=1, this will have the following effect on data:

data[i][index[i][j]] = 42

Here applied in-place:

>>> data.scatter_(index=indices, dim=1, value=42)
>>> data
tensor([[ 0, 42,  2],
        [ 3,  4, 42],
        [ 6, 42,  8]])
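If the original tensor should stay untouched, the out-of-place Tensor.scatter can be used instead. The sketch below also shows an alternative based on advanced indexing with torch.arange, which is a different technique from the one in this answer but produces the same result for this example. The variable names scattered, indexed and rows are only illustrative.

```python
import torch

data = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])   # shape (3, 3)
indices = torch.tensor([1, 2, 1]).unsqueeze(-1)           # shape (3, 1)

# Out-of-place variant: returns a new tensor and leaves `data` unchanged.
scattered = data.scatter(1, indices, value=42)

# Alternative: pair each row number with its column index and assign directly.
rows = torch.arange(data.size(0))                         # tensor([0, 1, 2])
indexed = data.clone()
indexed[rows, indices.squeeze(-1)] = 42

print(scattered)
print(indexed)
```

Both approaches write one value per row here; scatter_ is the more general choice when index has several columns per row.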