import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class KerasSupervisedModelWrapper(keras.Model):
    """Base ``keras.Model`` subclass that remembers a batch size.

    Subclasses provide ``call``. Note that ``summary`` here takes an input
    shape, unlike ``keras.Model.summary``.
    """

    def __init__(self, batch_size, **kwargs):
        """Store ``batch_size`` and initialise the underlying ``keras.Model``.

        Args:
            batch_size: Value kept on the instance as ``self.batch_size``.
            **kwargs: Extra keyword arguments forwarded to
                ``keras.Model.__init__`` (for example ``name``).
        """
        # Forward the extra keyword arguments rather than silently dropping
        # them, so options such as ``name`` reach keras.Model.
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def summary(self, input_shape):  # temporary fix for a bug
        """Print a layer summary for inputs of the given shape.

        A symbolic input is pushed through ``self.call`` to build a
        functional ``keras.Model``, and that model's summary is returned.

        Args:
            input_shape: Shape of one sample without the batch dimension,
                e.g. ``[28, 28, 1]``.

        Returns:
            The return value of ``keras.Model.summary`` on the traced model.
        """
        x = layers.Input(shape=input_shape)
        model = keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()
class ExampleModel(KerasSupervisedModelWrapper):
    """Example wrapper subclass made of a single 3x3 convolution."""

    def __init__(self, batch_size):
        """Create the model and its one ``Conv2D`` layer (32 filters, ReLU)."""
        super().__init__(batch_size)
        self.conv1 = layers.Conv2D(
            filters=32,
            kernel_size=(3, 3),
            activation='relu',
        )

    def call(self, x):
        """Return the convolution of ``x``."""
        return self.conv1(x)
# ExampleModel is the class defined above; 15 is its batch_size argument.
model = ExampleModel(15)
# Shape of one sample, without the batch dimension.
model.summary([28, 28, 1])
output:
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 28, 28, 1)] 0
conv2d_2 (Conv2D) (None, 26, 26, 32) 320
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
I'm writing a wrapper for Keras models that pre-defines some useful methods and variables, as shown above.
I'd like to extend the wrapper so that it can compose a model from a list of layers, the way keras.Sequential
does.
To do that, I added a Sequential
method that assigns a new call
method, as shown below.
class KerasSupervisedModelWrapper(keras.Model):
...(continue)...
@staticmethod
def Sequential(layers, **kwargs):
    """Build a wrapper instance meant to run ``layers`` as a ``keras.Sequential``.

    Args:
        layers: Layers passed to ``keras.Sequential``.
        **kwargs: Keyword arguments passed to ``KerasSupervisedModelWrapper``.

    Returns:
        The ``KerasSupervisedModelWrapper`` instance with ``call`` replaced.
    """
    model = KerasSupervisedModelWrapper(**kwargs)
    pipe = keras.Sequential(layers)
    def call(self, x):
        return pipe(x)
    # NOTE(review): a function stored as an instance attribute is not bound
    # to the instance, so `model.call(x)` passes `x` in the `self` slot and
    # leaves the `x` parameter unfilled.
    model.call = call
    return model
However, it does not seem to work as I intended. Instead, it raises the error message below.
# Build the wrapper from a one-layer list through the Sequential factory.
conv_layers = [layers.Conv2D(32, kernel_size=(3, 3), activation="relu")]
model = KerasSupervisedModelWrapper.Sequential(conv_layers, batch_size=15)
# Summarise for a single 28x28x1 sample.
model.summary((28, 28, 1))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_91471/2826773946.py in <module>
1 # model.build((None, 28, 28, 1))
2 # model.compile('adam', loss=keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'])
----> 3 model.summary((28, 28, 1))
/tmp/ipykernel_91471/3696340317.py in summary(self, input_shape)
10 def summary(self, input_shape): # temporary fix for a bug
11 x = layers.Input(shape=input_shape)
---> 12 model = keras.Model(inputs=[x], outputs=self.call(x))
13 return model.summary()
14
TypeError: call() missing 1 required positional argument: 'x'
What can I do so that the wrapper builds a keras.Sequential
model while still using its other properties?
CodePudding user response:
You could try something like this:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class KerasSupervisedModelWrapper(keras.Model):
    """Base ``keras.Model`` subclass that remembers a batch size.

    Subclasses provide ``call``; alternatively ``Sequential`` builds an
    instance from a list of layers. Note that ``summary`` here takes an
    input shape, unlike ``keras.Model.summary``.
    """

    def __init__(self, batch_size, **kwargs):
        """Store ``batch_size`` and initialise the underlying ``keras.Model``.

        Args:
            batch_size: Value kept on the instance as ``self.batch_size``.
            **kwargs: Extra keyword arguments forwarded to
                ``keras.Model.__init__`` (for example ``name``).
        """
        # Forward the extra keyword arguments rather than silently dropping
        # them, so options such as ``name`` reach keras.Model.
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def summary(self, input_shape):  # temporary fix for a bug
        """Print a layer summary for inputs of the given shape.

        A symbolic input is pushed through ``self.call`` to build a
        functional ``keras.Model``, and that model's summary is returned.

        Args:
            input_shape: Shape of one sample without the batch dimension,
                e.g. ``(28, 28, 1)``.

        Returns:
            The return value of ``keras.Model.summary`` on the traced model.
        """
        x = layers.Input(shape=input_shape)
        model = keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()

    @staticmethod
    def Sequential(layers, **kwargs):
        """Build a wrapper whose forward pass is a ``keras.Sequential`` stack.

        Args:
            layers: Layers passed to ``keras.Sequential``. Inside this
                method the name shadows the imported ``layers`` module.
            **kwargs: Keyword arguments for ``KerasSupervisedModelWrapper``;
                ``batch_size`` is required.

        Returns:
            A ``KerasSupervisedModelWrapper`` whose ``call`` attribute is
            the ``keras.Sequential`` model.
        """
        model = KerasSupervisedModelWrapper(**kwargs)
        pipe = keras.Sequential(layers)
        # The Sequential model is itself callable with the input tensor, so
        # it can replace `call` directly without needing a bound method.
        model.call = pipe
        return model
class ExampleModel(KerasSupervisedModelWrapper):
    """Example wrapper subclass made of a single 3x3 convolution."""

    def __init__(self, batch_size):
        """Create the model and its one ``Conv2D`` layer (32 filters, ReLU)."""
        super().__init__(batch_size)
        self.conv1 = layers.Conv2D(
            filters=32,
            kernel_size=(3, 3),
            activation='relu',
        )

    def call(self, x):
        """Return the convolution of ``x``."""
        return self.conv1(x)
# Subclassed model: the forward pass is ExampleModel.call.
model = ExampleModel(15)
model.summary([28, 28, 1])

# Layer-list model built through the Sequential factory.
conv_layers = [layers.Conv2D(32, kernel_size=(3, 3), activation="relu")]
model = KerasSupervisedModelWrapper.Sequential(conv_layers, batch_size=15)
model.summary((28, 28, 1))

# Run one random sample through the model and show the output shape.
sample = tf.random.normal((1, 28, 28, 1))
print(model(sample).shape)
Model: "model_9"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 28, 28, 1)] 0
conv2d_17 (Conv2D) (None, 26, 26, 32) 320
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
Model: "model_10"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_15 (InputLayer) [(None, 28, 28, 1)] 0
sequential_8 (Sequential) (None, 26, 26, 32) 320
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
(1, 26, 26, 32)