PyTorch UNet semantic segmentation dice score more than 1

Time:07-09

I'm trying to develop a program that detects road lanes using semantic segmentation with a UNet model. While training, however, the model reports a dice score greater than 1. Why is this happening?

Batch size: 16

Num_workers: 2

Epochs: 50

IMAGE_HEIGHT = 80

IMAGE_WIDTH = 120

PIN_MEMORY = True

Here's my accuracy function:

    import torch


    def check_accuracy(loader, model, device="cpu"):
        num_correct = 0
        num_pixels = 0
        dice_score = 0
        model.eval()

        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                # Add a channel dimension so y matches the model output: (N, H, W) -> (N, 1, H, W)
                y = y.to(device).unsqueeze(1)
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()  # threshold probabilities to 0/1
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # Per-batch dice, averaged over the number of batches below
                dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)

        print(f"Got {num_correct/num_pixels} with accuracy {num_correct/num_pixels*100:.2f}")
        print(f"Dice score: {dice_score/len(loader)}")
        model.train()

x is a batch of images, and y is the corresponding batch of ground-truth masks.
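To check what the dice expression in check_accuracy, 2 * (preds * y).sum() / ((preds + y).sum() + 1e-8), actually computes, here's a minimal sketch on small made-up arrays. It uses NumPy instead of PyTorch so it runs without a model, on the assumption that the same element-wise arithmetic applies to tensors; dice_from_sums and dice_from_counts are illustrative names, not part of the question's code. For 0/1 masks the sum form equals the usual 2·TP / (2·TP + FP + FN) and stays within [0, 1]. If a target pixel holds a value v > 1 where the prediction is 1, it adds 2v to the numerator but only 1 + v to the denominator, so the ratio can climb toward 2.

```python
import numpy as np


def dice_from_sums(preds: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Same expression as in check_accuracy: 2 * sum(p * y) / (sum(p + y) + eps)."""
    return float(2 * (preds * target).sum() / ((preds + target).sum() + eps))


def dice_from_counts(preds: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Dice written as 2*TP / (2*TP + FP + FN); only meaningful for 0/1 masks."""
    tp = np.logical_and(preds == 1, target == 1).sum()
    fp = np.logical_and(preds == 1, target == 0).sum()
    fn = np.logical_and(preds == 0, target == 1).sum()
    return float(2 * tp / (2 * tp + fp + fn + eps))


# Made-up 2x4 binary prediction and ground truth (1 = lane pixel).
preds = np.array([[1, 1, 0, 0],
                  [0, 1, 0, 0]], dtype=np.float32)
target = np.array([[1, 0, 0, 0],
                   [0, 1, 1, 0]], dtype=np.float32)

print(dice_from_sums(preds, target))    # TP=2, FP=1, FN=1 -> 4/6
print(dice_from_counts(preds, target))  # same value for 0/1 masks

# Same prediction, but one lane pixel in the target is 254 instead of 1
# (a value that an exact `mask == 255.0` comparison would leave untouched).
bad_target = target.copy()
bad_target[0, 0] = 254.0
print(dice_from_sums(preds, bad_target))  # greater than 1
```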

Here's an example of an image and its ground truth:

Image: [image]

Ground truth: [image]

Here's my dataset's __getitem__ method:

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.images[index])
        mask_path = os.path.join(self.val_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask
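The dice expression in check_accuracy assumes y contains only 0 and 1. The line mask[mask == 255.0] = 1.0 converts pixels that are exactly 255 but leaves any other non-zero gray value unchanged. If the ground-truth files contain such values (for example from anti-aliased lane edges or JPEG compression; the question doesn't show whether they do), those values reach the loss and the metric as-is. Here's a minimal sketch of a thresholded conversion plus a check; binarize_mask is a hypothetical helper, and the threshold of 127 is an assumption, not something taken from the question.

```python
import numpy as np


def binarize_mask(mask: np.ndarray, threshold: float = 127.0) -> np.ndarray:
    """Hypothetical helper: map a 0-255 grayscale mask to float32 0/1 values.

    Pixels strictly above `threshold` become 1.0, everything else 0.0.
    The default threshold of 127 is an assumption.
    """
    return (mask > threshold).astype(np.float32)


# Made-up grayscale mask with a few values that are neither 0 nor 255.
raw = np.array([[0, 255, 254],
                [3, 128, 0]], dtype=np.float32)

# The question's mapping only touches exact 255s.
exact = raw.copy()
exact[exact == 255.0] = 1.0
print(np.unique(exact))  # still contains 3, 128 and 254

binary = binarize_mask(raw)
assert set(np.unique(binary).tolist()) <= {0.0, 1.0}
print(binary)
```

In __getitem__, this would replace the two mask lines with something like mask = binarize_mask(np.array(Image.open(mask_path).convert("L"), dtype=np.float32)).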

Here are my transformations:

    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )
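With mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0] and max_pixel_value=255.0, A.Normalize reduces to dividing the image by 255. Albumentations documents the formula as (img - mean * max_pixel_value) / (std * max_pixel_value), and Normalize is an image-only transform, so the mask is not rescaled by it. The NumPy sketch below just reproduces that documented arithmetic; normalize_like_albumentations is a hypothetical name and not the library's implementation.

```python
import numpy as np


def normalize_like_albumentations(img, mean, std, max_pixel_value=255.0):
    """Hypothetical re-implementation of the documented A.Normalize formula."""
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    # Per-channel mean/std broadcast over the last (channel) axis of an HWC image
    return (img.astype(np.float32) - mean) / std


# Made-up 1x2 RGB image in HWC layout, as np.array(Image.open(...)) returns.
img = np.array([[[0, 128, 255], [255, 255, 255]]], dtype=np.uint8)
out = normalize_like_albumentations(img, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
print(out)  # every value is img / 255, i.e. within [0, 1]
```

Because the mask bypasses this step, it keeps whatever values __getitem__ produced.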

And this is the training log with the accuracy metrics:

Epoch: 1/50

100%|██████████| 1/1 [00:00<00:00, 1.99it/s, loss=1.13]

=> Saving checkpoint 

Epoch: 2/50

100%|██████████| 1/1 [00:00<00:00, 1.81it/s, loss=1.13]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 3/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=0.7]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 4/50

100%|██████████| 1/1 [00:00<00:00, 1.87it/s, loss=0.354]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 5/50

100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-.094]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 6/50

100%|██████████| 1/1 [00:00<00:00, 1.94it/s, loss=-.419]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 7/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-.914]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 8/50

100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-.7]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 9/50

100%|██████████| 1/1 [00:00<00:00, 1.87it/s, loss=-1.26]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 10/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-1.6]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 11/50

100%|██████████| 1/1 [00:00<00:00, 1.96it/s, loss=-2.04]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 12/50

100%|██████████| 1/1 [00:00<00:00, 1.89it/s, loss=-2.53]

=> Saving checkpoint Got 0.9816353917121887 with accuracy 98.16 Dice score: 0.0 

Epoch: 13/50

100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-2.77]

=> Saving checkpoint Got 0.9814687371253967 with accuracy 98.15 Dice score: 0.0 

Epoch: 14/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-3.14]

=> Saving checkpoint Got 0.9801874756813049 with accuracy 98.02 Dice score: 0.0 

Epoch: 15/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-3.77]

=> Saving checkpoint Got 0.9771562218666077 with accuracy 97.72 Dice score: 0.0 

Epoch: 16/50

100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-3.95]

=> Saving checkpoint Got 0.972906231880188 with accuracy 97.29 Dice score: 0.013821512460708618 

Epoch: 17/50

100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-4.83]

=> Saving checkpoint Got 0.9672812223434448 with accuracy 96.73 Dice score: 0.09691906720399857 

Epoch: 18/50

100%|██████████| 1/1 [00:00<00:00, 1.96it/s, loss=-4.96]

=> Saving checkpoint Got 0.9596353769302368 with accuracy 95.96 Dice score: 0.2153720110654831 

Epoch: 19/50

100%|██████████| 1/1 [00:00<00:00, 1.86it/s, loss=-5.43]

=> Saving checkpoint Got 0.9514479041099548 with accuracy 95.14 Dice score: 0.34086111187934875 

Epoch: 20/50

100%|██████████| 1/1 [00:00<00:00, 1.98it/s, loss=-5.81]

=> Saving checkpoint Got 0.9496874809265137 with accuracy 94.97 Dice score: 0.385390967130661 

Epoch: 21/50

100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-5.86]

=> Saving checkpoint Got 0.9429270625114441 with accuracy 94.29 Dice score: 0.4814487397670746 

Epoch: 22/50

100%|██████████| 1/1 [00:00<00:00, 1.98it/s, loss=-6.34]

=> Saving checkpoint Got 0.9388750195503235 with accuracy 93.89 Dice score: 0.5995486974716187 

Epoch: 23/50

100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-6.92]

=> Saving checkpoint Got 0.9380520582199097 with accuracy 93.81 Dice score: 0.7058220505714417 

Epoch: 24/50

100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-7.42]

=> Saving checkpoint Got 0.9415416717529297 with accuracy 94.15 Dice score: 0.8273581266403198 

Epoch: 25/50

100%|██████████| 1/1 [00:00<00:00, 1.94it/s, loss=-7.84]

=> Saving checkpoint Got 0.9451770782470703 with accuracy 94.52 Dice score: 0.9627659916877747 

Epoch: 26/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-8.32]

=> Saving checkpoint Got 0.9467916488647461 with accuracy 94.68 Dice score: 1.1096603870391846 

Epoch: 27/50

100%|██████████| 1/1 [00:00<00:00, 1.84it/s, loss=-8.75]

=> Saving checkpoint Got 0.9468228816986084 with accuracy 94.68 Dice score: 1.2025865316390991 

Epoch: 28/50

100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-8.6]

=> Saving checkpoint Got 0.942562460899353 with accuracy 94.26 Dice score: 1.3215676546096802 

Epoch: 29/50

100%|██████████| 1/1 [00:00<00:00, 1.95it/s, loss=-9.3]

=> Saving checkpoint Got 0.9366874694824219 with accuracy 93.67 Dice score: 1.4410816431045532 

Epoch: 30/50

100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-9.37]

=> Saving checkpoint Got 0.9291979074478149 with accuracy 92.92 Dice score: 1.5965386629104614 

Epoch: 31/50

100%|██████████| 1/1 [00:00<00:00, 1.89it/s, loss=-9.29]

=> Saving checkpoint Got 0.9251145720481873 with accuracy 92.51 Dice score: 1.7157979011535645 

Epoch: 32/50

100%|██████████| 1/1 [00:00<00:00, 1.89it/s, loss=-9]

=> Saving checkpoint Got 0.9198125004768372 with accuracy 91.98 Dice score: 1.7378650903701782 

Epoch: 33/50

100%|██████████| 1/1 [00:00<00:00, 1.88it/s, loss=-9.82]

=> Saving checkpoint Got 0.9136666655540466 with accuracy 91.37 Dice score: 1.758676290512085 

Epoch: 34/50

100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-9.84]

=> Saving checkpoint Got 0.9044270515441895 with accuracy 90.44 Dice score: 1.7790669202804565 

Epoch: 35/50

100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-10.4]

=> Saving checkpoint Got 0.8995833396911621 with accuracy 89.96 Dice score: 1.7764993906021118 

Epoch: 36/50

100%|██████████| 1/1 [00:00<00:00, 1.95it/s, loss=-10.7]

=> Saving checkpoint Got 0.8992916345596313 with accuracy 89.93 Dice score: 1.775970697402954 

Epoch: 37/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-10.8]

=> Saving checkpoint Got 0.8976249694824219 with accuracy 89.76 Dice score: 1.77357017993927 

Epoch: 38/50

100%|██████████| 1/1 [00:00<00:00, 1.95it/s, loss=-10.5]

=> Saving checkpoint Got 0.8952499628067017 with accuracy 89.52 Dice score: 1.782752513885498 

Epoch: 39/50

100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-10.8]

=> Saving checkpoint Got 0.8901249766349792 with accuracy 89.01 Dice score: 1.7879419326782227 

Epoch: 40/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-11]

=> Saving checkpoint Got 0.8890937566757202 with accuracy 88.91 Dice score: 1.7865288257598877 

Epoch: 41/50

100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-11.2]

=> Saving checkpoint Got 0.8915520906448364 with accuracy 89.16 Dice score: 1.7899266481399536 

Epoch: 42/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-11.1]

=> Saving checkpoint Got 0.8923645615577698 with accuracy 89.24 Dice score: 1.7993781566619873 

Epoch: 43/50

100%|██████████| 1/1 [00:00<00:00, 1.96it/s, loss=-10.9]

=> Saving checkpoint Got 0.887333333492279 with accuracy 88.73 Dice score: 1.804529070854187 

Epoch: 44/50

100%|██████████| 1/1 [00:00<00:00, 1.85it/s, loss=-11.9]

=> Saving checkpoint Got 0.8823854327201843 with accuracy 88.24 Dice score: 1.7975029945373535 

Epoch: 45/50

100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-11.2]

=> Saving checkpoint Got 0.8859270811080933 with accuracy 88.59 Dice score: 1.802530288696289 

Epoch: 46/50

100%|██████████| 1/1 [00:00<00:00, 1.92it/s, loss=-11.6]

=> Saving checkpoint Got 0.890500009059906 with accuracy 89.05 Dice score: 1.8090462684631348 

Epoch: 47/50

100%|██████████| 1/1 [00:00<00:00, 1.90it/s, loss=-11.7]

=> Saving checkpoint Got 0.9020833373069763 with accuracy 90.21 Dice score: 1.8257639408111572 

Epoch: 48/50

100%|██████████| 1/1 [00:00<00:00, 1.81it/s, loss=-12.1]

=> Saving checkpoint Got 0.9160521030426025 with accuracy 91.61 Dice score: 1.8463400602340698 

Epoch: 49/50

100%|██████████| 1/1 [00:00<00:00, 1.91it/s, loss=-12]

=> Saving checkpoint Got 0.925125002861023 with accuracy 92.51 Dice score: 1.859954833984375 

Epoch: 50/50

100%|██████████| 1/1 [00:00<00:00, 1.93it/s, loss=-11.4]

=> Saving checkpoint Got 0.933968722820282 with accuracy 93.40 Dice score: 1.873431921005249

As you can see, the dice score at epoch 50 is 1.873431921005249. Why is the dice score greater than 1?

Answer:

The equation you are using to calculate the dice coefficient is wrong. This will work.

    dice_score += (2 * (preds * y).sum()) / (2 * (preds * y).sum() + (preds != y).sum() + 1e-8)

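You can interpret it as 2 × correctly_classified / (2 × correctly_classified + wrongly_classified), where correctly_classified is the number of pixels predicted as lane that are also lane in the ground truth, and wrongly_classified is the number of pixels where the prediction and the ground truth disagree. Note that this only works in the binary case.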
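For 0/1 masks, this interpretation and the sum-based formula from the question describe the same quantity. A short derivation, writing p_i, y_i ∈ {0, 1} for the thresholded prediction and the ground truth at pixel i:

$$
\mathrm{TP}=\sum_i p_i y_i,\qquad \mathrm{FP}=\sum_i p_i(1-y_i),\qquad \mathrm{FN}=\sum_i (1-p_i)\,y_i
$$

$$
\sum_i (p_i + y_i) = \sum_i p_i + \sum_i y_i = (\mathrm{TP}+\mathrm{FP}) + (\mathrm{TP}+\mathrm{FN}) = 2\,\mathrm{TP}+\mathrm{FP}+\mathrm{FN}
$$

$$
\mathrm{Dice} = \frac{2\sum_i p_i y_i}{\sum_i (p_i + y_i)} = \frac{2\,\mathrm{TP}}{2\,\mathrm{TP}+\mathrm{FP}+\mathrm{FN}} \le 1
$$

Here FP + FN is exactly the number of pixels with p_i ≠ y_i. The bound relies on p_i, y_i ∈ {0, 1}: if some y_i = v > 1 where p_i = 1, that pixel contributes 2v to the numerator but only 1 + v to the denominator, so the ratio can exceed 1.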
