I am attempting to adapt the Keras VAE model found here to fit my data, and I would like to add a call method so I can use validation data to monitor performance. However, I cannot figure out how to pass it the different losses I calculate on my data.
Here is what my code currently looks like:
import tensorflow as tf
from tensorflow import keras


class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            ## BASE RECONSTRUCTION LOSS:
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            ## ELBO RECONSTRUCTION LOSS:
            # reconstruction_loss = tf.reduce_mean( keras.backend.sum(keras.backend.binary_crossentropy(data, reconstruction), axis=-1) )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            ## BASE TOTAL LOSS:
            total_loss = reconstruction_loss + kl_loss
            ## WEIGHTED TOTAL LOSS: try to increase importance of reconstruction loss
            # total_loss = reconstruction_loss + 0.1 * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, data):
        ## TENTATIVE CALL FUNCTION FOR VALIDATION DATA
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            keras.losses.binary_crossentropy(data, reconstruction)
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        self.add_loss(reconstruction_loss)
        self.add_loss(kl_loss)
        self.add_loss(total_loss)
        return reconstruction
The self.add_loss() pattern comes from this page of the TF guide, but during training the log just shows 0.0 for all validation losses. Should I be using separate metric trackers for validation and updating those instead?
CodePudding user response:
Personally, when I was learning how to use keras.Model, everything in the Keras API was Arabic to me (and I don't know a single word of Arabic)... However, this page of the TF documentation explains it pretty clearly; in particular, it explains what test_step does, which is what you are looking for:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def test_step(self, data):
        ## VALIDATION STEP: Keras calls this on each batch of validation data
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        # note: unlike train_step, this does not sum over the image axes,
        # so val_reconstruction_loss is on a different scale than reconstruction_loss
        reconstruction_loss = tf.reduce_mean(
            keras.losses.binary_crossentropy(data, reconstruction)
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
I can see this output:
Epoch 1/30
438/438 [==============================] - 7s 15ms/step - loss: 147.3851 - reconstruction_loss: 141.3100 - kl_loss: 6.2865 - val_loss: 6.6573 - val_reconstruction_loss: 0.1790 - val_kl_loss: 6.4783
Which I think is what you were looking for.
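One caveat: because this test_step returns raw tensors, the val_ values you see should be those of the last validation batch rather than a mean over the whole validation set. If you want epoch-level averages instead, a sketch that reuses the Mean trackers already defined above (Keras resets the metrics listed in the metrics property before each validation pass) would be:

    def test_step(self, data):
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            keras.losses.binary_crossentropy(data, reconstruction)
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
        # update the same trackers used in train_step; .result() is then the
        # running mean over all validation batches seen so far this pass
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }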