Assume there are two models, X and Y. Data is passed sequentially through X and then Y. Only the parameters of model Y need to be optimized, with respect to a loss computed over the output of model Y. Is the following snippet a correct implementation of this requirement? A few specific questions I need answered:
- What exactly does `with torch.no_grad()` do?
- As only the parameters of model Y are registered with the optimizer, do we still need to freeze model X for correctness, or is freezing only needed to reduce the computational load?
- More generally, I would like an explanation of how the computation graph and backpropagation behave in the presence of `with torch.no_grad()`, or when some layers are frozen by setting the corresponding `requires_grad` parameter to False.
- Also, comment on whether we can have non-consecutive layers of the network frozen at once.
```python
import torch
from torch.optim import AdamW

optimizer = AdamW(model_Y.parameters(), lr=..., eps=...)
optimizer.zero_grad()

with torch.no_grad():
    A = model_X(data)   # forward pass through model X, without tracking gradients

B = model_Y(A)          # forward pass through model Y
loss = some_function(B)
loss.backward()
optimizer.step()
```
Answer:
`torch.no_grad` serves as a context manager that disables gradient computation. This is very useful for, e.g., inference, where you will not call `.backward()`. It saves both memory and computation.
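A minimal sketch of that behavior: tensors produced inside the context carry no `grad_fn`, so autograd cannot trace back through them.

```python
import torch

x = torch.randn(3, requires_grad=True)

# Outside no_grad: the operation is recorded in the computation graph.
y = x * 2
print(y.requires_grad)  # True
print(y.grad_fn)        # <MulBackward0 ...>

# Inside no_grad: no graph is built and the result is detached.
with torch.no_grad():
    z = x * 2
print(z.requires_grad)  # False
print(z.grad_fn)        # None
# z.sum().backward() would raise a RuntimeError here.
```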
In your example, you can treat `A`, the output of `model_X`, as your "input": you will not modify anything related to `model_X`. In this case, you do not care about gradients w.r.t. `model_X`: neither w.r.t. its parameters nor w.r.t. its inputs. It is therefore safe to wrap the call `A = model_X(data)` in the `torch.no_grad()` context.
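As a concrete check of the snippet from the question (using small `nn.Linear` modules as hypothetical stand-ins for the two models, and a squared-error loss standing in for `some_function`), only `model_Y` receives gradients:

```python
import torch
import torch.nn as nn
from torch.optim import AdamW

# Hypothetical stand-ins for model_X and model_Y.
model_X = nn.Linear(8, 8)
model_Y = nn.Linear(8, 1)

optimizer = AdamW(model_Y.parameters(), lr=1e-3)
optimizer.zero_grad()

data = torch.randn(4, 8)
with torch.no_grad():
    A = model_X(data)       # computed without any gradient history

B = model_Y(A)              # the graph starts here, at model_Y's input
loss = B.pow(2).mean()      # stand-in for some_function
loss.backward()
optimizer.step()

print(A.grad_fn)                         # None: the graph was cut at A
print(model_Y.weight.grad is not None)   # True: model_Y got gradients
print(model_X.weight.grad is None)       # True: model_X never sees gradients
```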
However, in other cases you might not want to modify the weights of `model_X` ("freeze" them), yet still need to propagate gradients through it, for instance if you want to optimize elements feeding into `model_X`. In that case `torch.no_grad()` is the wrong tool, since it cuts the graph entirely; setting `requires_grad` to False on the frozen parameters keeps the graph intact while preventing those parameters from accumulating gradients.
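A minimal sketch of that case, assuming a learnable input tensor as the upstream element (the `nn.Linear` stand-ins and the learnable input are illustrative assumptions):

```python
import torch
import torch.nn as nn

model_X = nn.Linear(8, 8)   # hypothetical stand-in
model_Y = nn.Linear(8, 1)

# Freeze model_X's parameters without cutting the graph.
for p in model_X.parameters():
    p.requires_grad_(False)

# An upstream element we DO want to optimize: a learnable input.
data = torch.randn(4, 8, requires_grad=True)

A = model_X(data)           # no no_grad here: the graph runs through model_X
B = model_Y(A)
loss = B.pow(2).mean()
loss.backward()

print(model_X.weight.grad is None)   # True: frozen parameters get no gradient
print(data.grad is not None)         # True: gradients still flowed through model_X
```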
It's all in the chain rule.
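As for the last question: yes, non-consecutive layers can be frozen at once, since the chain rule only requires that gradients can pass through a frozen layer, not that the layer accumulates them. A sketch with a hypothetical three-layer stack where only the middle layer stays trainable:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1))

# Freeze the first and last layers; the middle one remains trainable.
for p in net[0].parameters():
    p.requires_grad_(False)
for p in net[2].parameters():
    p.requires_grad_(False)

x = torch.randn(4, 8)
loss = net(x).pow(2).mean()
loss.backward()

# Gradients pass through the frozen layers but do not accumulate in them.
print(net[0].weight.grad is None)       # True
print(net[1].weight.grad is not None)   # True
print(net[2].weight.grad is None)       # True
```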