I have created a model with an LSTM layer, shown below. I want to get its internal state (hidden state and cell state) after each training step and save it. After a training step, I use the network for a prediction. Before the next training step, I want to reinitialize the LSTM with the saved internal state, so that training continues from the same point. I haven't been able to find anything helpful for the current version of TensorFlow, i.e., 2.x.
import tensorflow as tf
class LTSMNetwork(object):
    """Build a stateful LSTM network with Keras.

    The model is LSTM -> Dense (sigmoid) -> Dense output (tanh). It is compiled
    with an MSE loss and a ``binary_accuracy`` metric. Because the LSTM is
    stateful, its hidden and cell states carry over between batches until they
    are explicitly reset or overwritten.

    Args:
        num_channels: number of input features per time step; also the number
            of output units.
        num_hidden_neurons: two layer sizes, ``[lstm_units, dense_units]``.
        learning_rate: learning rate handed to the optimizer.
        time_steps: sequence length of each input sample.
        batch_size: fixed batch size (a stateful LSTM requires one).
    """

    def __init__(self, num_channels, num_hidden_neurons, learning_rate, time_steps, batch_size):
        self.num_channels = num_channels
        self.num_hidden_neurons = num_hidden_neurons
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.batch_size = batch_size

    def lstm_model(self):
        """Create and compile the model, store it on ``self.model``, and return it."""
        self.model = tf.keras.Sequential()
        # A stateful LSTM needs the full batch shape up front. ``Input`` with
        # ``batch_size`` works on both Keras 2 and Keras 3, whereas the
        # ``batch_input_shape`` layer argument is rejected by Keras 3.
        self.model.add(tf.keras.Input(shape=(self.time_steps, self.num_channels),
                                      batch_size=self.batch_size))
        # ``units`` is a required LSTM argument.
        self.model.add(tf.keras.layers.LSTM(units=self.num_hidden_neurons[0],
                                            activation='tanh', recurrent_activation='sigmoid',
                                            return_sequences=True, stateful=True))
        hidden_layer = tf.keras.layers.Dense(units=self.num_hidden_neurons[1], activation=tf.nn.sigmoid)
        self.model.add(hidden_layer)
        self.model.add(tf.keras.layers.Dense(units=self.num_channels, name="output_layer", activation=tf.nn.tanh))
        # NOTE(review): the optimizer argument was lost from the original
        # compile() call; Adam is assumed so the stored learning_rate is used.
        self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
                           loss='mse', metrics=['binary_accuracy'])
        return self.model
if __name__ == '__main__':
    # Single-sample, single-step configuration for the stateful network.
    num_channels, time_steps, batch_size = 3, 1, 1
    num_hidden_neurons = [150, 100]
    learning_rate = 0.001

    lstm_network = LTSMNetwork(
        num_channels=num_channels,
        num_hidden_neurons=num_hidden_neurons,
        learning_rate=learning_rate,
        time_steps=time_steps,
        batch_size=batch_size,
    )
    model = lstm_network.lstm_model()
Answer:
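You can define a custom `Callback` that saves the hidden and cell states, for example at the end of every epoch. Save copies of the values (e.g. with `.numpy()`), because `lstm_layer.states` holds the live state variables, which keep changing during training. Afterwards, choose the epoch whose states you want and write them back into the LSTM layer to set its initial state again. In Keras 2 you can use `lstm_layer.reset_states(saved_states)`; in any version you can assign each saved value to the corresponding variable in `lstm_layer.states`: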
import tensorflow as tf
class LTSMNetwork(object):
    """Build a stateful LSTM network with Keras.

    The model is LSTM -> Dense (sigmoid) -> Dense output (tanh). It is compiled
    with an MSE loss and a ``binary_accuracy`` metric. Because the LSTM is
    stateful, its hidden and cell states carry over between batches until they
    are explicitly reset or overwritten.

    Args:
        num_channels: number of input features per time step; also the number
            of output units.
        num_hidden_neurons: two layer sizes, ``[lstm_units, dense_units]``.
        learning_rate: learning rate handed to the optimizer.
        time_steps: sequence length of each input sample.
        batch_size: fixed batch size (a stateful LSTM requires one).
    """

    def __init__(self, num_channels, num_hidden_neurons, learning_rate, time_steps, batch_size):
        self.num_channels = num_channels
        self.num_hidden_neurons = num_hidden_neurons
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.batch_size = batch_size

    def lstm_model(self):
        """Create and compile the model, store it on ``self.model``, and return it."""
        self.model = tf.keras.Sequential()
        # A stateful LSTM needs the full batch shape up front. ``Input`` with
        # ``batch_size`` works on both Keras 2 and Keras 3, whereas the
        # ``batch_input_shape`` layer argument is rejected by Keras 3.
        self.model.add(tf.keras.Input(shape=(self.time_steps, self.num_channels),
                                      batch_size=self.batch_size))
        # ``units`` is a required LSTM argument.
        self.model.add(tf.keras.layers.LSTM(units=self.num_hidden_neurons[0],
                                            activation='tanh', recurrent_activation='sigmoid',
                                            return_sequences=True, stateful=True))
        hidden_layer = tf.keras.layers.Dense(units=self.num_hidden_neurons[1], activation=tf.nn.sigmoid)
        self.model.add(hidden_layer)
        self.model.add(tf.keras.layers.Dense(units=self.num_channels, name="output_layer", activation=tf.nn.tanh))
        # NOTE(review): the optimizer argument was lost from the original
        # compile() call; Adam is assumed so the stored learning_rate is used.
        self.model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
                           loss='mse', metrics=['binary_accuracy'])
        return self.model
# Maps epoch index -> list of NumPy arrays copied from the LSTM layer's state
# variables at the end of that epoch (for an LSTM: hidden state and cell state).
states = {}


class CustomCallback(tf.keras.callbacks.Callback):
    """Record a snapshot of a stateful LSTM layer's states after every epoch.

    Args:
        lstm_layer: the stateful LSTM layer whose ``states`` are recorded.
        store: optional dict to write the snapshots into; defaults to the
            module-level ``states`` dict.
    """

    def __init__(self, lstm_layer, store=None):
        # Let the Keras base class initialise its own bookkeeping attributes.
        super().__init__()
        self.lstm_layer = lstm_layer
        self.store = states if store is None else store

    def on_epoch_end(self, epoch, logs=None):
        # ``lstm_layer.states`` holds the live variables, which keep changing as
        # training continues. Copy their current values so each epoch's entry
        # is an independent snapshot instead of a reference to the same objects.
        self.store[epoch] = [state.numpy() for state in self.lstm_layer.states]
if __name__ == "__main__":
    num_channels = 3
    num_hidden_neurons = [150, 100]
    learning_rate = 0.001
    time_steps = 1
    batch_size = 1
    lstm_network = LTSMNetwork(num_channels=num_channels, num_hidden_neurons=num_hidden_neurons,
                               learning_rate=learning_rate, time_steps=time_steps, batch_size=batch_size)
    model = lstm_network.lstm_model()
    lstm_layer = model.layers[0]

    # Dummy data shaped to match the configured (batch, time, channels) input.
    x = tf.random.normal((batch_size, time_steps, num_channels))
    y = tf.random.normal((batch_size, time_steps, num_channels))
    model.fit(x, y, epochs=5, callbacks=[CustomCallback(lstm_layer)])

    # Restore the states recorded after the first epoch (index 0). Assigning
    # into the state variables works across Keras versions; in Keras 3,
    # reset_states() no longer accepts explicit state values.
    for variable, value in zip(lstm_layer.states, states[0]):
        variable.assign(value)
`states` then contains 5 entries: one set of internal states (hidden state and cell state) for each of the 5 epochs.