Gradient Checkpointing Is Neither About Gradients Nor About Checkpointing

The name gradient checkpointing is misleading. Checkpoint suggests saving something to disk. Gradient carries its usual meaning. Together, they suggest we persist gradients to storage. The technique does something else entirely.

Gradient checkpointing keeps only a subset of the activations (the intermediate outputs of forward-pass layers) and recomputes the rest during backpropagation instead of storing them all in memory. No gradients are checkpointed. Nothing is saved to disk. The technique trades compute for memory, letting you train larger models on limited hardware. So how did we end up with such a misleading name?
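To make the trade concrete, here is a minimal pure-Python sketch. It assumes a toy "network" that is just a chain of scalar layers; the layer function and every name below are hypothetical, chosen only for illustration. It computes the same gradient two ways: once keeping every activation, and once keeping only every `segment`-th activation and recomputing the rest one segment at a time during the backward pass.

```python
import math


def forward_layer(x: float) -> float:
    """One toy layer: y = sin(x) + x (hypothetical, chosen only for illustration)."""
    return math.sin(x) + x


def layer_grad(x: float) -> float:
    """Derivative of forward_layer with respect to its input x."""
    return math.cos(x) + 1.0


def grad_full_storage(x0: float, n: int):
    """Baseline: keep all n + 1 activations, then backpropagate d(output)/d(x0)."""
    acts = [x0]
    for _ in range(n):
        acts.append(forward_layer(acts[-1]))
    grad = 1.0  # d(output)/d(output)
    for i in reversed(range(n)):
        grad *= layer_grad(acts[i])  # chain rule through layer i
    return acts[-1], grad, len(acts)


def grad_checkpointed(x0: float, n: int, segment: int):
    """Keep only every `segment`-th activation; recompute the rest during backprop."""
    # Forward pass: every activation except the checkpoints is discarded.
    checkpoints = {0: x0}
    x = x0
    for i in range(n):
        x = forward_layer(x)
        if (i + 1) % segment == 0:
            checkpoints[i + 1] = x
    output = x

    # Backward pass: walk the segments from last to first.
    grad = 1.0
    recomputed = 0
    for start in reversed(range(0, n, segment)):
        end = min(start + segment, n)
        # Rebuild this segment's layer inputs from its checkpoint (the extra compute).
        seg_acts = [checkpoints[start]]
        for _ in range(start, end - 1):
            seg_acts.append(forward_layer(seg_acts[-1]))
            recomputed += 1
        # Backpropagate through the segment; the list is replaced on the next
        # iteration, so only one segment's activations are held at a time.
        for a in reversed(seg_acts):
            grad *= layer_grad(a)
    return output, grad, len(checkpoints), recomputed


_, g_full, stored_full = grad_full_storage(0.5, 16)
_, g_ckpt, stored_ckpt, extra = grad_checkpointed(0.5, 16, segment=4)
```

With 16 layers and segments of 4, the checkpointed run holds 5 activations between the forward and backward passes (plus at most 4 inside the segment being processed) instead of 17, and pays for it with 12 extra layer evaluations, roughly one additional forward pass. Choosing segments of about √n layers is what yields the sublinear, O(√n) memory cost that Chen et al. analyze.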

The Automatic Differentiation Origins

In 2000, Griewank and Walther published Algorithm 799: Revolve, describing optimal checkpointing schedules for reverse-mode automatic differentiation.[1] In this context, checkpoint refers to saving the computational state at specific nodes in a computation graph so you can recover intermediate values later without storing the entire trajectory.

The metaphor makes sense within AD theory: if you save the state at steps 5 and 10, those become your checkpoints, and you reconstruct step 7 by rolling forward from step 5. The problem is that checkpoint carries entirely different connotations outside this narrow domain.
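The metaphor can be sketched in a few lines of Python, assuming a hypothetical `step` function that stands in for one step of a long computation. A single forward pass keeps the state only at steps 0, 5, and 10; any other step is rebuilt by rolling forward from the nearest earlier checkpoint. This shows only the basic idea; Revolve addresses the harder question of where to place a limited number of checkpoints.

```python
def step(state: int) -> int:
    """Hypothetical stand-in for one step of a long computation."""
    return 3 * state + 1


def run(state: int, steps: int) -> int:
    """Apply `step` repeatedly, starting from `state`."""
    for _ in range(steps):
        state = step(state)
    return state


# One forward pass over 10 steps that saves the state only at steps 0, 5 and 10.
initial = 2
checkpoints = {0: initial}
state = initial
for t in range(1, 11):
    state = step(state)
    if t % 5 == 0:
        checkpoints[t] = state


def recover(t: int) -> int:
    """Rebuild the state at step t by rolling forward from the nearest earlier checkpoint."""
    base = max(c for c in checkpoints if c <= t)
    return run(checkpoints[base], t - base)


print(recover(7))  # 5467: recomputes steps 6 and 7 from the step-5 checkpoint
```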

When Tianqi Chen et al. brought the technique to deep learning in 2016 with Training Deep Nets with Sublinear Memory Cost, they inherited the AD terminology.[2] Their paper explicitly cites Griewank's work and even thanks David Warde-Farley "for pointing out the relation to gradient checkpointing." The name was borrowed; it was never coined for neural networks.

PyTorch exposed the technique as torch.utils.checkpoint, and TensorFlow added its own support. The academic terminology became an API name, and the confusion became institutionalized. In AD literature, checkpointing describes a memory management strategy for computation graphs. To everyone else, it sounds like saving model weights or gradient values to persistent storage.

More descriptive names gained ground around 2021–2022, as models like GPT-3 strained GPU memory to the breaking point. NVIDIA's Megatron team put the actual operation in the title of their 2022 paper, Reducing Activation Recomputation in Large Transformer Models.[3] The title says it all: the activations are what get recomputed. It names exactly what the GPU does, with no ambiguity about gradients or disk checkpoints.

Meanwhile, JAX approached the problem from a compiler perspective: when you delete a computed value and compute it again later, you are rematerializing it. JAX adopted remat as an API name (jax.remat, an alias of jax.checkpoint), and the term stuck in that ecosystem.

If you remember nothing else: when someone says gradient checkpointing, they mean recomputing activations. The gradients were never the point. And nothing is being checkpointed in any sense a normal person would recognize.

References

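  1. Algorithm 799: Revolve
    Griewank & Walther · ACM Trans. Math. Software 26(1) · 2000
  2. Training Deep Nets with Sublinear Memory Cost
    Chen, Xu, Zhang & Guestrin · arXiv:1604.06174 · 2016
  3. Reducing Activation Recomputation in Large Transformer Models
    Korthikanti et al. · arXiv:2205.05198 · 2022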