Home > database >  tensorflow.keras.Model inherit
tensorflow.keras.Model inherit

Time:02-22

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class KerasSupervisedModelWrapper(keras.Model):
    """Base ``keras.Model`` that stores a batch size and a shape-aware summary.

    Subclasses are expected to provide ``call``.
    """

    def __init__(self, batch_size, **kwargs):
        """Initialise the wrapper.

        Args:
            batch_size: Value stored on the instance as ``self.batch_size``.
            **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        # Forward the extra keyword arguments instead of silently dropping them.
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def summary(self, input_shape): # temporary fix for a bug
        """Print a layer summary for inputs of shape ``input_shape``.

        Args:
            input_shape: Per-sample input shape passed to ``layers.Input``,
                e.g. ``(28, 28, 1)``.

        Returns:
            The return value of ``summary()`` on the temporary model.
        """
        # Run ``call`` on a symbolic input and wrap the result in a temporary
        # functional model; the summary is taken from that temporary model.
        x = layers.Input(shape=input_shape)
        model = keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()

class ExampleModel(KerasSupervisedModelWrapper):
    """Minimal wrapper subclass made of a single convolution layer."""

    def __init__(self, batch_size):
        """Create the model and its one ``Conv2D`` layer (32 filters, 3x3)."""
        super().__init__(batch_size)
        self.conv1 = layers.Conv2D(
            filters=32,
            kernel_size=(3, 3),
            activation='relu',
        )

    def call(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return self.conv1(x)
        
# Instantiate the example model defined above (batch_size=15) and print its
# summary for 28x28x1 inputs. The class is named ExampleModel; `MyModel` was
# never defined and raised NameError.
model = ExampleModel(15)
model.summary([28, 28, 1])

output:

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 26, 26, 32)        320       
                                                                 
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________

I'm writing a wrapper for a Keras model that pre-defines some useful methods and variables, as shown above.
I'd like to extend the wrapper so that it can compose a model from a list of layers, the way keras.Sequential does.
Therefore, I added a Sequential method that assigns a new call method, as shown below.

class KerasSupervisedModelWrapper(keras.Model):
    ...(continue)...

    @staticmethod
    def Sequential(layers, **kwargs):
        # Build a wrapper instance from **kwargs and try to route its forward
        # pass through a keras.Sequential made from `layers`.
        model = KerasSupervisedModelWrapper(**kwargs)
        pipe = keras.Sequential(layers)
        def call(self, x):
            return pipe(x)
        # NOTE: `call` is a plain function stored on the instance, so it is not
        # bound to it. `self.call(x)` in summary() therefore passes `x` as
        # `self`, which produces the TypeError
        # "call() missing 1 required positional argument: 'x'".
        model.call = call
        return model

However, it does not seem to work as I intended. Instead, it shows the error message below.

# Build the wrapper from a one-layer stack (batch_size is forwarded to the
# wrapper's constructor) and request its summary for 28x28x1 inputs.
model = KerasSupervisedModelWrapper.Sequential([
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu")
], batch_size=15)
model.summary((28, 28, 1))

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_91471/2826773946.py in <module>
      1 # model.build((None, 28, 28, 1))
      2 # model.compile('adam', loss=keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'])
----> 3 model.summary((28, 28, 1))

/tmp/ipykernel_91471/3696340317.py in summary(self, input_shape)
     10     def summary(self, input_shape): # temporary fix for a bug
     11         x = layers.Input(shape=input_shape)
---> 12         model = keras.Model(inputs=[x], outputs=self.call(x))
     13         return model.summary()
     14 

TypeError: call() missing 1 required positional argument: 'x'

What can I do so that the wrapper can build a keras.Sequential-style model while still using its other properties?

CodePudding user response:

You could try something like this:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class KerasSupervisedModelWrapper(keras.Model):
    """Base ``keras.Model`` that stores a batch size and a shape-aware summary.

    Subclasses provide ``call``; alternatively, ``Sequential`` builds an
    instance whose forward pass is a ``keras.Sequential`` stack.
    """

    def __init__(self, batch_size, **kwargs):
        """Initialise the wrapper.

        Args:
            batch_size: Value stored on the instance as ``self.batch_size``.
            **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        # Forward the extra keyword arguments instead of silently dropping them.
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def summary(self, input_shape): # temporary fix for a bug
        """Print a layer summary for inputs of shape ``input_shape``.

        Args:
            input_shape: Per-sample input shape passed to ``layers.Input``,
                e.g. ``(28, 28, 1)``.

        Returns:
            The return value of ``summary()`` on the temporary model.
        """
        # Run ``call`` on a symbolic input and wrap the result in a temporary
        # functional model; the summary is taken from that temporary model.
        x = layers.Input(shape=input_shape)
        model = keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    @staticmethod
    def Sequential(layers, **kwargs):
        """Build a wrapper whose forward pass is a ``keras.Sequential`` stack.

        Args:
            layers: Layers handed to ``keras.Sequential``. Inside this method
                the name shadows the imported ``layers`` module.
            **kwargs: Constructor arguments for the wrapper; ``batch_size``
                is required.

        Returns:
            A ``KerasSupervisedModelWrapper`` whose ``call`` attribute is the
            ``keras.Sequential`` instance.
        """
        model = KerasSupervisedModelWrapper(**kwargs)
        pipe = keras.Sequential(layers)
        # The Sequential instance is itself callable, so it can stand in for
        # ``call`` directly; no unbound ``self`` parameter is involved.
        model.call = pipe
        return model

class ExampleModel(KerasSupervisedModelWrapper):
    """Minimal wrapper subclass made of a single convolution layer."""

    def __init__(self, batch_size):
        """Create the model and its one ``Conv2D`` layer (32 filters, 3x3)."""
        super().__init__(batch_size)
        self.conv1 = layers.Conv2D(
            filters=32,
            kernel_size=(3, 3),
            activation='relu',
        )

    def call(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return self.conv1(x)
        
# Subclassed variant: summary of the hand-written single-convolution model.
model = ExampleModel(batch_size=15)
model.summary([28, 28, 1])

# Sequential variant: the same wrapper built from a list of layers.
conv_stack = [layers.Conv2D(32, kernel_size=(3, 3), activation="relu")]
model = KerasSupervisedModelWrapper.Sequential(conv_stack, batch_size=15)
model.summary((28, 28, 1))

# Run one random 28x28x1 sample through the model and show the output shape.
sample = tf.random.normal((1, 28, 28, 1))
print(model(sample).shape)
Model: "model_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_14 (InputLayer)       [(None, 28, 28, 1)]       0         
                                                                 
 conv2d_17 (Conv2D)          (None, 26, 26, 32)        320       
                                                                 
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
Model: "model_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_15 (InputLayer)       [(None, 28, 28, 1)]       0         
                                                                 
 sequential_8 (Sequential)   (None, 26, 26, 32)        320       
                                                                 
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
(1, 26, 26, 32)
  • Related