I am going through this tutorial on how to customize the training loop. The last example shows a GAN implemented with a custom training loop, where only the `__init__`, `train_step`, and `compile` methods are defined:
```python
class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        ...
```
What happens if my model also has a custom `call()` function? Does `train_step()` override `call()`? Aren't `call()` and `train_step()` both called by `fit()`, and what is the difference between the two?

Below is another piece of code "I" wrote, where I wonder which of `call()` or `train_step()` is called by `fit()`:
```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, rnn_units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(rnn_units,
                                       return_sequences=True,
                                       return_state=True,
                                       reset_after=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x

    @tf.function
    def train_step(self, inputs):
        # Unpack the data
        inputs, labels = inputs
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)  # forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(labels, predictions,
                                      regularization_losses=self.losses)
        # Compute the gradients
        grads = tape.gradient(loss, self.trainable_variables)
        # Update weights
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(labels, predictions)
        # Return a dict mapping metric names to current values
        return {m.name: m.result() for m in self.metrics}
```
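For context, this is roughly how I compile and train it (a minimal sketch; the vocabulary size, hyperparameters, and `dataset` are placeholders):

```python
model = MyModel(vocab_size=5000, embedding_dim=256, rnn_units=1024)
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# `dataset` is assumed to yield (input_ids, label_ids) batches
model.fit(dataset, epochs=1)
```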
Answer:
These are different concepts and are used like this:

- `train_step` is called by `fit`. Basically, `fit` loops over the dataset and provides each batch to `train_step` (and then handles metrics, bookkeeping, etc., of course).
- `call` is used when you, well, call the model. To be precise, writing `model(inputs)` or, in your case, `self(inputs)` will use the function `__call__`, but the `Model` class has that function defined such that it will in turn use `call`.
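To make the chain concrete, here is a minimal sketch (the model, data, and names are made up) where both methods report when they run:

```python
import numpy as np
import tensorflow as tf

class Tiny(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):  # the forward pass
        tf.print("call() running")
        return self.dense(inputs)

    def train_step(self, data):  # one optimization step, invoked by fit()
        tf.print("train_step() running")
        return super().train_step(data)  # the default step calls self(x) internally

model = Tiny()
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(8, 4), np.random.rand(8, 1)
model.fit(x, y, epochs=1, verbose=0)  # prints train_step(), then call()
model(x)                              # a direct call: only call() runs
```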
Those are the technical aspects. Intuitively:

- `call` should define the forward pass of your model, i.e. how the input is transformed into the output.
- `train_step` defines the logic of a training step, usually with gradient descent. It will often make use of `call`, since the training step tends to include a forward pass of the model to compute gradients.
As for the GAN tutorial you linked, I would say it can actually be considered incomplete. It works without defining `call` because the custom `train_step` explicitly calls the generator/discriminator fields (as these are predefined models, they can be called as usual). If you tried to call the GAN model like `gan(inputs)`, I would assume you would get an error message (I did not test this). So you would always have to call `gan.generator(inputs)` to generate, for example.
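If you did want `gan(inputs)` to work, one option (my sketch, not part of the tutorial) is to add a `call` that delegates to the generator, so that calling the GAN simply means generating from latent vectors:

```python
from tensorflow import keras

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    # compile() and train_step() exactly as in the tutorial ...

    def call(self, inputs):
        # Delegate the forward pass to the generator, so that
        # gan(latent_vectors) produces generated images.
        return self.generator(inputs)
```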
Finally (this part may be a bit confusing), note that you can subclass a `Model` to define a custom training step, but then initialize it via the functional API (like `model = Model(inputs, outputs)`), in which case you can make use of `call` in the training step without ever defining it yourself, because the functional API takes care of that.
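For example (a sketch following that pattern; the toy architecture is made up):

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # uses the call() generated by the functional API
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

# Functional-API construction: no call() is defined by hand anywhere.
inputs = tf.keras.Input(shape=(32,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse")
```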