The PyTorch training loop

Every model you train is the same four steps on repeat: guess, measure how wrong, find which way is downhill, take a step.

The idea

Training a model means nudging its numbers — its weights — until its predictions stop being so wrong. You measure "wrong" with a loss: one number that gets smaller as the model gets better. The whole job is to make that number go down.

PyTorch does this with the same little dance every iteration; one full pass over the data is called an epoch. Forward pass: predict, then compute the loss. Backward pass: loss.backward() asks autograd "if I wiggle each weight, which way does the loss move?" Then optimizer.step() nudges every weight a small step downhill. One subtlety wrecks more first attempts than any other: gradients add up by default, so you must clear them with optimizer.zero_grad() before each backward pass.
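
To see the "gradients add up" behaviour in isolation, here is a minimal sketch. The single parameter w, the input x, and the helper loss_fn are hypothetical, chosen only so the numbers are easy to check by hand. They are not part of the training loop in this article.

```python
import torch

# Hypothetical one-parameter setup so the gradient is easy to verify by hand.
w = torch.tensor(1.0, requires_grad=True)
x = torch.tensor(3.0)

def loss_fn():
    # L = (w * x)^2, so dL/dw = 2 * w * x^2 = 18 at w = 1, x = 3
    return (w * x) ** 2

loss_fn().backward()
first = w.grad.item()        # gradient from a single backward pass

loss_fn().backward()         # second backward pass with no clearing in between
second = w.grad.item()       # the new gradient was added to the old one

w.grad.zero_()               # clear the buffer by hand for this one tensor
loss_fn().backward()
after_reset = w.grad.item()  # back to the single-pass value

print(first, second, after_reset)
```

The second reading is double the first because nothing cleared w.grad between the two backward calls. In a real loop, optimizer.zero_grad() does that clearing for every parameter the optimizer manages.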

[Interactive figure: one panel shows the data and the fitted line y = w·x + b, the other shows loss vs. epoch. As gradient descent runs, each epoch shifts the line toward the points and the loss drops.]

How it works

The loop below is the canonical shape of almost every PyTorch training script. Read it top to bottom: clear last epoch's gradients, predict, score, back-propagate, step.

import torch
import torch.nn as nn

# Toy data for illustration: noisy points around y = 0.8x + 0.5
x = torch.linspace(0, 4, 20).unsqueeze(1)        # shape (20, 1)
y = 0.8 * x + 0.5 + 0.1 * torch.randn_like(x)

model     = nn.Linear(1, 1)                      # one weight w and bias b
criterion = nn.MSELoss()                         # mean squared error
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

model.train()                                    # train mode (matters for dropout, batchnorm)
for epoch in range(200):
    optimizer.zero_grad()                        # clear grads: they ACCUMULATE otherwise
    pred = model(x)                              # forward pass: prediction
    loss = criterion(pred, y)                    # forward pass: how wrong
    loss.backward()                              # backward pass: autograd fills .grad
    optimizer.step()                             # update: w -= lr * w.grad, etc.

Why optimizer.zero_grad()? When you call loss.backward(), PyTorch adds the new gradients into each parameter's .grad buffer instead of overwriting it. That accumulation is deliberate (it lets you split a big batch across several backward passes), but if you forget to clear it, epoch 5's update is computed from the summed gradients of epochs 1 through 5. The step direction is corrupted and training destabilises. So zero_grad() clears every .grad at the top of each iteration (recent PyTorch releases set it to None by default rather than filling it with zeros; the next backward pass behaves the same either way), and step() then applies plain gradient descent: for SGD, param -= lr * param.grad.
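
The deliberate use of accumulation can be sketched as follows. It assumes synthetic data and a split into two equal micro-batches. Dividing each micro-batch loss by the number of micro-batches is one common convention for making the accumulated gradient match the full-batch one when the loss is a mean.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 1)                     # synthetic inputs, for illustration only
y = 0.8 * x + 0.5
model = nn.Linear(1, 1)
criterion = nn.MSELoss()

# Reference: gradient from the full batch of 8 in a single backward pass.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Same gradient built up from two micro-batches of 4.
model.zero_grad()
for xb, yb in zip(x.chunk(2), y.chunk(2)):
    loss = criterion(model(xb), yb) / 2   # divide by the number of micro-batches
    loss.backward()                       # adds into .grad; no clearing in between
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))
```

In a training loop built this way, optimizer.step() and optimizer.zero_grad() would run once per group of micro-batches rather than once per backward pass.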

Signals & trade-offs

Symptom | Likely cause | Fix
Loss goes to NaN / inf | Learning rate too high, or missing zero_grad() so grads explode | Lower lr (try 10× smaller), add zero_grad(), clip gradients
Loss barely moves | Learning rate too low, or vanishing gradients | Raise lr, normalise inputs, check network depth
Loss zig-zags then climbs | Step size overshoots the minimum | Reduce lr or add a learning-rate schedule
Train loss down, val loss up | Overfitting: memorising, not learning | Regularise, add data, early-stop on val loss
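
One way to apply the "clip gradients" fix from the first row is sketched below, using torch.nn.utils.clip_grad_norm_ together with a guard that stops on a non-finite loss. The unnormalised data and the threshold max_norm = 1.0 are assumptions chosen for illustration, not recommended settings.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.linspace(0, 10, 50).unsqueeze(1)   # unnormalised inputs make gradients large
y = 0.8 * x + 0.5

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
max_norm = 1.0                               # illustrative threshold, not a recommendation

largest_clipped = 0.0
for epoch in range(200):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    if not math.isfinite(loss.item()):       # stop instead of training on NaN / inf
        print(f"non-finite loss at epoch {epoch}")
        break
    loss.backward()
    # Rescale all gradients in place so their combined L2 norm is at most max_norm.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    clipped = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters())).item()
    largest_clipped = max(largest_clipped, clipped)
    optimizer.step()

final_loss = criterion(model(x), y).item()
print(final_loss, largest_clipped)
```

Clipping caps the size of each update but does not fix a learning rate that is too high for the problem. Lowering lr or normalising the inputs addresses the cause, while clipping limits the damage.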

Watch out for

Worked example

Say your data lies roughly on the line y = 0.8x + 0.5 and you start from w = 0, b = 0. The first forward pass predicts all zeros, so the MSE loss is large. loss.backward() computes dL/dw and dL/db — both negative here, because raising w and b would shrink the error. With lr = 0.03, optimizer.step() nudges w and b a little in the positive direction. Repeat: each epoch the fitted line rotates and rises toward the cloud of points, and the loss readout falls — fast at first, then slowing as it nears the minimum. Crank lr up too far and you'll watch the same line whip past the points and the loss shoot back up: that's divergence, and it's exactly the NaN trap from the table above.
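
The first step of this worked example can be checked numerically. This sketch assumes noise-free points with positive x values (the text does not specify the data) and sets w and b to zero by hand, since nn.Linear initialises them randomly by default.

```python
import torch
import torch.nn as nn

x = torch.linspace(0, 4, 20).unsqueeze(1)    # assumed inputs, all non-negative
y = 0.8 * x + 0.5                            # points exactly on the target line

model = nn.Linear(1, 1)
with torch.no_grad():                        # start from w = 0, b = 0 as in the text
    model.weight.zero_()
    model.bias.zero_()

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

optimizer.zero_grad()
loss_before = criterion(model(x), y)         # predictions are all zero here
loss_before.backward()
grad_w = model.weight.grad.item()            # dL/dw
grad_b = model.bias.grad.item()              # dL/db
optimizer.step()                             # w -= lr * dL/dw, b -= lr * dL/db

w1, b1 = model.weight.item(), model.bias.item()
loss_after = criterion(model(x), y).item()
print(grad_w, grad_b, w1, b1, loss_before.item(), loss_after)
```

Both gradients come out negative, so subtracting lr times the gradient moves w and b in the positive direction, and the loss after the step is lower than before it.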

Check yourself

You delete optimizer.zero_grad() from the loop and rerun. What happens to training?