I have a tf model that has two outputs, as indicated by this model.compile():
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=7e-4),
              loss={"BV": tf.keras.losses.MeanAbsoluteError(), "Rsp": tf.keras.losses.MeanAbsoluteError()},
              metrics={"BV": [tf.keras.metrics.RootMeanSquaredError(name="RMSE"), tfa.metrics.r_square.RSquare(name="R2")],
                       "Rsp": [tf.keras.metrics.RootMeanSquaredError(name="RMSE"), tfa.metrics.r_square.RSquare(name="R2")]})
I would like to use the ModelCheckpoint callback, which should monitor a sum of val_BV_R2 and val_Rsp_R2. I am able to run the callback like this:
save_best_model = tf.keras.callbacks.ModelCheckpoint("xyz.hdf5", monitor="val_Rsp_R2")
However, I don't know how to make it save the model with the highest sum of the two metrics.
CodePudding user response:
According to the tf.keras.callbacks.ModelCheckpoint documentation, only one metric can be monitored at a time.
One way to achieve what you want would be to define an additional custom metric that sums the two metrics. You could then monitor that custom metric and save checkpoints as you already do. However, this is a bit complicated because the model has multiple outputs.
Alternatively, you could define a custom callback that does the same combining. Below is a simple example of this second option. It should work (sorry, I can't test it right now):
class CombineCallback(tf.keras.callbacks.Callback):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def on_epoch_end(self, epoch, logs=None):
        # Add the averaged validation R2 of both outputs to the shared logs dict.
        if logs is not None:
            logs['combine_metric'] = 0.5*logs['val_BV_R2'] + 0.5*logs['val_Rsp_R2']
Inside the callback you should be able to access your metrics directly with logs['name_of_my_metric'] or through the get function logs.get("name_of_my_metric").
Also, I multiplied by 0.5 to keep the combined metric in approximately the same range as the individual ones, but see if this works for your case.
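To use it, put it in the callback list before a ModelCheckpoint that monitors the combined metric. The order matters, because the checkpoint has to see the value the callback adds, and mode="max" is needed because Keras cannot infer the direction from the name combine_metric:
combine = CombineCallback()
save_best_model = tf.keras.callbacks.ModelCheckpoint("xyz.hdf5", monitor="combine_metric", mode="max", save_best_only=True)
model.fit(..., callbacks=[combine, save_best_model])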
More information can be found at the Examples of Keras callback applications.
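If you want to sanity-check the combining logic without training anything, here is a rough pure-Python sketch. It does not use TensorFlow: CombineLogs and BestTracker are hypothetical stand-ins I made up for the custom callback and for a checkpoint running with mode="max", combine_metrics is an illustrative helper, and the metric values are invented. It assumes, as the approach above does, that the same logs dict is handed to every callback in list order.

```python
# Pure-Python stand-ins; no TensorFlow needed to run this sketch.

def combine_metrics(logs, names=("val_BV_R2", "val_Rsp_R2"), weights=(0.5, 0.5)):
    """Return the weighted sum of the named metrics, or None if any is missing."""
    values = [logs.get(name) for name in names]
    if any(v is None for v in values):
        return None
    return sum(w * v for w, v in zip(weights, values))


class CombineLogs:
    """Hypothetical stand-in for the custom callback: writes the combined value into logs."""

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return
        score = combine_metrics(logs)
        if score is not None:
            logs["combine_metric"] = score


class BestTracker:
    """Hypothetical stand-in for a checkpoint in 'max' mode: remembers the best epoch."""

    def __init__(self, monitor):
        self.monitor = monitor
        self.best = float("-inf")
        self.best_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is not None and current > self.best:
            self.best, self.best_epoch = current, epoch


# Invented validation metrics for three epochs.
history = [
    {"val_BV_R2": 0.60, "val_Rsp_R2": 0.90},
    {"val_BV_R2": 0.80, "val_Rsp_R2": 0.80},
    {"val_BV_R2": 0.85, "val_Rsp_R2": 0.70},
]

callbacks = [CombineLogs(), BestTracker("combine_metric")]
for epoch, logs in enumerate(history):
    for cb in callbacks:  # the same dict object goes to every callback, in order
        cb.on_epoch_end(epoch, logs)

tracker = callbacks[1]
print(tracker.best_epoch, round(tracker.best, 3))
```

Swapping the order of the two entries in callbacks leaves best_epoch unset, since the tracker then looks for combine_metric before it has been written. That is the same reason the combining callback should come first in the list passed to model.fit.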