Getting ResNet18 to work with float32 data [duplicate]

Time:09-28

I have float32 data that I am trying to get ResNet18 to work with. I am using the ResNet model from torchvision (with PyTorch Lightning) and modified it to accept single-channel (grayscale) input like so:

import torch
import pytorch_lightning as pl
from torch import nn
from torchvision.models import resnet18
# auto_move_data is only available in older PyTorch Lightning releases
from pytorch_lightning.core.decorators import auto_move_data


class ResNetMSTAR(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # define model and loss
        self.model = resnet18(num_classes=3)
        # replace the first convolution so it takes 1 input channel instead of 3
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.loss = nn.CrossEntropyLoss()

    @auto_move_data  # this decorator automatically handles moving your tensors to GPU if required
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_no):
        # implement single training step
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        return loss

    def configure_optimizers(self):
        # choose your optimizer
        return torch.optim.RMSprop(self.parameters(), lr=0.005)

When I try to run this model I am getting the following error:

File "/usr/local/lib64/python3.6/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward

Is there anything that I can do differently to keep this error from happening?

CodePudding user response:

The problem is that the y you're feeding to your cross entropy loss is a FloatTensor, not a LongTensor. nn.CrossEntropyLoss expects class-index targets as a LongTensor (dtype torch.int64), so it raises this error. The float32 input x is not the issue.
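
To make the expected dtypes concrete, here is a minimal sketch using random stand-in tensors (the batch size of 4 and the label values are made up for illustration; only the 3 classes come from the question). It shows the combination that nn.CrossEntropyLoss accepts for class-index targets: float logits of shape (N, C) and int64 targets of shape (N,).

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()

# Illustrative batch: 4 samples, 3 classes (matching num_classes=3 in the question).
logits = torch.randn(4, 3)            # float32, shape (N, C)
targets = torch.tensor([0, 2, 1, 2])  # integer literals give int64, shape (N,)

print(logits.dtype, targets.dtype)    # torch.float32 torch.int64

loss = loss_fn(logits, targets)       # scalar tensor (mean over the batch)
print(loss.dim())                     # 0
```

Printing y.dtype inside training_step is a quick way to confirm whether your own batches match this.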

A quick but ugly fix is to cast the targets inside training_step:

x, y = batch
y = y.long()

What I recommend instead is going to where the dataset is defined and making sure it generates long targets. That way the error won't reappear if you change how your training loop works.
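
As a rough sketch of that dataset-side fix: the question does not show the dataset, so the class name GrayscaleDataset, the array shapes, and the random data below are all hypothetical. The idea is to convert the labels to torch.long once, when the dataset is built, while the images stay float32.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class GrayscaleDataset(Dataset):  # hypothetical name, not from the question
    def __init__(self, images, labels):
        # images: (N, H, W) float array -> (N, 1, H, W) float32 tensor
        self.images = torch.as_tensor(images, dtype=torch.float32).unsqueeze(1)
        # labels: (N,) class ids, stored as int64 regardless of the input dtype
        self.labels = torch.as_tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


# Made-up data: 8 grayscale 64x64 images, labels deliberately stored as floats.
images = np.random.rand(8, 64, 64).astype(np.float32)
labels = np.array([0, 1, 2, 0, 1, 2, 0, 1], dtype=np.float32)

loader = DataLoader(GrayscaleDataset(images, labels), batch_size=4)
x, y = next(iter(loader))
print(x.dtype, y.dtype)  # torch.float32 torch.int64
```

With the cast done here, training_step can pass y straight to the loss without calling y.long().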
