I've had this question bugging me for some time: is it possible to use the call() method of tf.keras.Model with labels? From what I've seen it isn't possible, but it strikes me as odd that you can train a model using this method yet can't pass it labels the way you can with .fit().
This question came up while I was reading the DCGAN tutorial in the TensorFlow documentation.
Source: https://www.tensorflow.org/tutorials/generative/dcgan
Answer:
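You can pass a list of tensors to call(), so technically you could pass the labels along with the inputs. However, that is not how TensorFlow/Keras training is structured: call() is the forward pass, and labels belong to the loss computation. In the DCGAN tutorial, the training routine is train_step. The generator's and discriminator's call() methods first compute the output tensors, and those outputs are then passed to the functions that calculate the losses (the real/fake targets are built inside generator_loss and discriminator_loss with tf.ones_like and tf.zeros_like). This is the standard way of doing things: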
```python
import tensorflow as tf

# generator, discriminator, generator_loss, discriminator_loss, the two
# optimizers, BATCH_SIZE and noise_dim are defined as in the DCGAN tutorial.
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Forward passes: call() only receives inputs, never labels.
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # The targets (real = 1, fake = 0) are created inside the loss functions.
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
```
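The same pattern applies to ordinary supervised training with labels: call() produces predictions, and the labels are only handed to the loss. Below is a minimal sketch on a toy regression problem; the model, the data y = 2x + 1, the loss and the learning rate are illustrative assumptions, not taken from the tutorial.

```python
import tensorflow as tf

# Toy regression model (illustrative assumption): its call() only takes x.
model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)


def supervised_train_step(x, y):
    """One gradient step; y is used only by the loss, never by call()."""
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)  # forward pass: inputs only
        loss = loss_fn(y, y_pred)         # the labels enter here
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return float(loss)


# Hypothetical data for illustration: y = 2x + 1.
x = tf.random.normal([64, 1], seed=0)
y = 2.0 * x + 1.0
losses = [supervised_train_step(x, y) for _ in range(50)]
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
```

As far as I know, the default train_step that .fit() runs does essentially the same thing internally: it unpacks (x, y) from the batch, calls self(x, training=True), and computes the loss from y and the predictions. So .fit() does not pass labels to call() either.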
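If you do pass labels through call() as part of a list of tensors, they become just another input. That is useful for conditioning the output on a class, not for supervision. A minimal sketch, assuming a hypothetical ConditionalGenerator class (not part of the DCGAN tutorial) that receives a [noise, labels] list and concatenates one-hot labels to the noise; the layer sizes are arbitrary assumptions.

```python
import tensorflow as tf


class ConditionalGenerator(tf.keras.Model):
    """Hypothetical generator whose call() receives a [noise, labels] list."""

    def __init__(self, num_classes=10, out_dim=784):
        super().__init__()
        self.num_classes = num_classes
        self.hidden = tf.keras.layers.Dense(128, activation="relu")
        self.to_pixels = tf.keras.layers.Dense(out_dim, activation="tanh")

    def call(self, inputs, training=False):
        noise, labels = inputs                          # unpack the list of tensors
        one_hot = tf.one_hot(labels, self.num_classes)  # int labels -> float vectors
        x = tf.concat([noise, one_hot], axis=-1)        # condition the noise on the labels
        return self.to_pixels(self.hidden(x))


gen = ConditionalGenerator()
noise = tf.random.normal([4, 100])
labels = tf.constant([0, 1, 2, 3])
images = gen([noise, labels], training=True)
print(images.shape)  # (4, 784)
```

Here the labels only shape the generated output; any loss that compares predictions against targets would still be computed outside call(), as in train_step.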