I have subclassed tf.keras.Model and done what I consider to be transfer learning. Here's the code:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.applications.xception import Xception

class MyModel(tf.keras.Model):
    def __init__(self, input_shape, num_classes=5, dropout_rate=0.5):  # note: dropout_rate is currently unused
        super(MyModel, self).__init__()
        self.weight_dict = {}
        # Frozen ImageNet-pretrained Xception backbone without its classification head
        self.weight_dict['backbone'] = Xception(input_shape=input_shape, weights='imagenet', include_top=False)
        self.weight_dict['backbone'].trainable = False
        # 1x1 convolution producing per-location class probabilities
        self.weight_dict['outputs'] = Conv2D(num_classes, (1, 1), padding="same", activation="softmax")

    def call(self, inputs, training=False):
        x = self.weight_dict['backbone'](inputs)
        x = self.weight_dict['outputs'](x)
        return x
model = MyModel(input_shape=(256, 256, 3))
model.compute_output_shape(input_shape=(None, 256, 256, 3))
model.summary()
Here's the output of model.summary():
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_5 (Conv2D)            multiple                  10245
=================================================================
Total params: 20,871,725
Trainable params: 10,245
Non-trainable params: 20,861,480
I suppose that the backbone has successfully been initialized. However, why does it not show up in the output of .summary()? Can I be assured that the weights will get saved properly (if I use model.save())?
Also, would you consider this a good/normal way to subclass tf.keras models and do transfer learning?
Answer:
Restructuring the model with the Keras functional API would give you the correct model.summary():
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.applications.xception import Xception

IMG_SHAPE = (256, 256, 3)
num_classes = 5

# Frozen ImageNet-pretrained backbone
base_model = Xception(input_shape=IMG_SHAPE, weights='imagenet', include_top=False)
base_model.trainable = False

inputs = tf.keras.Input(shape=IMG_SHAPE)
x = base_model(inputs, training=False)  # keep the backbone's BatchNorm layers in inference mode
outputs = Conv2D(num_classes, (1, 1), padding="same", activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
model.summary()
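To address the question about saving: in this functional model the frozen backbone weights are part of the model, so they round-trip through model.save(). A minimal sketch of a save/load check, where the save path and the random test input are only illustrative:

import numpy as np

model.save('xception_transfer')  # illustrative path; saves in TensorFlow SavedModel format
restored = tf.keras.models.load_model('xception_transfer')

# The restored model should produce identical predictions on the same input
sample = np.random.rand(1, 256, 256, 3).astype('float32')
np.testing.assert_allclose(model.predict(sample), restored.predict(sample), rtol=1e-5)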
Also, did you mean to use a Dense layer as the final output? Usually, ReLU activation is used for the outputs of Conv2D layers.
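For reference, the usual classification head in transfer learning pools the backbone features and then applies a Dense softmax layer, as in the tutorial linked below. A sketch reusing base_model, IMG_SHAPE, and num_classes from the snippet above:

from tensorflow.keras.layers import GlobalAveragePooling2D, Dense

inputs = tf.keras.Input(shape=IMG_SHAPE)
x = base_model(inputs, training=False)   # frozen Xception features, shape (batch, 8, 8, 2048)
x = GlobalAveragePooling2D()(x)          # collapse the spatial dimensions into one feature vector
outputs = Dense(num_classes, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

Note that the 1x1 Conv2D head keeps a spatial grid of class scores, whereas this Dense head yields a single class distribution per image.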
To see the proper way to use models for transfer learning, you can have a look at examples from the documentation: https://www.tensorflow.org/tutorials/images/transfer_learning
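That tutorial also demonstrates a second fine-tuning phase, which unfreezes the backbone and recompiles with a low learning rate. Roughly, with illustrative optimizer settings:

base_model.trainable = True  # unfreeze the pretrained backbone
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # small learning rate so the pretrained weights change slowly
              loss='categorical_crossentropy',
              metrics=['accuracy'])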