I am trying to develop a GAN. I have created the generator and the discriminator, and now I am trying to train it. I am using the MNIST dataset, but I plan to use some others as well. The problem is that when I train it, I get this error: Input 0 of layer "conv2d_transpose_4" is incompatible with the layer: expected ndim=4, found ndim=2. Full shape received: (None, 100)
I don't really know whether the problem is in the networks or in the data used to train the GAN. Can someone tell me how I should train it, or where the problem is?
imports:
import tensorflow
import keras
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Flatten, Input, BatchNormalization,
LeakyReLU, Reshape
from keras.layers import Conv2D, Conv2DTranspose, MaxPooling2D
from tensorflow.keras.optimizers import Adam
from keras import backend as K
from keras.utils import np_utils
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
generator:
def generator():
    """Assemble the generator network exactly as posted in the question.

    The first layer is declared with ``input_shape=img_shape``, so the
    stack is built for image-shaped inputs. The model summary is printed
    before the model is returned.

    NOTE(review): the setup script later calls this model on an
    ``Input(shape=(100,))`` tensor. That rank mismatch looks like the
    cause of the reported "expected ndim=4, found ndim=2" error; the
    layers are kept as posted so the question stays reproducible.

    Returns:
        The built ``Sequential`` model.
    """
    stack = [
        Conv2DTranspose(32, (3, 3), strides=(2, 2), activation='relu',
                        use_bias=False, input_shape=img_shape),
        BatchNormalization(momentum=0.3),
        Conv2DTranspose(128, (3, 3), strides=(2, 2), activation='relu',
                        padding='same', use_bias=False),
        MaxPooling2D(pool_size=(2, 2)),
        LeakyReLU(alpha=0.2),
        Conv2DTranspose(128, (3, 3), strides=(2, 2), activation='relu',
                        padding='same', use_bias=False),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.5),
        BatchNormalization(momentum=0.3),
        LeakyReLU(alpha=0.2),
        Conv2DTranspose(128, (3, 3), strides=(2, 2), activation='relu',
                        padding='same', use_bias=False),
        BatchNormalization(),
        Dense(512, activation=LeakyReLU(alpha=0.2)),
        BatchNormalization(momentum=0.7),
    ]

    # Layers are appended one at a time, in the same order as the original.
    net = Sequential()
    for layer in stack:
        net.add(layer)

    net.build()
    net.summary()
    return net
discriminator:
def discriminator():
    """Assemble the discriminator network exactly as posted in the question.

    Three strided 5x5 convolutions are followed by a 512-unit Dense layer,
    a Flatten, batch normalisation and a single sigmoid output unit. Note
    that in this version the Dense(512) layer comes *before* Flatten. The
    model summary is printed before the model is returned.

    Returns:
        The built ``Sequential`` model.
    """
    stack = [
        Conv2D(32, (5, 5), strides=(2, 2), activation='relu',
               use_bias=False, input_shape=img_shape),
        BatchNormalization(momentum=0.3),
        Conv2D(64, (5, 5), strides=(2, 2), activation='relu',
               padding='same', use_bias=False),
        MaxPooling2D(pool_size=(2, 2)),
        LeakyReLU(alpha=0.2),
        Conv2D(64, (5, 5), strides=(2, 2), activation='relu',
               padding='same', use_bias=False),
        Dropout(0.5),
        BatchNormalization(momentum=0.3),
        LeakyReLU(alpha=0.2),
        Dense(512, activation=LeakyReLU(alpha=0.2)),
        Flatten(),
        BatchNormalization(momentum=0.7),
        Dense(1, activation='sigmoid'),
    ]

    # Layers are appended one at a time, in the same order as the original.
    net = Sequential()
    for layer in stack:
        net.add(layer)

    net.build()
    net.summary()
    return net
train function:
def train(epochs, batch_size, save_interval):
    """Alternate discriminator and generator updates on MNIST.

    Uses the module-level ``generator``, ``discriminator`` and ``combined``
    models. Each iteration trains the discriminator on half a batch of real
    images and half a batch of generated ones, then trains the generator
    through ``combined`` on a full batch of noise labelled as real.

    Args:
        epochs: Number of training iterations to run.
        batch_size: Size of the noise batch for the generator step; the
            discriminator sees ``int(batch_size / 2)`` real and fake samples.
        save_interval: ``save_imgs(epoch)`` is called whenever ``epoch`` is
            a multiple of this value.
    """
    (x_train, _), (_, _) = mnist.load_data()
    # Shift and scale pixel values by 127.5, then append a channel axis.
    scaled = (x_train.astype(np.float32) - 127.5) / 127.5
    x_train = np.expand_dims(scaled, axis=3)

    half_batch = int(batch_size / 2)
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    valid_y = np.array([1] * batch_size)

    for epoch in range(epochs):
        # --- discriminator step: random real samples vs. fresh fakes ---
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs = x_train[idx]
        noise = np.random.normal(0, 1, (half_batch, 100))
        gen_imgs = generator.predict(noise)
        d_loss_real = discriminator.train_on_batch(imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- generator step: try to make the discriminator answer "real" ---
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = combined.train_on_batch(noise, valid_y)

        report = (epoch, d_loss[0], 100*d_loss[1], g_loss)
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % report)

        if epoch % save_interval == 0:
            save_imgs(epoch)
