I need to do something I believe should be straightforward: divide the output of a convolutional layer by the batch size (I can elaborate on why if you are interested).
Here is the minimal code to reproduce what I'm trying to do:
from keras.models import *
from keras.layers import *
from keras import backend as K
input = Input((24,24,3))
conv = Conv2D(8,4,1,'SAME')(input)
norm = Lambda(lambda x:x[0]/x[1])((conv,input.shape[0]))
model = Model(inputs = input, outputs = norm)
However, I get the error:
----> norm = Lambda(lambda x:x[0]/x[1])((conv,input.shape[0]))
model = Model(inputs = input, outputs = norm)
ValueError: Exception encountered when calling layer "lambda_9" (type Lambda).
None values not supported.
Call arguments received:
• inputs=('tf.Tensor(shape=(None, 24, 24, 8), dtype=float32)', 'None')
• mask=None
• training=None
I feel like this "should be allowed". What am I missing or doing wrong? Thank you!
CodePudding user response:
There are actually two ways to get a symbolic tensor's shape: tensor.shape (the one you are using) and tf.shape(tensor). In a nutshell, tensor.shape returns the "static shape" of the tensor as it is known at compile time. This can be problematic in cases like this: You are trying to use a dimension which is unknown at compile time, but defined when the model actually runs. tensor.shape fails here because it will just return the statically known None (which is essentially a placeholder for the unknown batch size).
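To make the difference concrete, here is a small sketch that is not part of the original post. It traces a function whose batch dimension is left unknown; the names batch_sizes and static_dims are made up for this illustration. While tracing, x.shape[0] is the plain Python value None, whereas tf.shape(x)[0] is a tensor that is only filled in when the function runs.

```python
import tensorflow as tf

static_dims = []  # hypothetical helper list: records what x.shape[0] looks like while tracing

@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
def batch_sizes(x):
    # Static shape: a Python value fixed at trace time; the batch dimension is unknown here.
    static_dims.append(x.shape[0])
    # Dynamic shape: a scalar int32 tensor evaluated each time the function runs.
    return tf.shape(x)[0]

print(int(batch_sizes(tf.zeros([5, 3]))))  # run-time batch size of the first call
print(int(batch_sizes(tf.zeros([2, 3]))))  # run-time batch size of the second call
print(static_dims)                         # what the static shape reported during tracing
```

Dividing by the static value means dividing by None, which is exactly what the error message in the question complains about.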
On the other hand, tf.shape(tensor) returns the dynamic shape of the tensor, which will actually be a tensor itself, and you can use this to define operations based on unknown shapes. What this means is that you just need to replace one line:
norm = Lambda(lambda x:x[0]/x[1])((conv, tf.cast(tf.shape(input)[0], tf.float32)))
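Note that we have to cast the shape to a float to avoid a dtype mismatch: tf.shape returns int32 values, while the convolution output is float32. The line also needs import tensorflow as tf, which the snippet in the question does not include. This lets the cell run fine: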
Model: "model"
 Layer (type)                    Output Shape          Param #   Connected to
 input_3 (InputLayer)            [(None, 24, 24, 3)]   0         []
 tf.compat.v1.shape_1 (TFOpLamb  (4,)                  0         ['input_3[0][0]']
 tf.__operators__.getitem_1 (Sl  ()                    0         ['tf.compat.v1.shape_1[0][0]']
 conv2d_2 (Conv2D)               (None, 24, 24, 8)     392       ['input_3[0][0]']
 tf.cast (TFOpLambda)            ()                    0         ['tf.__operators__.getitem_1[0][0
 lambda_2 (Lambda)               (None, 24, 24, 8)     0         ['conv2d_2[0][0]',
Total params: 392
Trainable params: 392
Non-trainable params: 0
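For completeness, here is a self-contained sketch of the same idea, written slightly differently from the one-line fix: the tf.shape call is moved inside the Lambda so that it only ever sees the tensor flowing through the layer. This is an assumption on my part about what is most portable; as far as I know, newer Keras versions may reject TensorFlow functions called directly on a symbolic Input. The second model, conv_only, is a made-up helper used only to check the result.

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input((24, 24, 3))
conv = tf.keras.layers.Conv2D(filters=8, kernel_size=4, strides=1, padding="same")(inputs)

# Divide by the run-time batch size; tf.shape(x)[0] is int32, so cast it to x's dtype.
norm = tf.keras.layers.Lambda(
    lambda x: x / tf.cast(tf.shape(x)[0], x.dtype)
)(conv)

model = tf.keras.Model(inputs=inputs, outputs=norm)
conv_only = tf.keras.Model(inputs=inputs, outputs=conv)  # shares the same Conv2D weights

batch = np.random.rand(5, 24, 24, 3).astype("float32")
out = np.asarray(model(batch))
ref = np.asarray(conv_only(batch))

# The normalized output should equal the raw convolution output divided by the batch size.
print(np.allclose(out, ref / batch.shape[0], atol=1e-6))
```

Because the division happens inside the layer, the same model works for any batch size without rebuilding it.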