Hello, I'm working on object detection using the TensorFlow 2 Object Detection API and its model_main_tf2.py script. Normally I can use an early-stopping callback with model.fit(), but when I train through model_main_tf2.py with a pipeline .config file I can't implement it, because I'm unable to locate any call to model.fit() in that script. Is there a way to implement early stopping when training with model_main_tf2.py? Please help me.
Link to the file: https://github.com/tensorflow/models/blob/master/research/object_detection/model_main_tf2.py
CodePudding user response:
I had a look inside the model_main_tf2.py file. Let's take the following piece of code:
model_lib_v2.train_loop(
    pipeline_config_path=FLAGS.pipeline_config_path,
    model_dir=FLAGS.model_dir,
    train_steps=FLAGS.num_train_steps,
    use_tpu=FLAGS.use_tpu,
    checkpoint_every_n=FLAGS.checkpoint_every_n,
    record_summaries=FLAGS.record_summaries)
Instead of running the training through fit, a custom training loop is used. The code above calls the function that performs the whole training; model_lib_v2 is just another module in the repo you linked.
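Since train_loop is an ordinary Python function, you are not forced to go through the command-line flags: you can also drive it from your own script. A minimal sketch, using only the keyword arguments visible above; the paths are placeholders for your own files, and passing None for train_steps is assumed to fall back to the value in the config (the flag's default):

from object_detection import model_lib_v2

model_lib_v2.train_loop(
    pipeline_config_path='path/to/pipeline.config',  # placeholder: your .config file
    model_dir='path/to/model_dir',                    # placeholder: where checkpoints are written
    train_steps=None,                                 # assumption: None uses the number of steps from the config
    use_tpu=False,
    checkpoint_every_n=1000,
    record_summaries=True)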
If you have a look at the train_loop function, you'll see that at some point the following code is executed:
with tf.GradientTape() as tape:
    losses_dict, _ = _compute_losses_and_predictions_dicts(
        detection_model, features, labels,
        training_step=training_step,
        add_regularization_loss=add_regularization_loss)
    losses_dict = normalize_dict(losses_dict, num_replicas)

trainable_variables = detection_model.trainable_variables
total_loss = losses_dict['Loss/total_loss']
gradients = tape.gradient(total_loss, trainable_variables)
GradientTape basically records the forward pass and computes the gradients needed to update the weights of the model during the training phase. I won't go into much detail, but if you are interested you can have a look at the linked documentation.
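To make the mechanism concrete, here is a toy, self-contained example of a single training step built on GradientTape; the model, optimizer, and data below are placeholders and have nothing to do with the Object Detection API:

import tensorflow as tf

# Toy model, optimizer, and data, just to illustrate the pattern.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))

# The tape records the forward pass...
with tf.GradientTape() as tape:
    predictions = model(x, training=True)
    loss = loss_fn(y, predictions)

# ...so that gradients of the loss w.r.t. the trainable variables can be
# computed and applied by the optimizer.
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))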
Now, you are interested in adding an early-stopping callback, but you don't have a fit call. You can still add early stopping, just in a different way.
You can follow a strategy like the one below (refer to this tutorial by TensorFlow for the full code):
import time

epochs = 100
patience = 5          # you can play with these values to find the best config
wait = 0
best = float('inf')   # best (lowest) validation loss seen so far

for epoch in range(epochs):
    start_time = time.time()

    # training (calling the function that holds the GradientTape)
    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
        loss_value = train_step(x_batch_train, y_batch_train)

    # updating the metrics after the whole training loop of a single epoch
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    # evaluating the model just trained in this epoch on the validation data
    for x_batch_val, y_batch_val in ds_test:
        test_step(x_batch_val, y_batch_val)

    # updating the metrics for validation
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
        best = val_loss
        wait = 0
    if wait >= patience:
        break
break