Evaluating Pretrained Tensorflow keras model using various loss functions

Time:12-06

I'm looking for a way to evaluate a pre-trained TensorFlow Keras model using various loss functions such as MAE, MSE, etc. As far as I can tell, Model.evaluate() doesn't accept a loss function as an argument. Is it possible to do this without recompiling the model every time I want to evaluate with a new loss function? What is the easiest way to do this?

Answer:

You can use multiple loss functions without recompiling. Assuming the first loss tensor is loss1 and the second is loss2, in TensorFlow 1.x graph mode:

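# TensorFlow 1.x graph mode (tf.compat.v1 in TF 2.x). loss1 and loss2 are loss
# tensors built from the same model output; x and y are input placeholders.
optimizer1 = tf.train.AdamOptimizer().minimize(loss1)
optimizer2 = tf.train.AdamOptimizer().minimize(loss2)

# Fetching the optimizer ops also runs one training step for each loss.
_, _, l1, l2 = sess.run(fetches=[optimizer1, optimizer2, loss1, loss2], feed_dict={x: batch_x, y: batch_y})

# For evaluation only (weights left unchanged), fetch just the loss tensors:
l1, l2 = sess.run(fetches=[loss1, loss2], feed_dict={x: batch_x, y: batch_y})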

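In TensorFlow 2.x with eager execution, the same idea needs no session and no optimizer: run the model once and apply as many loss objects as you like to its predictions. This is a minimal sketch, assuming a hypothetical tiny regressor and random data in place of a real pre-trained model (in practice you would load yours, e.g. with tf.keras.models.load_model):

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a pre-trained model: a tiny, uncompiled regressor.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])

# Hypothetical evaluation data.
rng = np.random.default_rng(0)
x_eval = rng.normal(size=(32, 4)).astype("float32")
y_eval = rng.normal(size=(32, 1)).astype("float32")

# One forward pass; training=False keeps layers such as Dropout in inference mode.
y_pred = model(x_eval, training=False)

# Any number of loss objects can be applied to the same predictions; no compile() needed.
loss_fns = {
    "mae": tf.keras.losses.MeanAbsoluteError(),
    "mse": tf.keras.losses.MeanSquaredError(),
    "huber": tf.keras.losses.Huber(),
}
results = {name: float(fn(y_eval, y_pred)) for name, fn in loss_fns.items()}
print(results)
```

Reusing y_pred means the forward pass runs once, no matter how many losses you compare.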

Answer:

You can recompile your model with new metrics. For evaluation, I believe what you need are new metrics, not a new loss.
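Recompiling is safe for a pre-trained model because compile() only configures the loss, metrics and optimizer; it does not reinitialize the layer weights. A quick check, using a hypothetical tiny model in place of a real pre-trained one:

```python
import numpy as np
import tensorflow as tf

# Hypothetical tiny model standing in for a pre-trained one.
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(2)])
model.compile(loss="mse", optimizer="adam")
weights_before = [w.copy() for w in model.get_weights()]

# Recompile with a different loss and different metrics.
model.compile(loss="mae", metrics=["mse"], optimizer="adam")
weights_after = model.get_weights()

# The kernel and bias arrays are identical before and after recompiling.
unchanged = all(np.array_equal(a, b) for a, b in zip(weights_before, weights_after))
print(unchanged)  # True
```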

For example, define a model like this:

import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(x_train.shape[1:])),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"], optimizer="adam")
model.fit(x_train, y_train, epochs=3, validation_split=0.2)
model.evaluate(x_test, y_test) 
# 313/313 [=================] - 1s 3ms/step - loss: 0.2179 - accuracy: 0.9444

Then recompile and evaluate again, for example:

# Change metrics
model.compile(metrics=["mae", "mse"], loss="sparse_categorical_crossentropy")
model.evaluate(x_test, y_test)
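# 313/313 [=================] - 1s 3ms/step - loss: 0.2179 - mae: 4.3630 - mse: 27.3351

Note that in this classifier the labels are integers while the outputs are 10 softmax probabilities, so Keras broadcasts each label against all 10 probabilities when computing mae and mse. The values above therefore only demonstrate the mechanism; they are not meaningful error measures for this model. For a regression model, or with one-hot labels, mae and mse are the usual errors.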
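If you want to avoid recompiling altogether, one option is to run the model yourself and feed its predictions to standalone Keras metric objects, which accumulate results over batches. This sketch assumes a hypothetical stand-in classifier and random data in place of the trained MNIST model; it compares the probabilities against one-hot labels (via tf.one_hot) so that MAE and MSE are per-class errors:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the trained MNIST classifier:
# any model mapping (28, 28) images to 10 softmax probabilities works the same way.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# Hypothetical evaluation data with integer class labels, as in MNIST.
rng = np.random.default_rng(0)
x_eval = rng.random((100, 28, 28)).astype("float32")
y_eval = rng.integers(0, 10, size=100)

# Streaming metric objects keep running totals; no compile() is involved.
metrics = {
    "crossentropy": tf.keras.metrics.SparseCategoricalCrossentropy(),
    "mae": tf.keras.metrics.MeanAbsoluteError(),
    "mse": tf.keras.metrics.MeanSquaredError(),
}

dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval)).batch(32)
for x_batch, y_batch in dataset:
    y_prob = model(x_batch, training=False)
    metrics["crossentropy"].update_state(y_batch, y_prob)
    # One-hot labels have the same shape as y_prob, so no broadcasting occurs.
    y_onehot = tf.one_hot(y_batch, depth=10)
    metrics["mae"].update_state(y_onehot, y_prob)
    metrics["mse"].update_state(y_onehot, y_prob)

results = {name: float(m.result()) for name, m in metrics.items()}
print(results)
```

Because the metrics keep running totals, this also works for datasets too large to predict in one call; call reset_state() on each metric before reusing it.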