Pytorch input tensor is correct data type, but throws RuntimeError: Expected object of scalar type Long

Time:09-17

I am trying to write a simple CNN using PyTorch, but I get an error in my first layer. The following code produces the error RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'mat1' in call to _th_addmm_, even though I set the input to type Long. The print statement outputs torch.LongTensor, yet the model call on the last line still throws the error.

import torch
import torch.nn as nn

data = torch.randint(low=0, high=255, size=[2, 1, 1024, 1024], dtype=torch.int64)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)
print(data.type())
out = model(data)
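A quick way to see the mismatch is to print both the input dtype and the layer's weight dtype side by side. The sketch below uses a smaller 8x8 input (an assumption, only to keep it fast) and catches the error instead of crashing; the exact error wording differs between PyTorch versions, so the code only checks that a RuntimeError is raised.

```python
import torch
import torch.nn as nn

# Smaller stand-in for the 1024x1024 input from the question.
data = torch.randint(low=0, high=255, size=[2, 1, 8, 8], dtype=torch.int64)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)

# The convolution needs these two dtypes to agree.
print(data.dtype)          # torch.int64
print(model.weight.dtype)  # torch.float32 (default dtype of nn layer parameters)

try:
    model(data)
    failed = False
except RuntimeError as err:
    # Message text varies by PyTorch version; the cause is the dtype mismatch.
    failed = True
    print(err)
```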

Answer:

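The type of your data is Long, but the weights of your model are Float (float32, the default for nn layers). The convolution requires its input and weights to have the same dtype, so setting the input to Long is what causes the error, not what fixes it. Change the type of your data to float instead, which you need to do anyway if you plan to train the model, since integer tensors cannot require gradients: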

import torch
import torch.nn as nn

data = torch.randint(low=0, high=255, size=[2, 1, 1024, 1024], dtype=torch.float32)  # float32 matches the Conv2d weights
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)
print(data.type())  # prints torch.FloatTensor
out = model(data)
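If the data already exists as an integer tensor (for example, raw pixel values loaded from disk), you can convert it instead of regenerating it. A minimal sketch, assuming the same layer as above and a small 8x8 input for illustration; `data.to(model.weight.dtype)` casts to whatever dtype the weights use, and `data.float()` would work equally well for the default float32 case.

```python
import torch
import torch.nn as nn

# Integer input, e.g. 8-bit pixel values (illustrative data, not from the question).
data = torch.randint(low=0, high=256, size=[2, 1, 8, 8], dtype=torch.int64)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=False)

# Cast the input to the dtype of the layer's weights (float32 by default).
data_f = data.to(model.weight.dtype)
out = model(data_f)

print(data_f.dtype)  # torch.float32
print(out.shape)     # torch.Size([2, 3, 8, 8]) since padding=1 keeps 8x8
```

Casting to `model.weight.dtype` rather than hard-coding float32 keeps the input in sync if the model is later converted, e.g. with `model.double()`.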