Fix CNN overfitting

Time:11-23

I'm using a CNN based on MobileNet to classify sign-language images into alphabet letters, so it is a multi-class classification problem. After compiling and fitting the model, I got high accuracy (98%). But when I plot the confusion matrix, it looks badly messed up. Does this mean the model is overfitting? And how can I fix it to get a better matrix?


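import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import confusion_matrix

train_path = 'train'
test_path = 'test'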

train_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input).flow_from_directory(
    directory=train_path, target_size=(64,64), batch_size=10)


test_batches = ImageDataGenerator(preprocessing_function=tf.keras.applications.mobilenet.preprocess_input).flow_from_directory(
    directory=test_path, target_size=(64,64), batch_size=10)


mobile = tf.keras.applications.mobilenet.MobileNet()

x = mobile.layers[-6].output
output = Dense(units=32, activation='softmax')(x)
model = Model(inputs=mobile.input, outputs=output)
for layer in model.layers[:-23]:
    layer.trainable = False
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

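class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Stop training once validation accuracy reaches 98%.
        val_acc = (logs or {}).get('val_accuracy')
        if val_acc is not None and val_acc >= 0.98:
            print('\n Reached to good accuracy')
            self.model.stop_training = True
callbacks = myCallback()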


model.fit(train_batches,
            steps_per_epoch=len(train_batches), 
            validation_data=test_batches,
            validation_steps=len(test_batches),
            epochs=10,callbacks=[callbacks])




Epoch 1/10
4498/4498 [==============================] - 979s 217ms/step - loss: 1.3062 - accuracy: 0.6530 - val_loss: 0.1528 - val_accuracy: 0.9594
Epoch 2/10
4498/4498 [==============================] - 992s 221ms/step - loss: 0.1777 - accuracy: 0.9491 - val_loss: 0.1164 - val_accuracy: 0.9691
Epoch 3/10
4498/4498 [==============================] - 998s 222ms/step - loss: 0.1117 - accuracy: 0.9654 - val_loss: 0.0925 - val_accuracy: 0.9734
Epoch 4/10
4498/4498 [==============================] - 1000s 222ms/step - loss: 0.0789 - accuracy: 0.9758 - val_loss: 0.0992 - val_accuracy: 0.9750
Epoch 5/10
4498/4498 [==============================] - 1001s 223ms/step - loss: 0.0626 - accuracy: 0.9805 - val_loss: 0.0818 - val_accuracy: 0.9783
Epoch 6/10
4498/4498 [==============================] - 1007s 224ms/step - loss: 0.0521 - accuracy: 0.9834 - val_loss: 0.0944 - val_accuracy: 0.9789
Epoch 7/10
4498/4498 [==============================] - 1004s 223ms/step - loss: 0.0475 - accuracy: 0.9863 - val_loss: 0.0935 - val_accuracy: 0.9795
Epoch 8/10
4498/4498 [==============================] - 1013s 225ms/step - loss: 0.0371 - accuracy: 0.9880 - val_loss: 0.0854 - val_accuracy: 0.9781
Epoch 9/10
4498/4498 [==============================] - 896s 199ms/step - loss: 0.0365 - accuracy: 0.9879 - val_loss: 0.0766 - val_accuracy: 0.9806

 Reached to good accuracy


test_labels = test_batches.classes

predictions = model.predict(x=test_batches, steps=len(test_batches),verbose=0)

cm = confusion_matrix(y_true=test_labels, y_pred=predictions.argmax(axis=1))


cm_plot_labels = ['0','1','2','3','4','5','6','7','8','9','10','11','12','13','14','15','16',
                  '17','18','19','20','21','22','23','24','25','26','27','28','29','30','31'
                 ]
plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title='Confusion Matrix')

[Image: the resulting confusion matrix]

CodePudding user response:

There are some tricks that help with overfitting:

  1. Add data augmentation. This slightly transforms the input each time (rotation, random cropping, etc.), so the model sees more variations of the same image, which helps it generalize better.
  2. Add a dropout layer. This layer randomly sets input units to 0 during training, so the model can train for more epochs before it starts to overfit.
  3. Use L1 or L2 regularization. These penalize large weights by adding the absolute values (L1) or the squared values (L2) of the weights to the total loss.
  4. It's better to replace your callback with callback = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=3). I think your model stopped while there was still room for improvement.
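
A minimal sketch of how the four suggestions above might look in Keras. The augmentation ranges, the dropout rate (0.3), the L2 factor (1e-4), the `restore_best_weights=True` option and the helper name `build_head` are illustrative assumptions, not values from the question; only the 32 output classes and the `EarlyStopping` arguments come from the post.

```python
import tensorflow as tf

NUM_CLASSES = 32  # matches Dense(units=32) in the question

# 1. Data augmentation: random transforms for the TRAINING generator only.
#    The ranges are assumed starting points, not tuned values.
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet.preprocess_input,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
)


def build_head(features):
    """Hypothetical helper: classification head with dropout and L2."""
    # 2. Dropout: randomly zeroes inputs during training (rate is assumed).
    x = tf.keras.layers.Dropout(0.3)(features)
    # 3. L2 regularization: adds the squared weights, scaled by 1e-4
    #    (assumed), to the total loss.
    return tf.keras.layers.Dense(
        NUM_CLASSES,
        activation="softmax",
        kernel_regularizer=tf.keras.regularizers.l2(1e-4),
    )(x)


# 4. Early stopping: stop after 3 epochs without val_accuracy improving.
#    restore_best_weights is an extra option added here for illustration.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy", patience=3, restore_best_weights=True
)
```

In the question's code, `build_head(mobile.layers[-6].output)` would take the place of the plain `Dense` layer, `train_gen.flow_from_directory(...)` would create `train_batches`, and `callbacks=[early_stop]` would be passed to `model.fit`. Separately, the log already shows a `val_accuracy` of about 0.98 on the test set, so the scrambled matrix may not be overfitting at all: `flow_from_directory` shuffles by default (`shuffle=True`), while `test_batches.classes` is in file order. Passing `shuffle=False` when creating `test_batches` should keep the predictions aligned with the labels.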