I am trying to write a simple CNN using PyTorch, but I am getting an error in my first layer. The following code produces the error RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'mat1' in call to _th_addmm_, even though I set the input to type Long. The print statement also shows torch.LongTensor, yet the last line, out = model(data), still throws the error.
import torch
import torch.nn as nn
data = torch.randint(low=0, high=255, size=[2, 1, 1024, 1024], dtype=torch.int64)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)
print(data.type())
out = model(data)  # RuntimeError: the input is int64 but the conv weights are float32
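One quick way to confirm the mismatch is to compare the input's dtype with the dtype of the layer's weights before calling the model. This is a minimal sketch that reuses the question's layer. It shrinks the input to 64x64 to keep it light, which is an assumption that does not affect the dtypes. It also assumes PyTorch's default floating-point dtype has not been changed.

```python
import torch
import torch.nn as nn

data = torch.randint(low=0, high=255, size=[2, 1, 64, 64], dtype=torch.int64)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)

# The input is int64, but Conv2d creates its weight with the default float dtype
print(data.dtype)          # torch.int64
print(model.weight.dtype)  # torch.float32

# A mismatch here is what makes the forward pass raise a RuntimeError
dtypes_match = data.dtype == model.weight.dtype
print(dtypes_match)        # False
```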
Answer:
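The type of your data is Long, but the type of your model's weights is Float, and nn.Conv2d requires the input and the weights to have the same dtype. Convolution layers, and gradient-based training in general, work with floating-point tensors, so change the type of your data to float: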
import torch
import torch.nn as nn
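# Create the input as float32 so it matches the default dtype of the Conv2d weights
data = torch.randint(low=0, high=255, size=[2, 1, 1024, 1024], dtype=torch.float32)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)
print(data.type())  # torch.FloatTensor
out = model(data)   # works: input and weights are both float32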
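If the data already exists as an integer tensor, for example pixel values loaded elsewhere, you can cast it instead of regenerating it. The following is a sketch under that assumption. It casts to model.weight.dtype rather than hard-coding torch.float32, which keeps the input in sync if the model is later converted, e.g. with model.double(). It again uses a smaller 64x64 input only to keep the run quick.

```python
import torch
import torch.nn as nn

# Integer input, as in the question (smaller spatial size for a quick run)
data = torch.randint(low=0, high=255, size=[2, 1, 64, 64], dtype=torch.int64)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)

# Cast the input to whatever floating-point dtype the layer's weights use
data = data.to(model.weight.dtype)
out = model(data)

print(data.dtype)  # torch.float32
print(out.shape)   # torch.Size([2, 3, 64, 64])
```

With the default float32 weights, data.float() does the same thing as data.to(model.weight.dtype). With kernel_size=3 and padding=1, the spatial size is preserved, so only the channel dimension changes from 1 to 3.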