After Hours Academic

How to make LLM checkpoint writes fast

Unless you have been living under a rock for the last couple of years, you already know what a large language model (LLM) is. For the rock-dwellers, here is a quick note: an LLM is a machine learning model (more specifically, a neural network) trained on a large corpus of data (think: all the text on the internet) to predict a sequence of words. The outcome is a seemingly "smart" model that you can converse with because it can form coherent sentences. ChatGPT, Gemini, etc. are examples of LLMs. There is a whole lot of debate about whether LLMs are actually "smart" or even useful. But that's not what I am going to talk about.

I am interested in the systems that enable training LLMs. In particular, for this post, I am interested in checkpointing during training.

Why do we need checkpointing?

You see, these models are really large, with millions or even billions of parameters (a model is essentially a collection of parameters that control its computation). And these large models need to be trained on large compute clusters. Training on hundreds or thousands of GPUs is commonplace, and some of the largest models are trained on hundreds of thousands of GPUs. And these training runs can span days, weeks, or even months because of the amount of data they have to process.

That's right, hundreds of thousands of GPUs burn for months so that we can ask an oracle about the meaning of life.

However, with large scale come frequent failures. Public reports (like this one) show that failures become more frequent as the scale increases. At 131,072 nodes, the training job fails every 13 minutes! So the question becomes: how do week-long training jobs complete with such frequent failures?
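
As a rough sanity check on that number: if node failures are independent, the whole job's mean time to failure (MTTF) is roughly one node's MTTF divided by the number of nodes. Here is a back-of-the-envelope calculation; the per-node figure below is a hypothetical value chosen to be consistent with the quoted 13 minutes, not a number from the report:

    # With independent failures, job-level MTTF ~= per-node MTTF / node count.
    node_mttf_years = 3.24                    # hypothetical: one node fails every ~3.24 years
    nodes = 131_072
    node_mttf_minutes = node_mttf_years * 365 * 24 * 60
    job_mttf_minutes = node_mttf_minutes / nodes
    print(f"job fails roughly every {job_mttf_minutes:.0f} minutes")   # ~13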

Checkpointing is what enables long-running training jobs to deal with failures. Every so often, the training job checkpoints its entire state. This allows a failed job to restart from the last checkpoint instead of starting from scratch. A checkpoint includes the model state (i.e., the model parameters) as well as the meta-state of the training job (e.g., the learning rate, which controls how quickly the model's parameters change).
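
To make that concrete, here is a minimal sketch of a checkpoint write in PyTorch (which is what most of the papers below build on); the exact contents vary by framework and training setup:

    import torch

    def save_checkpoint(model, optimizer, scheduler, step, path):
        # Model parameters plus the training meta-state (optimizer moments,
        # learning-rate schedule, current step) are bundled into one object.
        state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step,
        }
        torch.save(state, path)  # a blocking write to durable storage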

Why are fast checkpoint writes important?

Checkpointing is additional work for the training job, and it is non-productive work in the sense that it does not move the training itself forward. This leads us to a metric known as the effective training time ratio (ETTR): the ratio of time spent on productive training to the total wall-clock time. The more time a training job spends on checkpointing, the lower the ETTR.
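
To make the metric concrete, here is a back-of-the-envelope example with made-up numbers, counting only checkpoint stalls as non-productive time (a full ETTR accounting would also include time lost to failures and restarts):

    step_time_s = 1.0                 # one productive training step
    steps_between_checkpoints = 1000
    checkpoint_stall_s = 50.0         # time the job is blocked writing a checkpoint

    productive = step_time_s * steps_between_checkpoints
    total = productive + checkpoint_stall_s
    ettr = productive / total
    print(f"ETTR = {ettr:.3f}")       # ~0.952, i.e. ~5% of wall-clock time lost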

Intuitively, fast checkpoint writes would lead to a higher ETTR because the job would waste less time on the non-productive part of training. However, checkpoints need to be stored durably (on a hard disk or SSD), and writing to durable storage is typically slow compared to the compute that updates model parameters in the productive part of training.

How to make checkpoint writes fast?

The most common technique used to speed up checkpoint writes was proposed in the CheckFreq paper. The key idea is that the training job running on the GPU can write the checkpoint to the attached CPU's memory and continue with the training. The CPU can then write the checkpoint out to durable storage in the background. This is referred to as asynchronous checkpointing. The training job can continue with the productive part of training without being blocked on the unproductive part of persisting the checkpoint. Thus, from the training job's perspective, checkpoint writes just got a whole lot faster!
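
Here is a minimal sketch of the idea in PyTorch (not CheckFreq's actual implementation), assuming a single GPU:

    import threading
    import torch

    def snapshot_to_cpu(model, optimizer):
        # Phase 1, briefly blocks training: copy state from GPU to CPU memory.
        # (A complete version would also move the optimizer's GPU tensors.)
        return {
            "model": {k: v.detach().to("cpu") for k, v in model.state_dict().items()},
            "optimizer": optimizer.state_dict(),
        }

    def persist(cpu_state, path):
        # Phase 2, runs in the background: write the CPU copy to durable storage.
        torch.save(cpu_state, path)

    def async_checkpoint(model, optimizer, path):
        cpu_state = snapshot_to_cpu(model, optimizer)
        writer = threading.Thread(target=persist, args=(cpu_state, path))
        writer.start()
        return writer   # training resumes immediately; join() before the next snapshot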

Asynchronous checkpointing is already prevalent today. Some recent papers try to go beyond just that.

Gemini proposes writing the checkpoint to the CPU memory of multiple machines instead of writing it out to persistent storage. With enough replicas of the checkpoint spread across different machines' CPU memory, it can be considered persistent. I like to think of this as the RAMCloud equivalent of checkpoint storage.
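
A rough sketch of the idea, assuming torch.distributed with the gloo backend is already initialized; group membership and failure handling, which are the hard parts Gemini actually solves, are omitted:

    import io
    import torch
    import torch.distributed as dist

    def replicate_to_peers(cpu_state, peer_ranks):
        # Serialize the CPU snapshot and ship the bytes to a few peer machines,
        # which hold them in their CPU memory instead of writing them to disk.
        buffer = io.BytesIO()
        torch.save(cpu_state, buffer)
        payload = torch.frombuffer(bytearray(buffer.getvalue()), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.int64)
        for rank in peer_ranks:
            dist.send(size, dst=rank)      # tell the peer how many bytes to expect
            dist.send(payload, dst=rank)

    def receive_replica(src_rank):
        size = torch.empty(1, dtype=torch.int64)
        dist.recv(size, src=src_rank)
        payload = torch.empty(int(size.item()), dtype=torch.uint8)
        dist.recv(payload, src=src_rank)
        return payload                     # an in-memory replica of the checkpoint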

PCheck aims to increase the checkpointing frequency to reduce the amount of computation that is lost in case of a failure. However, at a high enough frequency, the write from CPU memory to persistent storage becomes the bottleneck again. That is, the training job can't write a new checkpoint to CPU memory if the previous one hasn't been written out to storage yet. To address this, PCheck lets the training job copy multiple checkpoints into CPU memory to unblock itself. The CPU, upon completing the write of a checkpoint to persistent storage, picks the most recent checkpoint in its memory and starts writing that out.
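
Here is my own simplified sketch of that "keep snapshots in CPU memory, persist only the most recent" pattern (not PCheck's actual design):

    import threading
    import torch

    class LatestWinsWriter:
        def __init__(self, path_prefix):
            self.path_prefix = path_prefix
            self.latest = None                 # newest (step, cpu_state) snapshot
            self.cond = threading.Condition()
            threading.Thread(target=self._writer_loop, daemon=True).start()

        def submit(self, step, cpu_state):
            # Training thread: drop the snapshot into CPU memory and keep going.
            with self.cond:
                self.latest = (step, cpu_state)
                self.cond.notify()

        def _writer_loop(self):
            while True:
                with self.cond:
                    while self.latest is None:
                        self.cond.wait()
                    step, cpu_state = self.latest
                    self.latest = None         # older snapshots are superseded
                # The slow write to persistent storage happens outside the lock.
                torch.save(cpu_state, f"{self.path_prefix}.step{step}.pt")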

FastPersist has four clever tricks to speed up the checkpoint writes. The one I found the most interesting exploits the fact that each part of the checkpoint state is replicated across a subset of the GPUs involved in the training. The GPUs holding the same replicated state can each write out a different slice of it in parallel, effectively increasing the write throughput.
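
A sketch of that last trick as I understand it (my paraphrase, not FastPersist's code): every rank in a data-parallel group holds an identical copy of the model, so each rank can persist a disjoint slice of the state, and together they cover the whole checkpoint at many times the single-writer bandwidth.

    import torch
    import torch.distributed as dist

    def write_my_shard(model, path_prefix):
        # Assumes torch.distributed is initialized and every rank in the group
        # holds the same (replicated) model state.
        rank = dist.get_rank()
        world = dist.get_world_size()
        full_state = model.state_dict()                  # identical on every rank
        my_names = sorted(full_state.keys())[rank::world]
        shard = {k: full_state[k].cpu() for k in my_names}
        torch.save(shard, f"{path_prefix}.rank{rank}.pt")
        # On restore, load all shards and merge them back into one state dict.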

#checkpoints #computer-science #llm #machine-learning