Summing Torch tensor based on another tensor

I have two tensors, where the first contains floats and the second contains 0s and 1s. I want to sum over the first tensor based on the second tensor. More specifically, I want to sum between the occurrence of two 0s. For example, consider

a = tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
b = tensor([0., 1., 1., 1., 0., 1., 1., 1., 1., 0.])

I would like a (preferably vectorised) operation that takes the two tensors and returns

c = tensor([4., 5., 1.])

Each element of c is the sum of the elements of a from one 0 in b up to (but not including) the next 0 in b. The last 0 starts a final group that runs to the end of the tensor.

Answer

You can use torch.tensor_split to split a at the indices where b is 0, and then sum each piece individually.

For example:

import torch

# Split a at every index where b is 0
group = torch.tensor_split(a, torch.where(b == 0)[0])
# Output:
# (tensor([]),
# tensor([1., 1., 1., 1.]),
# tensor([1., 1., 1., 1., 1.]),
# tensor([1.]))

individual_sum = list(map(torch.sum, group))  # a loop or list comprehension works too
# Output
# [tensor(0.), tensor(4.), tensor(5.), tensor(1.)]

Note that the first 0 in b is at index 0, so the split also produces an empty leading tensor, whose sum is 0. You can drop it when combining the results:

torch.tensor(individual_sum[1:])
# Output
# tensor([4., 5., 1.])
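
Since the question asks for a vectorised operation, here is a minimal sketch of an alternative that avoids the Python-level loop over the pieces. This is not the tensor_split approach above but a different technique: it assigns every element a group id by taking the running count of 0s in b, then accumulates a per group with index_add_. The names group_ids, num_groups and sums are only illustrative, and the sketch assumes a and b are non-empty 1-D tensors of the same length.

```python
import torch

a = torch.tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
b = torch.tensor([0., 1., 1., 1., 0., 1., 1., 1., 1., 0.])

# Every 0 in b starts a new group, so the running count of 0s
# gives each element its group id (ids start at 1 at the first 0).
group_ids = torch.cumsum((b == 0).long(), dim=0)

# One slot per group id, plus slot 0 for anything before the first 0.
num_groups = int(group_ids[-1]) + 1

# Accumulate the elements of a into the slot given by their group id.
sums = torch.zeros(num_groups, dtype=a.dtype).index_add_(0, group_ids, a)

# Drop slot 0, mirroring individual_sum[1:] above.
c = sums[1:]
print(c)
```

As with the split-based version, elements that come before the first 0 in b end up in the leading slot and are dropped by the final slice.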