PyTorch memory leak reference cycle in for loop


I am facing a memory leak when iteratively updating tensors in PyTorch on my Mac M1 GPU through the MPS backend. The following minimal reproducible example replicates the behavior:

import torch 

def leak_example(p1, device):
    
    t1 = torch.rand_like(p1, device = device) # torch.cat((torch.diff(ubar.detach(), dim=0).detach().clone(), torch.zeros_like(ubar.detach()[:1,:,:,:], dtype = torch.float32)), dim = 0)
    u1 = p1.detach() + 2 * (t1.detach())  # NB: the operator was lost in the original post; '+' is assumed
    
    B = torch.rand_like(u1, device = device)
    mask = u1 < B
    
    a1 = u1.detach().clone()
    a1[~mask] = torch.rand_like(a1)[~mask]  # in-place boolean-mask assignment (the suspected leak)
    return a1

if torch.cuda.is_available(): # cuda gpus
    device = torch.device("cuda")
elif torch.backends.mps.is_available(): # mac gpus
    device = torch.device("mps")
else: # fall back to CPU so device is always defined
    device = torch.device("cpu")
torch.set_grad_enabled(False)
        
p1 = torch.rand(5, 5, 224, 224, device = device)
for i in range(10000):
    p1 = leak_example(p1, device)    

My Mac's GPU memory grows steadily as this loop executes. I have also run it on a CUDA GPU in Google Colab, where it behaves similarly: the GPU's Active, Non-releasable, and Allocated memory all increase as the loop progresses.
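(For anyone reproducing this, the growth can be watched with PyTorch's own allocator counters. The sketch below assumes a PyTorch build recent enough to expose torch.mps.current_allocated_memory(), roughly 2.0 and later.)

# Sketch: print allocator usage every 1000 iterations while reproducing.
# torch.mps.current_allocated_memory() needs PyTorch >= 2.0;
# torch.cuda.memory_allocated() is the CUDA counterpart.
for i in range(10000):
    p1 = leak_example(p1, device)
    if i % 1000 == 0:
        if device.type == "mps":
            print(i, torch.mps.current_allocated_memory())
        elif device.type == "cuda":
            print(i, torch.cuda.memory_allocated())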

I have tried detaching and cloning the tensors and using weakrefs, to no avail. Interestingly, if I don't reassign the output of leak_example to p1, the leak disappears, so it really seems tied to feeding each iteration's output back in as the next input. Does anyone have an idea how I could resolve this?

CodePudding user response:

I think I found the cause of the leak: it was the masked assignment. Replacing it with an equivalent torch.where() call makes the leak disappear. I imagine this is related to masked_scatter not yet being implemented for the MPS backend in PyTorch.
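For illustration, here is the masked assignment rewritten with torch.where (a sketch of the fix described above; the rest of leak_example, including the assumed '+' operator from the question, is unchanged):

import torch

def fixed_example(p1, device):
    t1 = torch.rand_like(p1, device=device)
    u1 = p1.detach() + 2 * t1.detach()

    B = torch.rand_like(u1, device=device)
    mask = u1 < B

    # torch.where builds a new tensor elementwise: keep u1 where mask is
    # True, otherwise take a freshly sampled random value. No in-place
    # boolean-mask write (and hence no masked_scatter) is involved.
    a1 = torch.where(mask, u1, torch.rand_like(u1))
    return a1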
