Keep Warm with Portable torch.compile Caches

April 06, 2025
python torch

PyTorch's torch.compile speeds up your code by compiling the computational graph and caching the compiled artifacts for later use. torch.compile distinguishes between a "cold compile" and a "warm compile": a cold compile is the first run, which compiles from scratch, while a warm compile is faster because it reuses the cache. The latest torch nightly wheel introduces a portable caching solution whose cache can be loaded on separate machines. In this post, we learn about this new caching solution and how it keeps your cache warm!

Cold vs Warm Compile Runtimes

In this example, we run torch.compile on a ResNet152 nn.Module without a cache and measure the cold start run:

import torch
from torchvision.models import resnet152

model = resnet152().to(torch.float32).cuda()
model_opt = torch.compile(model, mode="reduce-overhead", fullgraph=True)
X = torch.randn(32, 3, 128, 128).to(torch.float32).cuda()

# timed() runs model_opt(X) and reports its wall-clock duration
timed(model_opt, X)
# Duration: 21.1s

The first run takes 21.1s because torch.compile compiles everything from scratch. When we run the script again in a new process, the cache populated by the first run is warm, and the duration drops to 8.9s, roughly 2.4x faster than the cold start!
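The snippet above relies on a timed helper that the post does not show. Below is a minimal sketch of one, assuming timed simply wraps a single call with wall-clock timing; the name, signature, and return value are our own choice, not a PyTorch API. Because CUDA kernels launch asynchronously, the sketch synchronizes the device before and after the call when a GPU is available, so the measurement covers the real work.

```python
import time
from typing import Any, Callable, Tuple

try:
    import torch  # optional: used only to synchronize CUDA work
except ImportError:  # lets the helper run on machines without torch
    torch = None


def _sync() -> None:
    # Wait for queued CUDA kernels so the timing covers the actual computation
    if torch is not None and torch.cuda.is_available():
        torch.cuda.synchronize()


def timed(fn: Callable[..., Any], *args: Any) -> Tuple[Any, float]:
    """Hypothetical helper: call fn(*args), print and return the elapsed seconds."""
    _sync()
    start = time.perf_counter()
    result = fn(*args)
    _sync()
    elapsed = time.perf_counter() - start
    print(f"Duration: {elapsed:.1f}s")
    return result, elapsed
```

With this sketch, timed(model_opt, X) prints a line in the same "Duration: ...s" format shown above and also returns the model output.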

Portable Cache

What if we wanted to populate the cache once on one machine and use it to warm up the cache on many machines? I have had success setting the cache directory with TORCHINDUCTOR_CACHE_DIR, saving the entire directory, and restoring it on the new machine. The latest torch nightly adds another option: torch.compiler.save_cache_artifacts and torch.compiler.load_cache_artifacts.
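The directory-based approach can be sketched with the standard library alone. This is a minimal sketch under our own assumptions: the helper names pack_cache_dir and unpack_cache_dir are hypothetical, and we assume TORCHINDUCTOR_CACHE_DIR is set to the same path on both machines (before torch is imported, to be safe), so that torch.compile writes to and reads from that directory.

```python
import shutil
from pathlib import Path


def pack_cache_dir(cache_dir: Path, archive_base: Path) -> Path:
    """Hypothetical helper: archive the whole compile cache directory.

    archive_base should live outside cache_dir; returns the created .tar.gz path.
    """
    archive = shutil.make_archive(str(archive_base), "gztar", root_dir=str(cache_dir))
    return Path(archive)


def unpack_cache_dir(archive: Path, cache_dir: Path) -> None:
    """Hypothetical helper: restore an archived cache directory on a new machine."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    shutil.unpack_archive(str(archive), extract_dir=str(cache_dir))
```

On the first machine you would call pack_cache_dir after a compiled run, copy the resulting .tar.gz to each new machine, and call unpack_cache_dir into the directory that TORCHINDUCTOR_CACHE_DIR points to before launching.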

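torch.compiler.save_cache_artifacts is a Python API that serializes the cache artifacts created during compilation into bytes, which we can write to a binary file. In this example, after the compiled model has run, we save the artifact as "artifact.bin" from the same process: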

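from pathlib import Path
import torch

artifact_path = Path("artifact.bin")

# Serialize the cache artifacts recorded by torch.compile in this process
artifact_bytes, _ = torch.compiler.save_cache_artifacts()
artifact_path.write_bytes(artifact_bytes)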
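Two details are worth keeping in mind on the saving side. First, torch.compile compiles lazily, so the compiled model has to be called at least once in the process before there is anything to save. Second, save_cache_artifacts is annotated as returning an optional (bytes, CacheInfo) pair, so a cautious caller can check for None before unpacking. The helper below is a hypothetical sketch of that check; it takes the return value as an argument, so it can be tried without a GPU.

```python
from pathlib import Path
from typing import Any, Optional, Tuple


def write_cache_artifacts(result: Optional[Tuple[bytes, Any]], path: Path) -> bool:
    """Hypothetical helper: persist the output of torch.compiler.save_cache_artifacts().

    Returns True if bytes were written, False if there was nothing to save.
    """
    if result is None:
        return False  # nothing serialized, so leave any existing file untouched
    artifact_bytes, _cache_info = result  # the CacheInfo is not needed on disk
    path.write_bytes(artifact_bytes)
    return True
```

In the example above, this would be called as write_cache_artifacts(torch.compiler.save_cache_artifacts(), artifact_path).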

On a separate machine, we load "artifact.bin" before running the compiled model, which warms up the cache!

import torch
from pathlib import Path
artifact_bytes = Path("artifact.bin").read_bytes()
torch.compiler.load_cache_artifacts(artifact_bytes)
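When the same launch script runs on every machine, it can be convenient to warm the cache only when an artifact is present and otherwise fall back to a cold compile. Here is a minimal sketch, assuming the artifact lives at a known path; warm_cache_if_available is a hypothetical helper, and the loader is passed in (normally torch.compiler.load_cache_artifacts) so the logic can be exercised without torch. The call should happen before the compiled model runs for the first time, so that compilation can hit the loaded cache.

```python
from pathlib import Path
from typing import Any, Callable


def warm_cache_if_available(path: Path, load: Callable[[bytes], Any]) -> bool:
    """Hypothetical helper: load cache artifacts from path if the file exists.

    `load` is expected to be torch.compiler.load_cache_artifacts.
    Returns True if artifacts were loaded, False if we fall back to a cold compile.
    """
    if not path.is_file():
        return False  # no artifact yet: torch.compile will compile from scratch
    load(path.read_bytes())
    return True
```

A launch script would call warm_cache_if_available(Path("artifact.bin"), torch.compiler.load_cache_artifacts) first and only then compile and run the model.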

To learn more, refer to the documentation for save_cache_artifacts and load_cache_artifacts.

Conclusion

Now that frameworks such as SGLang and vLLM use torch.compile to speed up workloads, it is also important to optimize startup times so we can scale up faster to meet demand. Keeping the torch.compile cache warm is one way to improve startup times. The new save_cache_artifacts and load_cache_artifacts APIs give us a more portable way to keep the cache warm across multiple machines.
