Tensorflow with custom loss containing multiple inputs - Graph disconnected error

I have a CNN that outputs a scalar; this output is concatenated with the output of an MLP and then fed to another dense layer. When I build the model below with a custom loss, I get a Graph disconnected error.

Please advise on how to fix this. Thanks in advance.

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, concatenate, Input
import tensorflow as tf

tf.keras.backend.clear_session()


#----custom loss function
def custom_loss(ytrue, ypred):
    loss = tf.math.log(1. + ytrue) - tf.math.log(1. + ypred)
    loss = tf.math.square(loss)
    loss = tf.math.reduce_mean(loss)
    return loss


#------------------
cnnin = Input(shape=(10, 10, 1))
x = Conv2D(8, 4)(cnnin)
x = Conv2D(16, 4)(x)
x = Conv2D(32, 2)(x)
x = Conv2D(64, 2)(x)
x = Flatten()(x)
x = Dense(4)(x)
x = Dense(4, activation="relu")(x)
cnnout = Dense(1, activation="linear")(x)

cnnmodel = Model(cnnin, cnnout, name="cnn_model")

yt = Input(shape=(2, ))   #---target placeholder, used only inside the custom loss

#---mlp start 
mlpin    = Input(shape=(2, ), name="mlp_input")
z        = Dense(4, activation="sigmoid")(mlpin)
z        = Dense(4, activation="softmax")(z)
mlpout   = Dense(1, activation="linear")(z)
mlpmodel = Model(mlpin, mlpout, name="mlp_model")


#----concatenate 
combinedout = concatenate([mlpmodel.output, cnnmodel.output])

x = Dense(4, activation="sigmoid")(combinedout)
finalout = Dense(2, activation="linear")(x)


model = Model([mlpin, cnnin], finalout)

model.add_loss(custom_loss(yt, finalout))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss=None)

Graph disconnected: cannot obtain value for tensor Tensor("input_8:0", shape=(None, 2), dtype=float32) at layer "input_8". The following previous layers were accessed without issue: ['input_7', 'conv2d_12', 'conv2d_13', 'conv2d_14', 'conv2d_15', 'flatten_3', 'mlp_input', 'dense_24', 'dense_27', 'dense_25', 'dense_28', 'dense_29', 'dense_26', 'concatenate_3', 'dense_30', 'dense_31']

CodePudding user response:

The Graph disconnected error occurs because yt is an Input tensor that is not among the inputs of Model([mlpin, cnnin], finalout), so Keras cannot trace a path from it to the loss added with add_loss. Instead of wiring the targets into the graph, you can customize what happens in Model.fit, following https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit:

  • We create a new class that subclasses keras.Model.
  • We just override the method train_step(self, data).
  • We return a dictionary mapping metric names (including the loss) to their current value.

For example with your models:

loss_tracker = tf.keras.metrics.Mean(name="custom_loss")

class TestModel(tf.keras.Model):
    def __init__(self, model1):
        super(TestModel, self).__init__()
        self.model1 = model1

    def compile(self, optimizer):
        super(TestModel, self).compile()
        self.optimizer = optimizer

    def train_step(self, data):
        # `data` is what fit() receives: x is the (mlp_input, cnn_input)
        # tuple and y is the targets.
        x, y = data
        with tf.GradientTape() as tape:
            ypred = self.model1(x, training=True)
            loss_value = custom_loss(y, ypred)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss_value, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Track the running mean of the loss
        loss_tracker.update_state(loss_value)
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of each epoch
        return [loss_tracker]

import numpy as np
x = np.random.rand(6, 10, 10, 1)   # CNN input
x2 = np.random.rand(6, 2)          # MLP input
y = tf.ones((6, 2))                # targets

model = Model([mlpin, cnnin], finalout)
trainable_model = TestModel(model)
trainable_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001))
trainable_model.fit(x=(x2, x), y=y, epochs=5)

Gives the following output:

Epoch 1/5
1/1 [==============================] - 0s 382ms/step - loss: 0.2641
Epoch 2/5
1/1 [==============================] - 0s 4ms/step - loss: 0.2640
Epoch 3/5
1/1 [==============================] - 0s 6ms/step - loss: 0.2638
Epoch 4/5
1/1 [==============================] - 0s 7ms/step - loss: 0.2635
Epoch 5/5
1/1 [==============================] - 0s 6ms/step - loss: 0.2632
<tensorflow.python.keras.callbacks.History at 0x14c69572688>