Home > database >  making GRU/LSTM states trainable in Tensorflow/Keras and add some random noise
Making GRU/LSTM states trainable in TensorFlow/Keras and adding some random noise

Time:02-23

I train the following GRU-based model. Note that I pass the argument stateful=True to the GRU constructor.

class LearningToSurpriseModel(tf.keras.Model):
  """Embedding -> stateful GRU -> Dense model over a token vocabulary.

  Args:
    vocab_size: number of token ids; sets both the embedding-table size and
      the width of the final Dense layer.
    embedding_dim: size of each embedding vector.
    rnn_units: number of GRU units.
  """

  def __init__(self, vocab_size, embedding_dim, rnn_units):
    # Subclassed models call Model.__init__ without positional arguments;
    # the original `super().__init__(self)` passed `self` a second time.
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # stateful=True: the GRU keeps its last state between batches until
    # reset_states() is called (see the epoch-end callback).
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   stateful=True,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True)
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    """Run the model on a batch of token ids.

    Args:
      inputs: integer token ids (presumably shaped (batch, time) — confirm).
      states: optional GRU initial state; when None, the value returned by
        `self.gru.get_initial_state(x)` is used instead.
      return_state: if True, also return the GRU's final state.
      training: forwarded to every sub-layer.

    Returns:
      The Dense outputs for every time step (return_sequences=True), or
      `(outputs, states)` when `return_state` is True.
    """
    x = self.embedding(inputs, training=training)
    if states is None:
      # NOTE(review): an explicit initial_state is always passed below; with
      # stateful=True, whether Keras uses it or the carried-over state
      # depends on the Keras version — confirm for the installed one.
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)
    if return_state:
      return x, states
    return x

  @tf.function
  def train_step(self, inputs):
    # The original post elided this body ("[defining here my training step]");
    # replace this with the actual custom training step.
    raise NotImplementedError(
        "train_step: the custom training step was elided in the original post")

I instantiate my model:

# Instantiate the model; the vocabulary size is read from `ids_from_chars`.
model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
)

[compile and do stuff] The custom callback below resets the states manually at the end of each epoch:

# Fetch the GRU through the model attribute it was assigned to; this does not
# depend on the layer's position in `model.layers`.
gru_layer = model.gru


class CustomCallback(tf.keras.callbacks.Callback):
  """Reset the GRU's recorded states at the end of every epoch.

  Args:
    gru_layer: the stateful recurrent layer whose states are reset.
  """

  def __init__(self, gru_layer):
    # Run the base Callback initialiser, which the original skipped.
    super().__init__()
    self.gru_layer = gru_layer

  def on_epoch_end(self, epoch, logs=None):
    # Called without arguments, reset_states() sets the states back to zero.
    self.gru_layer.reset_states()


model.fit(train_dataset, validation_data=validation_dataset,
          epochs=EPOCHS, callbacks=[EarlyS, CustomCallback(gru_layer)],
          verbose=1)

The states are reset to zero. I would like to follow the ideas in https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html to make the states trainable. The implementation in that post appears to be based on low-level TensorFlow and overrides native functions; maybe there is a more elegant way in Keras.

(1) How do I make the states trainable?

(2) How do I combine trainable states with random initialization?

CodePudding user response:

You could try defining a custom GRU layer with a trainable variable for the states, although I am not sure how well it will perform:

import tensorflow as tf

