How to add class_weight to a custom train_step function in a custom keras model?

Time:09-02

I am using TensorFlow 2.8. I followed the TensorFlow tutorial on customizing what happens in fit() by overriding the train_step method in a custom Keras model class.

I wanted to add class_weight, but the tutorial's section "Supporting sample_weight & class_weight" only shows how to use sample_weight, not class_weight.

Is there a way to use class_weight in a custom train_step function?

I also found this Colab notebook in a GitHub issue. However, it defines a custom model class but never actually uses it, so it doesn't help either.

When I actually create the custom model and call fit(), I get the error TypeError: __call__() got an unexpected keyword argument 'class_weight' at the point where the loss is computed in train_step().

Example code (with the error) of what I'm trying to do:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
# Load MNIST and flatten each 28x28 image into a 784-element vector.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Get model
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, class_weight = data
        else:
            x, y = data
            class_weight = None

        with tf.GradientTape() as tape:
            logits = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compiled_loss(
                y,
                logits,
                class_weight=class_weight,
                regularization_losses=self.losses,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        # Metrics are configured in `compile()`.
        self.compiled_metrics.update_state(y, logits, class_weight=class_weight)

        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}


# Construct and compile an instance of CustomModel
model = CustomModel(inputs=inputs, outputs=outputs)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

class_weight = {
    0: 1.0,
    1: 1.0,
    2: 1.0,
    3: 1.0,
    4: 1.0,
    # Set weight "2" for class "5",
    # making this class 2x more important
    5: 2.0,
    6: 1.0,
    7: 1.0,
    8: 1.0,
    9: 1.0,
}

model.fit(train_dataset, class_weight=class_weight, epochs=3)

This doesn't work because, as far as I understand from the documentation, loss functions don't accept class_weight as an argument.
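For reference, Keras loss objects do accept a per-sample sample_weight argument in their __call__. The sketch below uses made-up logits and labels (assumptions, not data from the question) to show how that weighting combines with the default reduction, which sums the weighted per-sample losses and divides by the number of samples:

```python
import tensorflow as tf

# Made-up logits for 3 samples and 3 classes, plus integer labels.
logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.2, 1.5, 0.3],
                      [0.1, 0.2, 3.0]])
labels = tf.constant([0, 1, 2])
# Per-sample weights, e.g. what class_weight={0: 1, 1: 2, 2: 1} would produce.
sample_weight = tf.constant([1.0, 2.0, 1.0])

# Weighted loss computed by the Keras loss object itself.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
weighted = loss_fn(labels, logits, sample_weight=sample_weight)

# Same value by hand: per-sample losses, scaled by their weights,
# summed, then divided by the batch size (3).
per_sample = tf.keras.losses.sparse_categorical_crossentropy(
    labels, logits, from_logits=True)
manual = tf.reduce_sum(per_sample * sample_weight) / 3.0

print(float(weighted), float(manual))
```

Note that the divisor is the batch size, not the sum of the weights, so larger weights also scale the overall loss.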

I tried to fix this by creating my own loss function:

@tf.function
def weighted_sparse_categorical_crossentropy(labels, logits, class_weight=None):
    # Per-sample losses (no reduction) so each sample can be weighted individually.
    loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    if class_weight is None:
        return tf.reduce_mean(loss)

    # get all class weights as list
    if type(class_weight) is not list:
        class_weight = list(class_weight.values())

    # Look up each sample's weight by its integer label.
    labels = tf.cast(tf.reshape(labels, [-1]), tf.int32)
    class_weights_gathered = tf.gather(class_weight, labels)

    return tf.reduce_mean(class_weights_gathered * loss)

Then I used it to compile my model and called .fit():

model.compile(optimizer=optimizer, loss=weighted_sparse_categorical_crossentropy)
model.fit(train_dataset, class_weight=class_weight, epochs=1)

But I still get TypeError: __call__() got an unexpected keyword argument 'class_weight', even though class_weight is clearly a parameter of my function.

I also looked at the TensorFlow source on GitHub to see what fit() actually does with class_weight, and it seems to convert it into a sample_weight somehow.
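Conceptually, that conversion is a per-sample lookup: each sample gets the class weight of its own label. A minimal sketch of the idea follows; the helper name class_weight_to_sample_weight is hypothetical and this is not Keras's internal code, and it assumes the class_weight keys are the contiguous integers 0..num_classes-1:

```python
import numpy as np
import tensorflow as tf


def class_weight_to_sample_weight(y, class_weight):
    """Hypothetical helper: map each integer label in `y` to its class weight."""
    # Dense lookup table indexed by class id (assumes keys 0..num_classes-1).
    num_classes = max(class_weight) + 1
    table = tf.constant([class_weight[c] for c in range(num_classes)],
                        dtype=tf.float32)
    # Flatten so labels shaped (batch, 1) also work, then gather one weight per sample.
    labels = tf.reshape(tf.cast(y, tf.int32), [-1])
    return tf.gather(table, labels)


y = np.array([[0], [0], [1]])
print(class_weight_to_sample_weight(y, {0: 7.0, 1: 9.0}).numpy())
```

The resulting vector has one weight per sample, which is exactly the shape a loss's sample_weight argument expects.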

So I'm not sure whether what I want is even possible. But if it isn't, the section in the official TensorFlow tutorial would be misleading, since there would be no support for class_weight.

Answer:

Pretty sure the error is here:

loss = self.compiled_loss(
    y,
    logits,
    class_weight=class_weight,
    regularization_losses=self.losses,
)

because compiled_loss does not accept a class_weight argument. Use sample_weight instead; according to the Keras documentation, this is as simple as unpacking the extra element of the incoming data, as you are already doing:

def train_step(self, data):
    x, y, sample_weight = data

Your confusion about the "missing example" comes from the fact that Keras automatically converts class_weight into sample_weight before train_step() is called, as you can see with this code:

import numpy as np
import tensorflow as tf
from tensorflow import keras as K

class M(K.Model):
    def train_step(self, data):
        tf.print(data)
        return {"loss": 0}

model = M()
model.compile(K.optimizers.Adam(), K.losses.SparseCategoricalCrossentropy(), run_eagerly=True)
model.fit(np.array([[-1], [-1], [-1]]), np.array([[0], [0], [1]]), class_weight={0: 7, 1: 9})

which prints:

([[-1][-1][-1]], 
 [[0] [1] [0]], 
 [7 9 7])

Here the first element is x, the second is y (shuffled by fit()), and the third holds the class weight corresponding to each label in y. That third element is an ordinary per-sample weight, which you can pass to the loss as usual:

x, y, sample_weight = data
... 
loss = self.compiled_loss(
    y,
    logits,
    sample_weight=sample_weight,
    regularization_losses=self.losses,
)
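Putting it together, here is a self-contained sketch of a full train_step that consumes the per-sample weights fit() derives from class_weight. To keep it independent of compile()-internal helpers, this variant holds its own loss and metric objects (the pattern used in the Keras "customizing fit()" guide) instead of calling self.compiled_loss; the toy model, data, and weights are assumptions for illustration only:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Own loss and metric objects instead of compiled_loss / compiled_metrics.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_metric = keras.metrics.SparseCategoricalAccuracy(name="accuracy")

    def train_step(self, data):
        # fit() yields (x, y, sample_weight) when class_weight (or sample_weight)
        # is passed; class weights are already mapped to per-sample weights.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
            sample_weight = None

        with tf.GradientTape() as tape:
            logits = self(x, training=True)
            loss = self.loss_fn(y, logits, sample_weight=sample_weight)
            if self.losses:
                loss += tf.add_n(self.losses)  # regularization losses, if any

        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        # Metrics take sample_weight as well (not class_weight).
        self.loss_tracker.update_state(loss)
        self.acc_metric.update_state(y, logits, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}

    @property
    def metrics(self):
        # Listed here so fit() resets them at the start of each epoch.
        return [self.loss_tracker, self.acc_metric]


# Toy functional model and random data (illustrative assumptions).
inputs = keras.Input(shape=(4,))
hidden = keras.layers.Dense(16, activation="relu")(inputs)
outputs = keras.layers.Dense(3)(hidden)
model = CustomModel(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(1e-3))

rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.integers(0, 3, size=64)

# Class 2 counts twice as much as the other classes.
history = model.fit(x, y, class_weight={0: 1.0, 1: 1.0, 2: 2.0},
                    batch_size=16, epochs=2, verbose=0)
print(history.history["loss"])
```

If you prefer to stay with compile(loss=...), the same train_step works by replacing the self.loss_fn call with the self.compiled_loss(..., sample_weight=sample_weight, ...) call shown in the answer, as in TensorFlow 2.8.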