Tensorflow seq2seq - keep max three checkpoints not working

Time:03-09

I am writing a seq2seq model and would like to keep only the three most recent checkpoints. I thought I was implementing this with

import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=3)

then

  # saving (checkpointing) the model every 2 epochs
  if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

Unfortunately this is not working: more than three checkpoints accumulate. Would you have a hint?

CodePudding user response:

Hmm, maybe you should try restoring your checkpoint every time you begin training again:

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=3)

if manager.latest_checkpoint:
  checkpoint.restore(manager.latest_checkpoint)
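One detail worth noting: `tf.train.CheckpointManager` only prunes checkpoints that were written through `manager.save()`; files written with `checkpoint.save(...)` bypass the manager, so `max_to_keep` never takes effect on them. Here is a minimal, self-contained sketch of the restore-then-save pattern, where a single `tf.Variable` stands in for your encoder/decoder/optimizer:

    import tempfile
    import tensorflow as tf

    # A single variable stands in for the real model objects.
    checkpoint_dir = tempfile.mkdtemp()
    step = tf.Variable(0)
    checkpoint = tf.train.Checkpoint(step=step)
    manager = tf.train.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=3)

    # Resume from the latest checkpoint if one exists.
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)

    for epoch in range(10):
        step.assign_add(1)
        # save every 2 epochs, through the manager so pruning happens
        if (epoch + 1) % 2 == 0:
            manager.save()

    # 5 saves were made, but the manager keeps only the newest 3
    print(len(manager.checkpoints))

With ten epochs this performs five saves, yet `manager.checkpoints` lists only the three most recent ones.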
