I am training a model (VAEGAN) with intermediate outputs, and I have two losses:
- a KL divergence loss computed from the output layer, and
- a similarity (rec) loss computed from an intermediate layer.
Can I simply sum them up and apply gradients like below?
with tf.GradientTape() as tape:
    z_mean, z_log_sigma, z_encoder_output = self.encoder(real_images, training=True)
    kl_loss = self.kl_loss_fn(z_mean, z_log_sigma) * kl_loss_coeff
    fake_images = self.decoder(z_encoder_output)
    fake_inter_activations, logits_fake = self.discriminator(fake_images, training=True)
    real_inter_activations, logits_real = self.discriminator(real_images, training=True)
    rec_loss = self.rec_loss_fn(fake_inter_activations, real_inter_activations) * rec_loss_coeff
    total_encoder_loss = kl_loss + rec_loss

grads = tape.gradient(total_encoder_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads, self.encoder.trainable_weights))
Or do I need to separate them like below, while keeping the tape persistent?
with tf.GradientTape(persistent=True) as tape:
    z_mean, z_log_sigma, z_encoder_output = self.encoder(real_images, training=True)
    kl_loss = self.kl_loss_fn(z_mean, z_log_sigma) * kl_loss_coeff
    fake_images = self.decoder(z_encoder_output)
    fake_inter_activations, logits_fake = self.discriminator(fake_images, training=True)
    real_inter_activations, logits_real = self.discriminator(real_images, training=True)
    rec_loss = self.rec_loss_fn(fake_inter_activations, real_inter_activations) * rec_loss_coeff

grads_kl_loss = tape.gradient(kl_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads_kl_loss, self.encoder.trainable_weights))

grads_rec_loss = tape.gradient(rec_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads_rec_loss, self.encoder.trainable_weights))
CodePudding user response:
Yes, you can simply sum the losses and compute a single gradient. The gradient of a sum is the sum of the gradients of its terms, so the step taken with the summed loss is the same as taking both individual steps one after the other.
Here's a simple example: Say you have two weights, and you are currently at the point (1, 3) ("starting point"). The gradient for loss 1 is (2, -4) and the gradient for loss 2 is (1, 2).
- If you apply the steps one after the other, you will first move to (3, -1) and then to (4, 1).
- If you sum the gradients first, the overall step is (3, -2). Following this direction from the starting point gets you to (4, 1) as well.
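Here is a minimal sketch (not from your code) that checks this numerically in TensorFlow; the two toy losses are hypothetical stand-ins for kl_loss and rec_loss:

import tensorflow as tf

w = tf.Variable([1.0, 3.0])  # the "starting point" from the example above

with tf.GradientTape(persistent=True) as tape:
    loss_a = tf.reduce_sum(w ** 2)     # hypothetical stand-in for kl_loss
    loss_b = tf.reduce_sum(tf.sin(w))  # hypothetical stand-in for rec_loss
    total_loss = loss_a + loss_b

grad_a = tape.gradient(loss_a, w)
grad_b = tape.gradient(loss_b, w)
grad_total = tape.gradient(total_loss, w)
del tape

# The gradient of the sum equals the sum of the gradients (up to float rounding).
tf.debugging.assert_near(grad_total, grad_a + grad_b)

Summing the losses also means you only do one backward pass and don't need a persistent tape.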