I'm having a hard time understanding what the problem is. Consider the following model:
Model: "model_8"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input (InputLayer)             [(None, 15)]         0           []
 dense_1 (Dense)                (None, 128)          2048        ['input[0][0]']
 dense_2 (Dense)                (None, 1024)         132096      ['dense_1[0][0]']
 dense_3 (Dense)                (None, 5120)         5248000     ['dense_2[0][0]']
 a_out (Dense)                  (None, 17)           87057       ['dense_3[0][0]']
 b_out (Dense)                  (None, 27)           138267      ['dense_3[0][0]']
 c_out (Dense)                  (None, 71)           363591      ['dense_3[0][0]']
 d_out (Dense)                  (None, 29)           148509      ['dense_3[0][0]']
==================================================================================================
Total params: 6,119,568
Trainable params: 6,119,568
Non-trainable params: 0
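For reference, a functional model that reproduces this summary could look like the sketch below. It is a hypothetical reconstruction, not the asker's code: the summary does not show activation functions, so relu for the hidden layers and softmax for the four heads are assumptions, and head_sizes is an illustrative helper. The printed parameter count can be checked against the summary (for example, dense_1 has 15 * 128 weights plus 128 biases = 2048).

```python
import tensorflow as tf

# Hypothetical reconstruction of "model_8". The summary does not list
# activations, so "relu" (hidden) and "softmax" (heads) are assumptions.
inputs = tf.keras.Input(shape=(15,), name="input")
x = tf.keras.layers.Dense(128, activation="relu", name="dense_1")(inputs)
x = tf.keras.layers.Dense(1024, activation="relu", name="dense_2")(x)
x = tf.keras.layers.Dense(5120, activation="relu", name="dense_3")(x)

# Hypothetical helper: output width of each head, taken from the summary.
head_sizes = {"a_out": 17, "b_out": 27, "c_out": 71, "d_out": 29}

# All four heads read from dense_3, matching the "Connected to" column.
outputs = {
    name: tf.keras.layers.Dense(units, activation="softmax", name=name)(x)
    for name, units in head_sizes.items()
}

model = tf.keras.Model(inputs=inputs, outputs=outputs, name="model_8")
print(model.count_params())  # 6119568, the "Total params" figure
```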
It's a rather simple model with one input and four outputs (a_out, b_out, c_out, and d_out). I'm trying to fit the model by feeding it a dataset:
dataset = tf.data.Dataset.from_tensor_slices((inputs, {'a_out': targets[:, 0],
                                                       'b_out': targets[:, 1],
                                                       'c_out': targets[:, 2],
                                                       'd_out': targets[:, 3]}))
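One way to see what this dataset actually yields is to inspect its element_spec. The sketch below uses small random stand-in arrays (100 rows of hypothetical values instead of 525081) with the same trailing shapes as inputs and targets. It shows that each element is a single, unbatched sample whose input has shape (15,) and dtype float64, the same shape and dtype reported in the error.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 100 rows instead of 525081, random values,
# same trailing shapes as in the question.
inputs = np.random.rand(100, 15)                   # float64 by default
targets = np.random.randint(0, 17, size=(100, 4))

dataset = tf.data.Dataset.from_tensor_slices((inputs, {'a_out': targets[:, 0],
                                                       'b_out': targets[:, 1],
                                                       'c_out': targets[:, 2],
                                                       'd_out': targets[:, 3]}))

# from_tensor_slices splits along axis 0, so each element is ONE sample:
# the input row has shape (15,) and each target is a scalar.
x_spec, y_spec = dataset.element_spec
print(x_spec.shape, x_spec.dtype.name)  # (15,) float64
print(y_spec["a_out"].shape)            # ()
```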
The inputs and targets are two NumPy arrays with shapes (525081, 15) and (525081, 4), respectively. When I run the fit method:
model.fit(dataset, epochs=10, batch_size=128)
I get the following error:
ValueError: Exception encountered when calling layer "model_8" (type Functional).
Input 0 of layer "dense_1" is incompatible with the layer: expected min_ndim=2, found ndim=1. Full shape received: (15,)
Call arguments received:
• inputs=tf.Tensor(shape=(15,), dtype=float64)
• training=True
• mask=None
It seems to me that the tensor sent to dense_1 is missing the batch dimension, which does not make sense to me. Am I constructing my dataset wrong?
Answer:
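When you pass a tf.data.Dataset to model.fit, the batch_size parameter of model.fit is ignored: the dataset is expected to yield batches itself. Batching should be done with the .batch() method of tf.data.Dataset.
tf.data.Dataset.from_tensor_slices slices your arrays along their first axis, so each element of your dataset is a single sample with shape (15,). Without .batch(), model.fit passes those 1-D tensors straight to dense_1, which is why the error reports ndim=1 and a full shape of (15,).
In your case it should be:
dataset = dataset.batch(128)
model.fit(dataset, epochs=10)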
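A minimal end-to-end sketch of the fix follows. It assumes each target column holds integer class indices for its head; the question does not say what the targets are, so the sparse categorical cross-entropy loss is an assumption. The model is a deliberately small stand-in with the same input width and output names, the data is random, and head_sizes is a hypothetical helper; the batching step is the part that matters.

```python
import numpy as np
import tensorflow as tf

# Hypothetical helper: output width of each head (from the question's summary).
head_sizes = {"a_out": 17, "b_out": 27, "c_out": 71, "d_out": 29}

# Random stand-in data: 256 samples instead of 525081. Each target column is
# assumed to be an integer class index for the matching head.
inputs = np.random.rand(256, 15).astype("float32")
targets = np.stack(
    [np.random.randint(0, n, size=256) for n in head_sizes.values()], axis=1
)

# Deliberately small model with the same input width and output names.
x_in = tf.keras.Input(shape=(15,))
hidden = tf.keras.layers.Dense(32, activation="relu")(x_in)
outputs = {name: tf.keras.layers.Dense(n, activation="softmax", name=name)(hidden)
           for name, n in head_sizes.items()}
model = tf.keras.Model(x_in, outputs)
model.compile(optimizer="adam",
              loss={name: "sparse_categorical_crossentropy" for name in head_sizes})

dataset = tf.data.Dataset.from_tensor_slices(
    (inputs, {name: targets[:, i] for i, name in enumerate(head_sizes)})
)

# Batch the dataset itself, so elements now have a leading batch dimension.
dataset = dataset.shuffle(1024).batch(128)
print(dataset.element_spec[0].shape)  # (None, 15)

# No batch_size argument: the dataset already yields batches of 128.
history = model.fit(dataset, epochs=2, verbose=0)
```

Because the dataset is batched before it reaches model.fit, dense_1 receives tensors of shape (None, 15). The shuffle call is optional; placing it before batch shuffles individual samples rather than whole batches.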