I am writing a seq2seq model and would like to keep only the three most recent checkpoints. I thought I was implementing this with:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=3)
then
# saving (checkpoint) the model every 2 epochs
if (epoch + 1) % 2 == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)
I am disappointed that this is not working. Would you have a hint?
CodePudding user response:
Hmm, maybe you should try restoring your checkpoint every time you begin training again:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=3)
if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)
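Also, as far as I know, max_to_keep only applies to checkpoints written through the manager itself; calling checkpoint.save() directly bypasses the manager's bookkeeping, so old files are never deleted. A minimal sketch of the saving step, reusing the manager defined above:
# saving (checkpoint) the model every 2 epochs through the manager,
# so that max_to_keep=3 actually prunes older checkpoints
if (epoch + 1) % 2 == 0:
    save_path = manager.save()
    print(f'Saved checkpoint for epoch {epoch + 1} at {save_path}')
With this change the checkpoint directory should never contain more than the three most recent checkpoints.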