How to properly stack RNN layers?


I've been trying to implement a character-level language model in TensorFlow based on this tutorial.

I would like to extend the model by allowing multiple RNN layers to be stacked. So far I've come up with this:

class MyModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_type, rnn_units, num_layers, dropout):
    super().__init__()
    self.rnn_type = rnn_type.lower()
    self.num_layers = num_layers
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    if self.rnn_type == 'gru':
      rnn_layer = tf.keras.layers.GRU
    elif self.rnn_type == 'lstm':
      rnn_layer = tf.keras.layers.LSTM
    elif self.rnn_type == 'rnn':
      rnn_layer = tf.keras.layers.SimpleRNN
    else:
      raise ValueError(f'Unsupported RNN layer: {rnn_type}')
    
    setattr(self, self.rnn_type, rnn_layer(rnn_units, return_sequences=True, return_state=True, dropout=dropout))

    for i in range(1, num_layers):
      setattr(self, f'{self.rnn_type}_{i}', rnn_layer(rnn_units, return_sequences=True, return_state=True, dropout=dropout))
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    
    rnn = getattr(self, self.rnn_type)  # look up by attribute, not by Keras layer name
    
    if states is None:
      states = rnn.get_initial_state(x)
    # Unpack with * so both GRU (one state tensor) and LSTM (two) work.
    x, *states = rnn(x, initial_state=states, training=training)
    for i in range(1, self.num_layers):
      layer = getattr(self, f'{self.rnn_type}_{i}')
      x, *states = layer(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

model = MyModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_type='gru',
    rnn_units=512,
    num_layers=3,
    dropout=dropout)

When trained for 30 epochs on the dataset from the tutorial, this model generates random gibberish. I don't know whether I'm stacking the layers incorrectly or whether the dataset is simply too small.

CodePudding user response:

There are multiple factors contributing to the bad predictions of your model:

  • The dataset is small
  • The model itself you are using is quite simple
  • The training time is very short
  • Predicting Shakespeare sonnets will produce some gibberish even when the model is trained correctly
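
As for the stacking itself: if you don't need per-layer access to the hidden states, Keras can stack recurrent layers directly inside a `Sequential` model, as long as every intermediate layer returns full sequences. A minimal sketch (the vocabulary and layer sizes here are illustrative, not taken from your post):

```python
import tensorflow as tf

vocab_size = 65      # assumption: character vocabulary of the tutorial's dataset
embedding_dim = 64
rnn_units = 512

# Each GRU except (optionally) the last must use return_sequences=True,
# so the next layer receives a (batch, time, units) tensor rather than
# only the final hidden state.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim),
    tf.keras.layers.GRU(rnn_units, return_sequences=True),
    tf.keras.layers.GRU(rnn_units, return_sequences=True),
    tf.keras.layers.GRU(rnn_units, return_sequences=True),
    tf.keras.layers.Dense(vocab_size),
])

# One forward pass: (batch, time) int ids -> (batch, time, vocab) logits.
logits = model(tf.zeros((2, 10), dtype=tf.int32))
print(logits.shape)  # (2, 10, 65)
```

For text generation you would still want to carry the hidden states between calls, as the tutorial's subclassed model does, but this is a quick way to verify that the stacked architecture itself trains.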

Try to train it for longer. This will ultimately lead to better results, although predicting coherent text is one of the hardest tasks in machine learning in general. For example, GPT-3, one of the models that solves this problem almost perfectly, consists of billions of parameters (see here).

EDIT: The reason your model performs worse than the one in the tutorial, even though it has more stacked RNN layers, may be that more layers need more training time. Simply increasing the number of layers will not necessarily improve prediction quality. As I said, try increasing the training time, or play around with the hyperparameters (learning rate, normalization layers, etc.).
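
To see why the deeper model needs more training, you can compare parameter counts: each extra recurrent layer adds millions of weights that all have to be fitted from the same small dataset. A rough sketch (sizes are illustrative):

```python
import tensorflow as tf

def char_model(num_layers, vocab_size=65, embedding_dim=64, rnn_units=512):
    # Illustrative sizes; only the number of GRU layers varies.
    layers = [tf.keras.layers.Embedding(vocab_size, embedding_dim)]
    layers += [tf.keras.layers.GRU(rnn_units, return_sequences=True)
               for _ in range(num_layers)]
    layers.append(tf.keras.layers.Dense(vocab_size))
    model = tf.keras.Sequential(layers)
    model(tf.zeros((1, 4), dtype=tf.int32))  # one forward pass to build the weights
    return model

one_layer = char_model(1).count_params()
three_layers = char_model(3).count_params()
# Each additional 512-unit GRU adds roughly 1.5 million parameters.
print(one_layer, three_layers)
```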
