There are two models, A and B. Model A outputs a bike deployment plan for all stations in a city, and model B takes this plan as input and produces an evaluation of each station.
Model B is pretrained, and I want to use the evaluation given by model B as a loss to optimize the parameters of model A.
Here is the sample code:

import torch

A = modelA()
B = modelB()
optimizer = torch.optim.Adam(A.parameters())

def my_loss(deploy):
    shape = deploy.size()
    # Per station: station index in column 0, a placeholder in column 1, the full plan in columns 2 onward
    state = torch.zeros((shape[0], shape[1], 2 + shape[1]), dtype=torch.long)
    # Notice: this step will copy deploy
    state[:, :, 2:] = torch.reshape(deploy, (shape[0], 1, shape[1]))
    state[:, :, 0] = torch.arange(0, shape[1])
    state = torch.reshape(state, (-1, 2 + shape[1]))
    eval = B(state)
    eval = torch.reshape(eval, (shape[0], shape[1]))
    return torch.mean(eval)

# Train model A
for epoch in range(EPOCHS):
    for batch_idx, (x, useless_y) in enumerate(dataloader):
        optimizer.zero_grad()
        pred = A(x)
        loss = my_loss(pred)
        loss.backward()
        optimizer.step()
But in fact, during training nothing happens: the parameters of model A are not updated. I also tried
optimizer = torch.optim.Adam([{'params': A.parameters()}, {'params': B.parameters(), 'lr':0}])
and nothing happened either.
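For reference, here is roughly how I confirm that the weights do not move (a minimal sketch; it just snapshots the parameters of A before training and compares them afterwards):

before = [p.detach().clone() for p in A.parameters()]
# ... run the training loop above ...
unchanged = all(torch.equal(p, q) for p, q in zip(A.parameters(), before))
print(unchanged)  # True in my case, i.e. model A was not updated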
Any ideas?
CodePudding user response:
The reason why you have no update on the parameters of A is that loss, the result of my_loss, is not attached to the computation graph, i.e. to the output of model A. If you have a look at your implementation, this is clearly the case:

state = torch.reshape(state_tensors, (-1, 2 + shape[1]))
eval = B(state)
eval = torch.reshape(eval, (shape[0], shape[1]))
The variable state is defined from state_tensors, which is not defined in your example, so it is either a typo or it is meant to be defined from state. Either way, the tensor that my_loss outputs should have a grad_fn attached to it. Do make sure that is the case. The last operation here is an average, so you should see something like:

>>> my_loss(pred)
tensor(0.7742, grad_fn=<MeanBackward0>)

After fixing this, your gradient should be able to backpropagate to the parameters of model A.
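As a quick end-to-end connectivity check, here is a minimal sketch that reuses the names from the question (x is one batch from the dataloader):

pred = A(x)
loss = my_loss(pred)
print(pred.requires_grad, pred.grad_fn)  # expected: True and a non-None grad_fn
print(loss.requires_grad, loss.grad_fn)  # False / None means the graph is cut inside my_loss

If the second print shows None, the break is inside my_loss, most likely because state is created as a fresh long tensor, so the copy of deploy cannot carry a gradient.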
CodePudding user response:
The computational graph is cut off at state, so the loss does not backpropagate to A.
Try:

state = torch.zeros((shape[0], shape[1], 2 + shape[1]), dtype=torch.long)

->

state = torch.zeros((shape[0], shape[1], 2 + shape[1]), requires_grad=True)  # add requires_grad=True; dtype=torch.long may throw an error
However, I still don't think it will work with your code. Optional suggestions:
- I think state_tensors is not defined.
- The in-place operation state[:, :, 2:] = ... may not be good (in my case, it throws an error). To copy a tensor, .expand() or .repeat() may be useful, and .unsqueeze() can add the extra dimension; see the sketch below for how to build state without the in-place write.
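For example, here is a minimal sketch of my_loss built that way. It assumes deploy has shape (batch, num_stations) and that B accepts floating-point input of width 2 + num_stations; the names other than deploy and B are mine:

def my_loss(deploy):
    batch, n = deploy.size()
    # Same layout as before: station index in column 0, a zero placeholder in column 1, the plan in columns 2 onward
    idx = torch.arange(n, dtype=deploy.dtype, device=deploy.device)
    idx = idx.unsqueeze(0).unsqueeze(-1).expand(batch, n, 1)        # (batch, n, 1)
    pad = torch.zeros(batch, n, 1, dtype=deploy.dtype, device=deploy.device)
    plan = deploy.unsqueeze(1).expand(batch, n, n)                  # one copy of the full plan per station, no in-place write
    state = torch.cat([idx, pad, plan], dim=-1).reshape(-1, 2 + n)  # float tensor, stays attached to deploy
    evaluation = B(state).reshape(batch, n)
    return evaluation.mean()                                        # carries a grad_fn back to model A

With this, loss.backward() reaches the parameters of A through deploy, and keeping only A.parameters() in the optimizer, as in the original code, is enough to leave B's weights untouched.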