Remove only the last (dense) layer of an already trained model, keeping all the weights of the model intact


I want to remove only the last dense layer from a model already saved as an .h5 file and add a new dense layer.

Information about the saved model:

I used transfer learning on the EfficientNetB0 model and added a dropout layer followed by 2 dense layers. The last dense layer had 3 nodes, equal to my number of classes, as shown below:

inputs = tf.keras.layers.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3))
x = img_augmentation(inputs)
model = tf.keras.applications.EfficientNetB0(include_top=False, input_tensor=x, weights="imagenet")
# Freeze the pretrained weights
model.trainable = False
# Rebuild top
x = tf.keras.layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
x = tf.keras.layers.BatchNormalization()(x)

x = tf.keras.layers.Dropout(0.3)(x)
x = tf.keras.layers.Dense(5, activation=tf.nn.relu)(x)

outputs = tf.keras.layers.Dense(len(class_names), activation="softmax", name="pred")(x)

After training, I saved my model as my_h5_model.h5.

Main Task: I want to use the saved model's architecture with its weights and replace only the last dense layer with a new dense layer of 4 nodes.
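
For reference, the layers of the saved model can be inspected after loading it; a minimal sketch, assuming the file is available locally as my_h5_model.h5:

from tensorflow import keras

# Load the saved model and list its layers so the layer to replace is easy to identify
saved = keras.models.load_model("my_h5_model.h5")
saved.summary()

# The last two layers should be the 5-unit dense layer and the 3-node softmax head named "pred"
print(saved.layers[-2].name, saved.layers[-1].name)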

I tried several things suggested by the StackOverflow community:


Iterate over all the layers except the last one and add them to a separate, already defined Sequential model:

from tensorflow.keras import Sequential

new_model = Sequential()
for layer in model.layers[:-1]:   # copy every layer except the old head
    new_model.add(layer)

But it gives an error which states:

ValueError: Exception encountered when calling layer "block1a_se_excite" (type Multiply).

A merge layer should be called on a list of inputs. Received: inputs=Tensor("Placeholder:0", shape=(None, 1, 1, 32), dtype=float32) (not a list of tensors)

Call arguments received:

• inputs=tf.Tensor(shape=(None, 1, 1, 32), dtype=float32)
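
The error happens because EfficientNetB0 is not a linear stack of layers: block1a_se_excite is a Multiply merge layer inside a squeeze-and-excite block, and it expects a list of input tensors, which a Sequential model can never pass it. A minimal sketch for spotting the layers that expect multiple inputs, assuming model is the loaded full model:

# List the layers whose input is a list of tensors (merge layers such as Multiply);
# these are the ones that cannot be re-added to a plain Sequential stack.
for layer in model.layers:
    layer_inputs = layer.input
    if isinstance(layer_inputs, (list, tuple)) and len(layer_inputs) > 1:
        print(layer.name, "expects", len(layer_inputs), "inputs")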


I also tried the functional approach:

input_layer = model.input
for layer in (model.layers[:-1]):
    x = layer(input_layer)

which throws an error, as mentioned below:

ValueError: Exception encountered when calling layer "stem_bn" (type BatchNormalization).

Dimensions must be equal, but are 3 and 32 for '{{node stem_bn/FusedBatchNormV3}} = FusedBatchNormV3[T=DT_FLOAT, U=DT_FLOAT, data_format="NHWC", epsilon=0.001, exponential_avg_factor=1, is_training=false](Placeholder, stem_bn/ReadVariableOp, stem_bn/ReadVariableOp_1, stem_bn/FusedBatchNormV3/ReadVariableOp, stem_bn/FusedBatchNormV3/ReadVariableOp_1)' with input shapes: [?,224,224,3], [32], [32], [32], [32].

Call arguments received:

• inputs=tf.Tensor(shape=(None, 224, 224, 3), dtype=float32)

• training=False
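
For what it's worth, the immediate bug in this loop is that every layer is called on input_layer rather than on the previous layer's output, which is why stem_bn receives the raw 3-channel image. A chained version would look like the sketch below, but even then the loop stops at the first merge layer, because EfficientNet's graph branches and a single running tensor cannot represent it:

# Chain each layer onto the previous output instead of the raw input.
# This fixes the shape mismatch, but still fails at block1a_se_excite
# (a Multiply merge layer), since a branched graph cannot be rebuilt this way.
x = model.input
for layer in model.layers[1:-1]:   # skip the InputLayer and the old head
    x = layer(x)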


Lastly, I tried something that came to mind:

inputs = tf.keras.layers.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3))
x = img_augmentation(inputs)
x = model.layers[:-1](x)
x = keras.layers.Dense(5, name="compress_1")(x)

which simply gave the error:

'list' object is not callable
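
The reason is that model.layers[:-1] is a plain Python list of layer objects, not a layer or a model, so it cannot be called on a tensor:

# model.layers[:-1] is just a list; lists are not callable, hence the error.
print(type(model.layers[:-1]))   # <class 'list'>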

CodePudding user response:

Have you tried switching between import keras and tensorflow.keras in your imports? This has worked in similar issues.
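
If mixed imports are the culprit, a minimal check is to stick to the tf.keras namespace for both the loaded model and any new layers, for example:

# Use one Keras namespace consistently; mixing the standalone keras package
# with tensorflow.keras objects in the same graph can cause confusing errors.
import tensorflow as tf

model = tf.keras.models.load_model("my_h5_model.h5")   # path assumed
new_head = tf.keras.layers.Dense(4, activation="softmax")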

CodePudding user response:

I did some more experiments and was able to remove the last layer and add a new dense layer:

# Load the pretrained saved model
from tensorflow import keras
import tensorflow as tf

model = keras.models.load_model('/content/my_h5_model.h5')

# Take the output of the second-to-last layer, i.e. everything except the old head
x = model.layers[-2].output
outputs = tf.keras.layers.Dense(4, activation="softmax", name="predictions")(x)
model = tf.keras.Model(inputs=model.input, outputs=outputs)
model.summary()

In the saved model, the last dense layer had 3 nodes, but in the new model the replacement layer has 4 nodes (its 24 parameters come from 5 × 4 weights plus 4 biases, since it sits on top of the 5-unit dense layer). The last layers of the summary are shown below:

 dropout_3 (Dropout)            (None, 1280)         0           ['batch_normalization_4[0][0]']  
                                                                                                  
 dense_3 (Dense)                (None, 5)            6405        ['dropout_3[0][0]']              
                                                                                                  
 predictions (Dense)            (None, 4)            24          ['dense_3[0][0]']                
                                                                                                  
==================================================================================================
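
As a follow-up, the rebuilt model has to be recompiled before the new head can be trained; a minimal sketch, assuming integer class labels and that only the new 4-node layer should be trainable:

# Freeze every reused layer so only the new predictions head is updated
for layer in model.layers[:-1]:
    layer.trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",   # assumes integer labels
    metrics=["accuracy"],
)
# model.fit(train_ds, validation_data=val_ds, epochs=5)   # hypothetical datasets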