AttributeError: 'Functional' object has no attribute 'predict_segmentation' When Loading a Saved Model

Time:12-07

I have successfully trained a Keras model like this:

import tensorflow as tf
from keras_segmentation.models.unet import vgg_unet

# Initiate the model
model = vgg_unet(n_classes=50, input_height=512, input_width=608)

# Train
model.train(
    train_images=train_images,
    train_annotations=train_annotations,
    checkpoints_path="/tmp/vgg_unet_1", epochs=5
)

I then saved it in HDF5 format with:

tf.keras.models.save_model(model,'my_model.hdf5')

Then I loaded it with:

model=tf.keras.models.load_model('my_model.hdf5')

Finally, I want to make a segmentation prediction on a new image with:

out = model.predict_segmentation(
    inp=image_to_test,
    out_fname="/tmp/out.png"
)

I am getting the following error:

AttributeError: 'Functional' object has no attribute 'predict_segmentation'

What am I doing wrong? Is the problem in how I save the model or in how I load it?

Thanks!

Answer:

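predict_segmentation isn't a method of normal Keras models. The keras_segmentation library attaches it to the model object as a plain Python attribute right after building the model, together with train and some metadata such as n_classes. An HDF5 file only stores the model's architecture and weights (plus its training configuration), not extra Python attributes like these, so tf.keras.models.load_model returns a plain Functional model without predict_segmentation. In other words, neither your save call nor your load call is wrong; the method simply isn't part of what gets saved.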
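A minimal sketch of that mechanism, assuming TensorFlow 2.x with h5py available. It uses a tiny stand-in model rather than vgg_unet, and the attached `predict_segmentation` function and the file name are hypothetical; the point is only that an attribute attached after building does not survive an HDF5 round trip.

```python
import os
import tempfile
from types import MethodType

import tensorflow as tf

# Tiny stand-in for vgg_unet; any Functional model behaves the same way here.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)


def predict_segmentation(self, inp=None, out_fname=None):
    # Hypothetical stand-in for keras_segmentation's predict function.
    return "prediction for %r" % (inp,)


# Attach the method as a plain Python attribute, similar to what
# keras_segmentation does after building its models.
model.predict_segmentation = MethodType(predict_segmentation, model)

path = os.path.join(tempfile.mkdtemp(), "tiny_model.h5")
# The HDF5 file gets the layer config and weights, not extra attributes.
tf.keras.models.save_model(model, path)
loaded = tf.keras.models.load_model(path)

print(hasattr(model, "predict_segmentation"))   # True
print(hasattr(loaded, "predict_segmentation"))  # False
```

Because nothing in the file records the attached method, adjusting the save_model or load_model arguments is unlikely to bring it back; it has to be re-attached, or the model rebuilt by the library.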

I think you have two options:

  1. Manually re-attach the method to the loaded model, using the same line keras_segmentation uses in the code I linked to:

       from types import MethodType
       from keras_segmentation.predict import predict

       model.predict_segmentation = MethodType(predict, model)

  2. Create a new vgg_unet with the same arguments you trained with, then load the weights from your HDF5 file into it, as suggested in the Keras documentation:

       from keras_segmentation.models.unet import vgg_unet

       model = vgg_unet(n_classes=50, input_height=512, input_width=608)
       model.load_weights('my_model.hdf5')
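Re-attaching the method alone may not be enough. In the keras_segmentation source I am aware of, `predict` also reads `model.n_classes`, `model.input_height`, `model.input_width`, `model.output_height` and `model.output_width`, which are set at build time and are lost on reload just like the method. Below is a hedged sketch of restoring them from the loaded model's own shapes. `infer_output_shape` and `restore_segmentation_api` are hypothetical helpers, and the code assumes the default channels_last layout and a network that ends in a `Reshape` (followed by a softmax), as keras_segmentation's models did in the versions I have seen.

```python
from types import MethodType

import tensorflow as tf


def infer_output_shape(model):
    """Return (output_height, output_width, n_classes) of a reloaded model.

    Assumption: channels_last layout, and the last Reshape layer flattens a
    (batch, height, width, n_classes) tensor, so its input carries the sizes.
    """
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Reshape):
            _, height, width, n_classes = layer.input.shape
            return int(height), int(width), int(n_classes)
    raise ValueError("no Reshape layer found; set the sizes explicitly")


def restore_segmentation_api(model, predict_fn):
    """Hypothetical helper: re-attach the metadata and predict_segmentation
    method that keras_segmentation adds at build time and HDF5 does not keep."""
    _, input_height, input_width, _ = model.inputs[0].shape
    output_height, output_width, n_classes = infer_output_shape(model)
    model.input_height = int(input_height)
    model.input_width = int(input_width)
    model.output_height = output_height
    model.output_width = output_width
    model.n_classes = n_classes
    # MethodType binds the model as predict_fn's first ("model") argument.
    model.predict_segmentation = MethodType(predict_fn, model)
    return model


# Usage with the real library (assumes keras_segmentation is installed):
# from keras_segmentation.predict import predict
# model = restore_segmentation_api(tf.keras.models.load_model("my_model.hdf5"), predict)
# out = model.predict_segmentation(inp=image_to_test, out_fname="/tmp/out.png")
```

Option 2 sidesteps all of this, because vgg_unet sets up the method and the metadata again before the weights are loaded into it.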