Model evaluation in PyTorch

Switch the model out of training mode, turn off gradient tracking, then count how often it is right.

The idea

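A trained model still has two faces: a training mode and an evaluation mode. During training it deliberately adds noise. Dropout randomly zeroes activations, and batch-norm normalizes with the statistics of the current mini-batch. That is great for learning, but it makes predictions jittery: dropout makes them change from run to run, and batch-norm makes each prediction depend on whatever else happens to be in its batch.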
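To see the difference concretely, here is a small sketch using a hypothetical toy model (a linear layer followed by nn.Dropout). In train() mode two forward passes on the same input give different outputs; in eval() mode they match. It covers only dropout; batch-norm's switch from batch statistics to running statistics is not shown.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the demo is repeatable

# Hypothetical toy model: a linear layer whose outputs pass through dropout.
model = nn.Sequential(nn.Linear(8, 64), nn.Dropout(p=0.5))
x = torch.randn(2, 8)

model.train()                 # training mode: dropout zeroes about half the activations at random
train_a = model(x)
train_b = model(x)

model.eval()                  # eval mode: dropout passes its input through unchanged
eval_a = model(x)
eval_b = model(x)

print(torch.equal(train_a, train_b))  # a fresh random mask on each call, so the outputs differ
print(torch.equal(eval_a, eval_b))    # same input, same output
```

In train() mode PyTorch also scales the surviving activations by 1/(1-p), which is why eval() can simply switch dropout off without changing the expected scale of the outputs.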

Evaluation means asking the model a fair, final question: given inputs you did not train on, how often are you right? To ask it fairly, you flip the model into eval() mode; to ask it cheaply, you also turn off gradient tracking with torch.no_grad(). Then you run the validation set through once and tally the results into a confusion matrix from which you can read off the metrics.

[Interactive figure: the model is switched out of model.train() (grad on) and validation samples fill a 2×2 confusion matrix, with rows for the actual class, columns for the predicted class (spam, not-spam), and cells TP, FN / FP, TN.]

How it works

The loop is small but every line earns its place. eval() changes layer behaviour; no_grad() stops PyTorch building the autograd graph (faster, less memory); argmax turns logits into a class; and the running counts become accuracy, precision and recall at the end.

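import torch

# Assumes `model` (a two-class classifier) and `val_loader` (yielding (x, y) batches) already exist.
device = next(model.parameters()).device   # evaluate on whatever device the model lives on

model.eval()                      # dropout off, batch-norm uses running stats
correct, total = 0, 0
tp = fp = fn = tn = 0

with torch.no_grad():             # no autograd graph: faster, less memory
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)         # raw scores, shape [batch, n_classes]
        pred = logits.argmax(dim=1)   # highest-scoring class per sample

        correct += (pred == y).sum().item()
        total   += y.size(0)

        # for the "spam" class (label 1):
        tp += ((pred == 1) & (y == 1)).sum().item()
        fp += ((pred == 1) & (y == 0)).sum().item()
        fn += ((pred == 0) & (y == 1)).sum().item()
        tn += ((pred == 0) & (y == 0)).sum().item()

acc       = correct / total
precision = tp / (tp + fp) if tp + fp else 0.0
recall    = tp / (tp + fn) if tp + fn else 0.0
f1        = 2 * precision * recall / (precision + recall) if precision + recall else 0.0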
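The loop above assumes a `model` and `val_loader` that already exist. Below is a self-contained sketch under made-up assumptions: a random 200-sample binary dataset and an untrained model, both hypothetical stand-ins, plus a hypothetical helper named `confusion_matrix`. Instead of four separate masks, it counts every (actual, predicted) pair in one torch.bincount call, which yields the full confusion matrix and works unchanged for more than two classes. Rows are the actual class and columns the predicted class, in label order 0, 1, so TP (spam, label 1) sits at the bottom right.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins: 200 random samples with binary labels, and an untrained classifier.
x_val = torch.randn(200, 10)
y_val = torch.randint(0, 2, (200,))
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=32)
model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Dropout(0.2), nn.Linear(16, 2))


def confusion_matrix(model, loader, n_classes=2):
    """Hypothetical helper: [n_classes, n_classes] counts, rows = actual, columns = predicted."""
    model.eval()
    cm = torch.zeros(n_classes, n_classes, dtype=torch.long)
    with torch.no_grad():
        for x, y in loader:
            pred = model(x).argmax(dim=1)
            # Encode each (actual, predicted) pair as one index and count all pairs at once.
            idx = y * n_classes + pred
            cm += torch.bincount(idx, minlength=n_classes * n_classes).reshape(n_classes, n_classes)
    return cm


cm = confusion_matrix(model, val_loader)
tn, fp = cm[0, 0].item(), cm[0, 1].item()   # actual 0 (not spam)
fn, tp = cm[1, 0].item(), cm[1, 1].item()   # actual 1 (spam)
print(cm)
print("accuracy:", (tp + tn) / cm.sum().item())

# no_grad() really does skip graph building: the output has no grad_fn.
with torch.no_grad():
    logits = model(x_val[:4])
print(logits.requires_grad, logits.grad_fn)
```

The last print shows `False None`: under no_grad() the logits carry no autograd history, which is where the speed and memory savings come from.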

Signals

Metric      Formula             When it matters
Accuracy    (TP+TN) / all       Balanced classes; misleading when one class is rare
Precision   TP / (TP+FP)        False alarms are costly (flagging real mail as spam)
Recall      TP / (TP+FN)        Misses are costly (letting spam through, missing a tumor)
F1          2·P·R / (P+R)       You need one number balancing precision (P) and recall (R); it is their harmonic mean
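The "misleading when one class is rare" entry is easy to reproduce. In this plain-Python sketch, a hypothetical validation set has 10 positives out of 1,000, and a do-nothing "model" predicts the majority class every time: accuracy looks excellent while recall is zero. Reporting precision as 0.0 when nothing is flagged is a convention chosen here; the true value is undefined.

```python
# Hypothetical imbalanced validation set: 1,000 samples, only 10 positives (label 1).
labels = [1] * 10 + [0] * 990
preds = [0] * 1000          # a lazy "model" that always predicts the majority class

tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
tn = sum(p == 0 and y == 0 for p, y in zip(preds, labels))

accuracy = (tp + tn) / len(labels)
precision = tp / (tp + fp) if tp + fp else 0.0   # undefined when nothing is flagged; report 0
recall = tp / (tp + fn) if tp + fn else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

print(f"accuracy={accuracy:.2f} precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
# → accuracy=0.99 precision=0.00 recall=0.00 f1=0.00
```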

Watch out for

Worked example

You evaluate a spam classifier on 200 validation emails: 40 are truly spam. The model flags 36 as spam: 30 of those are real spam (TP=30), 6 are good mail it got wrong (FP=6), and it misses 10 spam emails (FN=10). The rest are correct (TN=154). Accuracy is (30+154)/200 = 92%, which sounds great. But recall is 30/40 = 75%: a quarter of spam slips through. Precision is 30/36 ≈ 83%. The single accuracy number hid the real story; the confusion matrix told it.
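The worked example's arithmetic can be checked in a few lines of Python, with F1 added for completeness; the counts are exactly the ones stated in the example.

```python
# Counts from the worked example: 200 emails, 40 of them truly spam.
tp, fp, fn, tn = 30, 6, 10, 154
assert tp + fp + fn + tn == 200 and tp + fn == 40

accuracy = (tp + tn) / (tp + fp + fn + tn)   # 184 / 200
precision = tp / (tp + fp)                   # 30 / 36
recall = tp / (tp + fn)                      # 30 / 40
f1 = 2 * precision * recall / (precision + recall)

print(f"accuracy={accuracy:.0%} precision={precision:.0%} recall={recall:.0%} f1={f1:.0%}")
# → accuracy=92% precision=83% recall=75% f1=79%
```

F1 works out to exactly 15/19, about 79%, a single number that already sits well below the 92% accuracy.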

Check yourself

Your metrics change slightly every time you run evaluation on the same data. What is the most likely cause?

A fraud model scores 99% accuracy but the team is unhappy. Which metric most likely explains it?