How to transfer weights to Conv2DTranspose layer in TensorFlow?

Time:09-07

import numpy as np
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.utils import to_categorical

(trainX, trainY), (testX, testY) = keras.datasets.mnist.load_data()
trainX = trainX.reshape((trainX.shape[0], 28, 28, 1))
testX = testX.reshape((testX.shape[0], 28, 28, 1))
# one hot encode target values
trainY = to_categorical(trainY)
testY = to_categorical(testY)


model=Sequential()
model.add(Conv2D(32, (3, 3),activation='relu',input_shape=(28, 28, 1)))
model.add(MaxPooling2D((1, 1)))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(trainX, trainY, epochs=3)
score = model.evaluate(testX, testY)  # [test loss, test accuracy]

After training, I get the weights of the convolutional layer:

for layer in model.layers:
  if "con" in layer.name:
    # Conv2D kernel: (3, 3, 1, 32) = (kh, kw, in_channels, filters); bias: (32,)
    filterForThis,bias=layer.get_weights()

Now the aim is to transfer these weights from the Conv2D layer to a Conv2DTranspose layer in a new model:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2DTranspose
model2 = Sequential()
model2.add(Conv2DTranspose(32,kernel_size=(3, 3), padding='same',activation='relu',input_shape=(28, 28, 1)))
model2.summary()
model2.layers[0].set_weights([filterForThis,bias])

Here I get the following error:

Layer conv2d_transpose_15 weight shape (3, 3, 32, 1) is not compatible with provided weight shape (3, 3, 1, 32)
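The mismatch comes from Keras storing the channel axes of the two kernels in opposite order: Conv2D builds its kernel as `kernel_size + (input_channels, filters)`, Conv2DTranspose as `kernel_size + (filters, input_channels)`. A minimal sketch with two hypothetical helpers (they are not part of Keras) that just spell out those layouts for the two models here:

```python
# Hypothetical helpers: they only spell out the kernel layouts Keras uses,
# they are not part of the Keras API.

def conv2d_kernel_shape(kernel_size, in_channels, filters):
    # Conv2D stores its kernel as (kh, kw, in_channels, filters)
    return tuple(kernel_size) + (in_channels, filters)


def conv2d_transpose_kernel_shape(kernel_size, in_channels, filters):
    # Conv2DTranspose stores its kernel as (kh, kw, filters, in_channels)
    return tuple(kernel_size) + (filters, in_channels)


# Model 1: Conv2D(32, (3, 3)) on a 1-channel MNIST image
print(conv2d_kernel_shape((3, 3), in_channels=1, filters=32))            # (3, 3, 1, 32)
# Model 2: Conv2DTranspose(32, (3, 3)) on a 1-channel MNIST image
print(conv2d_transpose_kernel_shape((3, 3), in_channels=1, filters=32))  # (3, 3, 32, 1)
```

Same kernel size and channel counts, only the order of the last two axes differs, which is exactly what the error message reports.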

I know reshaping would make the shapes match, but what is the right way to do it?
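On the reshaping question: `reshape` only reinterprets the flat memory order, while `swapaxes` actually moves entries between axes. The NumPy check below (random arrays stand in for real kernels, and the 2-input/4-filter kernel is purely hypothetical) suggests that for this particular kernel both give the same array, because one of the two channel axes has size 1, but they diverge as soon as both channel axes are larger than 1. So `swapaxes`, or the equivalent `np.transpose(k, (0, 1, 3, 2))`, is the safer habit.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the question's Conv2D kernel: (kh, kw, in_channels=1, filters=32)
k = rng.standard_normal((3, 3, 1, 32))
swapped = np.swapaxes(k, 2, 3)        # (3, 3, 32, 1), channel axes really exchanged
reshaped = k.reshape(3, 3, 32, 1)     # same shape, memory order unchanged
print(np.array_equal(swapped, reshaped))  # True: the size-1 axis makes them coincide

# Hypothetical kernel with 2 input channels and 4 filters
k2 = rng.standard_normal((3, 3, 2, 4))
swapped2 = np.swapaxes(k2, 2, 3)      # (3, 3, 4, 2)
reshaped2 = k2.reshape(3, 3, 4, 2)    # same shape, but entries end up scrambled
print(np.array_equal(swapped2, reshaped2))  # False

# np.transpose with an explicit axis order is equivalent to swapaxes here
print(np.array_equal(swapped2, np.transpose(k2, (0, 1, 3, 2))))  # True
```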

Answer:

I think numpy.swapaxes should do it, since only the last two axes of the kernel (input and output channels) need to be exchanged:

print(filterForThis.shape)  # prints (3, 3, 1, 32)
filter2 = np.swapaxes(filterForThis, axis1=2, axis2=3)
print(filter2.shape)  # prints (3, 3, 32, 1)
model2.layers[0].set_weights([filter2,bias])
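Whether the swapped kernel is what you want depends on the goal. `swapaxes` makes the shapes fit a `Conv2DTranspose(32)` that maps the 1-channel image to 32 maps. If the goal is instead to map the 32 feature maps back to image space, then, as far as I understand, a `Conv2DTranspose(1, (3, 3))` fed 32 channels expects a kernel of shape (3, 3, 1, 32), i.e. the original Conv2D kernel unchanged, because a transposed convolution is the adjoint (gradient) of the convolution with the same kernel array. The NumPy sketch below illustrates that adjoint relation for stride 1 and 'valid' padding only; `conv2d_valid` and `conv2d_transpose_valid` are hypothetical names, random arrays stand in for real data, and bias is ignored (the Conv2D bias has shape (32,), which would not fit a 1-filter layer anyway).

```python
import numpy as np


def conv2d_valid(x, k):
    """Cross-correlation as in Keras Conv2D (stride 1, 'valid', no bias).

    x: (H, W, C_in), k: (kh, kw, C_in, C_out) -> (H - kh + 1, W - kw + 1, C_out)
    """
    kh, kw, _, c_out = k.shape
    h_out, w_out = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((h_out, w_out, c_out))
    for i in range(h_out):
        for j in range(w_out):
            # Sum over the kh x kw window and all input channels
            out[i, j] = np.tensordot(x[i:i + kh, j:j + kw], k, axes=3)
    return out


def conv2d_transpose_valid(y, k):
    """Transposed convolution (stride 1, 'valid', no bias), Keras kernel layout.

    y: (H, W, C_in), k: (kh, kw, C_out, C_in) -> (H + kh - 1, W + kw - 1, C_out)
    """
    kh, kw, c_out, _ = k.shape
    h_in, w_in = y.shape[:2]
    out = np.zeros((h_in + kh - 1, w_in + kw - 1, c_out))
    for i in range(h_in):
        for j in range(w_in):
            # Scatter each input vector back through the kernel window
            out[i:i + kh, j:j + kw] += np.tensordot(k, y[i, j], axes=([3], [0]))
    return out


rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3, 1, 32))   # stand-in for filterForThis
image = rng.standard_normal((8, 8, 1))        # small 1-channel "image"
maps = rng.standard_normal((6, 6, 32))        # 32 arbitrary feature maps

# Adjoint check: <conv(x), y> == <x, conv_transpose(y)> with the SAME kernel array
lhs = np.sum(conv2d_valid(image, kernel) * maps)
back = conv2d_transpose_valid(maps, kernel)   # 32 maps -> 1 channel
rhs = np.sum(image * back)
print(back.shape)            # (8, 8, 1)
print(np.isclose(lhs, rhs))  # True
```

In Keras terms this would suggest, untested here, something like `Conv2DTranspose(1, (3, 3), input_shape=(26, 26, 32))` (default 'valid' padding, so 26x26 maps come back as 28x28) loaded with `set_weights([filterForThis, np.zeros(1)])`.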