My tensorflow custom loss function does NOT work. What's wrong?

Time:06-29

I'd like to get this custom loss function working. My code seems to have a problem in the loop over the batch: the loss function cannot get the tensor's shape. My code is:

backbone = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None, input_shape=INPUT_SHAPE)
x = tf.keras.layers.Conv2D(filters=5, kernel_size=3, padding='same', activation='sigmoid')(backbone.output)
model = tf.keras.Model(inputs=backbone.input, outputs=x)

def custom_loss(y_true, y_pred):
    batch_loss = 0.0
    batch_cnt = len(y_true)
    for i in range(batch_cnt):
        tf.autograph.experimental.set_loop_options(shape_invariants=[(batch_loss, tf.TensorShape([None]))])
        y_true_unit = y_true[i]
        y_pred_unit = y_pred[i]
        
        loss = 0.0
        for j in range(18):
            for k in range(32):
                conf_true = y_true_unit[j,k,0]
                cell_loss = tf.where(conf_true==1, 5 * tf.math.abs(y_true_unit - y_pred_unit), 0.5 * tf.math.abs(conf_true - y_pred_unit[j,k,0]))
                loss = tf.where(loss==0, tf.identity(cell_loss), tf.math.add(loss, cell_loss))
        batch_loss = tf.where(batch_loss==0, tf.identity(loss), tf.math.add(batch_loss, loss))
    return batch_loss / batch_cnt
sgd = tf.keras.optimizers.SGD(momentum=0.99)
model.compile(sgd, custom_loss)

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5)
model.fit(
    train_batch,
    validation_data = val_batch, 
    epochs = 100,
    callbacks = [reduce_lr]
)

and the error is:

ValueError: in user code:

    File "C:\Users\user\anaconda3\lib\site-packages\keras\engine\training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "C:\Users\user\AppData\Local\Temp\ipykernel_5952\2961884429.py", line 4, in yolo_loss  *
        for i in range(batch_cnt):

    ValueError: 'batch_loss' has shape () before the loop, which does not conform with the shape invariant (None,).
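The error says that batch_loss starts the loop as a scalar (shape ()), while the shape invariant passed to set_loop_options declares it a rank-1 tensor of unknown length (None,). Both cannot be true at once. The invariant was apparently needed because the inner tf.where broadcasts 5 * tf.math.abs(y_true_unit - y_pred_unit) to the whole (18, 32, 5) grid, so loss, and with it batch_loss, changes shape inside the loop. A minimal sketch, assuming float32 inputs of shape (batch, ...), that avoids the problem by reducing each sample's error to a scalar so the accumulator never changes shape (summed_abs_error is a hypothetical name):

```python
import tensorflow as tf


@tf.function
def summed_abs_error(y_true, y_pred):
    """Mean over the batch of each sample's summed absolute error."""
    # The accumulator starts as a float32 scalar, shape ().
    total = tf.constant(0.0)
    n = tf.shape(y_true)[0]
    # AutoGraph converts a loop over tf.range into a tf.while_loop.
    for i in tf.range(n):
        # reduce_sum keeps every update a scalar, so `total` keeps shape ()
        # and no shape invariant is needed.
        total += tf.reduce_sum(tf.abs(y_true[i] - y_pred[i]))
    return total / tf.cast(n, tf.float32)
```

The explicit loop is only there to mirror the question; for 2-D inputs the same value is tf.reduce_mean(tf.reduce_sum(tf.abs(y_true - y_pred), axis=1)).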

Answer:

You shouldn't put complex logic inside a loss function. The loss needs to be differentiable and built from tensor operations, and Python for loops and if statements over tensor values get in the way of this.

Write your loss function without for and if, using vectorized tensor operations instead; otherwise it is very hard to make it work reliably.

To get an idea of the operations you can use and how to rewrite your loss function, check out the Keras backend functions.

For an example, see: https://towardsdatascience.com/how-to-create-a-custom-loss-function-keras-3a89156ec69b
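As an illustration of the loop-free approach, here is a sketch of one possible reading of the question's loss. It assumes y_true and y_pred have shape (batch, 18, 32, 5) with channel 0 as the confidence, and it assumes the intent is per cell: a cell whose true confidence is 1 contributes 5 times its absolute error summed over the five channels, any other cell contributes 0.5 times its absolute confidence error. The per-cell choice that the question expresses with nested loops and tf.where becomes a 0/1 mask. The name masked_abs_loss is hypothetical.

```python
import tensorflow as tf


def masked_abs_loss(y_true, y_pred):
    """Assumed layout: (batch, H, W, 5), channel 0 = confidence."""
    conf_true = y_true[..., 0]
    conf_pred = y_pred[..., 0]
    # 1.0 where the cell contains an object, 0.0 elsewhere (replaces `if`).
    obj = tf.cast(tf.equal(conf_true, 1.0), y_pred.dtype)
    # Object cells: 5 * absolute error summed over all 5 channels.
    obj_err = 5.0 * tf.reduce_sum(tf.abs(y_true - y_pred), axis=-1)
    # Empty cells: 0.5 * absolute error on the confidence only.
    noobj_err = 0.5 * tf.abs(conf_true - conf_pred)
    per_cell = obj * obj_err + (1.0 - obj) * noobj_err
    # Sum over the H x W grid, then average over the batch.
    return tf.reduce_mean(tf.reduce_sum(per_cell, axis=[1, 2]))
```

It is passed to Keras like any other loss, e.g. model.compile(sgd, masked_abs_loss). tf.where(obj > 0, obj_err, noobj_err) would select the same values as the mask arithmetic.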

Answer:

I solved the problem with the code below, but I still don't fully understand why it works.

def yolo_loss(y_true, y_pred):
    batch_loss = 0.0
    count = len(y_true)
    for i in range(0, count):
        y_true_unit = tf.identity(y_true[i])
        y_pred_unit = tf.identity(y_pred[i])
        
        # Flatten the 18 x 32 output grid into 576 cells of [conf, x, y, w, h]
        y_true_unit = tf.reshape(y_true_unit, [576, 5])
        y_pred_unit = tf.reshape(y_pred_unit, [576, 5])
        
        loss = 0
        for j in range(0,len(y_true_unit)):
            conf_true = tf.identity(y_true_unit[j,0])
            box_true = tf.identity(y_true_unit[j,1:])
            
            conf_pred = tf.identity(y_pred_unit[j,0])
            box_pred = tf.identity(y_pred_unit[j,1:])
            
            # A cell whose true box is all zeros holds no object, so its error is masked out
            obj_exist = tf.ones_like(conf_true)
            if box_true[0] == 0.0 and box_true[1] == 0.0 and box_true[2] == 0.0 and box_true[3] == 0.0:
                obj_exist = tf.zeros_like(conf_true)
            
            conf_err = tf.math.pow(tf.math.subtract(conf_true, conf_pred), 2)
            local_err_x = tf.math.pow(tf.math.subtract(box_true[0], box_pred[0]), 2)
            local_err_y = tf.math.pow(tf.math.subtract(box_true[1], box_pred[1]), 2)
            local_err_w = tf.math.pow(tf.math.subtract(tf.sqrt(box_true[2]), tf.sqrt(box_pred[2])), 2)
            local_err_h = tf.math.pow(tf.math.subtract(tf.sqrt(box_true[3]), tf.sqrt(box_pred[3])), 2)
            
            # Zero out NaN terms (tf.sqrt of a negative width or height gives NaN)
            if tf.math.is_nan(conf_err) == True:
                conf_err = tf.zeros_like(conf_err, dtype=tf.float32)
            if tf.math.is_nan(local_err_w) == True:
                local_err_w = tf.zeros_like(local_err_w, dtype=tf.float32)
            if tf.math.is_nan(local_err_h) == True:
                local_err_h = tf.zeros_like(local_err_h, dtype=tf.float32)
            
            local_err_coord = tf.math.add(local_err_x, local_err_y)
            local_err_shape = tf.math.add(local_err_w, local_err_h)
            local_err = tf.math.add(local_err_coord, local_err_shape)
            total_err = tf.math.add(conf_err, local_err)
            
            weighted_err = tf.math.multiply(total_err, 5.0)
            weighted_err = tf.math.multiply(weighted_err, obj_exist)
            
            if loss == 0:
                loss = tf.identity(weighted_err)
            else:
                loss = tf.math.add(loss, weighted_err)
        
        if batch_loss == 0:
            batch_loss = tf.identity(loss)
        else:
            batch_loss = tf.math.add(batch_loss, loss)
        
    return tf.math.divide(batch_loss, float(count))
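This version works because model.fit runs the training step inside tf.function, and AutoGraph converts the loss along with it: a Python for loop over a tensor-valued range becomes a tf.while_loop, and an if on a tensor condition becomes a tf.cond. That is slow to trace and run, though. Below is a sketch of the same computation with whole-tensor operations, assuming the same 5-channel [conf, x, y, w, h] layout. Two things are assumptions rather than the author's method: the NaN checks are replaced by a different technique, clamping the inputs of tf.sqrt at a small epsilon (which also keeps gradients finite), and the name yolo_loss_vectorized is hypothetical. As in the original, cells without an object contribute nothing to the loss.

```python
import tensorflow as tf


def yolo_loss_vectorized(y_true, y_pred, eps=1e-9):
    """Loop-free sketch of the answer's loss; layout [conf, x, y, w, h]."""
    batch = tf.shape(y_true)[0]
    # Flatten every sample's grid to (batch, cells, 5).
    y_true = tf.reshape(y_true, [batch, -1, 5])
    y_pred = tf.reshape(y_pred, [batch, -1, 5])
    conf_true, box_true = y_true[..., 0], y_true[..., 1:]
    conf_pred, box_pred = y_pred[..., 0], y_pred[..., 1:]

    # 1.0 where the true box is not all zeros (replaces the Python `if`).
    obj_exist = tf.cast(
        tf.reduce_any(tf.not_equal(box_true, 0.0), axis=-1), y_pred.dtype)

    # Clamp before tf.sqrt instead of zeroing NaNs afterwards (assumption).
    def safe_sqrt(t):
        return tf.sqrt(tf.maximum(t, eps))

    conf_err = tf.square(conf_true - conf_pred)
    xy_err = tf.reduce_sum(
        tf.square(box_true[..., :2] - box_pred[..., :2]), axis=-1)
    wh_err = tf.reduce_sum(
        tf.square(safe_sqrt(box_true[..., 2:]) - safe_sqrt(box_pred[..., 2:])),
        axis=-1)

    # Weight by 5 and keep only cells that contain an object.
    weighted_err = 5.0 * (conf_err + xy_err + wh_err) * obj_exist
    # Sum over cells per sample, then average over the batch
    # (the original divides the batch sum by the batch size).
    return tf.reduce_mean(tf.reduce_sum(weighted_err, axis=1))
```

It plugs in the same way as the loop version, e.g. model.compile(sgd, yolo_loss_vectorized).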