How to train a model with loss calculated by another model in pytorch?


There are two models, A and B. Model A outputs a bike deployment plan for all stations in a city, and model B takes this plan as input and gives an evaluation of each station.

Model B is pretrained, and I want to use the evaluation given by model B as the loss to optimize the parameters of model A.

Here is the sample code:

A = modelA()
B = modelB()

optimizer = torch.optim.Adam(A.parameters())

def my_loss(deploy):
  shape = deploy.size()
  state = torch.zeros((shape[0], shape[1], 2 + shape[1]), dtype=torch.long)

  # Notice: this step will copy deploy
  state[:, :, 2:] = torch.reshape(deploy, (shape[0], 1, shape[1]))
  state[:, :, 0] = torch.arange(0, shape[1])

  state = torch.reshape(state, (-1, 2 + shape[1]))
  eval = B(state)
  eval = torch.reshape(eval, (shape[0], shape[1]))

  return torch.mean(eval)

# Train model A
for epoch in range(EPOCHS):
  for batch_idx, (x, useless_y) in enumerate(dataloader):
    optimizer.zero_grad()
    pred = A(x)
    loss = my_loss(pred)
    loss.backward()
    optimizer.step()

But in fact, during training nothing happens; the parameters of model A are not updated. I also tried

optimizer = torch.optim.Adam([{'params': A.parameters()}, {'params': B.parameters(), 'lr':0}])

but nothing changed either.

Any ideas?

CodePudding user response:

The reason the parameters of A receive no update is that loss, the result of my_loss, is not attached to the computation graph, i.e. to the output of model A. If you look at your implementation, this is clearly the case:

    state = torch.reshape(state_tensors, (-1, 2 + shape[1]))
    eval = B(state)
    eval = torch.reshape(eval, (shape[0], shape[1]))

The variable state is defined from state_tensors, which is not defined in your example, so it is either a typo or it is meant to be state. Either way, the tensor my_loss outputs should have a grad_fn attached to it; do make sure that is the case. The last operation here is a mean, so you should see something like:

>>> my_loss(pred)
tensor(0.7742, grad_fn=<MeanBackward0>)

After fixing this, your gradient should be able to backpropagate to the parameters of model A.
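
As an aside, rather than passing B's parameters to the optimizer with lr=0, you can freeze the pretrained B explicitly. A minimal sketch, assuming B is an ordinary nn.Module:

B.eval()                 # fix dropout/batch-norm behaviour while training A
B.requires_grad_(False)  # stop gradient accumulation in B's own parameters

Note that this does not block backpropagation through B: gradients still flow through B's operations back to the output of A; only B's own weights stop collecting gradients.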

CodePudding user response:

The computational graph is cut off at state, so the loss does not backpropagate to A.

Try:

state = torch.zeros((shape[0], shape[1], 2 + shape[1]), dtype=torch.long)
->
state = torch.zeros((shape[0], shape[1], 2 + shape[1]), requires_grad=True)  # add requires_grad=True; dtype=torch.long may throw an error

However, I still don't think it will work with your code as written. Some optional suggestions:

  • I think state_tensors is not defined.
  • The in-place operation state[:, :, 2:] = ... may not be good (in my case, it throws an error). To copy a tensor, .expand() or .repeat() may be useful, and .unsqueeze() to add a dimension; see the sketch after this list.
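
Putting these together, here is a minimal sketch of my_loss rebuilt without the in-place copy, assuming B accepts floating-point input (if B embeds integer station ids internally, those columns would need separate handling). Writing deploy into a fresh dtype=torch.long tensor casts it to integers, which by itself detaches the gradient; building state with unsqueeze/expand and torch.cat keeps the graph connected to deploy:

def my_loss(deploy):
  batch, n = deploy.size()
  # station index for channel 0, broadcast over the batch: (batch, n, 1)
  idx = torch.arange(n, dtype=deploy.dtype, device=deploy.device)
  idx = idx.view(1, n, 1).expand(batch, n, 1)
  # channel 1 stays zero, as in the original code: (batch, n, 1)
  zeros = torch.zeros(batch, n, 1, dtype=deploy.dtype, device=deploy.device)
  # the full plan, repeated for every station: (batch, n, n)
  plan = deploy.unsqueeze(1).expand(batch, n, n)
  # torch.cat is differentiable, so the graph stays attached to deploy
  state = torch.cat([idx, zeros, plan], dim=-1)  # (batch, n, 2 + n)
  state = torch.reshape(state, (-1, 2 + n))      # (batch*n, 2 + n)
  eval = B(state)
  eval = torch.reshape(eval, (batch, n))
  return torch.mean(eval)

With this version, the returned loss carries a grad_fn (MeanBackward0) and loss.backward() reaches the parameters of A.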