Saving Attributes and Methods with TensorFlow for a Custom Model

Time:08-05

I created a basic model with a custom method (new_method) and a custom attribute (testing) that I want to save. Is it possible to do so using model.save()? Below is an example of what I want to accomplish.

import numpy as np
import tensorflow as tf

@tf.keras.utils.register_keras_serializable()
class GreatClass(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.testing = 3424
        self.dense = tf.keras.layers.Dense(100)
        
    def get_config(self):
        config = super().get_config()
        config['testing'] = self.testing
        config['dense'] = self.dense
        return config
    
    def new_method(self):
        print('hello world')
    
    def call(self, inputs):
        return self.dense(inputs)

Below I create and save an instance of the above class.

model = GreatClass()
model.compile()

array = np.array([100,10])
model.predict(array)

model.save('testing')

I can save the model, but the loaded model does not have access to the new_method method or testing attribute.

loaded_model = tf.keras.models.load_model("testing")
loaded_model.new_method()

AttributeError: 'Custom>GreatClass' object has no attribute 'new_method'

loaded_model.testing

AttributeError: 'Custom>GreatClass' object has no attribute 'testing'

Is it possible to save custom methods and attributes using model.save()?

Answer:

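Store the attributes you want to serialize as tf.Variables, and decorate the methods you want to keep with tf.function and an input_signature so that they are traced and saved. The SavedModel format stores the model's variables and traced tf.function graphs, not its Python code, so plain Python attributes and undecorated methods are not restored when Keras cannot rebuild the original class. See here for more details. Here is an example: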

import tensorflow as tf

class GreatClass(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.testing = tf.Variable(3424, trainable=False)
        self.dense = tf.keras.layers.Dense(100)

    @tf.function(input_signature=[])
    def new_method(self):
        tf.print('hello world')
    
    def call(self, inputs):
        return self.dense(inputs)

model = GreatClass()
model.compile()
model(tf.random.normal((1, 10)))
model.save('testing')


loaded_model = tf.keras.models.load_model("testing")
loaded_model.new_method()
loaded_model.testing

Output:

hello world
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=3424>
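A likely reason the question's own approach fails: when a registered class is loaded, Keras tries to rebuild it with from_config(config), which by default calls cls(**config). The question's get_config() adds testing and dense keys, but __init__ has no testing parameter and forwards both keys to the base Model constructor, which rejects unknown keyword arguments; Keras then appears to fall back to the generic 'Custom>GreatClass' object seen in the error. The sketch below models only that config round trip in plain Python so it runs without TensorFlow. FakeKerasModel, BrokenGreatClass, FixedGreatClass and round_trip are hypothetical names, and real Keras also restores weights separately after the object is rebuilt.

```python
# A TensorFlow-free sketch of the Keras config round trip, assuming that a
# loader calls cls.from_config(config), which calls cls(**config).
# FakeKerasModel and both GreatClass variants are hypothetical stand-ins.


class FakeKerasModel:
    """Mimics how a Keras Model base class handles constructor kwargs."""

    def __init__(self, name="great_class", trainable=True):
        # No **kwargs here, so unknown keyword arguments raise TypeError,
        # similar to how tf.keras.Model rejects arguments it does not know.
        self.name = name
        self.trainable = trainable

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class BrokenGreatClass(FakeKerasModel):
    # Mirrors the question: __init__ has no `testing` parameter,
    # yet get_config() emits `testing` and `dense` keys.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.testing = 3424
        self.dense = object()  # stand-in for tf.keras.layers.Dense(100)

    def get_config(self):
        config = super().get_config()
        config["testing"] = self.testing
        config["dense"] = self.dense
        return config


class FixedGreatClass(FakeKerasModel):
    # Every key in get_config() is a constructor argument; the sublayer is
    # recreated in __init__ instead of being stored in the config.
    def __init__(self, testing=3424, **kwargs):
        super().__init__(**kwargs)
        self.testing = testing
        self.dense = object()  # stand-in for tf.keras.layers.Dense(100)

    def get_config(self):
        config = super().get_config()
        config["testing"] = self.testing
        return config

    def new_method(self):
        return "hello world"


def round_trip(model):
    """Rebuild a model from its own config, as a loader would."""
    return type(model).from_config(model.get_config())


if __name__ == "__main__":
    try:
        round_trip(BrokenGreatClass())
    except TypeError as err:
        print("broken:", err)

    restored = round_trip(FixedGreatClass(testing=7))
    print("fixed:", restored.testing, restored.new_method())
```

With a config shaped like FixedGreatClass, a real load could rebuild GreatClass itself, so new_method would come back as ordinary Python code; the tf.Variable and tf.function approach above does not depend on the class being rebuilt at all.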