I am using VGG16 for image segmentation with a custom "balanced cross entropy" loss function, defined by the following code:
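import tensorflow as tf

beta = 0.5

def balanced_cross_entropy(beta):
    def loss(y_true, y_pred):
        # Weight positive targets by beta and negative targets by (1 - beta)
        weight_a = beta * tf.cast(y_true, tf.float32)
        weight_b = (1 - beta) * tf.cast(1 - y_true, tf.float32)
        # Numerically stable weighted cross entropy; y_pred is treated as logits
        o = (tf.math.log1p(tf.exp(-tf.abs(y_pred))) + tf.nn.relu(-y_pred)) * (weight_a + weight_b) + y_pred * weight_b
        return tf.reduce_mean(o)
    return loss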
Everything works fine. I then save the model to an h5 file with:
vgg.save('vgg.h5')
But when I load it with load_model from Keras:
model = load_model('vgg.h5', custom_objects={'balanced_cross_entropy(beta)': balanced_cross_entropy(beta)})
I encounter an error.
Unknown loss function: loss. Please ensure this object is passed to the `custom_objects` argument.
Can anybody help? I suspect the problem may be due to beta.
Answer:
If you only want to perform inference, you can avoid this problem by passing compile=False:
model = load_model('vgg.h5', compile=False)
Otherwise, you need to load the model in the following way:
model = load_model("vgg.h5", custom_objects={'loss': balanced_cross_entropy(beta)})
In your code you used 'balanced_cross_entropy(beta)' as the dictionary key instead of 'loss'.
Short explanation:
The key in the custom_objects dictionary must be the name of the inner function, loss, which is what balanced_cross_entropy(beta) actually returns and what Keras looks up when loading the model (hence "Unknown loss function: loss"). The call to the outer function, balanced_cross_entropy(beta), belongs in the value of the key-value pair, not in the key.
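To illustrate why the key has to be 'loss', here is a minimal sketch. As far as I understand, the h5 file stores the custom loss under the function's __name__, and the closure returned by balanced_cross_entropy(beta) is always named loss, whatever beta is. The snippet uses a plain-Python stand-in for the loss body, so it runs without TensorFlow. The arithmetic inside is only a placeholder and is not the real loss.

```python
def balanced_cross_entropy(beta):
    # Stand-in for the factory in the question. The body is a placeholder
    # (an assumption for illustration) so the example runs without TensorFlow.
    def loss(y_true, y_pred):
        return beta * y_true + (1 - beta) * (1 - y_true)
    return loss

beta = 0.5
loss_fn = balanced_cross_entropy(beta)

# The returned closure keeps the inner function's name, regardless of beta.
print(loss_fn.__name__)                      # loss
print(balanced_cross_entropy(0.9).__name__)  # loss

# Deriving the key from the function itself avoids typing the name by hand.
custom_objects = {loss_fn.__name__: loss_fn}
print(list(custom_objects))                  # ['loss']
```

With the real TensorFlow loss, the same dictionary could be passed as load_model('vgg.h5', custom_objects=custom_objects). If you plan to continue training, you would presumably want to pass the same beta that was used originally.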