Is there a way to pickle a custom tensorflow.keras metric?

Time:06-21

I defined the following custom metric to train my model in TensorFlow:

import pickle

import tensorflow as tf
from tensorflow import keras as ks
N_CLASSES = 15

class MulticlassMeanIoU(tf.keras.metrics.MeanIoU):
    def __init__(self,
                 y_true = None,
                 y_pred = None,
                 num_classes = None,
                 name = "Multi_MeanIoU",
                 dtype = None):
        super(MulticlassMeanIoU, self).__init__(num_classes = num_classes,
                                             name = name, dtype = dtype)
        self.__name__ = name

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "num_classes": self.num_classes}

    def update_state(self, y_true, y_pred, sample_weight = None):
        y_pred = tf.math.argmax(y_pred, axis = -1)
        return super().update_state(y_true, y_pred, sample_weight)

met = MulticlassMeanIoU(num_classes = N_CLASSES)

After training, I save the model, and I also tried to save the custom metric object as follows:

with open("/some/path/custom_metrics.pkl", "wb") as f:
    pickle.dump(met, f)

However, when I try to load the metric like this:

with open(path_custom_metrics, "rb") as f:
    met = pickle.load(f)

I always get an error, e.g. AttributeError: 'MulticlassMeanIoU' object has no attribute 'update_state_fn'.

Now I wonder whether it is possible to pickle a custom metric at all and, if so, how. It would be handy to save custom metrics together with the model, so that when I load the model in another Python session I always have the metric that is required to load it in the first place. I could redefine the metric by copying its full code into the other script before loading the model, but I think that would be bad style, and it could cause problems if I changed the metric in the training script and forgot to copy the change to the other script.

CodePudding user response:

If you need to pickle a metric, one possible solution is to implement the __getstate__() and __setstate__() methods. If a class defines them, pickle calls __getstate__() when serializing an object and __setstate__() when deserializing it. Add these methods to your metric class and it should survive a round trip through pickle. I tried to keep them as general as possible so that they work for any Metric; only the num_classes entry is specific to MeanIoU:

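    # Add both methods to the metric class (here, MulticlassMeanIoU).
    # Needs at module level: from typing import Any, Dict

    def __getstate__(self):
        # Snapshot every variable as a NumPy array, keyed by variable name.
        variables = {v.name: v.numpy() for v in self.variables}
        # Map each tracked attribute name to the value of its variable.
        # Note: _unconditional_dependency_names is a private attribute.
        state = {
            name: variables[var.name]
            for name, var in self._unconditional_dependency_names.items()
            if isinstance(var, tf.Variable)}
        state['name'] = self.name
        state['num_classes'] = self.num_classes
        return state

    def __setstate__(self, state: Dict[str, Any]):
        # pickle does not call __init__ on load, so rebuild the metric here.
        self.__init__(name=state.pop('name'), num_classes=state.pop('num_classes'))
        # Everything left in state is a saved variable value; restore it.
        for name, value in state.items():
            self._unconditional_dependency_names[name].assign(value)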
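
To see the mechanism in isolation, here is a minimal sketch that uses only the standard library. RunningMean is a hypothetical toy class, not part of TensorFlow or Keras, and its threading.Lock merely stands in for whatever member of a real metric refuses to pickle. The pattern is the same as above: __getstate__() returns only plain data, and __setstate__() calls __init__() to rebuild everything else before restoring that data.

```python
import pickle
import threading


class RunningMean:
    """Hypothetical toy metric: plain state plus one unpicklable member."""

    def __init__(self, name="running_mean"):
        self.name = name
        self.total = 0.0
        self.count = 0
        # A lock cannot be pickled; it stands in for the unpicklable
        # parts of a real metric object (an assumption for illustration).
        self._lock = threading.Lock()

    def update_state(self, value):
        with self._lock:
            self.total += value
            self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def __getstate__(self):
        # Hand pickle only plain data; the lock is deliberately left out.
        return {"name": self.name, "total": self.total, "count": self.count}

    def __setstate__(self, state):
        # pickle does not call __init__ on load, so call it here to
        # recreate the lock, then restore the saved numbers.
        self.__init__(name=state["name"])
        self.total = state["total"]
        self.count = state["count"]


metric = RunningMean()
for v in (1.0, 2.0, 6.0):
    metric.update_state(v)

# Round trip through pickle; without the two methods above this would
# fail on the lock.
restored = pickle.loads(pickle.dumps(metric))
```

After the round trip, restored carries the same name, total and count as metric but owns a fresh lock, so it can keep accumulating values independently of the original.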