Plot and save the history of k-fold training

Time:06-17

I am training a model with k-fold cross-validation, and I want to keep the training history of every fold so that I can plot it and save it. How can I do that?

Answers to similar questions save each fold's history to a separate file. I want to save and plot all of the histories at once, not as separate files.

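import tensorflow as tf
from sklearn.model_selection import KFold

# X, label, siamese, batch_size and epochs are defined earlier in the script
num_folds = 10
kfold = KFold(n_splits=num_folds, shuffle=True)
# K-fold cross-validation model evaluation
fold_no = 1
a = []  # one Keras History object per fold
for train, test in kfold.split(X, label):
    print("---"*20)
    # train on this fold's training indices, validate on its held-out indices
    history = siamese.fit(
        [tf.gather(X[:,0], train), tf.gather(X[:,1], train)],
        tf.gather(label, train),
        validation_data=([tf.gather(X[:,0], test), tf.gather(X[:,1], test)], tf.gather(label, test)),
        batch_size=batch_size,
        epochs=epochs,
    )
    a.append(history)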
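Since the list a already holds one Keras History object per fold, one way to save everything in a single file is to keep only each object's history.history dictionary (metric name mapped to a list of per-epoch values) and pickle the whole list at once; pickling the History objects themselves may also try to serialize the model attached to them. The minimal sketch below assumes that layout. The helper names save_fold_histories and load_fold_histories are hypothetical, and the demo uses dummy dictionaries, not real training output.

```python
import os
import pickle
import tempfile
from typing import Dict, List

# One fold's history: the plain dict stored in a Keras History object's .history
FoldHistory = Dict[str, List[float]]


def save_fold_histories(fold_histories: List[FoldHistory], path: str) -> None:
    """Pickle the histories of all folds into one file."""
    with open(path, "wb") as f:
        pickle.dump(fold_histories, f)


def load_fold_histories(path: str) -> List[FoldHistory]:
    """Read back the list written by save_fold_histories."""
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    # Dummy per-fold dicts for illustration only; in the training script
    # this list would be [h.history for h in a].
    dummy = [
        {"loss": [0.9, 0.7], "val_loss": [1.0, 0.8]},
        {"loss": [0.8, 0.6], "val_loss": [0.9, 0.85]},
    ]
    path = os.path.join(tempfile.mkdtemp(), "kfold_histories.pkl")
    save_fold_histories(dummy, path)
    print(load_fold_histories(path) == dummy)
```

After the loop in the question, this would be called as save_fold_histories([h.history for h in a], './kfold_histories.pkl').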

Answer:

Borrowing directly from this answer, you could save each fold's history to its own file. I used enumerate to give each file a distinct name, and you could add an if i % N == 0: check to save only every N-th fold.

import pickle

# kfold, siamese, X, label and the list a come from the question's code
for i, (train, test) in enumerate(kfold.split(X, label)):
    print("---"*20)
    history = siamese.fit(
        [tf.gather(X[:,0], train), tf.gather(X[:,1], train)],
        tf.gather(label, train),
        validation_data=([tf.gather(X[:,0], test), tf.gather(X[:,1], test)], tf.gather(label, test)),
        batch_size=batch_size,
        epochs=epochs,
    )
    a.append(history)
    # you could add: if i % N == 0: ... to save only every N-th fold
    # save this fold's metrics dict to its own file (relative path, not the filesystem root)
    with open(f'./trainHistoryDict_{i}', 'wb') as file_pi:
        pickle.dump(history.history, file_pi)
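If the per-fold files are already on disk, they can still be combined into one object afterwards. This sketch assumes each file holds one pickled history.history dictionary, as the loop above writes, and that every fold recorded the same metrics. It regroups the files by metric, so merged['val_loss'][i] is the per-epoch validation loss of fold i. The function name merge_fold_files is hypothetical.

```python
import os
import pickle
import tempfile
from typing import Dict, List


def merge_fold_files(paths: List[str]) -> Dict[str, List[List[float]]]:
    """Load one pickled history dict per fold and regroup them by metric.

    The result maps each metric name to a list holding one per-epoch list
    per fold, in the order the paths are given.
    """
    merged: Dict[str, List[List[float]]] = {}
    for path in paths:
        with open(path, "rb") as f:
            fold_history = pickle.load(f)  # e.g. {"loss": [...], "val_loss": [...]}
        for metric, values in fold_history.items():
            merged.setdefault(metric, []).append(list(values))
    return merged


if __name__ == "__main__":
    # Write two dummy per-fold files (illustration only), then merge them.
    folder = tempfile.mkdtemp()
    paths = []
    for i, fake in enumerate([{"loss": [0.9, 0.7]}, {"loss": [0.8, 0.6]}]):
        path = os.path.join(folder, f"trainHistoryDict_{i}")
        with open(path, "wb") as f:
            pickle.dump(fake, f)
        paths.append(path)
    print(merge_fold_files(paths))
```

For the files written above, this would be merge_fold_files([f'./trainHistoryDict_{i}' for i in range(num_folds)]), with the path prefix adjusted to wherever the files were actually written.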

Answer:

By collecting every fold's metrics in one dictionary, I can save and plot all histories at once:

import pickle

import tensorflow as tf
from sklearn.model_selection import KFold

num_folds = 10
kfold = KFold(n_splits=num_folds, shuffle=True)
# K-fold cross-validation model evaluation
fold_no = 1
# for each metric, one list of per-epoch values per fold
# (the accuracy keys require the model to be compiled with metrics=['accuracy'])
histories = {'accuracy': [], 'loss': [], 'val_accuracy': [], 'val_loss': []}

for train, test in kfold.split(X, label):
    print("---"*20)
    # note: siamese keeps its weights between folds; rebuild or re-initialize it
    # here if each fold should start from scratch
    history = siamese.fit(
        [tf.gather(X[:,0], train), tf.gather(X[:,1], train)],
        tf.gather(label, train),
        validation_data=([tf.gather(X[:,0], test), tf.gather(X[:,1], test)], tf.gather(label, test)),
        batch_size=batch_size,
        epochs=epochs,
    )
    for key in histories:
        histories[key].append(history.history[key])

# save the histories of all folds to a single file
with open('./trainHistoryDict', 'wb') as file_pi:
    pickle.dump(histories, file_pi)
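The question also asks about plotting. With histories holding one per-epoch list per fold for each metric, a common way to show all folds in one figure is the mean across folds with a shaded band of plus or minus one standard deviation. This is a sketch under a few assumptions: matplotlib is installed, the folds ran for the same number of epochs (if not, zip truncates to the shortest fold), and the helper names fold_mean_std and plot_kfold_metric are hypothetical.

```python
import statistics
from typing import Dict, List, Tuple


def fold_mean_std(per_fold: List[List[float]]) -> Tuple[List[float], List[float]]:
    """Per-epoch mean and population standard deviation across folds.

    zip(*per_fold) groups the i-th epoch of every fold together and stops at
    the shortest fold. pstdev is used so that a single fold still works.
    """
    by_epoch = list(zip(*per_fold))
    means = [statistics.fmean(values) for values in by_epoch]
    stds = [statistics.pstdev(values) for values in by_epoch]
    return means, stds


def plot_kfold_metric(histories: Dict[str, List[List[float]]], metric: str,
                      out_path: str) -> None:
    """Plot mean +/- std of metric and its val_ counterpart, then save the figure."""
    # imported here so fold_mean_std can be used without matplotlib installed
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for name in (metric, "val_" + metric):
        if name not in histories:
            continue
        means, stds = fold_mean_std(histories[name])
        epochs = range(1, len(means) + 1)
        ax.plot(epochs, means, label=f"{name} (mean over folds)")
        ax.fill_between(epochs,
                        [m - s for m, s in zip(means, stds)],
                        [m + s for m, s in zip(means, stds)],
                        alpha=0.2)
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric)
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)
```

With the dictionary from this answer, plot_kfold_metric(histories, 'loss', 'kfold_loss.png') draws loss and val_loss together. Averaging hides individual folds; plotting each fold as its own thin line is an equally reasonable choice when looking for an outlier fold.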