In tensorflow, how to combine multiple losses with a desired formula

Time:06-22

I have a CNN model with a single output neuron that uses a sigmoid activation, so its value lies between 0 and 1. I want to compute a combination of losses for this output neuron.

I was using Mean Absolute Error and Mean Squared Error for this output and tried to create the combined loss like this:

loss = tf.keras.losses.MeanAbsoluteError() + tf.keras.losses.MeanSquaredError()

However, TensorFlow does not allow loss objects to be combined this way, and it raises the following error:

Traceback (most recent call last):
  File "run_kfold.py", line 189, in <module>
    loss = tf.keras.losses.MeanAbsoluteError() + tf.keras.losses.MeanSquaredError()
TypeError: unsupported operand type(s) for +: 'MeanAbsoluteError' and 'MeanSquaredError'
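The TypeError is raised because MeanAbsoluteError() and MeanSquaredError() are loss objects (callables), not tensors, so Python defines no + between them. Calling each object on labels and predictions returns a scalar tensor, and tensors can be weighted and added. A minimal sketch, assuming TensorFlow 2.x; the batch values are made up for illustration:

```python
import tensorflow as tf

# Loss objects: calling them computes a loss, but adding the objects themselves fails.
mae = tf.keras.losses.MeanAbsoluteError()
mse = tf.keras.losses.MeanSquaredError()

# Hypothetical labels and sigmoid outputs for a batch of four samples.
y_true = tf.constant([[0.0], [1.0], [1.0], [0.0]])
y_pred = tf.constant([[0.1], [0.8], [0.6], [0.3]])

mae_value = mae(y_true, y_pred)  # mean(|y_true - y_pred|), a scalar tensor
mse_value = mse(y_true, y_pred)  # mean((y_true - y_pred)^2), a scalar tensor

# Tensors support arithmetic, so the weighted sum works directly.
combined = 0.6 * mae_value + 0.4 * mse_value
print(float(mae_value), float(mse_value), float(combined))
```

With these values the absolute errors are 0.1, 0.2, 0.4 and 0.3, so MAE = 0.25, MSE = 0.075, and the weighted sum is 0.6 * 0.25 + 0.4 * 0.075 = 0.18 (up to float32 rounding).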

Can anyone suggest how to compute a combined loss for a particular output layer? This would also make it possible to combine multiple weighted losses, like this:

l_1 = 0.6
l_2 = 0.4
loss = l_1 * tf.keras.losses.MeanAbsoluteError() + l_2 * tf.keras.losses.MeanSquaredError()

I could then pass this loss variable to the model.compile() function:

model.compile(optimizer=opt,
              loss=loss,
              metrics=['accuracy', sensitivity, specificity,
                       tf.keras.metrics.RootMeanSquaredError(name='rmse')])
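For more than two weighted terms, one option is a small factory that takes a list of weights and a list of Keras loss objects and returns an ordinary (y_true, y_pred) function, which model.compile() accepts as a loss. make_weighted_loss is a hypothetical helper written for this sketch (assuming TensorFlow 2.x), not part of tf.keras:

```python
import tensorflow as tf


def make_weighted_loss(weights, losses):
    """Hypothetical helper: build a loss computing sum(w_i * loss_i(y_true, y_pred))."""
    if len(weights) != len(losses):
        raise ValueError("weights and losses must have the same length")

    def weighted_loss(y_true, y_pred):
        # Each loss object returns a scalar tensor; scale each one, then sum them.
        terms = [w * loss_fn(y_true, y_pred) for w, loss_fn in zip(weights, losses)]
        return tf.add_n(terms)

    return weighted_loss


loss = make_weighted_loss(
    [0.6, 0.4],
    [tf.keras.losses.MeanAbsoluteError(), tf.keras.losses.MeanSquaredError()],
)

# Quick check on a hypothetical batch of four samples.
y_true = tf.constant([[0.0], [1.0], [1.0], [0.0]])
y_pred = tf.constant([[0.1], [0.8], [0.6], [0.3]])
value = float(loss(y_true, y_pred))
```

The returned function can then be passed as loss=loss in the model.compile() call above; adding a third term only means extending both lists.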

Answer:

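You can write a custom loss function that calls MeanAbsoluteError() and MeanSquaredError() on y_true and y_pred, combines the results with the desired weights, and returns the total. Keras accepts any function with the signature (y_true, y_pred) as a loss: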

import tensorflow as tf

# model = your_model
...

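def custom_loss(y_true, y_pred):
    l_1 = 0.6  # weight for the MAE term
    l_2 = 0.4  # weight for the MSE term
    mae = tf.keras.losses.MeanAbsoluteError()
    mse = tf.keras.losses.MeanSquaredError()
    loss_mae = mae(y_true, y_pred)  # scalar tensor
    loss_mse = mse(y_true, y_pred)  # scalar tensor
    # Tensors (unlike loss objects) can be scaled and added.
    total_loss = l_1 * loss_mae + l_2 * loss_mse
    return total_loss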


model.compile(loss=custom_loss, 
              optimizer='Adam')

model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS)
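An alternative, sketched under the assumption of TensorFlow 2.x, is to subclass tf.keras.losses.Loss so the weights are stored on the loss object and included in its config. WeightedMaeMse is a hypothetical name; the per-sample errors are computed directly with tf.abs and tf.square instead of reusing the built-in loss classes, and the base class then applies its default mean-over-batch reduction:

```python
import tensorflow as tf


class WeightedMaeMse(tf.keras.losses.Loss):
    """Hypothetical Loss subclass computing l_1 * MAE + l_2 * MSE."""

    def __init__(self, l_1=0.6, l_2=0.4, name="weighted_mae_mse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.l_1 = l_1
        self.l_2 = l_2

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Per-sample errors, averaged over the last axis; the base class
        # reduces these over the batch (mean by default).
        mae = tf.reduce_mean(tf.abs(y_true - y_pred), axis=-1)
        mse = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
        return self.l_1 * mae + self.l_2 * mse

    def get_config(self):
        # Include the weights so the loss can be recreated from its config.
        config = super().get_config()
        config.update({"l_1": self.l_1, "l_2": self.l_2})
        return config


# Quick check on a hypothetical batch of four samples.
y_true = tf.constant([[0.0], [1.0], [1.0], [0.0]])
y_pred = tf.constant([[0.1], [0.8], [0.6], [0.3]])
value = float(WeightedMaeMse(l_1=0.6, l_2=0.4)(y_true, y_pred))
```

It is used like any built-in loss, e.g. model.compile(loss=WeightedMaeMse(l_1=0.6, l_2=0.4), optimizer='Adam'). Reloading a saved model may require passing custom_objects={'WeightedMaeMse': WeightedMaeMse} to tf.keras.models.load_model.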