Home > Net > How to deal with NotImplementedError when training a U-Net model?
How to deal with NotImplementedError when training a U-Net model?

Time:07-07

def train_fn(data_loader, model, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        data_loader: Iterable of ``(images, masks)`` batches that also
            supports ``len()`` (used as the number of batches).
        model: Module whose call ``model(images, masks)`` returns
            ``(logits, loss)``.
        optimizer: Optimizer stepping ``model``'s parameters.

    Returns:
        Sum of ``loss.item()`` over all batches divided by
        ``len(data_loader)``.

    Raises:
        ZeroDivisionError: if ``len(data_loader)`` is 0.
    """
    model.train()
    total_loss = 0.0

    for images, masks in tqdm(data_loader):
        # NOTE(review): DEVICE is expected to be a module-level setting
        # defined elsewhere in the notebook (not shown here).
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()
        _, loss = model(images, masks)
        loss.backward()
        optimizer.step()

        # Accumulate across batches so the return value is an epoch average.
        total_loss += loss.item()

    return total_loss / len(data_loader)


def eval_fn(data_loader, model):
    """Evaluate ``model`` on ``data_loader`` and return the mean per-batch loss.

    Runs with ``torch.no_grad()``, so no autograd graph is built and no
    parameters are updated.

    Args:
        data_loader: Iterable of ``(images, masks)`` batches that also
            supports ``len()`` (used as the number of batches).
        model: Module whose call ``model(images, masks)`` returns
            ``(logits, loss)``.

    Returns:
        Sum of ``loss.item()`` over all batches divided by
        ``len(data_loader)``.

    Raises:
        ZeroDivisionError: if ``len(data_loader)`` is 0.
    """
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for images, masks in tqdm(data_loader):
            # NOTE(review): DEVICE is expected to be a module-level setting
            # defined elsewhere in the notebook (not shown here).
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            _, loss = model(images, masks)

            # Accumulate across batches so the return value is an average.
            total_loss += loss.item()

    return total_loss / len(data_loader)

# Train for EPOCHS epochs, checkpointing whenever validation loss improves.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# float("inf") rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
best_valid_loss = float("inf")

for i in range(EPOCHS):
    train_loss = train_fn(trainloader, model, optimizer)
    valid_loss = eval_fn(validloader, model)

    # Save weights only when this epoch beats the best validation loss so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'best_model.pt')
        print("SAVED_MODEL")
        best_valid_loss = valid_loss

    print(f"Epoch : {i + 1} Train_loss: {train_loss} Valid_loss: {valid_loss}")

I get the following error when I try to train the model:

  0%|          | 0/15 [00:00<?, ?it/s]

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
in <module>()
      4
      5
----> 6 train_loss = train_fn(trainloader, model, optimizer)
      7 valid_loss = eval_fn(validloader, model)
      8

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    199         registered hooks while the latter silently ignores them.
    200     """
--> 201     raise NotImplementedError
    202
    203

NotImplementedError:

How do I deal with this?

CodePudding user response:

Looking at the link you provided in the comment, your model definition looks like this:

class SegmentationModel(nn.Module):
  """The question's original model; calling it raises NotImplementedError."""

  def __init__(self):
    super(SegmentationModel,self).__init__()

    self.arc = smp.Unet(
        encoder_name = ENCODER,
        encoder_weights = WEIGHTS,
        in_channels = 3,
        classes = 1,
        activation = None
    )

    # BUG (the subject of this answer): the `def forward` below is indented one
    # level too far, so it is a local function inside __init__ rather than a
    # method of the class. The class therefore never overrides forward(), and
    # calling the model reaches nn.Module's _forward_unimplemented(), which
    # raises NotImplementedError (see the traceback in the question).
    def forward(self, images, masks = None):
      logits = self.arc(images)

      if masks != None:
        loss1 = DiceLoss(mode = 'binary')(logits, masks)
        loss2 = nn.BCEWithLogitsLoss()(logits,masks)
        return logits, loss1 + loss2

      return logits

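If you look closely, you'll see that forward() has an extra level of indentation, which makes it a function nested inside __init__() rather than a method of SegmentationModel. Because the class never overrides forward(), calling the model falls through to nn.Module's default _forward_unimplemented(), which raises the NotImplementedError shown in the traceback. Dedent forward() by one level and it should work: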

class SegmentationModel(nn.Module):
    """Binary-segmentation wrapper around ``smp.Unet``.

    Calling the model with only ``images`` returns the raw logits; calling it
    with ``masks`` as well returns ``(logits, dice_loss + bce_loss)``.
    """

    def __init__(self):
        super().__init__()

        # One output channel and no activation: the network emits raw logits,
        # which is what BCEWithLogitsLoss in forward() expects.
        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, images, masks=None):
        """Run the network and, when ``masks`` is given, the combined loss.

        Args:
            images: Input batch passed straight to ``self.arc``.
            masks: Optional targets; when provided, both a binary Dice loss
                and a BCE-with-logits loss are computed against the logits.

        Returns:
            ``logits`` when ``masks`` is None, otherwise ``(logits, loss)``
            where ``loss`` is the Dice loss plus the BCE-with-logits loss.
        """
        logits = self.arc(images)

        # Identity check, so the tensor's elementwise `!=` is never invoked.
        if masks is not None:
            loss1 = DiceLoss(mode='binary')(logits, masks)
            loss2 = nn.BCEWithLogitsLoss()(logits, masks)
            return logits, loss1 + loss2

        return logits
  • Related