PyTorch is throwing an error RuntimeError: result type Float can't be cast to the desired outpu

Time:12-05

How should I get rid of the following error?

>>> t = torch.tensor([[1, 0, 1, 1]]).T
>>> p = torch.rand(4,1)
>>> torch.nn.BCEWithLogitsLoss()(p, t)

The above code is throwing the following error:

RuntimeError: result type Float can't be cast to the desired output type Long

CodePudding user response:

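BCEWithLogitsLoss requires its target to be a float tensor, not long. torch.tensor infers the long dtype (torch.int64) when it is given Python integers, so the loss cannot write its float result into a long target. Set the dtype of t explicitly with dtype=torch.float32: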

import torch

t = torch.tensor([[1, 0, 1, 1]], dtype=torch.float32).T
p = torch.rand(4,1)
loss_fn = torch.nn.BCEWithLogitsLoss()

print(loss_fn(p, t))

Output (the exact value varies between runs because p is random):

tensor(0.5207)
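
If the target tensor already exists as long (for example, labels loaded from a dataset), it can be cast at the point of use instead of at creation. The following is a minimal sketch under that assumption; the helper name bce_with_logits is hypothetical and not part of PyTorch, and the seed is only there to make the run repeatable.

```python
import torch

def bce_with_logits(logits, target):
    # Hypothetical helper (not a PyTorch API): cast the target to the
    # dtype of the logits before computing the loss.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    return loss_fn(logits, target.to(logits.dtype))

torch.manual_seed(0)                   # assumption: fixed seed for repeatability
t = torch.tensor([[1, 0, 1, 1]]).T     # dtype inferred as torch.int64 (long)
p = torch.rand(4, 1)                   # torch.float32 by default

loss = bce_with_logits(p, t)
print(t.dtype)     # dtype of the original target
print(loss.dtype)  # dtype of the computed loss
```

Calling t.float() on the target has the same effect when the logits are float32; casting to logits.dtype also covers the case where the model runs in another floating-point precision.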