I'm working on the simple model below, but I'm having difficulty with a concat error.
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # registers the ops used by the BERT preprocessing model

# tfhub_handle_preprocess and tfhub_handle_encoder are TF Hub model URLs defined earlier.
def build_classifier_model():
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='input1')
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(0.1)(net)
    # (2,) is a one-element tuple; (2) is just the integer 2.
    side_input = tf.keras.layers.Input(shape=(2,), dtype=tf.float32, name='input2')
    print(net.shape)         # symbolic shape, batch dimension is None
    print(side_input.shape)
    net = tf.concat(values=[net, side_input], axis=1)
    # net = tf.keras.layers.concatenate([net, side_input], axis=1)
    net = tf.keras.layers.Dense(1, activation=None, name='classifier')(net)
    return tf.keras.Model(inputs=[text_input, side_input], outputs=net)
I printed the shapes of 'net' and 'side_input' and confirmed that they are (None, 512) and (None, 2).
However, I get a concat rank error that reports the shapes as (1, 512) and (2,).
(None, 512)
(None, 2)
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-241-a75a6db54831> in <module>
1 classifier_model = build_classifier_model()
2
----> 3 bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])
4
5
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
7184 def raise_from_not_ok_status(e, name):
7185   e.message += (" name: " + name if name is not None else "")
-> 7186   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
7187
7188
InvalidArgumentError: Exception encountered when calling layer "tf.concat_7" (type TFOpLambda).
ConcatOp : Ranks of all input tensors should match: shape[0] = [1,512] vs. shape[1] = [2] [Op:ConcatV2] name: concat
Call arguments received:
• values=['tf.Tensor(shape=(1, 512), dtype=float32)', 'tf.Tensor(shape=(2,), dtype=float32)']
• axis=1
• name=concat
The sample input I used was:
bert_raw_result = classifier_model([tf.constant(text_test), tf.reshape([0.3, 0.3], (2))])
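The mismatch can be reproduced without the BERT layers. The sketch below assumes TF 2.x eager execution and uses tf.zeros((1, 512)) as a hypothetical stand-in for the pooled output of a single text. It contrasts the symbolic shape printed while building the model, where the batch dimension is None, with the concrete tensors that actually reach tf.concat at call time.

```python
import tensorflow as tf

# Symbolic shape seen while building the model: the batch dimension is None.
side_input = tf.keras.layers.Input(shape=(2,), dtype=tf.float32)
print(side_input.shape)  # (None, 2)

# Concrete tensors that reach tf.concat at call time.
pooled = tf.zeros((1, 512))                # hypothetical stand-in for pooled_output of one text
side_value = tf.reshape([0.3, 0.3], (2,))  # the sample side input from the question: rank 1
print(pooled.shape.rank, side_value.shape.rank)  # 2 1

rank_mismatch_raised = False
try:
    tf.concat([pooled, side_value], axis=1)
except (tf.errors.InvalidArgumentError, ValueError):
    # Eager execution reports "Ranks of all input tensors should match".
    rank_mismatch_raised = True
print(rank_mismatch_raised)  # True
```

The (None, 2) declared on the Input layer is not enforced on the raw tensor passed in, so a rank-1 value slips through until tf.concat rejects it.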
Answer:
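The problem is your second input, i.e. tf.reshape([0.3, 0.3], (2)), which is a rank-1 tensor of shape [2].
Your model input (i.e. side_input, from what I see) expects a rank-2 input of shape [None, 2]. tf.concat requires all inputs to have the same rank, and when concatenating along axis=1 their batch dimensions must also agree.
So, given that your first input produces a pooled output of shape [1, 512], the solution is:
tf.reshape([0.3, 0.3], (1, 2))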
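To check the fix end to end without downloading the TF Hub models, here is a sketch in which the BERT preprocessing and encoder are replaced by a plain (batch, 512) float input. build_toy_model and the "pooled" input are hypothetical names for illustration only. It uses tf.keras.layers.Concatenate, the layer form of the commented-out concatenate call in the question. The side input has to be rank 2 with one row per text, so N texts need a side input of shape (N, 2).

```python
import tensorflow as tf

def build_toy_model(feature_dim=512):
    # Hypothetical stand-in for BERT's pooled_output: a (batch, feature_dim) float input.
    pooled = tf.keras.layers.Input(shape=(feature_dim,), dtype=tf.float32, name="pooled")
    # Same side input as in the question, with the shape written as a tuple.
    side_input = tf.keras.layers.Input(shape=(2,), dtype=tf.float32, name="input2")
    net = tf.keras.layers.Dropout(0.1)(pooled)
    # Joins along the feature axis; the batch dimensions of both inputs must agree.
    net = tf.keras.layers.Concatenate(axis=1)([net, side_input])
    net = tf.keras.layers.Dense(1, activation=None, name="classifier")(net)
    return tf.keras.Model(inputs=[pooled, side_input], outputs=net)

model = build_toy_model()

# One text: the side input must have shape (1, 2), not (2,).
single = model([tf.zeros((1, 512)), tf.reshape([0.3, 0.3], (1, 2))])
print(single.shape)  # (1, 1)

# Three texts: one row of side features per text, shape (3, 2).
batch_side = tf.tile(tf.constant([[0.3, 0.3]]), [3, 1])
batch = model([tf.zeros((3, 512)), batch_side])
print(batch.shape)  # (3, 1)
```

tf.tile is just one convenient way to build a (3, 2) side input here; in practice each row would hold the real side features for the corresponding text.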