PyTorch: why does running output = model(images) use so much GPU memory?

Time:05-28

In trying to understand why my maximum batch size is limited for my PyTorch model, I noticed that it's not the model itself nor loading the tensors onto the GPU that uses the most memory. Most memory is used up when generating a prediction for the first time, e.g. with the following line in the training loop:

output = model(images)

where images is some input tensor and model is my PyTorch model. Before running the line, I have about 9GB of GPU memory available, and afterwards I'm down to 2.5GB (it then drops further to 1GB available after running loss = criterion(output, labels)).

Two questions:

  1. Is this normal?
  2. Why is it happening? What is all that memory being used for? From what I understand the model is already loaded in, and the actual input tensors are already on the GPU before making that call. The output tensors themselves can't be that big. Does it have something to do with storing the computational graph?
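A rough size calculation shows where the memory goes: the final output really is small, but every intermediate feature map has the batch dimension times the channel and spatial dimensions. The shapes below are hypothetical (an ImageNet-style batch of 64 float32 images at 224x224, one 64-channel feature map at full resolution, 1000 output logits) and are not taken from the question; tensor_bytes is an illustrative helper, not a PyTorch function.

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Size in bytes of a dense tensor with this shape (float32 = 4 bytes by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


GiB = 1024 ** 3
batch = 64  # hypothetical batch size

input_bytes = tensor_bytes((batch, 3, 224, 224))          # the images tensor
feature_map_bytes = tensor_bytes((batch, 64, 224, 224))   # ONE intermediate activation
output_bytes = tensor_bytes((batch, 1000))                # the final logits

print(f"input:           {input_bytes / GiB:.3f} GiB")
print(f"one feature map: {feature_map_bytes / GiB:.3f} GiB")
print(f"output logits:   {output_bytes / GiB:.6f} GiB")
```

A single such activation is roughly 20 times larger than the input batch, and a deep network produces dozens of them, which is how a few GB disappear during one forward pass.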

Answer:

This is normal. The key point is that the intermediate tensors recorded in the computation graph (the activations of each layer that autograd needs later) have to be kept in memory if you want to compute gradients with reverse-mode automatic differentiation (backpropagation). You can avoid that by using the torch.no_grad() context manager:

import torch

with torch.no_grad():  # do not record a computation graph in this block
    output = model(images)

You will observe that much less memory is used, because no computation graph is stored. This also means, however, that you can no longer compute gradients. It is the standard approach when you only want to evaluate the model (for example during validation or inference) without optimizing it.
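To see the difference directly, the sketch below counts the bytes autograd saves for the backward pass during one forward call, once normally and once under torch.no_grad(). It assumes PyTorch 1.10 or newer (for torch.autograd.graph.saved_tensors_hooks) and runs on the CPU; the small CNN, the input shape, and the helper saved_bytes_during_forward are hypothetical stand-ins for the model and images in the question.

```python
import torch
import torch.nn as nn


def saved_bytes_during_forward(model, x):
    """Add up the bytes of every tensor autograd saves for backward in one forward pass.

    A tensor saved by several operations (including weights) is counted each
    time it is saved, so treat the result as a rough upper bound.
    """
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # store the tensor unchanged

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        model(x)
    return total


# Hypothetical small CNN and input batch.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(16 * 32 * 32, 10),
)
images = torch.randn(8, 3, 32, 32)

with_grad = saved_bytes_during_forward(model, images)
with torch.no_grad():
    without_grad = saved_bytes_during_forward(model, images)

print(f"saved for backward with autograd: {with_grad / 1024**2:.1f} MiB")
print(f"saved for backward under no_grad: {without_grad} bytes")
```

Under no_grad nothing is recorded, so the pack hook is never called and the second count is zero.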

If you still want to optimize, one way to reduce memory consumption is gradient checkpointing (available in PyTorch as torch.utils.checkpoint). Instead of storing every intermediate tensor, only the tensors at selected points (the checkpoints) are kept. Whenever the backward pass needs an intermediate tensor that was not stored, it is recomputed from the input, or more precisely from the last checkpoint before it. This makes training computationally more expensive: you are trading memory for compute time.
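A minimal sketch of checkpointing with torch.utils.checkpoint.checkpoint, assuming a recent PyTorch version that accepts the use_reentrant argument. CountingBlock, run, and all layer sizes are hypothetical; the block counts its own forward calls so the recomputation during backward becomes visible, and the gradients are compared against an ordinary run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Hypothetical sub-network that counts how often its forward pass runs."""

    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(),
            nn.Linear(dim, dim), nn.ReLU(),
        )
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return self.layers(x)


def run(use_checkpoint):
    torch.manual_seed(0)  # identical weights and data for both runs
    block = CountingBlock(64)
    head = nn.Linear(64, 10)
    x = torch.randn(16, 64)
    labels = torch.randint(0, 10, (16,))

    if use_checkpoint:
        # Activations inside `block` are not kept; they are recomputed in backward.
        h = checkpoint(block, x, use_reentrant=False)
    else:
        h = block(x)
    loss = nn.functional.cross_entropy(head(h), labels)
    loss.backward()
    return block.calls, block.layers[0].weight.grad.clone()


plain_calls, plain_grad = run(use_checkpoint=False)
ckpt_calls, ckpt_grad = run(use_checkpoint=True)
print(plain_calls, ckpt_calls)  # the checkpointed block runs its forward twice
print(torch.allclose(plain_grad, ckpt_grad))  # same gradients either way
```

The extra forward call inside backward is the computational price paid for not storing the block's activations.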
