`model.summary()` with TensorFlow model subclassing prints output shape as "multiple"

Time:12-02

I tried to implement a VGG network with the following VggBlock:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPool2D, Input

class VggBlock(tf.keras.Model):
  def __init__(self, filters, repetitions):
    super(VggBlock, self).__init__()
    self.repetitions = repetitions

    self.conv_layers = [Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu') for _ in range(repetitions)]
    self.max_pool = MaxPool2D(pool_size=(2, 2))

  def call(self, inputs):
    x = inputs
    for layer in self.conv_layers:
      x = layer(x)
    return self.max_pool(x)

test_block = VggBlock(filters=64, repetitions=2)
temp_inputs = Input(shape=(224, 224, 3))
test_block(temp_inputs)
test_block.summary()

The code above prints:

Model: "vgg_block"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             multiple                  1792      
                                                                 
 conv2d_1 (Conv2D)           multiple                  36928     
                                                                 
 max_pooling2d (MaxPooling2D  multiple                 0         
 )                                                               
                                                                 
=================================================================
Total params: 38,720
Trainable params: 38,720
Non-trainable params: 0
_________________________________________________________________

If I build the full VGG model from these blocks, its summary() also prints "multiple".

There are some similar questions, e.g. https://github.com/keras-team/keras/issues/13782 and "model.summary() can't print output shape while using subclass model".

However, I cannot adapt the answers from the second link to handle a varying input_shape.

How can I make summary() print the actual shapes instead of "multiple"?

Answer:

You have already linked some workarounds. summary() shows "multiple" because Keras cannot determine the output shape of each layer inside a subclassed model. As stated here:

You can do all these things (printing input / output shapes) in a Functional or Sequential model because these models are static graphs of layers.

In contrast, a subclassed model is a piece of Python code (a call method). There is no graph of layers here. We cannot know how layers are connected to each other (because that's defined in the body of call, not as an explicit data structure), so we cannot infer input / output shapes.
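To illustrate the point in the quote, here is a minimal sketch of the same block written with the Functional API instead of subclassing (the function name `vgg_block` is just an illustrative choice, not part of Keras). Because every layer is called on a symbolic `tf.keras.Input`, Keras records the layer graph, so `summary()` can report a concrete output shape for each layer:

```python
import tensorflow as tf

def vgg_block(x, filters, repetitions):
    # Illustrative Functional-API version of VggBlock: each layer call on a
    # symbolic tensor records a node in the graph, so every layer's output
    # shape is known to Keras.
    for _ in range(repetitions):
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    return tf.keras.layers.MaxPool2D(pool_size=(2, 2))(x)

inputs = tf.keras.Input(shape=(224, 224, 3))
outputs = vgg_block(inputs, filters=64, repetitions=2)
block_model = tf.keras.Model(inputs, outputs, name='vgg_block_functional')

# The "Output Shape" column now shows concrete shapes instead of "multiple".
block_model.summary()
```

The trade-off is that the input shape has to be known when the graph is built, which is exactly the information a subclassed model does not have until it is called.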

You could also try something like this, which builds a Functional model inside the block from a fixed input shape:

import tensorflow as tf

class VggBlock(tf.keras.Model):

  def __init__(self, filters, repetitions, image_shape):
    super(VggBlock, self).__init__()
    self.repetitions = repetitions

    self.conv_layers = [tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu') for _ in range(repetitions)]
    self.max_pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))

    inputs = tf.keras.layers.Input(shape=image_shape)
    x = inputs
    for layer in self.conv_layers:
      x = layer(x)
    outputs = self.max_pool(x)
    self.model = tf.keras.Model(inputs, outputs)

  def call(self, inputs):
    return self.model(inputs)
  
  def summary(self):
    self.model.summary()

test_block = VggBlock(filters=64, repetitions=2, image_shape=(224, 224, 3))
test_block.summary()

This prints:

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 224, 224, 64)      1792      
                                                                 
 conv2d_1 (Conv2D)           (None, 224, 224, 64)      36928     
                                                                 
 max_pooling2d (MaxPooling2D  (None, 112, 112, 64)     0         
 )                                                               
                                                                 
=================================================================
Total params: 38,720
Trainable params: 38,720
Non-trainable params: 0
_________________________________________________________________
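The approach above fixes `image_shape` when the block is constructed. If you would rather keep the original constructor and choose the input shape only when you inspect the model, a commonly suggested variant is a small helper that wraps `call()` in a throwaway Functional model. This is a sketch, not a Keras API: the method name `build_graph` is hypothetical, and it assumes TF 2.x with `tf.keras`:

```python
import tensorflow as tf

class VggBlock(tf.keras.Model):
    def __init__(self, filters, repetitions, **kwargs):
        super().__init__(**kwargs)
        self.conv_layers = [
            tf.keras.layers.Conv2D(filters, (3, 3), padding='same', activation='relu')
            for _ in range(repetitions)
        ]
        self.max_pool = tf.keras.layers.MaxPool2D(pool_size=(2, 2))

    def call(self, inputs):
        x = inputs
        for layer in self.conv_layers:
            x = layer(x)
        return self.max_pool(x)

    def build_graph(self, input_shape):
        # Hypothetical helper (not part of Keras): run call() on a symbolic
        # Input so the inner layers are recorded in a Functional graph.
        # The input shape is supplied here, not in __init__.
        inputs = tf.keras.Input(shape=input_shape)
        return tf.keras.Model(inputs=inputs, outputs=self.call(inputs))

block = VggBlock(filters=64, repetitions=2)
graph = block.build_graph((224, 224, 3))
graph.summary()

# A different input size, using a fresh block instance.
small = VggBlock(filters=64, repetitions=2).build_graph((64, 64, 3))
small.summary()
```

Because `build_graph` calls `self.call(...)` directly rather than `self(...)`, the Conv2D and MaxPool2D layers themselves become nodes of the Functional graph and are listed individually with their shapes. Each call adds new nodes to the shared layers, so calling it twice on the same instance with different shapes may bring "multiple" back in some Keras versions; using a fresh instance per shape avoids that.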