I am building an image captioning model and am using ResNet50
as the feature extractor. I have written the code, and it works:
rs50 = tf.keras.applications.ResNet50(include_top = False, weights = 'imagenet', input_shape = (224, 224, 3))
new_input = rs50.input
hidden_layer = rs50.layers[-1].output
feature_extract = tf.keras.Model(new_input, hidden_layer)
Below are the last few lines of the model summary (feature_extract.summary()):
 conv5_block3_3_bn (BatchNormalization)  (None, 7, 7, 2048)  8192  ['conv5_block3_3_conv[0][0]']
 conv5_block3_add (Add)                  (None, 7, 7, 2048)  0     ['conv5_block2_out[0][0]', 'conv5_block3_3_bn[0][0]']
 conv5_block3_out (Activation)           (None, 7, 7, 2048)  0     ['conv5_block3_add[0][0]']
==================================================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
The problem is that it generates 2048 feature channels. I don't have that much memory, so I want to change the output shape from (None, 7, 7, 2048) to (None, 7, 7, 1024).
How do I do it?
CodePudding user response:
One way would be to find the last layer that has the output shape (None, 14, 14, 1024) and build the model from the layers up to that point. The conv4_block6_out layer is the last layer before the final block (conv5) begins. This way the final block is skipped altogether, which saves memory in the backbone. Then apply one or more Conv2D or MaxPooling2D layers to get the shape (None, 7, 7, 1024):
import tensorflow as tf

rs50 = tf.keras.applications.ResNet50(include_top = False, weights = 'imagenet', input_shape = (224, 224, 3))

# Find the index of the last layer of the conv4 stage.
index = 0
for i, l in enumerate(rs50.layers):
    if 'conv4_block6_out' in l.name:
        index = i

new_input = rs50.input
hidden_layer = rs50.layers[index].output  # shape (None, 14, 14, 1024)
# With the default 'valid' padding, an 8x8 kernel maps 14x14 to 7x7 (14 - 8 + 1 = 7).
output = tf.keras.layers.Conv2D(1024, kernel_size=8)(hidden_layer)
feature_extract = tf.keras.Model(new_input, output)
print(feature_extract.output)
KerasTensor(type_spec=TensorSpec(shape=(None, 7, 7, 1024), dtype=tf.float32, name=None), name='conv2d_4/BiasAdd:0', description="created by layer 'conv2d_4'")
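A caveat on the Conv2D variant above: an 8x8 convolution from 1024 to 1024 channels has 8 * 8 * 1024 * 1024 + 1024 = 67,109,888 parameters, which is more than the whole ResNet50 backbone, and those weights start out randomly initialized rather than pretrained. If memory is the main concern, the MaxPooling2D option mentioned above may be a better fit, since pooling adds no parameters. The following is a minimal sketch of that variant, not a tested recommendation for captioning quality; the function name build_pooled_extractor and its weights argument are hypothetical, introduced only for this illustration.

```python
import tensorflow as tf


def build_pooled_extractor(weights="imagenet"):
    """Truncate ResNet50 at conv4_block6_out and halve the spatial size by pooling.

    `build_pooled_extractor` and its `weights` argument are illustrative names,
    not part of the original answer.
    """
    rs50 = tf.keras.applications.ResNet50(
        include_top=False, weights=weights, input_shape=(224, 224, 3)
    )
    # Look the layer up by name; its output has shape (None, 14, 14, 1024).
    hidden_layer = rs50.get_layer("conv4_block6_out").output
    # 2x2 max pooling (stride defaults to the pool size): 14x14 -> 7x7,
    # channel count unchanged, no extra parameters.
    output = tf.keras.layers.MaxPooling2D(pool_size=2)(hidden_layer)
    return tf.keras.Model(rs50.input, output)
```

Calling build_pooled_extractor() with the default argument loads the ImageNet weights, as in the code above; passing weights=None skips the download, which is convenient for checking shapes only. Using get_layer by name also avoids the manual index search.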