PyTorch MNIST neural network produces several non-zero outputs

I tried to build a neural network that operates on the MNIST data set, mostly following the torch.nn tutorial. As a result, I got a model that learns, but there’s something wrong with the process or with the model itself. Instead of one active neuron at the output, I receive multiple ones.

Here’s the model itself:

model = nn.Sequential(
    nn.Linear(784, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.ReLU(),
)

And here’s the training process:

loss_func = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=lr)

for epoch in range(epochs):
    model.train()
    for xbt, ybt in train_dl:
        pred = model(xbt)
        loss = loss_func(pred, ybt)
        opt.zero_grad()
        loss.backward()
        opt.step()
        

    model.eval()
    # Validation
    if epoch % 10 == 0:
        with torch.no_grad():
            losses, nums = zip(
                *[(loss_func(model(xbv), ybv), len(xbv)) for xbv, ybv in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)

Here’s the average validation loss at every 10th epoch:

0 0.13384412774592638
10 0.0900113809091039
20 0.09795805384699234
30 0.10341344920364791
40 0.10804545368137551

And this is what the result of applying the model to the validation set looks like:

[[ 0.         0.         0.        ... 28.436266   0.         5.001435 ]
 [ 7.3331523 12.666427  31.898096  ...  0.         0.         0.       ]
 [ 0.        18.116354   8.049953  ...  4.330721   0.         0.       ]
 ...
 [ 8.504517   0.         6.302228  ...  0.         0.         0.       ]
 [ 1.7339934  0.         0.        ...  0.         2.1565871  0.       ]
 [45.750134   0.         6.2685804 ...  2.247082   0.         0.       ]]
 Shape: (9984, 10)

I tried changing the learning rate, the model layers, and the number of epochs, but nothing seems to work.

Answer

Your last layer has 10 neurons followed by ReLU, so getting several non-zero outputs is expected. Each of those neurons computes ReLU(w·x + b) on its input: the result is zero when w·x + b is negative and positive otherwise, and nothing in the network forces only one of the 10 values to be positive. nn.CrossEntropyLoss applies a softmax internally, so it treats the outputs as unnormalized class scores rather than as a one-hot vector. The way you infer a prediction from this is by taking the class corresponding to the neuron with the largest activation (using np.argmax or torch.argmax).
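As a minimal sketch of that last step, assuming the model outputs have already been converted to a NumPy array of shape (samples, 10) like the one printed in the question: the `outputs` values below are made up for illustration, not taken from the question's run, and the softmax step is only there to show how the scores could be read as probabilities.

```python
import numpy as np

# Hypothetical raw outputs for 3 samples and 10 classes (made-up values,
# not the actual numbers from the question).
outputs = np.array([
    [0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.0, 28.4, 0.0, 5.0],
    [7.3, 12.7, 31.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 18.1, 8.0, 0.0, 0.0, 0.0, 0.0, 4.3, 0.0, 0.0],
])

# Predicted class = index of the largest score in each row.
predicted = np.argmax(outputs, axis=1)

# Optional: softmax turns each row of scores into probabilities summing to 1.
# Subtracting the row maximum first keeps np.exp from overflowing.
shifted = outputs - outputs.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

print(predicted)  # [7 2 1]
```

If the outputs are still a torch tensor, `torch.argmax(outputs, dim=1)` should give the same class indices. Comparing those indices with the labels is then enough to compute accuracy on the validation set.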