Trouble with loss function tf.nn.weighted_cross_entropy_with_logits

Time:09-07

I am trying to train a U-Net network with binary targets. The usual binary cross-entropy loss does not perform well, since the labels are very imbalanced (many more 0 pixels than 1s), so I want to penalize false negatives more heavily. TensorFlow doesn't have a ready-made weighted binary cross-entropy loss, and since I didn't want to write one from scratch, I'm trying to use tf.nn.weighted_cross_entropy_with_logits. To be able to easily feed the loss to the model.compile function, I'm writing this wrapper:

def loss_wrapper(y, x):
    x = tf.cast(x, 'float32')
    loss = tf.nn.weighted_cross_entropy_with_logits(y, x, pos_weight=10)
    return loss

However, despite casting x to float, I'm still getting the error:

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type int32 of argument 'x'.

when the tf loss is called. Can someone explain what's happening?

CodePudding user response:

If x represents your predictions, it probably already has type float32. What you need to cast is y, which is presumably your labels. So:

loss = tf.nn.weighted_cross_entropy_with_logits(tf.cast(y, dtype=tf.float32), x, pos_weight=10)
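Putting it together, a minimal sketch of the corrected wrapper (assuming TensorFlow 2.x; the model itself is hypothetical). Note that Keras passes the arguments as (y_true, y_pred), and that tf.nn.weighted_cross_entropy_with_logits applies the sigmoid internally, so the model's final layer should output raw logits, not sigmoid activations:

```python
import tensorflow as tf

def loss_wrapper(y_true, y_pred):
    # Keras calls the loss as loss(y_true, y_pred). Binary labels often
    # arrive as int32/int64, so cast *them* to float32 to match the logits.
    y_true = tf.cast(y_true, tf.float32)
    # pos_weight > 1 penalizes false negatives more heavily.
    return tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred, pos_weight=10.0)

# Hypothetical usage:
# model.compile(optimizer='adam', loss=loss_wrapper)
```

For logits of 0 (sigmoid probability 0.5), the per-pixel loss is pos_weight * log(2) for a positive label and log(2) for a negative one, which shows the asymmetric penalty at work.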