Home > Software design >  error in training the GAN neural network (ValueError: Dimensions must be equal, but are N and M )
error in training the GAN neural network (ValueError: Dimensions must be equal, but are N and M )

Time:07-29

Trying to run SRGAN for training, however, I get the error Dimensions must be equal. To be honest, I don't really understand the problem. The problem seems to me to be with the discriminator model, but what it is I can't quite put my finger on it.

class SuperResolutionGAN:
    def generatorModel(self):
        inputLayer = Input(shape=(32, 32, 1))

        layerSet = Conv2D(filters=64, kernel_size=9, activation='relu',  padding='same')(inputLayer)
        residualOutput = skipConnection = PReLU(shared_axes=[1, 2])(layerSet)

        for _ in range(5):
            layerSet = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(layerSet)
            layerSet = BatchNormalization(momentum=0.8)(layerSet)
            #layerSet = PReLU(shared_axes=[1, 2])(layerSet)
            layerSet = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(layerSet)
            layerSet = BatchNormalization(momentum=0.8)(layerSet)
            residualOutput = Add()([residualOutput, layerSet])

        layerSet = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(residualOutput)
        layerSet = BatchNormalization()(layerSet)
        layerSet = Add()([skipConnection, layerSet])

        for _ in range(2):
            layerSet = Conv2D(filters=256, kernel_size=3, activation='relu', padding='same')(layerSet)
            layerSet = tf.nn.depth_to_space(layerSet, 2)
            layerSet = PReLU(shared_axes=[1, 2])(layerSet)

        outputLayer = Conv2D(filters=1, kernel_size=9, padding='same', activation='tanh')(layerSet)

        self.__generatorModel = Model(inputLayer, outputLayer)

        self.__generatorModel.compile(optimizer='adam', loss='mse')
        print(self.__generatorModel.summary())

    def discriminatorModel(self, inputShape=(128, 128, 1)):
        inputLayer = Input(shape=inputShape)
        layerSet = Lambda(lambda inputValue: inputValue / 127.5 - 1)(inputLayer)

        layerSet = Conv2D(filters=64, kernel_size=3, activation='relu', padding='same')(layerSet)
        layerSet = BatchNormalization(momentum=0.8)(layerSet)
        #layerSet = LeakyReLU(alpha=0.2)(layerSet)

        filterNumber = 64
        for i in range(4):
            print(filterNumber)
            layerSet = Conv2D(filters=filterNumber, kernel_size=3, strides=(2, 2), activation='relu',padding='same')(layerSet)
            #layerSet = LeakyReLU(alpha=0.2)(layerSet)
            layerSet = Conv2D(filters=filterNumber, kernel_size=3, activation='relu',padding='same')(layerSet)
            layerSet = BatchNormalization(momentum=0.8)(layerSet)
            #layerSet = LeakyReLU(alpha=0.2)(layerSet)
            filterNumber *= 2

        layerSet = Flatten()(layerSet)

        layerSet = Dense(1024)(layerSet)
        #layerSet = LeakyReLU(alpha=0.2)(layerSet)
        layerSet = Dense(1)(layerSet)

        self.__discriminatorModel = Model(inputLayer, layerSet)

        self.__discriminatorModel.compile(optimizer='adam', loss='mse')
        print(self.__discriminatorModel.summary())

However, in the Shape checks, all incoming and outgoing arrays have the correct dimensions.

fakeValidationData(23760, 128, 128, 1)
realValidationData(23760, 128, 128, 1)
generatedData(23760, 128, 128, 1)

     ValueError: Dimensions must be equal, but are 23760 and 128 for '{{node mean_squared_error/SquaredDifference}} = SquaredDifference[T=DT_FLOAT](model_32/dense_22/BiasAdd, IteratorGetNext:1)' with input shapes: [23760,1], [23760,128,128,1].

Here is the whole code.

