def train_fn(data_loader, model, optimizer):
    model.train()
    total_loss = 0.0
    for images, masks in tqdm(data_loader):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()
        logits, loss = model(images, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(data_loader)
def eval_fn(data_loader, model):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, masks in tqdm(data_loader):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            logits, loss = model(images, masks)
            total_loss += loss.item()

    return total_loss / len(data_loader)
optimizer = torch.optim.Adam(model.parameters(), lr = LR)
best_valid_loss = np.Inf

for i in range(EPOCHS):
    train_loss = train_fn(trainloader, model, optimizer)
    valid_loss = eval_fn(validloader, model)

    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'best_model.pt')
        print("SAVED_MODEL")
        best_valid_loss = valid_loss

    print(f"Epoch : {i + 1} Train_loss: {train_loss} Valid_loss: {valid_loss}")
I get the following error when I try to train the model:
  0%|          | 0/15 [00:00<?, ?it/s]
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input> in <module>()
      4
      5
----> 6 train_loss = train_fn(trainloader, model, optimizer)
      7 valid_loss = eval_fn(validloader, model)
      8

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    199     registered hooks while the latter silently ignores them.
    200     """
--> 201     raise NotImplementedError
    202
    203

NotImplementedError:
How do I deal with this?
CodePudding user response:
Looking at the link you provided in the comment, your model definition looks like this:
class SegmentationModel(nn.Module):
    def __init__(self):
        super(SegmentationModel, self).__init__()
        self.arc = smp.Unet(
            encoder_name = ENCODER,
            encoder_weights = WEIGHTS,
            in_channels = 3,
            classes = 1,
            activation = None
        )
        def forward(self, images, masks = None):
            logits = self.arc(images)
            if masks != None:
                loss1 = DiceLoss(mode = 'binary')(logits, masks)
                loss2 = nn.BCEWithLogitsLoss()(logits, masks)
                return logits, loss1 + loss2
            return logits
If you look closely, you'll see that forward()
has an extra level of indentation, which makes it a nested function defined inside __init__()
rather than a method of SegmentationModel
. Because of that, calling the model falls back to nn.Module's default forward, which raises NotImplementedError. Shift it one level to the left and it should work fine:
class SegmentationModel(nn.Module):
    def __init__(self):
        super(SegmentationModel, self).__init__()
        self.arc = smp.Unet(
            encoder_name = ENCODER,
            encoder_weights = WEIGHTS,
            in_channels = 3,
            classes = 1,
            activation = None
        )

    def forward(self, images, masks = None):
        logits = self.arc(images)
        if masks != None:
            loss1 = DiceLoss(mode = 'binary')(logits, masks)
            loss2 = nn.BCEWithLogitsLoss()(logits, masks)
            return logits, loss1 + loss2
        return logits
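Once the method is back at class level, a quick forward pass on dummy tensors is an easy way to confirm the fix before rerunning the full training loop. This is only a sketch: it assumes DEVICE, ENCODER and WEIGHTS are defined as in your notebook, and the batch of 2 and the 320x320 size are arbitrary choices (any spatial size divisible by 32 works for the U-Net encoder):

# Sanity check: with forward() defined on the class, calling the model no
# longer falls through to nn.Module's stub that raises NotImplementedError.
model = SegmentationModel().to(DEVICE)

dummy_images = torch.rand(2, 3, 320, 320, device=DEVICE)                    # batch of 2 RGB images
dummy_masks = torch.randint(0, 2, (2, 1, 320, 320), device=DEVICE).float()  # binary masks

logits, loss = model(dummy_images, dummy_masks)
print(logits.shape)   # torch.Size([2, 1, 320, 320])
print(loss.item())    # a finite scalar, so both loss terms ran

If this prints the expected shape and a scalar loss, the training loop at the top of your question should run as-is.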