I created a basic model with a custom method, new_method, and a custom attribute, testing, that I want to save. Is it possible to do so using model.save()? Below is an example of what I wish to accomplish.
import numpy as np
import tensorflow as tf
@tf.keras.utils.register_keras_serializable()
class GreatClass(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.testing = 3424  # plain Python attribute
        self.dense = tf.keras.layers.Dense(100)
    def get_config(self):
        # attempt to store the custom attribute and layer in the config
        config = super().get_config()
        config['testing'] = self.testing
        config['dense'] = self.dense
        return config
    def new_method(self):  # plain Python method
        print('hello world')
    def call(self, inputs):
        return self.dense(inputs)
Below I create and save an instance of the above class.
model = GreatClass()
model.compile()
array = np.array([100,10])
model.predict(array)
model.save('testing')
I can save the model, but the loaded model does not have access to the new_method method or the testing attribute.
loaded_model = tf.keras.models.load_model("testing")
loaded_model.new_method()
AttributeError: 'Custom>GreatClass' object has no attribute 'new_method'
loaded_model.testing
AttributeError: 'Custom>GreatClass' object has no attribute 'testing'
Is it possible to save custom methods and attributes using model.save()?
Answer:
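Save the attributes you want to serialize as tf.Variables, and decorate the methods you want to keep with tf.function and an input_signature. A SavedModel stores tracked variables and traced TensorFlow graphs, not arbitrary Python code, so plain Python attributes and undecorated methods are not available on the loaded model. See here for more details. Here is an example: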
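import tensorflow as tf
class GreatClass(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # A tf.Variable is tracked and saved with the model; a plain int is not.
        self.testing = tf.Variable(3424, trainable=False)
        self.dense = tf.keras.layers.Dense(100)
    # A fixed (here empty) input_signature lets TensorFlow trace the method,
    # so the traced graph is stored in the SavedModel.
    @tf.function(input_signature=[])
    def new_method(self):
        tf.print('hello world')
    def call(self, inputs):
        return self.dense(inputs)
model = GreatClass()
model.compile()
# Call the model once so its weights are built before saving.
model(tf.random.normal((1, 10)))
# In TF 2.x (Keras 2), a path without an .h5 extension is saved as a SavedModel.
model.save('testing')
loaded_model = tf.keras.models.load_model("testing")
loaded_model.new_method()
loaded_model.testing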
Output:
hello world
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=3424>
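The same mechanism can be checked without Keras in the loop. The following is a minimal sketch, assuming TensorFlow 2.x and using the lower-level tf.saved_model.save / tf.saved_model.load APIs on a plain tf.Module instead of a Keras model; the class and member names (Holder, plain, add_to_testing, bump, hello) are hypothetical and only for illustration. It shows that the tf.Variable and the tf.function methods with an input_signature survive the round trip, while the plain Python attribute and the undecorated method do not. It also shows a signature that takes an argument and a traced method that updates the saved variable.

```python
import tempfile

import tensorflow as tf


class Holder(tf.Module):  # hypothetical example class
    def __init__(self):
        super().__init__()
        self.testing = tf.Variable(3424, trainable=False)  # saved: tracked variable
        self.plain = 3424  # not saved: plain Python int

    # Saved: traced from its input_signature, callable after loading.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32)])
    def add_to_testing(self, x):
        return self.testing + x

    # Saved: a traced function may also modify the variable.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32)])
    def bump(self, delta):
        self.testing.assign_add(delta)
        return self.testing.read_value()

    # Not saved: a regular Python method is not part of the SavedModel.
    def hello(self):
        print("hello world")


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Holder(), export_dir)
restored = tf.saved_model.load(export_dir)

print(int(restored.add_to_testing(tf.constant(1))))  # 3425
print(int(restored.bump(tf.constant(10))))  # 3434
print(hasattr(restored, "plain"), hasattr(restored, "hello"))  # False False
```

Because only the traced graphs are saved, each method's input_signature fixes the shapes and dtypes it accepts after loading; calling a restored method with a tensor that does not match its signature should raise an error rather than retrace.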