I’m using some ShuffleNet code, but I have a problem understanding how correct is calculated in this function (it computes the top-1 and top-5 accuracy).
As I understand it, in the third line pred holds the indices of the highest-probability outputs, but I can’t understand why, two lines later, it is compared against the target with the equality function.
def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
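For context, here is a minimal sketch of how this function might be called; the shapes, the random data, and `topk=(1, 5)` are assumptions for illustration (the function body is the one from the question, with the missing `[]` on `res` filled in):

```python
import torch

def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the maxk highest logits per sample
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # boolean mask: does a top-k prediction match the target?
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

torch.manual_seed(0)
output = torch.rand(8, 10)           # 8 samples, 10 classes (assumed shapes)
target = torch.randint(0, 10, (8,))  # integer class labels
top1, top5 = accuracy(output, target, topk=(1, 5))
```

Both results are percentages, and top-5 can never be lower than top-1, since the top-1 prediction is contained in the top-5 set.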
Looking at the code, I can speculate that output is shaped (batch_size, n_logits), while target is a dense representation, shaped (batch_size,). This means the ground-truth class is designated by an integer value: the corresponding class label.
To follow this implementation of top-k accuracy, we first need to understand this: top-k accuracy counts how many ground-truth labels appear among the k highest predictions of our output. It’s essentially a generalized form of the standard top-1 accuracy, where we only look at the single highest prediction and check whether it matches the target.
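To make that definition concrete, here is a tiny pure-Python sketch of it (the scores, labels, and the `topk_accuracy` helper are made up for illustration, independent of the PyTorch code):

```python
# A prediction counts as correct if the true label is among the k
# highest-scoring classes for that sample.
def topk_accuracy(scores, labels, k):
    correct = 0
    for row, label in zip(scores, labels):
        # indices of the k largest scores in this row
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += label in topk
    return 100.0 * correct / len(labels)

scores = [[0.1, 0.7, 0.2],   # true label 2 is only the 2nd highest
          [0.5, 0.3, 0.2]]   # true label 0 is the highest
labels = [2, 0]
acc1 = topk_accuracy(scores, labels, k=1)  # 50.0: only sample 2 hits
acc2 = topk_accuracy(scores, labels, k=2)  # 100.0: both hit within top-2
```

Note how the same predictions score higher as k grows: top-k accuracy is monotonically non-decreasing in k.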
Let’s take a simple example with k=3, i.e. we’re interested in the top-3 accuracy. Here we sample a random prediction:
>>> output
tensor([[0.2110, 0.8316, 0.0597, 0.9992, 0.8407, 0.8398, 0.9557, 0.3631, 0.2889, 0.3226],
        [0.6811, 0.2932, 0.2117, 0.8841, 0.2734, 0.7357, 0.0336, 0.9232, 0.6522, 0.2633]])
We first look at the k highest logits and retrieve their indices:
>>> _, pred = output.topk(k=3, dim=1, largest=True, sorted=True)
>>> pred
tensor([[3, 6, 4],
        [7, 3, 5]])
This is nothing more than a sliced argsort: output.argsort(1, descending=True)[:, :3] returns the same result.
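That equivalence is easy to check on a small made-up tensor (values chosen only for illustration):

```python
import torch

# topk with sorted=True returns the same indices as a slice of a
# descending argsort (assuming no tied values).
output = torch.tensor([[0.2, 0.9, 0.1, 0.8],
                       [0.6, 0.3, 0.7, 0.5]])
_, pred = output.topk(k=3, dim=1, largest=True, sorted=True)
via_argsort = output.argsort(dim=1, descending=True)[:, :3]
# pred and via_argsort both hold [[1, 3, 0], [2, 0, 3]]
```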
We can then transpose to get the batch dimension last:
>>> pred = pred.T
>>> pred
tensor([[3, 7],
        [6, 3],
        [4, 5]])
Now that we have the top-k predictions for each batch element, we need to compare them with the ground truths. Let us now imagine a target tensor (remember, it is shaped (batch_size,)):

>>> target
tensor([1, 5])
We first need to expand it to the shape of pred:

>>> target.view(1, -1).expand_as(pred)
tensor([[1, 5],
        [1, 5],
        [1, 5]])
We then compare the two with torch.eq, the element-wise equality operator:

>>> correct = torch.eq(pred, target.view(1, -1).expand_as(pred))
>>> correct
tensor([[False, False],
        [False, False],
        [False,  True]])
As you can tell, on the 2nd batch element one of the three highest predictions matches the ground-truth class label 5. On the first batch element, none of the three highest predictions matches the ground-truth label, so it is not correct; only the second batch element counts as ‘correct’.
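Putting the last few steps together as runnable code (the pred and target values are the ones from the example above), this is how the boolean mask reduces to a top-3 hit count, the same reduction the accuracy() function performs with correct[:k]:

```python
import torch

pred = torch.tensor([[3, 7],   # top-3 indices, batch dimension last
                     [6, 3],
                     [4, 5]])
target = torch.tensor([1, 5])
# broadcast the targets against every one of the 3 prediction rows
correct = pred.eq(target.view(1, -1).expand_as(pred))
# flatten the first k rows and count the True entries
top3_count = correct[:3].reshape(-1).float().sum(0)  # one hit: label 5
```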
Of course, based on this equality mask tensor correct, you can slice it further to compute other top-k' accuracies where k' <= k. For instance, k' = 1:
>>> correct[:1]
tensor([[False, False]])
Here, for the top-1 accuracy, we have zero correct instances out of the two batch elements.
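Finally, to get from these counts to the percentages the function returns, each slice count is scaled by 100 / batch_size; a sketch using the mask from the example above:

```python
import torch

correct = torch.tensor([[False, False],
                        [False, False],
                        [False,  True]])
batch_size = 2
# mirror the correct_k.mul_(100.0 / batch_size) step for k = 1 and k = 3
accs = [float(correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size)
        for k in (1, 3)]
# top-1 accuracy: 0.0 %, top-3 accuracy: 50.0 %
```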