I am using an LSTM for fake news detection and added an embedding layer to my model.
It works fine without passing any input_shape to the LSTM layer, but I thought the input_shape parameter was mandatory. Could someone explain why there is no error even though I never define input_shape? Is it because the embedding layer implicitly defines the input_shape?
Here is the code:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.optimizers import SGD

model = Sequential()
embedding_layer = Embedding(total_words, embedding_dim, weights=[embedding_matrix], input_length=max_length)
model.add(embedding_layer)
model.add(LSTM(64))
model.add(Dense(1, activation='sigmoid'))
opt = SGD(learning_rate=0.01, decay=1e-6)
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=['accuracy'])
model.fit(data, train['label'], epochs=30, verbose=1)
```
Answer:
You only need to provide an input_length to the Embedding layer. Furthermore, if you use a Sequential model, you do not need to provide an input layer. Omitting the input layer essentially means that your model's weights are only created when you pass real data, as you did in model.fit(...). If you wanted to see the weights of your model before providing real data, you would have to add an input layer to the model before your Embedding layer, like this:

```python
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(max_length,)))
```
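As a minimal sketch of the input-layer approach, assuming TensorFlow 2.x and small hypothetical values for total_words, embedding_dim and max_length (the pretrained embedding_matrix is left out), the model below creates its weights as soon as the layers are added, so summary() and model.weights already work before any call to fit():

```python
import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
total_words = 1000
embedding_dim = 16
max_length = 20

model = tf.keras.Sequential()
# An explicit Input layer fixes the input shape up front, so each
# following layer can create its weights as soon as it is added.
model.add(tf.keras.Input(shape=(max_length,)))
model.add(tf.keras.layers.Embedding(total_words, embedding_dim))
model.add(tf.keras.layers.LSTM(64))
model.add(tf.keras.layers.Dense(1, activation="sigmoid"))

# No data has been passed yet, but the weights already exist.
model.summary()
print([tuple(w.shape) for w in model.weights])
```

The six weight tensors are the embedding matrix, the LSTM kernel, recurrent kernel and bias, and the Dense kernel and bias.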
And yes, as you mentioned, your model infers the input_shape implicitly when you provide the real data. Your LSTM layer does not need an input_shape either, as it is derived from the output of your Embedding layer. If the LSTM layer were the first layer of your model, it would be best to specify an input_shape for clarity. For example:
```python
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(32, input_shape=(10, 5)))
model.add(tf.keras.layers.Dense(1))
```
where 10 represents the number of time steps and 5 the number of features. In your example, the input to your LSTM layer has the shape (max_length, embedding_dim). Here, too, if you do not specify the input_shape, your model will infer the shape from your input data.
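To make this shape inference concrete, here is a hedged sketch, again assuming TensorFlow 2.x and hypothetical sizes, that calls an Embedding layer and an LSTM layer (created without any input_shape) on a dummy batch of token IDs. The LSTM builds itself on its first call and reads its feature dimension from the Embedding output:

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes for illustration only.
batch_size, max_length, total_words, embedding_dim = 4, 20, 1000, 16

# Dummy integer token IDs with shape (batch_size, max_length).
tokens = np.random.randint(0, total_words, size=(batch_size, max_length))

embedding = tf.keras.layers.Embedding(total_words, embedding_dim)
lstm = tf.keras.layers.LSTM(64)  # note: no input_shape given

embedded = embedding(tokens)  # shape (batch_size, max_length, embedding_dim)
# The LSTM creates its weights on this first call, using the last
# dimension of `embedded` (embedding_dim) as its input feature size.
hidden = lstm(embedded)  # shape (batch_size, 64)

print(embedded.shape, hidden.shape)
# The input kernel has shape (embedding_dim, 4 * units), one block per gate.
print(lstm.cell.kernel.shape)
```

The batch dimension never appears in input_shape; only the per-sample shape (max_length, embedding_dim) is inferred.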
For more information, check out the Keras documentation.