Home > database >  Get the internal states of an LSTM layer after training and initialize the LSTM layer with the saved internal state
Get the internal states of an LSTM layer after training and initialize the LSTM layer with the saved internal state

Time:02-22

I have created a model with an LSTM layer as shown below and want to get the internal state (hidden state and cell state) after the training step and save it. After the training step, I will use the network for a prediction and want to reinitialize the LSTM with the saved internal state before the next training step. This way I can continue from the same point after each training step. I haven't been able to find something helpful for the current version of TensorFlow, i.e., 2.x.

 import tensorflow as tf

class LTSMNetwork(object):
    """Factory for a small stateful-LSTM Keras model.

    Holds the hyper-parameters passed to the constructor and builds the
    compiled model on demand via :meth:`lstm_model`.
    """

    def __init__(self, num_channels, num_hidden_neurons, learning_rate, time_steps, batch_size):
        # num_hidden_neurons is expected to hold (at least) two sizes:
        # [LSTM units, hidden Dense units].
        self.num_channels = num_channels
        self.num_hidden_neurons = num_hidden_neurons
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.batch_size = batch_size

    def lstm_model(self):
        """Build, compile, store on ``self.model`` and return the model.

        The LSTM is ``stateful`` with a fixed ``batch_input_shape`` so its
        hidden/cell state carries over between batches.
        """
        layers = tf.keras.layers
        self.model = tf.keras.Sequential([
            layers.LSTM(
                batch_input_shape=(self.batch_size, self.time_steps, self.num_channels),
                units=self.num_hidden_neurons[0],
                activation='tanh',
                recurrent_activation='sigmoid',
                return_sequences=True,
                stateful=True,
            ),
            layers.Dense(units=self.num_hidden_neurons[1], activation=tf.nn.sigmoid),
            layers.Dense(units=self.num_channels, name="output_layer", activation=tf.nn.tanh),
        ])
        self.model.compile(
            optimizer=tf.optimizers.Adam(learning_rate=self.learning_rate),
            loss='mse',
            metrics=['binary_accuracy'],
        )
        return self.model


if __name__=='__main__':

    # Hyper-parameters collected in one place and unpacked into the
    # network constructor.
    config = {
        "num_channels": 3,
        "num_hidden_neurons": [150, 100],
        "learning_rate": 0.001,
        "time_steps": 1,
        "batch_size": 1,
    }

    network = LTSMNetwork(**config)
    model = network.lstm_model()
    model.summary()

CodePudding user response:

You can define a custom Callback and save the hidden and cell states at every epoch for example. Afterwards, you can choose from which epoch you want to extract the states and then use lstm_layer.reset_states(*) to set the initial state again:

import tensorflow as tf

class LTSMNetwork(object):
    """Wraps construction of a compiled, stateful LSTM Keras model."""

    def __init__(self, num_channels, num_hidden_neurons, learning_rate, time_steps, batch_size):
        self.num_channels = num_channels
        # Two sizes are used: index 0 for the LSTM, index 1 for the
        # hidden Dense layer.
        self.num_hidden_neurons = num_hidden_neurons
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.batch_size = batch_size

    def _recurrent_layer(self):
        """Stateful LSTM with a fixed batch input shape."""
        return tf.keras.layers.LSTM(
            batch_input_shape=(self.batch_size, self.time_steps, self.num_channels),
            units=self.num_hidden_neurons[0],
            activation='tanh',
            recurrent_activation='sigmoid',
            return_sequences=True,
            stateful=True,
        )

    def lstm_model(self):
        """Assemble and compile the model; keep it on ``self.model``."""
        model = tf.keras.Sequential()
        model.add(self._recurrent_layer())
        model.add(tf.keras.layers.Dense(units=self.num_hidden_neurons[1],
                                        activation=tf.nn.sigmoid))
        model.add(tf.keras.layers.Dense(units=self.num_channels,
                                        name="output_layer",
                                        activation=tf.nn.tanh))
        model.compile(optimizer=tf.optimizers.Adam(learning_rate=self.learning_rate),
                      loss='mse',
                      metrics=['binary_accuracy'])
        self.model = model
        return self.model


# Epoch number -> snapshot of the LSTM's [hidden state, cell state].
states = {}


class CustomCallback(tf.keras.callbacks.Callback):
    """Records the wrapped LSTM layer's internal state after every epoch.

    Args:
        lstm_layer: a stateful ``tf.keras.layers.LSTM`` instance whose
            ``states`` are snapshotted into the module-level ``states`` dict.
    """

    def __init__(self, lstm_layer):
        super().__init__()
        self.lstm_layer = lstm_layer

    def on_epoch_end(self, epoch, logs=None):
        # Use the layer passed to the constructor (the original read a
        # global ``lstm_layer``, silently ignoring the argument), and copy
        # the state tensors: ``layer.states`` holds variables that are
        # updated in place, so storing the raw references would make every
        # epoch's entry alias the final state.
        states[epoch] = [tf.identity(s) for s in self.lstm_layer.states]

# Hyper-parameters for a toy run.
num_channels = 3
num_hidden_neurons = [150, 100]
learning_rate = 0.001
time_steps = 1
batch_size = 1

network = LTSMNetwork(num_channels=num_channels,
                      num_hidden_neurons=num_hidden_neurons,
                      learning_rate=learning_rate,
                      time_steps=time_steps,
                      batch_size=batch_size)
model = network.lstm_model()
lstm_network = network
lstm_layer = model.layers[0]

# Random single-sample input/target pair; fit for a few epochs while the
# callback snapshots the LSTM state at the end of each one.
train_x = tf.random.normal((1, 1, 3))
train_y = tf.random.normal((1, 1, 3))
model.fit(train_x, train_y, epochs=5, callbacks=[CustomCallback(lstm_layer)])
model.summary()

# Restore the hidden/cell state recorded after the first epoch.
lstm_layer.reset_states(states[0])

The `states` dictionary ends up holding one internal-state snapshot (hidden state and cell state) for each of the 5 epochs.

  • Related