I'm trying to create a minimal non-convolutional binary image classifier with only one hidden layer (as practice before moving on to more complicated models):
from tensorflow import keras
from tensorflow.keras import layers

def make_model(input_shape):
    inputs = keras.Input(shape=input_shape)
    x = layers.Dense(128, activation="ReLU")(inputs)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)

model = make_model(input_shape=(256, 256, 3))
Calling model.summary() shows:
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_1 (InputLayer)        [(None, 256, 256, 3)]     0
 dense (Dense)               (None, 256, 256, 128)     512
 dense_1 (Dense)             (None, 256, 256, 1)       129
=================================================================
Total params: 641
Trainable params: 641
Non-trainable params: 0
Since the dense_1 layer has only one neuron, I expect its output shape to be (None, 1) (i.e., a single number indicating the predicted binary label), but instead the model gives (None, 256, 256, 1).
What's wrong with my model setting and how can I get it right?
CodePudding user response:
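A Dense layer only acts on the last axis of its input, so the two 256-sized image axes pass through unchanged. You have to flatten your preposterously large tensor if you want the output shape (None, 1):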
import tensorflow as tf

def make_model(input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Dense(128, activation="relu")(inputs)
    # Collapse (256, 256, 128) into one axis so the last layer sees a vector
    x = tf.keras.layers.Flatten()(x)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model(inputs, outputs)

model = make_model(input_shape=(256, 256, 3))
model.summary()  # summary() prints the table itself and returns None
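To see where the shapes and parameter counts in the summary come from, here is a rough pure-Python sketch of the arithmetic. It is not Keras code: the helpers dense and flatten are hypothetical names made up for illustration, and they only mimic how the shapes and counts behave, with the batch axis left out.

```python
from math import prod

def dense(input_shape, units):
    """Return (output_shape, param_count) for a Dense-like layer.

    Only the last axis is transformed; every other axis is carried through.
    """
    *leading, last = input_shape
    params = last * units + units  # weights plus biases
    return (*leading, units), params

def flatten(input_shape):
    """Collapse all non-batch axes into a single axis."""
    return (prod(input_shape),)

image = (256, 256, 3)  # batch axis omitted

# Original model: Dense -> Dense, no Flatten
hidden_shape, hidden_params = dense(image, 128)
out_shape, out_params = dense(hidden_shape, 1)
print(hidden_shape, hidden_params)  # (256, 256, 128) 512
print(out_shape, out_params)        # (256, 256, 1) 129

# With Flatten between the two Dense layers
flat_shape = flatten(hidden_shape)
fixed_shape, fixed_params = dense(flat_shape, 1)
print(fixed_shape, fixed_params)    # (1,) 8388609
```

The 512 and 129 match the summary in the question. Under the same arithmetic, flattening the image before the hidden layer instead, i.e. dense(flatten(image), 128), would give a (128,) hidden output with 25,165,952 parameters, which is probably closer to a classic one-hidden-layer network; which placement is preferable depends on what you want the model to do.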
CodePudding user response:
There is a mistake in your function make_model.
def make_model(input_shape):
    inputs = keras.Input(shape=input_shape)
    x = layers.Dense(128, activation="ReLU")(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)
You probably wanted the second line to be
    x = layers.Dense(128, activation="ReLU")(inputs)
and not
    x = layers.Dense(128, activation="ReLU")(x)
and unfortunately, x exists in scope, so it didn't throw an error.
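One caveat on the scoping point, as a quick check rather than a definitive diagnosis: in plain Python, a name that is assigned anywhere inside a function is local to that function, so a line of the form x = f(x) would normally fail even when an outer x exists. The sketch below uses a hypothetical identity function as a stand-in for the layer call, so it runs without Keras.

```python
def identity(value):
    # Hypothetical stand-in for a layer call such as layers.Dense(...)(...)
    return value

x = "outer x"  # a module-level x, as might be left over in a notebook

def build_with_typo():
    # x is assigned in this function, so Python treats it as local for the
    # whole body; reading it before the assignment completes fails.
    x = identity(x)
    return x

try:
    build_with_typo()
    outcome = "no error"
except UnboundLocalError as exc:
    outcome = type(exc).__name__

print(outcome)  # UnboundLocalError
```

So if the function really contained (x) on that line, an UnboundLocalError would be the expected symptom; the code posted in the question passes inputs there.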