How do I handle tensorflow concatenate shape error?

Time:10-20

I'm working on the simple model below, but I'm having trouble with a concat error.

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # registers the ops that the TF Hub BERT preprocessing models need

# tfhub_handle_preprocess and tfhub_handle_encoder are TF Hub handle strings defined elsewhere.
def build_classifier_model():
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='input1')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(0.1)(net)
    side_input = tf.keras.layers.Input(shape=(2,), dtype=tf.float32, name='input2')
    print(net.shape)
    print(side_input.shape)
    net = tf.concat(values=[net, side_input], axis=1)
    # net = tf.keras.layers.concatenate([net, side_input], axis=1)
    net = tf.keras.layers.Dense(1, activation=None, name='classifier')(net)
    return tf.keras.Model(inputs=[text_input, side_input], outputs=net)
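The same two-input pattern can be sketched so that it runs offline. This is only an illustration under assumptions: the TF Hub preprocessing and BERT encoder are replaced by a hypothetical 512-wide float input named pooled_stand_in, and tf.concat is swapped for the tf.keras.layers.Concatenate layer. Called with both inputs batched, it builds and runs without a concat error.

```python
import tensorflow as tf

def build_sketch_model():
    # Hypothetical stand-in for outputs['pooled_output'] (no TF Hub download needed).
    pooled_input = tf.keras.layers.Input(shape=(512,), dtype=tf.float32, name="pooled_stand_in")
    side_input = tf.keras.layers.Input(shape=(2,), dtype=tf.float32, name="input2")
    net = tf.keras.layers.Dropout(0.1)(pooled_input)
    # Join along the feature axis; both tensors are rank 2: (batch, features).
    net = tf.keras.layers.Concatenate(axis=1)([net, side_input])
    net = tf.keras.layers.Dense(1, activation=None, name="classifier")(net)
    return tf.keras.Model(inputs=[pooled_input, side_input], outputs=net)

model = build_sketch_model()

# A batch of one example: shapes (1, 512) and (1, 2).
pooled = tf.zeros((1, 512))
side = tf.constant([[0.3, 0.3]])
result = model([pooled, side])
print(result.shape)  # (1, 1)
```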

I printed the shapes of 'net' and 'side_input', and they are (None, 512) and (None, 2), as expected.

However, I get a concat rank error saying the shapes are [1, 512] and [2].
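The two printouts disagree because they describe different things. When the model is built, the Keras Input is symbolic and its first dimension (None) is the not-yet-known batch size; when the model is called, the layers receive whatever concrete tensors were passed in. A minimal sketch of that difference, reusing the input names from the code above:

```python
import tensorflow as tf

# Build time: the Keras Input is symbolic; None stands for the batch size.
symbolic_side = tf.keras.Input(shape=(2,), name="input2")
print(symbolic_side.shape)  # (None, 2)

# Call time: the layers see the concrete tensor that was actually passed.
concrete_side = tf.reshape([0.3, 0.3], (2))
print(concrete_side.shape)  # (2,)
```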

(None, 512)
(None, 2)
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-241-a75a6db54831> in <module>
      1 classifier_model = build_classifier_model()
      2 
----> 3 bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])
      4 
      5 

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   7184 def raise_from_not_ok_status(e, name):
   7185   e.message += (" name: " + name if name is not None else "")
-> 7186   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   7187 
   7188 

InvalidArgumentError: Exception encountered when calling layer "tf.concat_7" (type TFOpLambda).

ConcatOp : Ranks of all input tensors should match: shape[0] = [1,512] vs. shape[1] = [2] [Op:ConcatV2] name: concat

Call arguments received:
  • values=['tf.Tensor(shape=(1, 512), dtype=float32)', 'tf.Tensor(shape=(2,), dtype=float32)']
  • axis=1
  • name=concat

The sample input I used was:

bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])

Answer:

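The problem is your second input, tf.reshape([0.3, 0.3], (2)), which has shape [2]. In Python, (2) is just the integer 2, not a one-element tuple, so this produces a rank-1 tensor with no batch dimension.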
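A short sketch of why this fails, outside the model: (2) and (2,) are different Python objects, and ConcatV2 refuses inputs of different rank. Here tf.zeros((1, 512)) is only a stand-in for the pooled BERT output, and eager execution (the TF 2 default) is assumed, where the op raises tf.errors.InvalidArgumentError as in the traceback above.

```python
import tensorflow as tf

# Parentheses alone do not make a tuple: (2) is the int 2.
print(type((2)), type((2,)))  # <class 'int'> <class 'tuple'>

side = tf.reshape([0.3, 0.3], (2))  # shape (2,): rank 1, no batch dimension
pooled = tf.zeros((1, 512))         # shape (1, 512): rank 2, batch of one

# ConcatV2 requires every input to have the same rank, so this call raises.
try:
    tf.concat([pooled, side], axis=1)
    failed = False
except tf.errors.InvalidArgumentError:
    failed = True
print(failed)  # True
```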

Your model's second input (side_input) expects shape [None, 2], i.e. a rank-2 tensor whose first dimension is the batch size.

So the fix is to give it a batch dimension that matches the first input (one text string, whose pooled output has shape [1, 512]):

tf.reshape([0.3, 0.3], (1, 2))
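Reshaping to (1, 2) is one of several equivalent ways to add the batch dimension. A sketch comparing three of them, again with tf.zeros((1, 512)) as a hypothetical stand-in for the pooled output; each yields shape (1, 2), and the concatenation then succeeds:

```python
import tensorflow as tf

pooled = tf.zeros((1, 512))  # stand-in for the (1, 512) pooled BERT output

# Three equivalent ways to give the side features a batch dimension of 1.
side_a = tf.reshape([0.3, 0.3], (1, 2))
side_b = tf.constant([[0.3, 0.3]])                        # nested list -> shape (1, 2)
side_c = tf.expand_dims(tf.constant([0.3, 0.3]), axis=0)  # insert axis 0 -> shape (1, 2)

for side in (side_a, side_b, side_c):
    merged = tf.concat([pooled, side], axis=1)
    print(merged.shape)  # (1, 514)
```

The batch sizes must also agree: if text_test holds N strings, the side input needs N rows, i.e. shape (N, 2), or the concatenation will fail on the batch dimension instead.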