I am reading the part of the TensorFlow Keras Functional API documentation that builds an autoencoder, and I have trouble understanding a particular piece of code.
Input:
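from tensorflow import keras
from tensorflow.keras import layers

# Encoder: each layer call returns a new symbolic tensor, which is rebound to x
encoder_input = keras.Input(shape=(28, 28, 1), name="img")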
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.MaxPooling2D(3)(x)
x = layers.Conv2D(32, 3, activation="relu")(x)
x = layers.Conv2D(16, 3, activation="relu")(x)
encoder_output = layers.GlobalMaxPooling2D()(x)
encoder = keras.Model(encoder_input, encoder_output, name="encoder")
encoder.summary()
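# Decoder: starts from encoder_output (16 values reshaped to 4x4x1),
# so the autoencoder reuses the encoder's layers
x = layers.Reshape((4, 4, 1))(encoder_output)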
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu")(x)
x = layers.UpSampling2D(3)(x)
x = layers.Conv2DTranspose(16, 3, activation="relu")(x)
decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x)
autoencoder = keras.Model(encoder_input, decoder_output, name="autoencoder")
autoencoder.summary()
Output:
Model: "encoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
img (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 26, 26, 16) 160
_________________________________________________________________
conv2d_1 (Conv2D) (None, 24, 24, 32) 4640
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 8, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 6, 6, 32) 9248
_________________________________________________________________
conv2d_3 (Conv2D) (None, 4, 4, 16) 4624
_________________________________________________________________
global_max_pooling2d (Global (None, 16) 0
=================================================================
Total params: 18,672
Trainable params: 18,672
Non-trainable params: 0
_________________________________________________________________
Model: "autoencoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
img (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 26, 26, 16) 160
_________________________________________________________________
conv2d_1 (Conv2D) (None, 24, 24, 32) 4640
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 8, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 6, 6, 32) 9248
_________________________________________________________________
conv2d_3 (Conv2D) (None, 4, 4, 16) 4624
_________________________________________________________________
global_max_pooling2d (Global (None, 16) 0
_________________________________________________________________
reshape (Reshape) (None, 4, 4, 1) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 6, 6, 16) 160
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 8, 8, 32) 4640
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 24, 24, 32) 0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 26, 26, 16) 4624
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 28, 28, 1) 145
=================================================================
Total params: 28,241
Trainable params: 28,241
Non-trainable params: 0
_________________________________________________________________
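The shapes and parameter counts in these summaries can be checked by hand. The sketch below is plain Python (no TensorFlow needed) and assumes the Keras defaults used in the code above: stride 1 and "valid" padding for Conv2D and Conv2DTranspose, and pooling strides equal to the pool size. The helper names (conv_out, conv_transpose_out, pool_out, conv_params) are hypothetical, chosen only for this illustration.

```python
def conv_out(n, k=3):
    # Conv2D, "valid" padding, stride 1: output size is n - k + 1
    return n - k + 1

def conv_transpose_out(n, k=3):
    # Conv2DTranspose, "valid" padding, stride 1: output size is n + k - 1
    return n + k - 1

def pool_out(n, p=3):
    # MaxPooling2D(p): strides default to p, "valid" padding
    return (n - p) // p + 1

def conv_params(c_in, c_out, k=3):
    # k*k*c_in*c_out kernel weights plus one bias per output channel
    # (the same count applies to Conv2DTranspose)
    return k * k * c_in * c_out + c_out

# Encoder spatial sizes: conv, conv, pool, conv, conv
enc_sizes = [28]
enc_sizes.append(conv_out(enc_sizes[-1]))
enc_sizes.append(conv_out(enc_sizes[-1]))
enc_sizes.append(pool_out(enc_sizes[-1]))
enc_sizes.append(conv_out(enc_sizes[-1]))
enc_sizes.append(conv_out(enc_sizes[-1]))

# Decoder spatial sizes: convT, convT, upsample, convT, convT
dec_sizes = [4]
dec_sizes.append(conv_transpose_out(dec_sizes[-1]))
dec_sizes.append(conv_transpose_out(dec_sizes[-1]))
dec_sizes.append(dec_sizes[-1] * 3)  # UpSampling2D(3) repeats each pixel 3 times
dec_sizes.append(conv_transpose_out(dec_sizes[-1]))
dec_sizes.append(conv_transpose_out(dec_sizes[-1]))

# (input channels, output channels) of each convolution, in order
encoder_convs = [(1, 16), (16, 32), (32, 32), (32, 16)]
decoder_convs = [(1, 16), (16, 32), (32, 16), (16, 1)]
encoder_params = sum(conv_params(i, o) for i, o in encoder_convs)
decoder_params = sum(conv_params(i, o) for i, o in decoder_convs)

print(enc_sizes, dec_sizes)  # [28, 26, 24, 8, 6, 4] [4, 6, 8, 24, 26, 28]
print(encoder_params, encoder_params + decoder_params)  # 18672 28241
```

The pooling, reshape and upsampling layers have no weights, which is why they show 0 parameters in both summaries.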
My question is rather simple. The variable x is overwritten every time a new layer is called, so why is no information lost? The information about the intermediate layers clearly has to be stored somewhere, because when encoder.summary() is called it traces back the whole architecture. However, since I keep overwriting x, I would expect to lose all information about these layers. Where is this information stored?
Answer:
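The Functional API is built on the idea that a model is usually a directed acyclic graph (DAG) of layers, where each edge connects the output of one layer to the input of another. A layer is first instantiated with its configuration (e.g. dense = layers.Dense(64, activation="relu")) and then "called" on a tensor (e.g. y = dense(x)). That call is the equivalent of drawing an arrow from x to the dense layer: it returns a new symbolic tensor y that remembers which layer produced it, and Keras records the call itself together with the tensors that were passed in. So when you "stack" layers by repeatedly writing x = layer(x), only the Python name x is rebound. Each new tensor still references the layer that created it, and that layer call references its input tensor, so all earlier tensors and layers stay alive and reachable. If you evaluate x in the REPL right after the encoder has been built, you will see something like: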
<KerasTensor: shape=(None, 4, 4, 16) dtype=float32 (created by layer 'conv2d_3')>
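To make this concrete, here is a toy version of the mechanism in plain Python. SymbolicTensor and ToyLayer are hypothetical stand-ins, not Keras classes, and real KerasTensors store their history differently; the sketch only assumes what the REPL output above shows, namely that each tensor knows the layer that created it and what was fed into that layer.

```python
class SymbolicTensor:
    """Hypothetical stand-in for a KerasTensor: it holds no data, only a
    shape plus a record of the layer call that produced it."""

    def __init__(self, shape, layer=None, inputs=()):
        self.shape = shape
        self.layer = layer    # layer that created this tensor (None for an input)
        self.inputs = inputs  # tensors that were passed to that layer

    def __repr__(self):
        creator = self.layer.name if self.layer is not None else "input"
        return f"<SymbolicTensor shape={self.shape} (created by layer '{creator}')>"


class ToyLayer:
    """Hypothetical layer: configured when constructed, recorded when called."""

    def __init__(self, name, out_shape):
        self.name = name
        self.out_shape = out_shape

    def __call__(self, inp):
        # Return a NEW tensor that keeps references to this layer and its input
        return SymbolicTensor(self.out_shape, layer=self, inputs=(inp,))


encoder_input = SymbolicTensor((None, 28, 28, 1))
x = ToyLayer("conv2d", (None, 26, 26, 16))(encoder_input)
first = x  # extra name kept only for the identity check below
x = ToyLayer("conv2d_1", (None, 24, 24, 32))(x)     # rebinds the name x ...
x = ToyLayer("max_pooling2d", (None, 8, 8, 32))(x)  # ... and again

# The earlier tensors are still reachable through the newest one
print(x)  # <SymbolicTensor shape=(None, 8, 8, 32) (created by layer 'max_pooling2d')>
print(x.inputs[0].inputs[0] is first)  # True
```

Rebinding x drops only one name; the chain of references from the newest tensor back to encoder_input is what keeps the intermediate layers around.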
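That way, when you call keras.Model(inputs, outputs), it starts from the output tensor and follows these references backwards until it reaches the inputs, recovering every intermediate layer along the way. This is why encoder.summary() and autoencoder.summary() can list all the layers even though x was overwritten.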
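The backward walk can be sketched in plain Python as well. This assumes a hypothetical Node record holding a layer name and the nodes that feed it (Keras's actual bookkeeping is richer and its traversal is implemented differently), and uses a small branching graph to show that the general DAG case works too: a depth-first search from the output lists every layer exactly once, after all the layers that feed it.

```python
from collections import namedtuple

# Hypothetical record of one layer call: the layer's name and its input nodes
Node = namedtuple("Node", ["name", "inputs"])


def collect_layers(output):
    """Walk backwards from `output` and return layer names in topological
    order (every layer appears after all the layers that feed it)."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:  # a layer shared by several branches is listed once
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        order.append(node.name)

    visit(output)
    return order


# Toy graph: the input feeds two branches that are merged again
img = Node("img", ())
a = Node("conv_a", (img,))
b = Node("conv_b", (img,))
merged = Node("add", (a, b))
out = Node("dense", (merged,))

print(collect_layers(out))  # ['img', 'conv_a', 'conv_b', 'add', 'dense']
```

Only the output has to be handed over; everything else is discovered by following the stored input references, which mirrors why keras.Model needs nothing more than the input and output tensors.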