Calculate accuracy function in ShuffleNet

I’m using some ShuffleNet code, but I have trouble understanding how correct is computed in this function (it computes the top-1 and top-5 accuracy).
As I understand it, pred in the line _, pred = output.topk(...) holds the indices of the top predictions, but I can’t understand why, two lines later, it is compared to target with an equality function, since pred contains indices of the highest output probabilities.

def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0/batch_size))
    return res

Answer

Looking at the code, I can speculate that output is shaped (batch_size, n_logits) while target uses an index representation, shaped (batch_size, 1). This means the ground-truth class is designated by an integer value: the corresponding class label.

To understand this implementation of top-k accuracy, we first need to understand the metric itself: top-k accuracy counts how many ground-truth labels appear among the k highest predictions of the output. It is essentially a generalization of the standard top-1 accuracy, where we would only look at the single highest prediction and check whether it matches the target.
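Just for intuition, here is a minimal, unvectorized sketch of the same metric (the helper name topk_accuracy_naive is mine, not part of the ShuffleNet code): for each sample we simply check whether the target index appears among the k largest logits.

import torch

def topk_accuracy_naive(output, target, k=1):
    # output: (batch_size, n_logits), target: (batch_size,) or (batch_size, 1)
    correct = 0
    for scores, label in zip(output, target.reshape(-1)):
        # indices of the k largest logits for this sample
        if label in scores.topk(k).indices:
            correct += 1
    # percentage of samples whose target is among their top-k predictions
    return 100.0 * correct / output.size(0)

The function in your question computes exactly this, just vectorized over the whole batch.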

Let’s take a simple example with batch_size=2, n_logits=10, and k=3, i.e. we’re interested in the top-3 accuracy. Here we sample a random prediction:

>>> output
tensor([[0.2110, 0.1992, 0.0597, 0.9557, 0.8316, 0.8107, 0.8398, 0.3631, 0.2889, 0.3226],
        [0.4811, 0.2932, 0.2117, 0.6522, 0.2734, 0.5841, 0.0336, 0.7357, 0.2232, 0.2633]])
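For reference, here is a minimal setup that reproduces the REPL session below (the target tensor is the one introduced further down in this answer):

import torch

output = torch.tensor([[0.2110, 0.1992, 0.0597, 0.9557, 0.8316, 0.8107, 0.8398, 0.3631, 0.2889, 0.3226],
                       [0.4811, 0.2932, 0.2117, 0.6522, 0.2734, 0.5841, 0.0336, 0.7357, 0.2232, 0.2633]])
target = torch.tensor([[1],
                       [5]])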

We first look at the k highest logits and retrieve their indices:

>>> _, pred = output.topk(k=3, dim=1, largest=True, sorted=True)

>>> pred
tensor([[3, 6, 4],
        [7, 3, 5]])

This is nothing more than a sliced torch.argsort: output.argsort(1, descending=True)[:, :3] will return the same result.
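With the example tensor above (which has no ties), we can verify that equivalence directly:

>>> torch.equal(output.argsort(1, descending=True)[:, :3], pred)
True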

We then transpose pred so that the batch dimension comes last, giving shape (3, 2) (.T is equivalent to the .t() call in the original function for a 2-D tensor); this layout lets correct[:k] later slice along the top-k dimension:

>>> pred = pred.T
>>> pred
tensor([[3, 7],
        [6, 3],
        [4, 5]])

Now that we have the top-k predictions for each batch element, we need to compare them with the ground truths. Let us now imagine a target tensor (remember, it is shaped (batch_size=2, 1)):

>>> target
tensor([[1],
        [5]]) 

We first flatten it into a single row with view(1, -1), then expand it to the shape of pred:

>>> target.view(1, -1)
tensor([[1, 5]])

>>> target.view(1, -1).expand_as(pred)
tensor([[1, 5],
        [1, 5],
        [1, 5]])

We then compare the two element-wise with torch.eq, the equality operator (pred.eq(...) in the original function is the same operation in method form):

>>> correct = torch.eq(pred, target.view(1, -1).expand_as(pred))
>>> correct
tensor([[False, False],
        [False, False],
        [False,  True]])

As you can see, for the second batch element one of the three highest predictions matches the ground-truth class label 5, so it counts as one ‘correct’. For the first batch element, none of the three highest predictions matches the ground-truth label 1, so it contributes nothing.
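As a side note, the expand_as is not strictly required here: eq broadcasts, so comparing against the (1, 2)-shaped row directly yields the same mask:

>>> pred.eq(target.view(1, -1))
tensor([[False, False],
        [False, False],
        [False,  True]])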

Of course, based on this boolean mask correct, you can slice it further to compute any other top-k' accuracy with k' <= k. For instance, for k' = 1:

>>> correct[:1]
tensor([[False, False]])

Here, for the top-1 accuracy, we have zero correct predictions out of the two batch elements.
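Putting it all together: the final loop in accuracy just counts the True entries in the first k rows of this mask and scales the count to a percentage. On this toy example, calling the function from your question with topk=(1, 3) should return:

>>> accuracy(output, target, topk=(1, 3))
[tensor(0.), tensor(50.)]

i.e. 0% top-1 accuracy and 50% top-3 accuracy over the two samples.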