I am adding a custom loss to a VAE, as suggested here: https://www.linkedin.com/pulse/supervised-variational-autoencoder-code-included-ibrahim-sobh-phd/
Instead of defining a loss function, it uses a dense network and takes its output as the loss (if I understand correctly).
# New: add a classifier
clf_latent_inputs = Input(shape=(latent_dim,), name='z_sampling_clf')
clf_outputs = Dense(10, activation='softmax', name='class_output')(clf_latent_inputs)
clf_supervised = Model(clf_latent_inputs, clf_outputs, name='clf')
clf_supervised.summary()
# instantiate VAE model
# New: Add another output
outputs = [decoder(encoder(inputs)[2]), clf_supervised(encoder(inputs)[2])]
vae = Model(inputs, outputs, name='vae_mlp')
vae.summary()
reconstruction_loss = binary_crossentropy(inputs, outputs[0])
reconstruction_loss *= original_dim
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean((reconstruction_loss + kl_loss) / 100.0)
vae.add_loss(vae_loss)
# New: add the clf loss
vae.compile(optimizer='adam', loss={'clf': 'categorical_crossentropy'})  # ===> this line <===
vae.summary()
# reconstruction_loss = binary_crossentropy(inputs, outputs)
svae_history = vae.fit(x_train, {'clf': y_train},
                       epochs=epochs,
                       batch_size=batch_size)
I am stuck at the compilation step (annotated as ===> this line <===), where I get a type error:
TypeError: Expected float32, got <function BaseProtVAE.__init__.<locals>.vae_loss at 0x7ff53051dd08> of type 'function' instead.
I need your help if you've got any suggestions.
CodePudding user response:
There are several ways to implement a VAE in TensorFlow. I propose an alternative implementation, taken from the custom_layers_and_models page of the TensorFlow guide, which introduces its example like this:
"Let's put all of these things together into an end-to-end example: we're going to implement a Variational AutoEncoder (VAE). We'll train it on MNIST digits."
It uses custom Model classes and tf.GradientTape. With that setup it is easy to add the classifier to the VAE model and to add the categorical cross-entropy to the total loss during optimization.
All you need to do is modify the VariationalAutoEncoder class:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense

# Encoder and Decoder are the classes defined in the guide.
class VariationalAutoEncoder(Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)
        # New: classifier head that operates on the latent sample z.
        self.clf_supervised = Dense(10, activation='softmax', name='class_output')

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        # New: class probabilities predicted from the latent sample.
        y_pred = self.clf_supervised(z)
        return reconstructed, y_pred
The only changes are the line self.clf_supervised = Dense(10, activation='softmax', name='class_output') in __init__, the line y_pred = self.clf_supervised(z) in call, and returning y_pred alongside reconstructed.
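For reference, the KL term computed in call is the closed-form divergence between the encoder's diagonal Gaussian and a standard normal. Below is a minimal plain-Python sketch of that arithmetic for a single sample, useful for checking the sign and the constant 1 by hand. The helper name kl_to_standard_normal and its reduce argument are hypothetical and not part of the guide. The guide averages the terms (tf.reduce_mean), while the code in the question sums over the latent axis, so the sketch supports both.

```python
import math


def kl_to_standard_normal(z_mean, z_log_var, reduce="sum"):
    """KL( N(mean, exp(log_var)) || N(0, 1) ) for one sample.

    z_mean and z_log_var are plain lists with one entry per latent dimension.
    Hypothetical helper for illustration only.
    """
    # Per-dimension term: -0.5 * (1 + log_var - mean^2 - exp(log_var))
    terms = [
        -0.5 * (1.0 + log_var - mean * mean - math.exp(log_var))
        for mean, log_var in zip(z_mean, z_log_var)
    ]
    total = sum(terms)
    # "mean" mirrors tf.reduce_mean; anything else sums over latent dimensions.
    return total / len(terms) if reduce == "mean" else total


# A posterior equal to the prior has zero divergence;
# shifting one mean away from zero makes the term positive.
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0], reduce="mean"))
```

Summing and averaging differ only by a constant factor (the number of latent dimensions), but that factor changes how strongly the KL term weighs against the reconstruction and classifier losses.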
The optimization is done this way:
vae = VariationalAutoEncoder(original_dim, intermediate_dim, latent_dim)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()
# SparseCategoricalCrossentropy expects integer class labels in y_train.
clf_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
loss_metric = tf.keras.metrics.Mean()
epochs = 2
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=500).batch(4)

# Iterate over epochs.
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed, y_pred = vae(x_batch_train)
            clf_loss = clf_loss_fn(y_batch_train, y_pred)
            # Compute reconstruction loss
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(vae.losses)  # Add KLD regularization loss
            loss += clf_loss  # Add classifier loss
        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))
        loss_metric(loss)
        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))
The rest of the code, including the Encoder and Decoder classes, is in the guide mentioned above. The main change is that the optimization is done with tf.GradientTape(). It is a bit more involved than the fit method, but it is still quite simple and very powerful.
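One thing to check when adapting the loop: SparseCategoricalCrossentropy expects integer class labels, whereas the 'categorical_crossentropy' loss in the question expects one-hot labels. Assuming y_train in the question is one-hot encoded (the question does not show this), you would either switch to tf.keras.losses.CategoricalCrossentropy or convert the labels to class indices with an argmax. The plain-Python sketch below, with hypothetical helper names, shows that the two formulations give the same value for a single sample once the labels are in the matching format.

```python
import math


def one_hot_cross_entropy(one_hot, probs):
    """Cross-entropy for a one-hot label vector (illustrative helper)."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


def sparse_cross_entropy(label, probs):
    """Cross-entropy for an integer class label (illustrative helper)."""
    return -math.log(probs[label])


probs = [0.1, 0.7, 0.2]      # softmax output of the classifier head
one_hot = [0.0, 1.0, 0.0]    # one-hot label
label = one_hot.index(max(one_hot))  # argmax turns it into a class index

print(one_hot_cross_entropy(one_hot, probs))
print(sparse_cross_entropy(label, probs))
```

Passing one-hot labels to the sparse loss (or integer labels to the one-hot loss) is a common source of shape errors, so it is worth confirming the label format before training.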