def train(self, imageDataPath:string, inputShape=(32, 32, 1), 

    self.__discriminatorModel.trainable = False
    inputLayer = Input(shape=inputShape)
    outputLayer = self.__discriminatorModel(self.__generatorModel(inputLayer))
        
    self.__superResolutionModel = Model(inputLayer, outputLayer)
    self.__superResolutionModel.compile(optimizer='adam', loss='mse')

    fakeValidationData = numpy.zeros((self.__targetTrain.shape[0], 128, 128, 1), dtype=float)
    for i in range(self.__targetTrain.shape[0]):
      fakeValidationData[i] = numpy.zeros((128, 128, 1), dtype=float)

      realValidationData = numpy.zeros((self.__targetTrain.shape[0], 128, 128, 1), dtype=float)
    for i in range(self.__targetTrain.shape[0]):
      realValidationData[i] = numpy.ones((128, 128, 1), dtype=float)

    generatedData = numpy.zeros((self.__targetTrain.shape[0], 128, 128, 1), dtype=float)
    for i in range(self.__sourceTrain.shape[0]):
      input = numpy.zeros((1, self.__sourceTrain.shape[1], self.__sourceTrain.shape[2], 1), dtype=float)
      imageToProzess = self.__sourceTrain[i]
      input[0, :, :, 0] = imageToProzess.astype( dtype=float ) / 255
      output = self.__generatorModel.predict(input)
      generatedData[i] = output
                
      for currentEpoch in range(epochsNumber):
        print('EPOCH : ', currentEpoch)
            
        self.__discriminatorModel.train_on_batch(generatedData, fakeValidationData)
        self.__discriminatorModel.train_on_batch(self.__targetTrain, realValidationData)
        self.__discriminatorModel.save_weights('diskriminatorPretrainedWeights.h5')
        self.__superResolutionModel.train_on_batch(self.__targetTrain, realValidationData)
        self.__superResolutionModel.save_weights(weightsPath)
'
Model: "model_31"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_33 (InputLayer)          [(None, 32, 32, 1)]  0           []                               
                                                                                                  
 conv2d_264 (Conv2D)            (None, 32, 32, 64)   5248        ['input_33[0][0]']               
                                                                                                  
 conv2d_265 (Conv2D)            (None, 32, 32, 64)   36928       ['conv2d_264[0][0]']             
                                                                                                  
 batch_normalization_176 (Batch  (None, 32, 32, 64)  256         ['conv2d_265[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 conv2d_266 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_176[0][0]']
                                                                                                  
 batch_normalization_177 (Batch  (None, 32, 32, 64)  256         ['conv2d_266[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 conv2d_267 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_177[0][0]']
                                                                                                  
 batch_normalization_178 (Batch  (None, 32, 32, 64)  256         ['conv2d_267[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 conv2d_268 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_178[0][0]']
                                                                                                  
 batch_normalization_179 (Batch  (None, 32, 32, 64)  256         ['conv2d_268[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 conv2d_269 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_179[0][0]']
                                                                                                  
 batch_normalization_180 (Batch  (None, 32, 32, 64)  256         ['conv2d_269[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 conv2d_270 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_180[0][0]']
                                                                                                  
 batch_normalization_181 (Batch  (None, 32, 32, 64)  256         ['conv2d_270[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 conv2d_271 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_181[0][0]']
                                                                                                  
 batch_normalization_182 (Batch  (None, 32, 32, 64)  256         ['conv2d_271[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 conv2d_272 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_182[0][0]']
                                                                                                  
 p_re_lu_73 (PReLU)             (None, 32, 32, 64)   64          ['conv2d_264[0][0]']             
                                                                                                  
 batch_normalization_183 (Batch  (None, 32, 32, 64)  256         ['conv2d_272[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 add_66 (Add)                   (None, 32, 32, 64)   0           ['p_re_lu_73[0][0]',             
                                                                  'batch_normalization_177[0][0]']
                                                                                                  
 conv2d_273 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_183[0][0]']
                                                                                                  
 add_67 (Add)                   (None, 32, 32, 64)   0           ['add_66[0][0]',                 
                                                                  'batch_normalization_179[0][0]']
                                                                                                  
 batch_normalization_184 (Batch  (None, 32, 32, 64)  256         ['conv2d_273[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 add_68 (Add)                   (None, 32, 32, 64)   0           ['add_67[0][0]',                 
                                                                  'batch_normalization_181[0][0]']
                                                                                                  
 conv2d_274 (Conv2D)            (None, 32, 32, 64)   36928       ['batch_normalization_184[0][0]']
                                                                                                  
 add_69 (Add)                   (None, 32, 32, 64)   0           ['add_68[0][0]',                 
                                                                  'batch_normalization_183[0][0]']
                                                                                                  
 batch_normalization_185 (Batch  (None, 32, 32, 64)  256         ['conv2d_274[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 add_70 (Add)                   (None, 32, 32, 64)   0           ['add_69[0][0]',                 
                                                                  'batch_normalization_185[0][0]']
                                                                                                  
 conv2d_275 (Conv2D)            (None, 32, 32, 64)   36928       ['add_70[0][0]']                 
                                                                                                  
 batch_normalization_186 (Batch  (None, 32, 32, 64)  256         ['conv2d_275[0][0]']             
 Normalization)                                                                                   
                                                                                                  
 add_71 (Add)                   (None, 32, 32, 64)   0           ['p_re_lu_73[0][0]',             
                                                                  'batch_normalization_186[0][0]']
                                                                                                  
 conv2d_276 (Conv2D)            (None, 32, 32, 256)  147712      ['add_71[0][0]']                 
                                                                                                  
 tf.nn.depth_to_space_22 (TFOpL  (None, 64, 64, 64)  0           ['conv2d_276[0][0]']             
 ambda)                                                                                           
                                                                                                  
 p_re_lu_74 (PReLU)             (None, 64, 64, 64)   64          ['tf.nn.depth_to_space_22[0][0]']
                                                                                                  
 conv2d_277 (Conv2D)            (None, 64, 64, 256)  147712      ['p_re_lu_74[0][0]']             
                                                                                                  
 tf.nn.depth_to_space_23 (TFOpL  (None, 128, 128, 64  0          ['conv2d_277[0][0]']             
 ambda)                         )                                                                 
                                                                                                  
 p_re_lu_75 (PReLU)             (None, 128, 128, 64  64          ['tf.nn.depth_to_space_23[0][0]']
                                )                                                                 
                                                                                                  
 conv2d_278 (Conv2D)            (None, 128, 128, 1)  5185        ['p_re_lu_75[0][0]']             
                                                                                                  
==================================================================================================
Total params: 715,073
Trainable params: 713,665
Non-trainable params: 1,408
__________________________________________________________________________________________________



Model: "model_32"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_34 (InputLayer)       [(None, 128, 128, 1)]     0         
                                                                 
 lambda_11 (Lambda)          (None, 128, 128, 1)       0         
                                                                 
 conv2d_279 (Conv2D)         (None, 128, 128, 64)      640       
                                                                 
 batch_normalization_187 (Ba  (None, 128, 128, 64)     256       
 tchNormalization)                                               
                                                                 
 conv2d_280 (Conv2D)         (None, 64, 64, 64)        36928     
                                                                 
 conv2d_281 (Conv2D)         (None, 64, 64, 64)        36928     
                                                                 
 batch_normalization_188 (Ba  (None, 64, 64, 64)       256       
 tchNormalization)                                               
                                                                 
 conv2d_282 (Conv2D)         (None, 32, 32, 128)       73856     
                                                                 
 conv2d_283 (Conv2D)         (None, 32, 32, 128)       147584    
                                                                 
 batch_normalization_189 (Ba  (None, 32, 32, 128)      512       
 tchNormalization)                                               
                                                                 
 conv2d_284 (Conv2D)         (None, 16, 16, 256)       295168    
                                                                 
 conv2d_285 (Conv2D)         (None, 16, 16, 256)       590080    
                                                                 
 batch_normalization_190 (Ba  (None, 16, 16, 256)      1024      
 tchNormalization)                                               
                                                                 
 conv2d_286 (Conv2D)         (None, 8, 8, 512)         1180160   
                                                                 
 conv2d_287 (Conv2D)         (None, 8, 8, 512)         2359808   
                                                                 
 batch_normalization_191 (Ba  (None, 8, 8, 512)        2048      
 tchNormalization)                                               
                                                                 
 flatten_11 (Flatten)        (None, 32768)             0         
                                                                 
 dense_21 (Dense)            (None, 1024)              33555456  
                                                                 
 dense_22 (Dense)            (None, 1)                 1025      
                                                                 
=================================================================
Total params: 38,281,729
Trainable params: 38,279,681
Non-trainable params: 2,048
_________________________________________________________________

Thank you in advance

CodePudding user response:

generatedData takes as first arg the inputs, and as second argument the targets... here:

self.__discriminatorModel.train_on_batch(generatedData, fakeValidationData)

you are feeding as target images, but you want the labels (for example np.ones(self.__targetTrain.shape[0]))

  • Related