When are Model call() and train_step() called?


I am going through this tutorial on how to customize the training loop.


The last example shows a GAN implemented with a custom training loop, where only the __init__, train_step, and compile methods are defined:

from tensorflow import keras

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # ... (rest of the tutorial's train_step omitted)

What happens if my model also has a custom call() method? Does train_step() override call()? Aren't call() and train_step() both called by fit(), and what is the difference between them?

Below is another piece of code "I" wrote, where I wonder which one is called by fit(): call() or train_step()?

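import tensorflow as tf

class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()  # must run before assigning layers to self
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # return_state=True is required because call() unpacks (x, states);
    # return_sequences=True makes the dense layer predict at every time step
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)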

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    return x  # reached only when return_state is False

  def train_step(self, inputs):
    # unpack the data
    inputs, labels = inputs
    with tf.GradientTape() as tape:
      predictions = self(inputs, training=True) # forward pass
      # Compute the loss value
      # (the loss function is configured in `compile()`)
      loss = self.compiled_loss(labels, predictions, regularization_losses=self.losses)

    # compute the gradients (use self, not a global `model` variable)
    grads = tape.gradient(loss, self.trainable_variables)
    # Update weights
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(labels, predictions)

    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}

These are different concepts and are used like this:

  • train_step is called by fit. Basically, fit loops over the dataset and provides each batch to train_step (and then handles metrics, bookkeeping, etc., of course).
  • call is used when you, well, call the model. To be precise, writing model(inputs), or in your case self(inputs), invokes the method __call__, but the Model class defines __call__ such that it in turn uses call.
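To make this dispatch concrete, here is a toy, pure-Python sketch of the control flow just described, written so it runs without TensorFlow. `ToyModel`, `Doubler`, and the `trace` list are hypothetical stand-ins, not Keras's real implementation (the real `fit` and `__call__` do much more); only the call order is meant to match: `fit` hands each batch to `train_step`, and `self(x)` goes through `__call__`, which forwards to `call`.

```python
class ToyModel:
    """Hypothetical stand-in for keras.Model; only mimics the dispatch order."""

    def __init__(self):
        self.trace = []  # records which methods ran, in order

    def __call__(self, inputs, training=False):
        # Real Keras does extra work here (building, masking, ...); the toy just forwards.
        self.trace.append("__call__")
        return self.call(inputs, training=training)

    def call(self, inputs, training=False):
        raise NotImplementedError("subclasses must implement call()")

    def train_step(self, batch):
        x, y = batch
        self.trace.append("train_step")
        y_pred = self(x, training=True)  # forward pass: __call__ -> call
        return {"loss": (y_pred - y) ** 2}

    def fit(self, dataset, epochs=1):
        # fit only loops; the per-batch logic lives in train_step
        history = []
        for _ in range(epochs):
            for batch in dataset:
                history.append(self.train_step(batch))
        return history


class Doubler(ToyModel):
    def call(self, inputs, training=False):
        self.trace.append("call")
        return 2 * inputs


model = Doubler()
logs = model.fit([(1, 2), (3, 5)])
print(model.trace)
print(logs)
```

With two batches the trace is `train_step, __call__, call` twice: `fit` never invokes `call` directly; it only reaches it because `train_step` calls `self(x)`.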

Those are the technical aspects. Intuitively:

  • call should define the forward pass of your model, i.e. how the input is transformed into the output.
  • train_step defines the logic of a training step, usually with gradient descent. It will often make use of call since the training step tends to include a forward pass of the model to compute gradients.
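The same split, again as a plain-Python sketch rather than Keras code: for a hypothetical one-parameter linear model `y = w * x + b`, `call` is only the forward pass, while `train_step` wraps that forward pass with a mean-squared-error loss, hand-derived gradients (standing in for what `tf.GradientTape` would compute), and a plain SGD update. The class name, learning rate, step count, and data are made up for illustration.

```python
class LinearModel:
    """Hypothetical 1-D model y = w * x + b, trained with plain SGD."""

    def __init__(self, lr=0.05):
        self.w, self.b, self.lr = 0.0, 0.0, lr

    def __call__(self, x, training=False):
        return self.call(x, training=training)

    def call(self, x, training=False):
        # Forward pass only: how an input becomes an output. No weights change here.
        return self.w * x + self.b

    def train_step(self, batch):
        # One training step: forward pass, loss, gradients, weight update.
        xs, ys = batch
        n = len(xs)
        preds = [self(x, training=True) for x in xs]
        errors = [p - y for p, y in zip(preds, ys)]
        loss = sum(e * e for e in errors) / n
        # Gradients of the mean squared error w.r.t. w and b, derived by hand
        # (in Keras, tf.GradientTape would compute these).
        grad_w = sum(2 * e * x for e, x in zip(errors, xs)) / n
        grad_b = sum(2 * e for e in errors) / n
        self.w -= self.lr * grad_w
        self.b -= self.lr * grad_b
        return {"loss": loss}


model = LinearModel(lr=0.05)
batch = ([0.0, 1.0, 2.0, 3.0], [1.0, 4.0, 7.0, 10.0])  # points on y = 3x + 1
history = [model.train_step(batch) for _ in range(500)]
print(round(model.w, 3), round(model.b, 3))
```

Because the data lies exactly on `y = 3x + 1`, repeated `train_step` calls drive `w` and `b` toward 3 and 1, while calling `model.call(...)` on its own never touches the weights.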

As for the GAN tutorial you linked, I would say it can actually be considered incomplete. It works without defining call because the custom train_step explicitly calls the generator and discriminator attributes (as these are predefined models, they can be called as usual). If you tried to call the GAN model itself, as in gan(inputs), I would expect an error (I did not test this). So you would always have to call gan.generator(inputs) to generate, for example.
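A pure-Python sketch of that situation follows. It rests on an assumption, untested above, that the base class refuses to run without a `call()`: the toy `ToyModel` deliberately raises `NotImplementedError` from its default `call`, which is how I understand Keras 2's `Model.call` behaves, but treat that as an assumption. `ToyGAN` only uses its generator inside `train_step`, so `gan(z)` fails; `ToyGANWithCall` adds a one-line `call` that delegates to the generator, which is one possible way to make `gan(z)` work. The `generator` function is a hypothetical stand-in for a generator network.

```python
class ToyModel:
    def __call__(self, inputs, training=False):
        return self.call(inputs, training=training)

    def call(self, inputs, training=False):
        # Assumption: mirrors Keras 2, where an un-overridden Model.call raises.
        raise NotImplementedError("When subclassing Model, implement call().")


class ToyGAN(ToyModel):
    def __init__(self, generator):
        self.generator = generator  # hypothetical stand-in for a generator model

    def train_step(self, batch):
        fake = self.generator(batch)  # uses the sub-model directly, never self(...)
        return {"n_fake": len(fake)}


class ToyGANWithCall(ToyGAN):
    def call(self, inputs, training=False):
        # The GAN's forward pass: generate samples from latent vectors.
        return self.generator(inputs)


def generator(zs):
    return [2 * z for z in zs]


gan = ToyGAN(generator)
print(gan.train_step([1, 2, 3]))  # works: train_step never needs call()

call_failed = False
try:
    gan([1, 2, 3])
except NotImplementedError as err:
    call_failed = True
    print("gan(z) failed:", err)

print(ToyGANWithCall(generator)([1, 2, 3]))
```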

Finally (this part may be a bit confusing), note that you can subclass Model to define a custom training step, but then instantiate your subclass via the functional API (like model = CustomModel(inputs, outputs), where CustomModel is your subclass). In that case, you can use call in the training step without ever defining it yourself, because the functional API takes care of that.
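In plain Python, the idea can be sketched as a base class whose constructor optionally accepts an already-wired forward function, loosely imitating how `Model(inputs, outputs)` supplies the forward pass. `ToyModel`, the `forward` parameter, and `SquaredErrorModel` are hypothetical names; real Keras tracks a graph of layers rather than storing a Python function.

```python
class ToyModel:
    def __init__(self, forward=None):
        # Loosely imitates the functional API: the forward pass is supplied
        # at construction time instead of by overriding call().
        self._forward = forward

    def __call__(self, inputs, training=False):
        return self.call(inputs, training=training)

    def call(self, inputs, training=False):
        if self._forward is None:
            raise NotImplementedError("override call() or pass a forward function")
        return self._forward(inputs)


class SquaredErrorModel(ToyModel):
    # Only train_step is customized; call() comes from the base class.
    def train_step(self, batch):
        x, y = batch
        y_pred = self(x, training=True)
        return {"loss": (y_pred - y) ** 2}


model = SquaredErrorModel(forward=lambda x: 3 * x + 1)
print(model.train_step((2, 6)))
```

In actual Keras, the corresponding pattern (used in the Keras guide on customizing `fit()`) is to override only `train_step` in a `keras.Model` subclass and instantiate it as `CustomModel(inputs, outputs)` from `keras.Input` and layer calls, so the forward pass comes from that layer graph.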

