Create call function for custom VAE Model


I am attempting to adapt the Keras VAE model found here to fit my data, and I would quite like to include a call function so that I can use validation data to monitor performance. However, I cannot figure out how to pass it the different losses I calculate on my data.

Here is what my code currently looks like:

import tensorflow as tf
from tensorflow import keras

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            ## BASE RECONSTRUCTION LOSS:
            reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction))
            ## ELBO RECONSTRUCTION LOSS:
            # reconstruction_loss = tf.reduce_mean(keras.backend.sum(keras.backend.binary_crossentropy(data, reconstruction), axis=-1))
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            ## BASE TOTAL LOSS:
            total_loss = reconstruction_loss + kl_loss
            ## WEIGHTED TOTAL LOSS: try to increase the importance of reconstruction loss
            # total_loss = reconstruction_loss + 0.1 * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
    
    def call(self, data):
        ## TENTATIVE CALL FUNCTION FOR VALIDATION DATA
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction))
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        self.add_loss(reconstruction_loss)
        self.add_loss(kl_loss)
        self.add_loss(total_loss)
        return reconstruction


The self.add_loss() calls come from this page of the TF guide, but during training the log just shows 0.0 for all the validation losses.

Should I be using another metric and tracker and updating those?
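
For context, I compile and fit the model along these lines (a sketch; encoder, decoder, x_train and x_val come from my own pipeline):

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
# Validation data is passed as a 1-tuple because an autoencoder has no
# labels; Keras unpacks it and hands the bare array to the model.
vae.fit(x_train, epochs=30, batch_size=128, validation_data=(x_val,))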

CodePudding user response:

Personally, when I was learning how to use keras.Model, everything in the Keras API was Arabic to me (and I don't know a single word of Arabic)... however, this page of the TF documentation explains it pretty clearly. In particular, it explains what test_step does, which is what you are looking for:

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        ## VALIDATION STEP: compute the losses on held-out data, no gradient update
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction))
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
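
During fit, Keras prefixes whatever keys test_step returns with val_ in the log, and the same test_step also backs model.evaluate, so you can check the validation losses directly (a sketch, reusing the vae and x_val names from the question):

vae.evaluate(x_val, batch_size=128, return_dict=True)
# returns a dict with 'loss', 'reconstruction_loss' and 'kl_loss'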

I can see this output:

Epoch 1/30
438/438 [==============================] - 7s 15ms/step - loss: 147.3851 - reconstruction_loss: 141.3100 - kl_loss: 6.2865 - val_loss: 6.6573 - val_reconstruction_loss: 0.1790 - val_kl_loss: 6.4783

Which I think is what you were looking for. (Note that val_reconstruction_loss is on a different scale than the training reconstruction_loss: this test_step takes a plain mean of the binary cross-entropy, while train_step first sums it over the image axes.)
