PyTorch Tensors – vectorized slicing with given list of end indices

Suppose I have a 1D PyTorch tensor end_index of length L.

I want to construct a 2D PyTorch tensor T with L rows where T[i,j] = 2 when j < end_index[i] and T[i,j] = 1 otherwise.

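The following loop works (here with L = 4 and 3 columns):

T = torch.ones([4, 3], dtype=torch.long)
# fill row i only; indexing with T[:, :element] would overwrite every row
for i, element in enumerate(end_index):
    T[i, :element] = 2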

Is there a way to do this without a for loop, i.e. to vectorize it?

Answer

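You can construct such a tensor using broadcasting semantics. Comparing a row of column indices of shape [1, C] with the end indices reshaped to [L, 1] broadcasts to a boolean mask of shape [L, C]; converting the mask to long and adding 1 maps True to 2 and False to 1: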

import torch

# sample inputs
L, C = 4, 3
end_index = torch.tensor([0, 2, 2, 1])

# Construct tensor of shape [L, C] such that for all (i, j)
#     T[i, j] = 2 if j < end_index[i] else 1
j_range = torch.arange(C, device=end_index.device)
T = (j_range[None, :] < end_index[:, None]).long() + 1

which results in

T = 
tensor([[1, 1, 1],
        [2, 2, 1],
        [2, 2, 1],
        [2, 1, 1]])
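
If you need fill values other than 2 and 1, the same mask can be passed to torch.where. The following is a minimal sketch, not part of the original answer: the function names prefix_fill and prefix_fill_loop and the parameters inside and outside are hypothetical, chosen only for illustration. The loop version is included as a reference to check the vectorized one against.

```python
import torch


def prefix_fill(end_index, num_cols, inside=2, outside=1):
    """Vectorized: T[i, j] = inside if j < end_index[i] else outside."""
    device = end_index.device
    j_range = torch.arange(num_cols, device=device)
    # [1, C] < [L, 1] broadcasts to a boolean mask of shape [L, C]
    mask = j_range[None, :] < end_index[:, None]
    # 0-dim tensors broadcast against the mask; both are int64 by default
    return torch.where(
        mask,
        torch.tensor(inside, device=device),
        torch.tensor(outside, device=device),
    )


def prefix_fill_loop(end_index, num_cols, inside=2, outside=1):
    """Reference implementation with an explicit Python loop over rows."""
    T = torch.full(
        (len(end_index), num_cols), outside,
        dtype=torch.long, device=end_index.device,
    )
    for i, end in enumerate(end_index.tolist()):
        T[i, :end] = inside
    return T


end_index = torch.tensor([0, 2, 2, 1])
T = prefix_fill(end_index, 3)
# Both versions should agree on every entry
print(torch.equal(T, prefix_fill_loop(end_index, 3)))
```

For the sample end_index above, this produces the same tensor as shown in the answer. Values in end_index larger than the number of columns simply fill the whole row, since every j then satisfies j < end_index[i].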
