Cannot divide values in convolution output by batch size in Keras


I need to do something I believe should be straightforward: divide the output of a convolutional layer by the batch size (I can elaborate on why if you're interested).

Here is the minimal code to reproduce what I'm trying to do:

from keras.models import *
from keras.layers import *
from keras import backend as K

input = Input((24,24,3))
conv = Conv2D(8,4,1,'SAME')(input)
norm = Lambda(lambda x:x[0]/x[1])((conv,input.shape[0]))
model = Model(inputs = input, outputs = norm)
model.summary()

However, I get the error:

----> norm = Lambda(lambda x:x[0]/x[1])((conv,input.shape[0]))
      model = Model(inputs = input, outputs = norm)

ValueError: Exception encountered when calling layer "lambda_9" (type Lambda).

None values not supported.

Call arguments received:
  • inputs=('tf.Tensor(shape=(None, 24, 24, 8), dtype=float32)', 'None')
  • mask=None
  • training=None

I feel like this "should be allowed". What am I missing or doing wrong? Thank you!

CodePudding user response:

There are actually two ways to get a symbolic tensor's shape: tensor.shape (the one you are using) and tf.shape(tensor). In a nutshell, tensor.shape returns the "static shape" of the tensor as it is known at compile (graph-construction) time. That is exactly the problem in your case: you are trying to use a dimension that is unknown at compile time and only defined when the model actually runs. tensor.shape fails here because it simply returns the statically known None, which is essentially a placeholder for the unknown batch size.
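For example (a quick sketch, assuming a standard TF 2.x / tf.keras setup rather than the exact environment from the post), inspecting the static shape of a symbolic input shows the placeholder directly:

import tensorflow as tf

inp = tf.keras.Input((24, 24, 3))
print(inp.shape)     # (None, 24, 24, 3) -- static shape, known at build time
print(inp.shape[0])  # None -- the batch size is only known when the model runs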

On the other hand, tf.shape(tensor) returns the dynamic shape of the tensor, which will actually be a tensor itself, and you can use this to define operations based on unknown shapes. What this means is that you just need to replace one line:

norm = Lambda(lambda x:x[0]/x[1])((conv, tf.cast(tf.shape(input)[0], tf.float32)))

Note that we have to cast the shape to a float to avoid a dtype mismatch, since tf.shape returns int32 values.
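Here is a tiny sketch of why the cast is needed (the tensors here are stand-ins, not from the original post):

import tensorflow as tf

x = tf.ones((2, 3))             # float32, stand-in for the conv output
n = tf.shape(x)[0]              # int32 scalar: the dynamic batch size
# x / n                         # would fail: float32 / int32 dtype mismatch
y = x / tf.cast(n, tf.float32)  # cast to float32 first, then divide

With that change in place, the cell runs fine: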

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_3 (InputLayer)           [(None, 24, 24, 3)]  0           []                               
                                                                                                  
 tf.compat.v1.shape_1 (TFOpLamb  (4,)                0           ['input_3[0][0]']                
 da)                                                                                              
                                                                                                  
 tf.__operators__.getitem_1 (Sl  ()                  0           ['tf.compat.v1.shape_1[0][0]']   
 icingOpLambda)                                                                                   
                                                                                                  
 conv2d_2 (Conv2D)              (None, 24, 24, 8)    392         ['input_3[0][0]']                
                                                                                                  
 tf.cast (TFOpLambda)           ()                   0           ['tf.__operators__.getitem_1[0][0
                                                                 ]']                              
                                                                                                  
 lambda_2 (Lambda)              (None, 24, 24, 8)    0           ['conv2d_2[0][0]',               
                                                                  'tf.cast[0][0]']                
                                                                                                  
==================================================================================================
Total params: 392
Trainable params: 392
Non-trainable params: 0
__________________________________________________________________________________________________
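For completeness, here is a sketch of the full corrected snippet (same layout and imports as the question, plus tensorflow imported as tf for tf.shape):

from keras.models import *
from keras.layers import *
import tensorflow as tf

input = Input((24,24,3))
conv = Conv2D(8,4,1,'SAME')(input)
# tf.shape(input)[0] is the dynamic batch size; cast it to float32 so the
# division matches the conv output's dtype
norm = Lambda(lambda x:x[0]/x[1])((conv, tf.cast(tf.shape(input)[0], tf.float32)))
model = Model(inputs = input, outputs = norm)
model.summary()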