I’m using some ShuffleNet code, but I have a problem understanding how `correct` is calculated in this function (it computes the top-1 and top-5 accuracy). As I understand it, in the third line `pred` holds the indices, but I can’t understand why, two lines later, it is compared with `target` using the equality function, since `pred` contains the indices of the highest probabilities of the output.

```python
def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest logits for each sample, sorted descending
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # boolean mask: does each of the top-k predictions match the target?
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # count hits among the k highest predictions, over the whole batch
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
```

## Answer

Looking at the code, I would expect `output` to be shaped `(batch_size, n_logits)`, while `target` is a dense representation shaped `(batch_size, 1)`. This means the ground truth class is designated by an integer value: the corresponding class label.
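To make those assumed shapes concrete (they are my reading of the code, not stated in it), here is a minimal setup for a 10-class problem:

```python
>>> import torch
>>> output = torch.rand(2, 10)          # (batch_size=2, n_logits=10): one score per class
>>> target = torch.randint(10, (2, 1))  # (batch_size=2, 1): integer class labels
```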

To follow this implementation of top-`k` accuracy, we first need to understand what it measures: top-`k` accuracy counts how many ground-truth labels appear among the *k* highest predictions of the output. It’s essentially a generalized form of the standard top-`1` accuracy, where we would only look at the single highest prediction and check whether it matches the target.
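As a sanity reference before dissecting the original function, here is a direct, unoptimized sketch of that definition (a hypothetical helper, not part of the ShuffleNet code):

```python
import torch

def topk_accuracy_naive(output, target, k):
    # indices of the k highest-scoring classes per sample: (batch_size, k)
    _, topk_idx = output.topk(k, dim=1)
    # a sample is a hit if its true label appears anywhere in its top-k row
    hits = (topk_idx == target.view(-1, 1)).any(dim=1)
    return hits.float().mean().item() * 100.0
```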

Let’s take a simple example with `batch_size=2`, `n_logits=10`, and `k=3`, *i.e.* we’re interested in the top-`3` accuracy. Here we sample a random prediction:

```python
>>> output
tensor([[0.2110, 0.1992, 0.0597, 0.9557, 0.8316, 0.2407, 0.8398, 0.3631, 0.2889, 0.3226],
        [0.3811, 0.2932, 0.2117, 0.6522, 0.2734, 0.5841, 0.0336, 0.7357, 0.2232, 0.2633]])
```

We first look at the *k* highest logits and retrieve their indices:

```python
>>> _, pred = output.topk(k=3, dim=1, largest=True, sorted=True)
>>> pred
tensor([[3, 6, 4],
        [7, 3, 5]])
```

*This is nothing more than a sliced `torch.argsort`: `output.argsort(1, descending=True)[:, :3]` will return the same result.*
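You can check that claim on the `output` above:

```python
>>> output.argsort(1, descending=True)[:, :3]
tensor([[3, 6, 4],
        [7, 3, 5]])
```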

We can then transpose `pred` to get the batch dimension last, shaped `(3, 2)`:

```python
>>> pred = pred.T
>>> pred
tensor([[3, 7],
        [6, 3],
        [4, 5]])
```

Now that we have the top-`k` predictions for each batch element, we need to compare them with the ground truths. Let us now imagine a target tensor (remember, it is shaped `(batch_size=2, 1)`):

```python
>>> target
tensor([[1],
        [5]])
```

We first need to expand it to the shape of `pred`:

```python
>>> target.view(1, -1).expand_as(pred)
tensor([[1, 5],
        [1, 5],
        [1, 5]])
```

We then compare the two with `torch.eq`, the element-wise equality operator:

```python
>>> correct = torch.eq(pred, target.view(1, -1).expand_as(pred))
>>> correct
tensor([[False, False],
        [False, False],
        [False,  True]])
```

As you can tell from the second batch element, one of the three highest predictions matches the ground-truth class label `5`. On the first batch element, none of the three highest predictions matches the ground-truth label `1`, so it does not count. The second batch element counts as one *correct* instance.
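This is exactly what the last lines of `accuracy` turn into a percentage: flatten the first `k` rows of the mask, sum the hits, and scale by `100 / batch_size`. On our example:

```python
>>> correct[:3].contiguous().view(-1).float().sum(0)
tensor(1.)
>>> correct[:3].contiguous().view(-1).float().sum(0) * (100.0 / 2)
tensor(50.)
```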

Of course, based on this equality mask `correct`, you can slice it further to compute other top-`k'` accuracies where `k' <= k`. For instance, `k' = 1`:

```python
>>> correct[:1]
tensor([[False, False]])
```

Here, for the top-`1` accuracy, we have zero correct instances out of the two batch elements.
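Putting it all together, calling the original function on this example with `topk=(1, 3)` returns both accuracies at once:

```python
>>> accuracy(output, target, topk=(1, 3))
[tensor(0.), tensor(50.)]
```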