Data used:
# Image geometry used as the `input_shape` of both networks.
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
# One Adam instance is shared by all three compile() calls below; the two
# positional arguments are the learning rate and beta_1 in the Keras signature.
optimizer = Adam(0.0002, 0.5)
# NOTE: this rebinds the name `discriminator` from the builder function to the
# model instance it returns, so the builder cannot be called again afterwards.
discriminator = discriminator()
discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# Same rebinding as above: `generator` now refers to the model, not the builder.
generator = generator()
generator.compile(loss='binary_crossentropy',
optimizer=optimizer)
# Latent input: a 100-element noise vector per sample.
z = Input(shape=(100,))
# The reported exception is raised on the next line: `z` has shape (None, 100),
# while the generator's first layer was declared with input_shape=img_shape.
img = generator(z) #error
# Set after the discriminator's own compile() and before `combined` is
# compiled; the intent is that only the generator is updated via `combined`.
discriminator.trainable = False
valid = discriminator(img)
# Stacked model mapping noise -> discriminator verdict on the generated image.
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy',
optimizer=optimizer)
# Run the training loop, then save the generator to an HDF5 file.
train(epochs=100000, batch_size=32,
save_interval=10)
generator.save('generator_model.h5')
CodePudding user response:
The problem comes from the first Flatten layer in the Discriminator model, which converts your n-dimensional tensor into a 1D tensor. Since a MaxPooling2D layer cannot work with a 1D tensor, you are seeing that error. If you remove it, it should work:
def discriminator():
    """Assemble the answer's revised discriminator and print its summary.

    Three strided 5x5 convolutions extract features; the result is
    flattened *before* the 512-unit Dense layer, then batch-normalised and
    reduced to a single sigmoid output unit.

    Returns:
        The built ``Sequential`` model.
    """
    feature_layers = [
        Conv2D(32, (5, 5), strides=(2, 2), activation='relu',
               use_bias=False, input_shape=img_shape),
        BatchNormalization(momentum=0.3),
        Conv2D(64, (5, 5), strides=(2, 2), activation='relu',
               padding='same', use_bias=False),
        MaxPooling2D(pool_size=(2, 2)),
        LeakyReLU(alpha=0.2),
        Conv2D(64, (5, 5), strides=(2, 2), activation='relu',
               padding='same', use_bias=False),
        Dropout(0.5),
        BatchNormalization(momentum=0.3),
        LeakyReLU(alpha=0.2),
    ]
    head_layers = [
        Flatten(),
        Dense(512, activation=LeakyReLU(alpha=0.2)),
        BatchNormalization(momentum=0.7),
        Dense(1, activation='sigmoid'),
    ]

    # Layers are appended one at a time, in the same order as the original.
    net = Sequential()
    for layer in feature_layers + head_layers:
        net.add(layer)

    net.build()
    net.summary()
    return net
Update 1:
Try rewriting your Generator model like this:
def generator():
    """Build a generator that maps a 100-element noise vector to a 28x28x1 image.

    A Dense layer projects the noise to 7*7*256 values, which are reshaped
    to a 7x7x256 feature map and upsampled by transposed convolutions
    (7x7 -> 14x14 -> 28x28). The final layer uses a ``tanh`` activation.
    The ``assert`` statements document and check the expected output shape
    after each stage; ``None`` is the batch dimension.

    Returns:
        The ``Sequential`` generator model (its summary is printed first).
    """
    # The model is created once, with the ``Sequential`` already imported from
    # keras.models. The original second assignment ``tf.keras.Sequential()``
    # raised NameError because ``tf`` is never imported, and it also discarded
    # the instance created here.
    model = Sequential()

    # Project the latent vector and reshape it into a small feature map.
    model.add(Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    # Stride 1 keeps the 7x7 resolution while reducing the channel count.
    model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # Stride 2 doubles the spatial size: 7x7 -> 14x14.
    model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # Final upsampling to the 28x28 single-channel image.
    model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    model.summary()
    return model
It should then work, but you should definitely go through this