I'm trying to implement the U-Net example from the Keras website:
Image segmentation with a U-Net-like architecture
with only one change: using Dice loss instead of "sparse_categorical_crossentropy". However, every time I try something, I get a different error. I'm coding on Google Colab using TensorFlow 2.7.
For example, I tried using
from tensorflow.keras import backend as K
def DiceLoss(targets, inputs, smooth=1e-6):
    # flatten label and prediction tensors
    inputs = K.flatten(inputs)
    targets = K.flatten(targets)
    intersection = K.sum(K.dot(targets, inputs))
    dice = (2 * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
    return 1 - dice
The error I got:
ValueError: Shape must be rank 2 but is rank 1 for '{{node DiceLoss99/MatMul}} = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false](DiceLoss99/Reshape_1, DiceLoss99/Reshape)' with input shapes: [?], [?].
The problem is on this line:
intersection = K.sum(K.dot(targets, inputs))
I also tried this library:
!pip install git+https://github.com/qubvel/segmentation_models
# define optimizer
n_classes=3
LR = 0.0001
optim = keras.optimizers.Adam(LR)
dice_loss_sm = sm.losses.DiceLoss(class_weights=K.ones_like(n_classes))
However, I got the following error:
TypeError: Input 'y' of 'Mul' Op has type int32 that does not match type float32 of argument 'x'.
The remaining code is the same as on keras.io, but I've listed it below for completeness:
from tensorflow import keras
from tensorflow.keras import layers
def get_model(img_size, num_classes):
    inputs = keras.Input(shape=img_size + (3,))
    ### [First half of the network: downsampling inputs] ###
    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    previous_block_activation = x  # Set aside residual
    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
        # Project residual
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual
    ### [Second half of the network: upsampling inputs] ###
    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.UpSampling2D(2)(x)
        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual
    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)
    # Define the model
    model = keras.Model(inputs, outputs)
    return model
# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()
# Build model
model = get_model(img_size, num_classes)
model.summary()
# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.
# Note: in my version, I replace the loss below with the Dice loss instead of sparse_categorical_crossentropy.
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")
callbacks = [
    keras.callbacks.ModelCheckpoint("oxford_segmentation.h5", save_best_only=True)
]
# Train the model, doing validation at the end of each epoch.
epochs = 15
model.fit(train_gen, epochs=epochs, validation_data=val_gen, callbacks=callbacks)
CodePudding user response:
You are passing 1-dimensional vectors to K.dot, while the ValueError says that K.dot requires 2-dimensional arrays.
You can replace it with element-wise multiplication, i.e. intersection = K.sum(targets * inputs)
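One more thing to watch for: if the masks are integer class ids, as in the keras.io example, the flattened targets and the flattened softmax output will probably have different lengths, so the element-wise product may still fail. A minimal sketch of a Dice loss that one-hot encodes the targets first is below. The name sparse_dice_loss is made up for illustration, and it assumes the labels run from 0 to num_classes - 1 and that the model output has shape (batch, H, W, num_classes). Treat it as a starting point, not a reference implementation.

```python
import tensorflow as tf


def sparse_dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss for integer masks and softmax predictions (illustrative sketch).

    Assumes y_true holds class ids in [0, num_classes - 1] with shape
    (batch, H, W) or (batch, H, W, 1), and y_pred holds probabilities with
    shape (batch, H, W, num_classes).
    """
    num_classes = y_pred.shape[-1]
    # Flatten the labels to one id per pixel, then one-hot encode: (N, C)
    y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
    y_true = tf.one_hot(y_true, depth=num_classes, dtype=y_pred.dtype)
    # Bring the predictions to the same (N, C) layout
    y_pred = tf.reshape(y_pred, [-1, num_classes])
    # Element-wise product instead of a matrix product
    intersection = tf.reduce_sum(y_true * y_pred)
    dice = (2.0 * intersection + smooth) / (
        tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth
    )
    return 1.0 - dice
```

It can then be passed directly to compile, e.g. model.compile(optimizer="rmsprop", loss=sparse_dice_loss). As for the TypeError in the second attempt, K.ones_like(n_classes) on a Python int produces an int32 tensor, which would be consistent with the dtype mismatch in the message; passing float weights (for example np.ones(n_classes, dtype="float32")) is a likely fix, although I have not checked how the library applies them internally.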