class CustomGRULayer(tf.keras.layers.Layer):
  """Stateful GRU whose initial state is a trainable weight.

  The initial state `w` has shape (batch_size, rnn_units). It is randomly
  initialised from N(0, 0.01**2) and constrained to [0, 0.1] by
  `tf.clip_by_value`. Gaussian noise can optionally be added to it during
  training, which combines a trainable state with random perturbation.

  Args:
    rnn_units: number of GRU units.
    batch_size: fixed batch size; the trainable state has one row per batch
      element.
    noise_stddev: standard deviation of the Gaussian noise added to the
      initial state when `training` is true. The default 0.0 disables noise.
    **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, rnn_units, batch_size, noise_stddev=0.0, **kwargs):
    super().__init__(**kwargs)
    self.rnn_units = rnn_units
    self.batch_size = batch_size
    self.noise_stddev = noise_stddev
    self.gru = tf.keras.layers.GRU(self.rnn_units,
                                   stateful=True,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True)

  def build(self, input_shape):
    # add_weight is the Keras-native way to create a layer weight: it is
    # registered in trainable_weights and the constraint is attached to it.
    self.w = self.add_weight(
        name="initial_state",
        shape=(self.batch_size, self.rnn_units),
        dtype="float32",
        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.01),
        trainable=True,
        constraint=lambda z: tf.clip_by_value(z, 0.0, 0.1))

  def call(self, inputs, training=None):
    """Run the GRU from the (optionally noisy) trainable initial state.

    Returns the GRU outputs: the full output sequence and the final state.
    """
    initial_state = self.w
    # Check the Python float first so that no noise op is built when
    # noise is disabled.
    if self.noise_stddev > 0 and training:
      initial_state = initial_state + tf.random.normal(
          tf.shape(initial_state), stddev=self.noise_stddev)
    # NOTE(review): with stateful=True, some Keras 2.x versions use a passed
    # initial_state only while the GRU's recorded state is all zeros, and
    # otherwise continue from the recorded state. Other versions may always
    # use it. Confirm the behaviour for the installed version.
    return self.gru(inputs, initial_state=initial_state)

class CustomCallback(tf.keras.callbacks.Callback):
  """At each epoch end, reset the inner GRU's states to the trainable state.

  Args:
    gru_layer: a layer exposing `gru` (the stateful GRU) and `w` (the
      trainable initial state); the GRU's states are reset to `w`'s value.
  """

  def __init__(self, gru_layer):
    # Run the base Callback initialiser, which the original skipped.
    super().__init__()
    self.gru_layer = gru_layer

  def on_epoch_end(self, epoch, logs=None):
    self.gru_layer.gru.reset_states(self.gru_layer.w)

batch_size = 2
gru_layer = CustomGRULayer(rnn_units=32, batch_size=batch_size)
# A stateful layer needs a fixed batch size, hence batch_shape instead of shape.
inputs = tf.keras.layers.Input(batch_shape=(batch_size, 5, 10))
x, _ = gru_layer(inputs)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(units=1)(x)  # no activation: the output is a logit
model = tf.keras.Model(inputs, x)
# The Dense output has no sigmoid, so the loss must treat it as a logit.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
x_train = tf.random.normal((10, 5, 10))
# Binary targets in {0, 1}. A float uniform with maxval=2 would produce
# continuous values in [0, 2), which is not a valid binary label.
y_train = tf.cast(tf.random.uniform((10, 1), maxval=2, dtype=tf.int32),
                  tf.float32)
model.summary()  # summary() prints itself and returns None
# shuffle=False keeps the batch order fixed, so the state a stateful RNN
# carries from one batch to the next belongs to the following samples.
# NOTE: the sample log recorded for this snippet predates the label/loss
# fixes, so its loss values will differ.
model.fit(x_train, y_train, epochs=10, batch_size=batch_size, shuffle=False,
          callbacks=[CustomCallback(gru_layer)])
Model: "model_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_19 (InputLayer)       [(2, 5, 10)]              0         
                                                                 
 custom_gru_layer_16 (Custom  [(2, 5, 32),             4288      
 GRULayer)                    (2, 32)]                           
                                                                 
 flatten_11 (Flatten)        (2, 160)                  0         
                                                                 
 dense_11 (Dense)            (2, 1)                    161       
                                                                 
=================================================================
Total params: 4,449
Trainable params: 4,449
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/10
5/5 [==============================] - 2s 4ms/step - loss: 10.7969
Epoch 2/10
5/5 [==============================] - 0s 5ms/step - loss: 8.6558
Epoch 3/10
5/5 [==============================] - 0s 6ms/step - loss: 8.1546
Epoch 4/10
5/5 [==============================] - 0s 4ms/step - loss: 4.9713
Epoch 5/10
5/5 [==============================] - 0s 4ms/step - loss: 4.7329
Epoch 6/10
5/5 [==============================] - 0s 4ms/step - loss: 4.1399
Epoch 7/10
5/5 [==============================] - 0s 4ms/step - loss: 3.9296
Epoch 8/10
5/5 [==============================] - 0s 4ms/step - loss: 4.3430
Epoch 9/10
5/5 [==============================] - 0s 5ms/step - loss: 3.7707
Epoch 10/10
5/5 [==============================] - 0s 5ms/step - loss: 3.7045
<keras.callbacks.History at 0x7f940546b210>

Or without a custom layer:

import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
  """At each epoch end, reset a stateful RNN's recorded states to `w`.

  Args:
    gru_layer: the stateful recurrent layer to reset.
    w: the value (here a tf.Variable) that the states are reset to.
  """

  def __init__(self, gru_layer, w):
    # Run the base Callback initialiser, which the original skipped.
    super().__init__()
    self.gru_layer = gru_layer
    self.w = w

  def on_epoch_end(self, epoch, logs=None):
    self.gru_layer.reset_states(self.w)

batch_size = 2
w_init = tf.random_normal_initializer(mean=0.0, stddev=0.01)
# Trainable initial state: one row per batch element, projected into
# [0, 0.1] by the constraint after optimizer updates.
w = tf.Variable(
    initial_value=w_init(shape=(batch_size, 32), dtype='float32'),
    trainable=True,
    constraint=lambda z: tf.clip_by_value(z, 0.0, 0.1))

gru_layer = tf.keras.layers.GRU(32, stateful=True, return_sequences=True,
                                return_state=True, reset_after=True)

# A stateful layer needs a fixed batch size, hence batch_shape instead of shape.
inputs = tf.keras.layers.Input(batch_shape=(batch_size, 5, 10))
# NOTE(review): `w` is not assigned to any layer. It may therefore be absent
# from model.trainable_variables and never updated by fit(). Check
# `any(v is w for v in model.trainable_variables)` before relying on this.
x, _ = gru_layer(inputs, initial_state=w)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(units=1)(x)  # no activation: the output is a logit
model = tf.keras.Model(inputs, x)
# The Dense output has no sigmoid, so the loss must treat it as a logit.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
x_train = tf.random.normal((10, 5, 10))
# Binary targets in {0, 1}; a float uniform with maxval=2 would lie in [0, 2).
y_train = tf.cast(tf.random.uniform((10, 1), maxval=2, dtype=tf.int32),
                  tf.float32)
model.summary()  # summary() prints itself and returns None
# shuffle=False keeps the batch order fixed for the stateful RNN.
model.fit(x_train, y_train, epochs=10, batch_size=batch_size, shuffle=False,
          callbacks=[CustomCallback(gru_layer, w)])
  • Related