Diagnosing a PyTorch Memory Leak on Apple MPS

Before we started testing Stormlog on Apple Silicon, everything had worked as expected on CUDA. The memory timeline was clean, allocations were predictable, and the leak detector stayed quiet. Then we switched to the MPS backend.

The symptom

During a standard training loop with mixed-precision autocast enabled, GPU memory usage climbed steadily across epochs. Not dramatically; maybe 50 MB per epoch on a 16 GB M2 Max. Easy to miss if you're not watching closely. But over a long training run, the growth would eventually exhaust available memory and crash the process.
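To put a number on "eventually", here is a minimal sketch that fits a straight line to per-epoch memory readings and estimates how many epochs remain before a budget is hit. The function names, the 2 GB starting point, and the 16 GB budget are assumptions for illustration; in a real run the readings could come from something like `torch.mps.current_allocated_memory()` sampled once per epoch, and the usable budget will be lower than total unified memory.

```python
from typing import Optional, Sequence

MB = 1024 ** 2


def growth_per_epoch(samples_bytes: Sequence[float]) -> float:
    """Least-squares slope of the readings, in bytes per epoch."""
    n = len(samples_bytes)
    if n < 2:
        raise ValueError("need at least two epoch readings")
    mean_x = (n - 1) / 2
    mean_y = sum(samples_bytes) / n
    cov = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(samples_bytes))
    var = sum((i - mean_x) ** 2 for i in range(n))
    return cov / var


def epochs_until_exhausted(
    samples_bytes: Sequence[float], budget_bytes: float
) -> Optional[float]:
    """Epochs left before the budget is reached, or None if memory is not growing."""
    slope = growth_per_epoch(samples_bytes)
    if slope <= 0:
        return None
    return max(0.0, (budget_bytes - samples_bytes[-1]) / slope)


# Synthetic readings (assumed): start at 2 GB and grow by 50 MB each epoch.
readings = [2048 * MB + epoch * 50 * MB for epoch in range(10)]
slope = growth_per_epoch(readings)
remaining = epochs_until_exhausted(readings, budget_bytes=16 * 1024 * MB)
print(f"growth: {slope / MB:.1f} MB/epoch, epochs remaining: {remaining:.0f}")
```

A linear fit is a deliberately crude model, but it is enough to tell a flat timeline from one that will run out of headroom a few hundred epochs in.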

What made it hard to find

The leak only appeared under a specific combination of conditions: MPS backend, autocast enabled, and tensors moved between MPS and CPU within the same training step. Remove any one of those three conditions and the leak disappeared.

This is the kind of bug that survives manual testing because no single condition is unusual. Mixed-precision training is standard practice. Moving tensors between devices happens during data loading and metric computation. It's only the intersection that triggers the problem.
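One way to confirm that all three conditions are required is to run every on/off combination and record which ones leak. The sketch below is an assumption about how such a check could be organized, not how we actually ran it: `measure_growth` is a hypothetical callback that would run a short training loop under the given flags and return memory growth in bytes, and `fake_measure` is a stand-in that mimics the behavior described above.

```python
from itertools import product
from typing import Callable, Dict, List

CONDITIONS = ("mps", "autocast", "cross_device_copy")


def ablate(
    measure_growth: Callable[[Dict[str, bool]], float],
    threshold_bytes: float,
) -> List[Dict[str, bool]]:
    """Try all 2**3 combinations; return the configs whose growth exceeds the threshold."""
    leaking = []
    for flags in product((False, True), repeat=len(CONDITIONS)):
        config = dict(zip(CONDITIONS, flags))
        if measure_growth(config) > threshold_bytes:
            leaking.append(config)
    return leaking


def fake_measure(config: Dict[str, bool]) -> float:
    """Stand-in for a real training run: leaks only when every condition is on."""
    return 50 * 1024 ** 2 if all(config.values()) else 0.0


leaking = ablate(fake_measure, threshold_bytes=1024 ** 2)
print(leaking)
```

With only eight combinations the full grid is cheap, and it rules out the possibility that two of the three conditions are sufficient on their own.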

How Stormlog helped

Stormlog's allocation timeline made the pattern visible. Each training step showed a small residual allocation that was never freed: a cached autocast tensor that the MPS allocator held a reference to after the tensor was copied to CPU. The tensor-level breakdown tab showed the exact allocation site.

The fix

The fix was straightforward once we understood the cause: explicitly clearing the autocast cache after cross-device tensor operations. But finding the cause without a profiler that tracks individual allocations would have been significantly harder.
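For readers who want to see the shape of that fix, here is a minimal sketch of a training step that clears the autocast cache right after a device-to-CPU copy. `torch.clear_autocast_cache()` is a real PyTorch call, but where exactly to place it is an assumption here, as are the function names, the loss, and the CPU fallback with `bfloat16`.

```python
import torch
from torch import nn


def pick_device() -> torch.device:
    """Use MPS when available, otherwise fall back to CPU."""
    return torch.device("mps" if torch.backends.mps.is_available() else "cpu")


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    device: torch.device,
) -> float:
    # float16 on MPS, bfloat16 on CPU (assumed dtypes for this sketch).
    amp_dtype = torch.float16 if device.type == "mps" else torch.bfloat16
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device.type, dtype=amp_dtype):
        outputs = model(inputs.to(device))
        loss = nn.functional.mse_loss(outputs.float(), targets.to(device))
        # Cross-device copy inside the autocast region, e.g. for a CPU-side metric.
        outputs_cpu = outputs.detach().to("cpu")
        # Drop cached casts right after the copy. The autograd graph keeps its
        # own references to the tensors it needs for backward.
        torch.clear_autocast_cache()
    loss.backward()
    optimizer.step()
    del outputs_cpu  # the CPU copy would be consumed by metric code in practice
    return loss.item()
```

Clearing the cache means weights are re-cast on their next use, so this trades a little recomputation for memory that actually gets released.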

What I took away

Memory leaks in ML training are rarely dramatic. They're slow, conditional, and only surface under realistic training conditions. Building profiling tools that make these patterns visible is the kind of infrastructure work that compounds over time.