I'm trying to develop a program that finds road lanes using semantic segmentation with a UNet backend. While training the model, I'm getting a Dice score greater than 1. Why is this happening?
Batch size: 16
Num_workers: 2
Epochs: 50
IMAGE_HEIGHT = 80
IMAGE_WIDTH = 120
PIN_MEMORY = True
Here's my accuracy function:
import torch

def check_accuracy(loader, model, device="cpu"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)  # add channel dim to match model output
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()  # threshold probabilities to a binary mask
            # Accumulate over all batches in the loader
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
    print(f"Got {num_correct/num_pixels} with accuracy {num_correct/num_pixels*100:.2f}")
    print(f"Dice score: {dice_score/len(loader)}")  # mean per-batch Dice
    model.train()
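For binary masks, the Dice formula in this function cannot exceed 1: every pixel where both `preds` and `y` are 1 adds 2 to the numerator and 2 to the denominator, while every other foreground pixel (in either mask) adds only to the denominator. That bound depends on `y` containing nothing but 0s and 1s. The sketch below is a NumPy stand-in for the PyTorch expression (the helper name `dice` and the toy arrays are illustrative, not from the original code) and shows what happens if a foreground pixel in the target keeps a raw grayscale value instead of 1.

```python
import numpy as np

def dice(preds, target, eps=1e-8):
    """Dice = 2 * sum(p * y) / (sum(p) + sum(y)), the same formula as check_accuracy."""
    preds = np.asarray(preds, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    return 2 * (preds * target).sum() / ((preds + target).sum() + eps)

preds = np.array([1, 1, 0, 0], dtype=np.float64)

# Binary target: the score stays within [0, 1]
binary_target = np.array([1, 0, 1, 0], dtype=np.float64)
print(dice(preds, binary_target))      # 2 / (2 + 2), approximately 0.5

# Target whose foreground pixel kept a grayscale value (200) instead of 1
non_binary_target = np.array([200, 0, 0, 0], dtype=np.float64)
print(dice(preds, non_binary_target))  # 400 / 202, approximately 1.98
```

The per-batch average printed by `check_accuracy` inherits the same bound, so a mean above 1 suggests that some targets are not strictly 0/1. If the predicted mask covers foreground pixels of value v, the ratio behaves like 2v / (1 + v), which approaches 2 as v grows; that would be consistent with the logged scores climbing toward, but staying below, 2.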
The x's are images, and the y's are ground truths.
Here's an example of an image and its ground truth:
[example image and ground-truth mask not included]
Here's my dataset's __getitem__ method:
import os

import numpy as np
from PIL import Image

def __getitem__(self, index):
    img_path = os.path.join(self.root_dir, self.images[index])
    mask_path = os.path.join(self.val_dir, self.images[index])
    image = np.array(Image.open(img_path).convert("RGB"))
    mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
    # Only pixels that are exactly 255 become 1; other gray levels are left unchanged
    mask[mask == 255.0] = 1.0
    if self.transform is not None:
        augmentations = self.transform(image=image, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]
    return image, mask
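The line `mask[mask == 255.0] = 1.0` only remaps pixels that are exactly 255. If the mask files contain any other nonzero gray levels (for example from JPEG compression or anti-aliased lane edges, which is an assumption about the data, not something shown in the question), those values stay in the 1 to 254 range and reach the loss and the Dice computation unchanged. A minimal sketch of how to detect that and binarize with a threshold instead; the helper name `binarize_mask` and the threshold of 127 are hypothetical choices.

```python
import numpy as np

def binarize_mask(mask, threshold=127.0):
    """Map a 0-255 grayscale mask to {0.0, 1.0}. Hypothetical helper; threshold is an assumption."""
    mask = np.asarray(mask, dtype=np.float32)
    return (mask > threshold).astype(np.float32)

# Synthetic mask with some in-between gray levels, as anti-aliased edges might produce
raw = np.array([[0, 3, 128, 255], [250, 255, 0, 60]], dtype=np.float32)

# Reproduce the original remapping: only exact 255s become 1
original = raw.copy()
original[original == 255.0] = 1.0
print(np.unique(original))  # leftover values such as 3, 60, 128 and 250 remain

fixed = binarize_mask(raw)
print(np.unique(fixed))     # only 0.0 and 1.0
```

Running `np.unique` on a few real masks loaded by `__getitem__` would show whether this applies to the dataset; if it does, the remapping line could be replaced with something like `mask = binarize_mask(mask)`.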
Here are my transformations:
import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # Scales pixels to [0, 1] (mean 0, std 1); applied to the image only
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)
And these are the accuracy metrics:
Epoch: 1/50
100%|██████████| 1/1 [00:00<00:00, 1.99it/s, loss=1.13]
=> Saving checkpoint
Epoch: 2/50
100%|██████████| 1/1 [00:00<00:00, 1.81it/s, loss=1.13]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 3/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=0.7]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 4/50
100%|██████████| 1/1 [00:00<00:00, 1.87it/s, loss=0.354]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 5/50
100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-.094]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 6/50
100%|██████████| 1/1 [00:00<00:00, 1.94it/s, loss=-.419]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 7/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-.914]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 8/50
100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-.7]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 9/50
100%|██████████| 1/1 [00:00<00:00, 1.87it/s, loss=-1.26]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 10/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-1.6]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 11/50
100%|██████████| 1/1 [00:00<00:00, 1.96it/s, loss=-2.04]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 12/50
100%|██████████| 1/1 [00:00<00:00, 1.89it/s, loss=-2.53]
=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0
Epoch: 13/50
100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-2.77]
=> Saving checkpoint Got 0.9814687371253967 with accuracy 98.15 Dice score: 0.0
Epoch: 14/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-3.14]
=> Saving checkpoint Got 0.9801874756813049 with accuracy 98.02 Dice score: 0.0
Epoch: 15/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-3.77]
=> Saving checkpoint Got 0.9771562218666077 with accuracy 97.72 Dice score: 0.0
Epoch: 16/50
100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-3.95]
=> Saving checkpoint Got 0.972906231880188 with accuracy 97.29 Dice score: 0.013821512460708618
Epoch: 17/50
100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-4.83]
=> Saving checkpoint Got 0.9672812223434448 with accuracy 96.73 Dice score: 0.09691906720399857
Epoch: 18/50
100%|██████████| 1/1 [00:00<00:00, 1.96it/s, loss=-4.96]
=> Saving checkpoint Got 0.9596353769302368 with accuracy 95.96 Dice score: 0.2153720110654831
Epoch: 19/50
100%|██████████| 1/1 [00:00<00:00, 1.86it/s, loss=-5.43]
=> Saving checkpoint Got 0.9514479041099548 with accuracy 95.14 Dice score: 0.34086111187934875
Epoch: 20/50
100%|██████████| 1/1 [00:00<00:00, 1.98it/s, loss=-5.81]
=> Saving checkpoint Got 0.9496874809265137 with accuracy 94.97 Dice score: 0.385390967130661
Epoch: 21/50
100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-5.86]
=> Saving checkpoint Got 0.9429270625114441 with accuracy 94.29 Dice score: 0.4814487397670746
Epoch: 22/50
100%|██████████| 1/1 [00:00<00:00, 1.98it/s, loss=-6.34]
=> Saving checkpoint Got 0.9388750195503235 with accuracy 93.89 Dice score: 0.5995486974716187
Epoch: 23/50
100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-6.92]
=> Saving checkpoint Got 0.9380520582199097 with accuracy 93.81 Dice score: 0.7058220505714417
Epoch: 24/50
100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-7.42]
=> Saving checkpoint Got 0.9415416717529297 with accuracy 94.15 Dice score: 0.8273581266403198
Epoch: 25/50
100%|██████████| 1/1 [00:00<00:00, 1.94it/s, loss=-7.84]
=> Saving checkpoint Got 0.9451770782470703 with accuracy 94.52 Dice score: 0.9627659916877747
Epoch: 26/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-8.32]
=> Saving checkpoint Got 0.9467916488647461 with accuracy 94.68 Dice score: 1.1096603870391846
Epoch: 27/50
100%|██████████| 1/1 [00:00<00:00, 1.84it/s, loss=-8.75]
=> Saving checkpoint Got 0.9468228816986084 with accuracy 94.68 Dice score: 1.2025865316390991
Epoch: 28/50
100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-8.6]
=> Saving checkpoint Got 0.942562460899353 with accuracy 94.26 Dice score: 1.3215676546096802
Epoch: 29/50
100%|██████████| 1/1 [00:00<00:00, 1.95it/s, loss=-9.3]
=> Saving checkpoint Got 0.9366874694824219 with accuracy 93.67 Dice score: 1.4410816431045532
Epoch: 30/50
100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-9.37]
=> Saving checkpoint Got 0.9291979074478149 with accuracy 92.92 Dice score: 1.5965386629104614
Epoch: 31/50
100%|██████████| 1/1 [00:00<00:00, 1.89it/s, loss=-9.29]
=> Saving checkpoint Got 0.9251145720481873 with accuracy 92.51 Dice score: 1.7157979011535645
Epoch: 32/50
100%|██████████| 1/1 [00:00<00:00, 1.89it/s, loss=-9]
=> Saving checkpoint Got 0.9198125004768372 with accuracy 91.98 Dice score: 1.7378650903701782
Epoch: 33/50
100%|██████████| 1/1 [00:00<00:00, 1.88it/s, loss=-9.82]
=> Saving checkpoint Got 0.9136666655540466 with accuracy 91.37 Dice score: 1.758676290512085
Epoch: 34/50
100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-9.84]
=> Saving checkpoint Got 0.9044270515441895 with accuracy 90.44 Dice score: 1.7790669202804565
Epoch: 35/50
100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-10.4]
=> Saving checkpoint Got 0.8995833396911621 with accuracy 89.96 Dice score: 1.7764993906021118
Epoch: 36/50
100%|██████████| 1/1 [00:00<00:00, 1.95it/s, loss=-10.7]
=> Saving checkpoint Got 0.8992916345596313 with accuracy 89.93 Dice score: 1.775970697402954
Epoch: 37/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-10.8]
=> Saving checkpoint Got 0.8976249694824219 with accuracy 89.76 Dice score: 1.77357017993927
Epoch: 38/50
100%|██████████| 1/1 [00:00<00:00, 1.95it/s, loss=-10.5]
=> Saving checkpoint Got 0.8952499628067017 with accuracy 89.52 Dice score: 1.782752513885498
Epoch: 39/50
100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-10.8]
=> Saving checkpoint Got 0.8901249766349792 with accuracy 89.01 Dice score: 1.7879419326782227
Epoch: 40/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-11]
=> Saving checkpoint Got 0.8890937566757202 with accuracy 88.91 Dice score: 1.7865288257598877
Epoch: 41/50
100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-11.2]
=> Saving checkpoint Got 0.8915520906448364 with accuracy 89.16 Dice score: 1.7899266481399536
Epoch: 42/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-11.1]
=> Saving checkpoint Got 0.8923645615577698 with accuracy 89.24 Dice score: 1.7993781566619873
Epoch: 43/50
100%|██████████| 1/1 [00:00<00:00, 1.96it/s, loss=-10.9]
=> Saving checkpoint Got 0.887333333492279 with accuracy 88.73 Dice score: 1.804529070854187
Epoch: 44/50
100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-11.9]
=> Saving checkpoint Got 0.8823854327201843 with accuracy 88.24 Dice score: 1.7975029945373535
Epoch: 45/50
100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-11.2]
=> Saving checkpoint Got 0.8859270811080933 with accuracy 88.59 Dice score: 1.802530288696289
Epoch: 46/50
100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-11.6]
=> Saving checkpoint Got 0.890500009059906 with accuracy 89.05 Dice score: 1.8090462684631348
Epoch: 47/50
100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-11.7]
=> Saving checkpoint Got 0.9020833373069763 with accuracy 90.21 Dice score: 1.8257639408111572
Epoch: 48/50
100%|██████████| 1/1 [00:00<00:00, 1.81it/s, loss=-12.1]
=> Saving checkpoint Got 0.9160521030426025 with accuracy 91.61 Dice score: 1.8463400602340698
Epoch: 49/50
100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-12]
=> Saving checkpoint Got 0.925125002861023 with accuracy 92.51 Dice score: 1.859954833984375
Epoch: 50/50
100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-11.4]
=> Saving checkpoint Got 0.933968722820282 with accuracy 93.40 Dice score: 1.873431921005249
As you can see, the Dice score at epoch 50 is 1.873431921005249. Why is the Dice score greater than 1?
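The negative loss values in the log are another symptom worth checking. The training loop isn't shown, so this assumes the loss is `nn.BCEWithLogitsLoss`, a common pairing with the `torch.sigmoid(model(x))` used for predictions. Binary cross-entropy is non-negative only when targets lie in [0, 1]; with a target above 1, the (1 - y) * log(1 - sigmoid(x)) term flips sign and the loss can keep falling as the logit grows. The pure-Python sketch below uses the standard numerically stable form max(x, 0) - x * y + log(1 + exp(-|x|)); the function name is illustrative.

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for a single logit and target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

for logit in (0.0, 5.0, 10.0):
    # target 1.0 (valid label) vs. 2.0 (a value a non-binarized mask could contain)
    print(logit, bce_with_logits(logit, 1.0), bce_with_logits(logit, 2.0))
```

With target 1 the loss shrinks toward 0 from above as the logit grows; with target 2 it drops below 0 and keeps decreasing, which would be consistent with the steadily more negative losses in the log.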
Answer:
The equation you are using to calculate the dice coefficient is wrong. This will work.
dice_score = (2 * (preds * y).sum()) / (2 * (preds * y).sum() + (preds != y).sum())
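You can interpret it as 2 x correct_classified / (2 x correct_classified + wrong_classified), where correct_classified counts pixels that are lane in both the prediction and the ground truth, and wrong_classified counts pixels where the prediction and the ground truth disagree. Note that this only works in the binary case.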
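The interpretation 2 x correct_classified / (2 x correct_classified + wrong_classified) can be checked numerically by counting true positives (TP), false positives (FP) and false negatives (FN) directly. For binary masks, sum(preds) = TP + FP and sum(y) = TP + FN, so this is algebraically the same quantity as 2 * sum(preds * y) / (sum(preds) + sum(y)). The sketch uses random binary NumPy arrays shaped like one batch from the question (16 x 1 x 80 x 120) as stand-in data; it is not the real model output.

```python
import numpy as np

rng = np.random.default_rng(0)
preds = rng.integers(0, 2, size=(16, 1, 80, 120)).astype(np.float32)
y = rng.integers(0, 2, size=(16, 1, 80, 120)).astype(np.float32)

tp = (preds * y).sum()            # lane predicted as lane
fp = (preds * (1 - y)).sum()      # background predicted as lane
fn = ((1 - preds) * y).sum()      # lane predicted as background
wrong = (preds != y).sum()        # every disagreeing pixel, i.e. FP + FN

dice_counts = 2 * tp / (2 * tp + fp + fn)
dice_wrong = 2 * tp / (2 * tp + wrong)
dice_sums = 2 * tp / (preds + y).sum()

print(dice_counts, dice_wrong, dice_sums)
```

All three values agree for binary inputs. Counting wrong_classified as `(preds != y).sum()` keeps true negatives (background predicted as background) out of the denominator, which Dice does not include.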