Every model you train is the same four lines on repeat: guess, measure how wrong, find which way is downhill, take a step.
Training a model means nudging its numbers — its weights — until its predictions stop being so wrong. You measure "wrong" with a loss: one number that gets smaller as the model gets better. The whole job is to make that number go down.
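To make "one number that gets smaller" concrete, here is mean squared error computed by hand and with `nn.MSELoss`. The predictions and targets are made-up values used only for illustration.

```python
import torch
import torch.nn as nn

# Made-up predictions and targets, for illustration only
pred = torch.tensor([2.5, 0.0, 2.0])
target = torch.tensor([3.0, -0.5, 2.0])

# Mean squared error: the average of the squared differences
by_hand = ((pred - target) ** 2).mean()      # (0.25 + 0.25 + 0) / 3
by_torch = nn.MSELoss()(pred, target)        # same formula, default reduction="mean"
print(by_hand.item(), by_torch.item())       # both about 0.1667

# Predictions closer to the targets give a smaller loss
closer = target + 0.1
closer_loss = nn.MSELoss()(closer, target)   # about 0.01
print(closer_loss.item())
```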
PyTorch makes that number go down with the same little dance on every iteration. An epoch is one full pass over the data; in the loop below each iteration uses the whole dataset, so every iteration is also an epoch. Forward pass: predict, then compute the loss. Backward pass: loss.backward() asks autograd "if I wiggle each weight, which way does the loss move?" Then optimizer.step() nudges every weight a small step downhill. One subtlety wrecks more first attempts than any other: gradients add up by default, so you must clear them with optimizer.zero_grad() before each backward pass.
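The "wiggle" question can be checked by hand on a single weight. This is a minimal sketch that assumes one scalar weight `w` and a made-up loss L(w) = (3w - 1)², so the gradient autograd reports can be verified analytically. It then takes one manual downhill step.

```python
import torch

# A single "weight"; requires_grad=True tells autograd to track it
w = torch.tensor(2.0, requires_grad=True)

# Made-up loss for illustration: L(w) = (3w - 1)^2, equal to 25 at w = 2
loss = (3 * w - 1) ** 2
loss.backward()              # fills w.grad with dL/dw

# By hand: dL/dw = 2 * (3w - 1) * 3 = 6 * (3*2 - 1) = 30
print(w.grad)                # tensor(30.)

# One manual downhill step, the rule plain SGD applies
lr = 0.01
with torch.no_grad():        # the update itself must not be tracked
    w -= lr * w.grad         # w goes from 2.0 to about 1.7

new_loss = (3 * w.detach() - 1) ** 2   # about 16.81, lower than 25
print(new_loss.item())
```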
The loop below is the canonical shape of almost every PyTorch training script. Read it top to bottom: clear last epoch's gradients, predict, score, back-propagate, step.
```python
import torch
import torch.nn as nn

# Toy data (an assumption for illustration): 100 points near y = 0.8x + 0.5
torch.manual_seed(0)
x = torch.linspace(-1, 1, 100).unsqueeze(1)    # shape (100, 1)
y = 0.8 * x + 0.5 + 0.05 * torch.randn_like(x)

model = nn.Linear(1, 1)                        # one weight w and bias b
criterion = nn.MSELoss()                       # mean squared error
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

model.train()                                  # train mode (matters for dropout, batchnorm)
for epoch in range(200):
    optimizer.zero_grad()                      # clear grads: they ACCUMULATE
    pred = model(x)                            # forward pass: prediction
    loss = criterion(pred, y)                  # forward pass: how wrong
    loss.backward()                            # backward pass: autograd fills .grad
    optimizer.step()                           # update: w -= lr * w.grad, etc.
```
Why optimizer.zero_grad()? When you call loss.backward(), PyTorch adds the new gradients into each parameter's .grad buffer instead of overwriting it. That accumulation is deliberate: it lets you split a big batch across several backward passes. But if you forget to clear it, epoch 5's update is computed from the summed gradients of epochs 1 through 5, the step direction is corrupted, and training destabilises. So zero_grad() clears every .grad at the top of each iteration (since PyTorch 2.0 it sets them to None by default rather than filling them with zeros), and step() then applies plain gradient descent: for SGD with no momentum or weight decay, param -= lr * param.grad.
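Both behaviours can be observed directly. This is a small sketch that assumes a toy `nn.Linear` model and random data. Two backward passes without clearing double the gradient, and one clean SGD step matches `param -= lr * param.grad`. The "None after `zero_grad()`" behaviour assumes PyTorch 2.0 or later; older versions fill the buffer with zeros instead, and the check accepts either.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model and data, assumptions for illustration
x = torch.randn(8, 1)
y = 0.8 * x + 0.5
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

# Two backward passes with NO zero_grad() in between
criterion(model(x), y).backward()
first = model.weight.grad.clone()
criterion(model(x), y).backward()             # new forward pass, so a new graph
second = model.weight.grad.clone()
doubled = torch.allclose(second, 2 * first)   # .grad now holds the sum of both

# zero_grad() clears the buffers (None by default on PyTorch >= 2.0, zeros before)
optimizer.zero_grad()
cleared_ok = model.weight.grad is None or bool((model.weight.grad == 0).all())

# One clean step: plain SGD (no momentum, no weight decay) does param -= lr * param.grad
criterion(model(x), y).backward()
expected = model.weight.detach() - 0.03 * model.weight.grad
optimizer.step()
step_matches = torch.allclose(model.weight.detach(), expected)
print(doubled, cleared_ok, step_matches)      # True True True
```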
| Symptom | Likely cause | Fix |
|---|---|---|
| Loss goes to NaN / inf | Learning rate too high, or missing zero_grad() so grads explode | Lower lr (try 10× smaller), add zero_grad(), clip gradients |
| Loss barely moves | Learning rate too low, or vanishing gradients | Raise lr, normalise inputs, check network depth |
| Loss zig-zags then climbs | Step size overshoots the minimum | Reduce lr or add a learning-rate schedule |
| Train loss down, val loss up | Overfitting — memorising, not learning | Regularise, add data, early-stop on val loss |
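One way the table's fixes might fit into the loop is sketched below. It uses real PyTorch utilities (`torch.isfinite`, `nn.utils.clip_grad_norm_`, `torch.optim.lr_scheduler.StepLR`), but every setting is a hypothetical choice rather than a recommendation: the toy data, the 160/40 train/validation split, `max_norm=1.0`, halving the learning rate every 100 epochs, and a patience of 20 epochs for early stopping. Validation runs in `eval()` mode under `torch.no_grad()`, and the model and tensors are moved to the same device.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy data and train/validation split, assumptions for illustration
x = torch.linspace(-1, 1, 200).unsqueeze(1)
y = 0.8 * x + 0.5 + 0.05 * torch.randn_like(x)
perm = torch.randperm(200)
train_idx, val_idx = perm[:160], perm[160:]
x_tr, y_tr = x[train_idx].to(device), y[train_idx].to(device)
x_va, y_va = x[val_idx].to(device), y[val_idx].to(device)

model = nn.Linear(1, 1).to(device)            # model and data on the same device
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
# Hypothetical schedule: halve the learning rate every 100 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

best_val, patience, bad_epochs = math.inf, 20, 0   # early-stopping settings (hypothetical)
for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_tr), y_tr)
    if not torch.isfinite(loss):              # catch NaN / inf as soon as it appears
        raise RuntimeError(f"loss became {loss.item()} at epoch {epoch}")
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # cap the gradient norm
    optimizer.step()
    scheduler.step()                          # advance the schedule after the update

    model.eval()                              # eval mode for validation
    with torch.no_grad():                     # no graph needed when only measuring
        val_loss = criterion(model(x_va), y_va).item()
    if val_loss < best_val - 1e-6:
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:            # early stop on validation loss
            break

print(epoch, best_val)
```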
Common pitfalls:

- optimizer.zero_grad(). Gradients accumulate across iterations, so without clearing them each step is computed from a growing sum of past gradients; the update direction drifts and the loss often blows up.
- loss.backward() twice. The graph is freed after the first backward by default; a second call raises a RuntimeError unless you pass retain_graph=True, and even then you're usually accumulating gradients you didn't mean to.
- model.train() / model.eval(). Dropout and batch-norm behave differently in the two modes. Evaluate in eval(), train in train(), and wrap validation in torch.no_grad().
- .to(device).

Say your data lies roughly on the line y = 0.8x + 0.5 and you start from w = 0, b = 0. The first forward pass predicts all zeros, so the MSE loss is large. loss.backward() computes dL/dw and dL/db, and both are negative here because raising w and b would shrink the error. With lr = 0.03, optimizer.step() nudges w and b a little in the positive direction. Repeat: each epoch the fitted line rotates and rises toward the cloud of points, and the loss falls, fast at first and then more slowly as it nears the minimum. Crank lr up too far and you'll watch the same line whip past the points and the loss shoot back up: that's divergence, and it's exactly the NaN trap from the table above.
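The worked example can be replayed in a few lines. This sketch assumes noise-free points on y = 0.8x + 0.5 with x evenly spaced in [-1, 1]; the helper `fit` is hypothetical. The divergent run uses lr = 1.5, a value picked only because it exceeds the stable step size for this particular data.

```python
import torch
import torch.nn as nn

def fit(lr, epochs):
    """Hypothetical helper: fit y = w*x + b from w = b = 0 with plain SGD."""
    # Noise-free toy data on y = 0.8x + 0.5 (the x range is an assumption)
    x = torch.linspace(-1, 1, 50).unsqueeze(1)
    y = 0.8 * x + 0.5
    model = nn.Linear(1, 1)
    with torch.no_grad():                # start from w = 0, b = 0 as in the text
        model.weight.zero_()
        model.bias.zero_()
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    losses, first_grads = [], None
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        if first_grads is None:          # dL/dw and dL/db at w = b = 0
            first_grads = (model.weight.grad.item(), model.bias.grad.item())
        losses.append(loss.item())
        optimizer.step()
    return first_grads, losses, model

grads, losses, model = fit(lr=0.03, epochs=200)
print(grads)                             # both negative: raising w and b lowers the loss
print(losses[0], losses[-1])             # large at first, then small
print(model.weight.item(), model.bias.item())  # close to 0.8 and 0.5

_, fast_losses, _ = fit(lr=1.5, epochs=20)
print(fast_losses[0], fast_losses[-1])   # the step overshoots and the loss grows
```

With lr = 0.03 the loss shrinks every epoch; with lr = 1.5 the bias overshoots further on each step, which is the divergence the paragraph describes.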
You delete optimizer.zero_grad() from the loop and rerun. What happens to training?