I am trying to extend one of my already built models with another layer that subtracts an additional input from the model's previous output and sums the result over the last axis. All my attempts can be summarized with this Lambda layer:
new_input = Input(shape=(20, 3))
sub_layer = Lambda(lambda x: Reshape((20,1))(backend.sum(x[0] - x[1], axis=-1)))((model.output,new_input))
model = Model(inputs=[model.input, new_input], outputs=sub_layer.output)
But no matter what I do I always receive this error:
Could not build a TypeSpec for KerasTensor(type_spec=TensorSpec(shape=(None, 20, 1), dtype=tf.float32, name=None), name='tf.reshape/Reshape:0', description="created by layer 'tf.reshape'") of unsupported type <class 'keras.engine.keras_tensor.KerasTensor'>.
CodePudding user response:
Not sure what your pretrained model looks like, but you can try something like this:
import tensorflow as tf
inputs = tf.keras.layers.Input((10,))
outputs = tf.keras.layers.Dense(20, activation='relu')(inputs)
outputs = tf.keras.layers.RepeatVector(3)(outputs)
outputs = tf.keras.layers.Reshape((20,3))(outputs)
model1 = tf.keras.Model(inputs, outputs)
new_input = tf.keras.layers.Input(shape=(20, 3))
sub_layer = tf.keras.layers.Lambda(lambda x: tf.keras.layers.Reshape((20,1))(tf.keras.backend.sum(x[0] - x[1], axis=-1)))((model1.output,new_input))
model2 = tf.keras.Model(inputs=[model1.input, new_input], outputs=sub_layer)
print(model2.summary())
Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_7 (InputLayer)            [(None, 10)]         0           []
dense_3 (Dense)                 (None, 20)           220         ['input_7[0][0]']
repeat_vector_1 (RepeatVector)  (None, 3, 20)        0           ['dense_3[0][0]']
reshape (Reshape)               (None, 20, 3)        0           ['repeat_vector_1[0][0]']
input_8 (InputLayer)            [(None, 20, 3)]      0           []
lambda_3 (Lambda)               (None, 20, 1)        0           ['reshape[0][0]',
                                                                  'input_8[0][0]']
==================================================================================================
Total params: 220
Trainable params: 220
Non-trainable params: 0
__________________________________________________________________________________________________
None
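If you would rather not create a Reshape layer inside the Lambda, here is a variant you could try. It is only a sketch under the same assumption as above (a stand-in model1 whose output has shape (20, 3), since I don't know your real model): the subtraction is done with the built-in Subtract layer, and the sum uses tf.reduce_sum with keepdims=True, so the trailing axis of size 1 is kept and no reshape is needed. The random arrays x and extra are made up purely to check the result.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the pretrained model (assumption: its output shape is (20, 3)).
inputs = tf.keras.layers.Input((10,))
outputs = tf.keras.layers.Dense(20, activation='relu')(inputs)
outputs = tf.keras.layers.RepeatVector(3)(outputs)
outputs = tf.keras.layers.Reshape((20, 3))(outputs)
model1 = tf.keras.Model(inputs, outputs)

new_input = tf.keras.layers.Input(shape=(20, 3))

# Element-wise difference: model output minus the additional input.
diff = tf.keras.layers.Subtract()([model1.output, new_input])

# Sum over the last axis; keepdims=True keeps the shape (batch, 20, 1).
summed = tf.keras.layers.Lambda(
    lambda t: tf.reduce_sum(t, axis=-1, keepdims=True))(diff)

model2 = tf.keras.Model(inputs=[model1.input, new_input], outputs=summed)

# Quick check on made-up data.
x = np.random.rand(2, 10).astype('float32')
extra = np.random.rand(2, 20, 3).astype('float32')
result = model2.predict([x, extra], verbose=0)
print(result.shape)
```

Note that the model is built from the tensor returned by calling the layer (summed), not from an .output attribute of that tensor.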