I have the following wrapper class:
from tensorflow.keras.models import Sequential

class NeuralNet(Sequential):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
I can fit and save the model without problems, but when I try to load it:
from tensorflow.keras.models import load_model
model = load_model('model.h5')
I get:
--> 296 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
297
298 cls_config = config['config']
ValueError: Unknown layer: NeuralNet
I'd like to resolve this error while keeping the wrapper class.
CodePudding user response:
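When loading, Keras has to map the saved class name NeuralNet back to an actual class, so pass the class to load_model through the custom_objects argument. You can also keep that dictionary on the NeuralNet class as a class attribute, but assign it after the class definition: inside the class body the name NeuralNet is not bound yet, so referencing it there raises a NameError.
from tensorflow.keras.models import Sequential, load_model

class NeuralNet(Sequential):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

# Assigned after the class body, where the name NeuralNet is already defined.
NeuralNet.custom_objects = {"NeuralNet": NeuralNet}

model = load_model('model.h5', custom_objects=NeuralNet.custom_objects)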
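To see why the extra dictionary is needed, here is a toy sketch of name-based deserialization in plain Python. It is not the actual Keras implementation: `BUILT_IN_CLASSES`, `toy_deserialize`, and the stand-in `Sequential` class are all hypothetical names made up for illustration. The idea it shows is that a saved file stores only the class name as a string, so the loader needs some mapping from that string back to a class.

```python
# Toy model of name-based deserialization; NOT the actual Keras code.
# All names below are hypothetical and exist only for this illustration.

BUILT_IN_CLASSES = {}  # classes the toy loader knows without any help


class Sequential:
    """Stand-in for a built-in model class."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


BUILT_IN_CLASSES["Sequential"] = Sequential


class NeuralNet(Sequential):
    """User-defined wrapper; the toy loader does not know it by default."""


def toy_deserialize(config, custom_objects=None):
    """Look up config['class_name'] and build an instance from config['config']."""
    known = dict(BUILT_IN_CLASSES)
    known.update(custom_objects or {})  # user-supplied classes extend the lookup
    class_name = config["class_name"]
    if class_name not in known:
        raise ValueError("Unknown layer: " + class_name)
    return known[class_name](**config["config"])


# What a saved file would contain: only the class *name*, as a string.
saved = {"class_name": "NeuralNet", "config": {"name": "net"}}

# Without the mapping, the name cannot be resolved.
try:
    toy_deserialize(saved)
except ValueError as err:
    message = str(err)

# With the mapping, the wrapper class is found and instantiated.
restored = toy_deserialize(saved, custom_objects={"NeuralNet": NeuralNet})
```

In this sketch the first call fails because "NeuralNet" is not in the loader's own table, and the second succeeds because the caller supplies the missing entry, which is the role custom_objects plays in the load_model call.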