Trying to use Dice Loss with UNET

Time:01-02

I'm trying to implement the U-Net from the Keras website:

Image segmentation with a U-Net-like architecture

With only one change: I use Dice loss instead of "sparse_categorical_crossentropy". However, every time I try something, I get a different error. I'm coding on Google Colab using TensorFlow 2.7.

For example, I tried the following:

def DiceLoss(targets, inputs, smooth=1e-6):
    
    #flatten label and prediction tensors
    
    inputs = K.flatten(inputs)
    targets = K.flatten(targets)
    
    intersection = K.sum(K.dot(targets, inputs))
    
    dice = (2 * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
    
    return 1 - dice
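For reference, the soft Dice loss that DiceLoss is meant to compute, for flattened targets t and predictions p with smoothing constant s (the smooth argument), is:

$$
D(t, p) = \frac{2\sum_i t_i p_i + s}{\sum_i t_i + \sum_i p_i + s}, \qquad \mathcal{L}_{\text{Dice}} = 1 - D(t, p)
$$

The intersection term is a sum of element-wise products, so it needs no matrix product, only a multiply followed by a sum.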

The error I got:

ValueError: Shape must be rank 2 but is rank 1 for '{{node DiceLoss99/MatMul}} = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false](DiceLoss99/Reshape_1, DiceLoss99/Reshape)' with input shapes: [?], [?].

The problem is on this line:

intersection = K.sum(K.dot(targets, inputs))

I also tried this library:

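!pip install git+https://github.com/qubvel/segmentation_models
import segmentation_models as sm
from tensorflow import keras
from tensorflow.keras import backend as K

# define optimizer
n_classes = 3
LR = 0.0001
optim = keras.optimizers.Adam(LR)
dice_loss_sm = sm.losses.DiceLoss(class_weights=K.ones_like(n_classes))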

However, I got the following error:

TypeError: Input 'y' of 'Mul' Op has type int32 that does not match type float32 of argument 'x'.
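A plausible cause (my reading, not confirmed in the thread): n_classes is a Python int, so K.ones_like(n_classes) returns a scalar int32 tensor rather than a float vector with one weight per class, and multiplying it with float32 predictions produces exactly this int32/float32 mismatch. The sketch below checks the dtypes directly; it calls tf.ones_like, which is what K.ones_like wraps in TF 2.x Keras, and good_weights is a hypothetical replacement:

```python
import numpy as np
import tensorflow as tf

n_classes = 3

# Same result as K.ones_like(n_classes) in TF 2.x: a scalar int32 tensor,
# not a vector of 3 class weights.
bad_weights = tf.ones_like(n_classes)
print(bad_weights.dtype, bad_weights.shape)

# A float32 vector with one weight per class matches the float predictions.
good_weights = np.ones(n_classes, dtype="float32")
print(good_weights.dtype, good_weights.shape)
```

With that, sm.losses.DiceLoss(class_weights=good_weights) should avoid the dtype clash; leaving class_weights unset, which I believe weights all classes equally, is another option.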

The remaining code is the same as in the keras.io example, but I list it below for completeness:

from tensorflow import keras
from tensorflow.keras import layers


def get_model(img_size, num_classes):
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()

# Build model
# (img_size, num_classes, train_gen and val_gen are defined earlier, as in the keras.io example)
model = get_model(img_size, num_classes)
model.summary()
# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.

# Note: this is where I changed the loss to Dice loss instead of sparse_categorical_crossentropy
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")

callbacks = [
    keras.callbacks.ModelCheckpoint("oxford_segmentation.h5", save_best_only=True)
]

# Train the model, doing validation at the end of each epoch.
epochs = 15
model.fit(train_gen, epochs=epochs, validation_data=val_gen, callbacks=callbacks)

Answer:

You are passing 1-dimensional tensors to K.dot, but, as the ValueError says, the underlying MatMul op requires tensors of rank 2.

You can replace it with element-wise multiplication, i.e. intersection = K.sum(targets * inputs)
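With the keras.io model, this fix alone may not be enough (an observation about the tutorial setup, not part of the answer above). As I understand the keras.io data generator, the targets are integer masks of shape (batch, H, W, 1), while the softmax output has shape (batch, H, W, num_classes), so the two flattened tensors have different lengths and targets * inputs fails on incompatible shapes. A minimal multi-class sketch under those shape assumptions one-hot encodes the targets first and averages the per-class Dice scores; sparse_dice_loss is a hypothetical name:

```python
import numpy as np
import tensorflow as tf


def sparse_dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft multi-class Dice loss for integer masks (hypothetical helper).

    Assumes y_true has shape (batch, H, W, 1) holding class indices and
    y_pred has shape (batch, H, W, C) holding softmax probabilities.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    num_classes = tf.shape(y_pred)[-1]

    # Drop the trailing channel axis, then one-hot encode to (batch, H, W, C).
    labels = tf.cast(tf.squeeze(y_true, axis=-1), tf.int32)
    y_true_oh = tf.one_hot(labels, depth=num_classes, dtype=y_pred.dtype)

    # Sum over batch and spatial axes, keeping one value per class.
    axes = [0, 1, 2]
    intersection = tf.reduce_sum(y_true_oh * y_pred, axis=axes)
    denominator = tf.reduce_sum(y_true_oh, axis=axes) + tf.reduce_sum(y_pred, axis=axes)

    # Per-class Dice coefficient, averaged over classes.
    dice_per_class = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - tf.reduce_mean(dice_per_class)


# Tiny check: a 1x2x2 mask with classes 0..2 and a perfect prediction.
mask = np.array([[[[0], [1]], [[2], [1]]]], dtype="uint8")
perfect = tf.one_hot(mask[..., 0], depth=3)
print(float(sparse_dice_loss(mask, perfect)))  # prints 0.0
```

It can then be passed to compile in place of the string, e.g. model.compile(optimizer="rmsprop", loss=sparse_dice_loss). Averaging per-class scores, rather than pooling all classes together, keeps small classes from being swamped by the background; whether that is what you want depends on the dataset.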
