Feed decoder input in transformers

Time:11-29

While reading this tutorial on how to implement an encoder-decoder Transformer, I had some doubts about the training process. As described in the original paper, the decoder should run iteratively, taking the output of the previous step as its next input. However, the training step is implemented as

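import tensorflow as tf

# train_step_signature, transformer, loss_function and optimizer
# are defined earlier in the tutorial.
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]   # decoder input: target without its last token
  tar_real = tar[:, 1:]   # labels: target shifted left by one position

  with tf.GradientTape() as tape:
    # The whole target prefix is fed in one pass; the decoder's
    # look-ahead mask keeps position i from seeing later positions.
    predictions, _ = transformer([inp, tar_inp],
                                 training=True)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))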

Here tar_inp is the tokenized target sentence without its final (EOS) token, and tar_real is the same sentence without its first (start) token, i.e. shifted left by one position.
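A minimal sketch of this slicing on a single tokenized sentence shows how each decoder-input prefix lines up with the label it is trained to predict. It uses plain Python lists instead of a batched tensor, and the token IDs (including 1 and 2 as start and end tokens) are made up for illustration:

```python
# Hypothetical token IDs: 1 = <start>, 2 = <end>; the others stand for words.
tar = [1, 17, 42, 9, 2]

tar_inp = tar[:-1]   # decoder input: drop the final <end> token
tar_real = tar[1:]   # labels: drop the initial <start> token

# At position i the decoder may only look at tar_inp[: i + 1]
# and is trained to predict tar_real[i], i.e. the next token.
pairs = [(tar_inp[: i + 1], tar_real[i]) for i in range(len(tar_inp))]

for context, label in pairs:
    print(context, "->", label)
# [1] -> 17
# [1, 17] -> 42
# [1, 17, 42] -> 9
# [1, 17, 42, 9] -> 2
```

In the training step the same slicing is applied along the sequence axis of a batch, hence tar[:, :-1] and tar[:, 1:].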

However, I would have expected the target input (the decoder input) to be built up iteratively by appending the previous prediction (or, with teacher forcing, by appending one ground-truth token at a time).
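For comparison, the iterative scheme described above is what happens at inference time, when no ground truth is available. A minimal sketch, assuming a hypothetical next_token function that stands in for one decoder forward pass plus an argmax over the last position (a real model would also take the encoder output), and made-up start/end token IDs:

```python
from typing import Callable, List

START, END = 1, 2  # hypothetical special-token IDs


def greedy_decode(next_token: Callable[[List[int]], int], max_len: int = 10) -> List[int]:
    """Autoregressive decoding: each step feeds the model's own previous outputs back in."""
    output = [START]
    for _ in range(max_len):
        token = next_token(output)  # one forward pass per generated token
        output.append(token)        # the prediction becomes part of the next input
        if token == END:
            break
    return output


def toy_model(prefix: List[int]) -> int:
    """Hypothetical stand-in for a decoder: emits 10, 11, 12 and then <end>."""
    return 10 + len(prefix) - 1 if len(prefix) < 4 else END


print(greedy_decode(toy_model))  # [1, 10, 11, 12, 2]
```

Each step needs the previous step's output, so this loop cannot run in parallel. During training the ground-truth tokens are already known, which is what allows the whole sequence to be processed at once.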

Why is it not the case?

Answer:

This example does use teacher forcing, but instead of feeding one ground-truth token at a time, it feeds the whole decoder input at once. Because the decoder's self-attention is causal (masked so that each position attends only to itself and earlier positions, i.e. left to right), the prediction of the i-th token can depend only on tokens 0...i-1. Training this way is therefore equivalent to teacher forcing one token at a time, but much faster, because all positions are predicted in parallel in a single forward pass.
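The equivalence can be checked numerically with a simplified single-head self-attention layer in NumPy. This is a sketch, not the tutorial's layer: it has no learned query/key/value projections and no multiple heads, and the random embeddings are arbitrary. Running masked attention once over the whole sequence gives the same output at every position as recomputing attention on each growing prefix:

```python
import numpy as np


def causal_self_attention(x: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product self-attention with a look-ahead (causal) mask.

    Queries, keys and values are all x itself (no learned projections) to keep it short.
    """
    seq_len, d = x.shape
    scores = x @ x.T / np.sqrt(d)  # (seq_len, seq_len)
    # Position i may attend to positions 0..i only: mask the strict upper triangle.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Row-wise softmax; masked entries become exactly 0.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # 5 decoder-input embeddings of size 8

# One pass over the whole (teacher-forced) decoder input.
parallel = causal_self_attention(x)

# "One token at a time": attend over each growing prefix and keep its last row.
incremental = np.stack([causal_self_attention(x[: i + 1])[-1] for i in range(len(x))])

print(np.allclose(parallel, incremental))  # True
```

Since row i of the masked attention never sees rows after i, the same argument applies layer by layer in a decoder whose other components (position-wise feed-forward layers, cross-attention to the encoder output) do not mix information across target positions.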
