How to save and resume training a GAN with multiple model parts with Tensorflow 2/ Keras

Time:10-28

I am currently trying to add a feature to interrupt and resume training on a GAN created from this example code: https://machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-scratch-with-keras/

I managed to get it working by saving the weights of the entire composite GAN in the summarize_performance function, which is triggered every 10 epochs, like this:

# save all weights, using a zero-padded step number (e.g. weights_00000010.h5)
filename3 = 'weights_%08d.h5' % (step + 1)
gan_model.save_weights(filename3)
print('>Saved: %s and %s and %s' % (filename1, filename2, filename3))

These weights are loaded by a function I added at the start of the program, load_model, which takes the GAN built as usual and updates its weights to the most recently saved values, like this:

import glob

# load weights from the most recent file and return the batch number to resume from
def load_model(gan_model):
    start_batch = 0
    # glob returns files in arbitrary order; sorting puts the highest zero-padded step last
    files = sorted(glob.glob("./weights_0*.h5"))
    if len(files) > 0:
        most_recent_file = files[-1]
        gan_model.load_weights(most_recent_file)
        # TODO: breaks if using more than 8 digits for batches
        # "./weights_" is 10 characters long, so [10:18] is the 8-digit step number
        start_batch = int(most_recent_file[10:18])
        if start_batch != 0:
            print("> found existing weights; starting at batch %d" % start_batch)
    return start_batch

where the start_batch gets passed to the train function in order to skip the already completed epochs.
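The slicing in load_model is what the TODO warns about: it only works for exactly 8 digits and relies on file order. A minimal sketch of a more robust lookup, assuming the files are named like weights_00000010.h5 (the pattern the slicing implies); the function name find_latest_weights is hypothetical:

```python
import glob
import os
import re

# Matches e.g. 'weights_00000010.h5'; \d+ also accepts longer numbers,
# so more than 8 digits no longer breaks the parsing.
WEIGHTS_RE = re.compile(r"weights_(\d+)\.h5$")


def find_latest_weights(directory="."):
    """Return (path, step) for the highest-numbered weights file, or (None, 0)."""
    candidates = []
    for path in glob.glob(os.path.join(directory, "weights_*.h5")):
        match = WEIGHTS_RE.search(os.path.basename(path))
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None, 0
    # Compare step numbers numerically instead of relying on glob's ordering.
    step, path = max(candidates)
    return path, step
```

Inside load_model this would become `path, start_batch = find_latest_weights()` followed by `gan_model.load_weights(path)` when `path` is not None.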

While this weight-saving approach does "work", I still think it is wrong: the saved weights do not include the optimizer state of the GAN (for Adam, its moment estimates and step counter), so training does not continue as it would have without the interruption.
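To see why the missing optimizer state matters, here is a plain-Python toy (no TensorFlow; the scalar Adam class and the quadratic loss are hypothetical illustrations, not Keras code). It compares an uninterrupted run, a resume that keeps only the weight, and a resume that keeps the optimizer object as well:

```python
import math


class Adam:
    """Minimal scalar Adam; m, v and t are the state that saved weights alone miss."""

    def __init__(self, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = 0.0  # first-moment (mean) estimate of the gradient
        self.v = 0.0  # second-moment estimate of the gradient
        self.t = 0    # step counter used for bias correction

    def step(self, w, grad):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return w - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


def grad(w):
    # Gradient of the toy loss (w - 3)^2.
    return 2.0 * (w - 3.0)


def train(w, opt, steps):
    for _ in range(steps):
        w = opt.step(w, grad(w))
    return w


# Uninterrupted: 20 steps with a single optimizer.
w_full = train(0.0, Adam(), 20)

# Interrupted, weights only: resume after 10 steps with a fresh optimizer.
w_half = train(0.0, Adam(), 10)
w_weights_only = train(w_half, Adam(), 10)

# Interrupted, but the optimizer state survives the restart.
opt = Adam()
w_resume = train(0.0, opt, 10)
w_full_state = train(w_resume, opt, 10)

print(w_full, w_weights_only, w_full_state)
```

Only the run that keeps the optimizer state reproduces the uninterrupted result exactly; the fresh optimizer restarts its bias correction and moment history, so the trajectory diverges.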

The way I've found to save progress together with the optimizer state is to save the entire model instead of just the weights.

Here I run into a problem, since in a GAN I don't train just one model; I have three:

  • The generator model g_model
  • The discriminator model d_model
  • and the composite GAN model gan_model

which are all connected and dependent on each other. If I took the naive approach and saved and restored each of these models individually, I'd end up with three separate, disjoint models instead of a GAN.
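The concern is about object identity: the composite model reuses the very same layer objects as the generator and discriminator. A plain-Python analogy (the Layer and Model classes are hypothetical stand-ins, and pickle stands in for "save the whole model") shows why saving and loading the three models separately yields disjoint copies, while saving them as one object graph keeps the sharing:

```python
import pickle


class Layer:
    """Stand-in for a Keras layer: just holds a list of weights."""

    def __init__(self, weights):
        self.weights = weights


class Model:
    """Stand-in for a Keras model: a list of (possibly shared) layers."""

    def __init__(self, layers):
        self.layers = layers


g_layer = Layer([0.5])
d_layer = Layer([1.5])
g_model = Model([g_layer])
d_model = Model([d_layer])
gan_model = Model([g_layer, d_layer])  # the composite reuses the same layer objects

# Saving each model on its own: every copy gets its own private layers.
g2, d2, gan2 = (pickle.loads(pickle.dumps(m)) for m in (g_model, d_model, gan_model))

# Saving all three in one object graph: pickle stores each shared object once.
g3, d3, gan3 = pickle.loads(pickle.dumps((g_model, d_model, gan_model)))

print(gan2.layers[0] is g2.layers[0])  # False: disjoint copies
print(gan3.layers[0] is g3.layers[0])  # True: sharing preserved
```

tf.train.Checkpoint avoids the problem in a slightly different way: you rebuild the connected models in code as usual, and restoring writes the saved values (including optimizer slots) into those existing objects, so the sharing your code sets up stays intact.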

Is there a way to save and restore the entire GAN so that I can resume training as if no interruption had occurred?

Answer:

Consider using tf.train.Checkpoint if you want to restore your entire GAN, including the optimizer states:

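### In your training loop

import tensorflow as tf

checkpoint_dir = '/checkpoints'
# Track every object whose state must survive a restart: the models and their optimizers.
# (In the AC-GAN tutorial the optimizers are created inside compile(), so they are
# available as d_model.optimizer and gan_model.optimizer.)
checkpoint = tf.train.Checkpoint(gan_optimizer=gan_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator,
                                 gan_model=gan_model)

ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
# Restore into the already-built models and optimizers if a checkpoint exists.
if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')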

....
....


if (epoch + 1) % 40 == 0:
    ckpt_save_path = ckpt_manager.save()
    print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))

### After x number of epochs, just save your generator model for inference.

generator.save('your_model.h5')
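The snippet above saves checkpoints but does not show how the script knows which epoch to resume from. A minimal sketch, assuming TensorFlow 2 with tf.keras, stores an epoch counter as a tf.Variable inside the checkpoint. The tiny models, the dummy losses, and the helpers build_state and train_one_epoch are hypothetical stand-ins (not the AC-GAN from the question), and the networks are updated with GradientTape instead of a composite model:

```python
import tempfile

import tensorflow as tf


def build_state():
    """Build toy models, optimizers and an epoch counter, all tracked by one checkpoint."""
    generator = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(8)])
    discriminator = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
    g_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
    d_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
    # Keeping the epoch in the checkpoint avoids parsing it out of file names.
    epoch = tf.Variable(0, dtype=tf.int64, trainable=False)
    ckpt = tf.train.Checkpoint(generator=generator, discriminator=discriminator,
                               generator_optimizer=g_opt, discriminator_optimizer=d_opt,
                               epoch=epoch)
    return ckpt, generator, discriminator, g_opt, d_opt, epoch


def train_one_epoch(generator, discriminator, g_opt, d_opt):
    """One dummy update of each network so that both Adam optimizers acquire state."""
    noise = tf.random.normal((16, 4))
    with tf.GradientTape() as tape:
        # Toy discriminator loss: push scores on generated samples toward 0.
        d_loss = tf.reduce_mean(tf.square(discriminator(generator(noise))))
    d_vars = discriminator.trainable_variables
    d_opt.apply_gradients(zip(tape.gradient(d_loss, d_vars), d_vars))
    with tf.GradientTape() as tape:
        # Toy generator loss: push discriminator scores toward 1.
        g_loss = tf.reduce_mean(tf.square(discriminator(generator(noise)) - 1.0))
    g_vars = generator.trainable_variables
    g_opt.apply_gradients(zip(tape.gradient(g_loss, g_vars), g_vars))


ckpt_dir = tempfile.mkdtemp()

# First run: train 3 epochs, saving a checkpoint after each one.
ckpt, gen, disc, g_opt, d_opt, epoch = build_state()
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
for _ in range(3):
    train_one_epoch(gen, disc, g_opt, d_opt)
    epoch.assign_add(1)
    manager.save()
saved_gen_weights = [w.copy() for w in gen.get_weights()]

# "Restart": rebuild everything from scratch, then restore the latest checkpoint.
ckpt2, gen2, disc2, g_opt2, d_opt2, epoch2 = build_state()
manager2 = tf.train.CheckpointManager(ckpt2, ckpt_dir, max_to_keep=3)
if manager2.latest_checkpoint:
    ckpt2.restore(manager2.latest_checkpoint)
start_epoch = int(epoch2.numpy())
print("resuming at epoch", start_epoch)
```

The training loop would then run `for e in range(start_epoch, n_epochs)`. Including the optimizers in the checkpoint means Adam's moment estimates are restored as well (slot variables that do not exist yet are filled in once the optimizer creates them on its first update).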

You can also consider getting rid of the composite model completely. Here is an example of what I mean.
