I am trying to set up a model with 8 inputs. The first 7 are length-1 IDs, each fed into its own embedding layer, and the embedding outputs are concatenated with a set of 4 numeric variables.
So the model definition includes:
from tensorflow import keras

input_A = keras.Input(shape=(1,))
input_B = keras.Input(shape=(1,))
input_C = keras.Input(shape=(1,))
input_D = keras.Input(shape=(1,))
input_E = keras.Input(shape=(1,))
input_F = keras.Input(shape=(1,))
input_G = keras.Input(shape=(1,))
input_nums = keras.Input(shape=(4,))
embed_A = keras.layers.Embedding(1223 + 1, 50)(input_A)
embed_B = keras.layers.Embedding(50 + 1, 25)(input_B)
embed_C = keras.layers.Embedding(1259 + 1, 50)(input_C)
embed_D = keras.layers.Embedding(3995 + 1, 50)(input_D)
embed_E = keras.layers.Embedding(2040 + 1, 50)(input_E)
embed_F = keras.layers.Embedding(174 + 1, 50)(input_F)
embed_G = keras.layers.Embedding(227 + 1, 50)(input_G)
embed_A = keras.layers.Flatten()(embed_A)
embed_B = keras.layers.Flatten()(embed_B)
embed_C = keras.layers.Flatten()(embed_C)
embed_D = keras.layers.Flatten()(embed_D)
embed_E = keras.layers.Flatten()(embed_E)
embed_F = keras.layers.Flatten()(embed_F)
embed_G = keras.layers.Flatten()(embed_G)
x = keras.layers.concatenate([embed_A,embed_B,embed_C,embed_D,embed_E,embed_F,embed_G,input_nums])
Then the model is constructed:
model = keras.Model(inputs=[input_A, input_B, input_C, input_D, input_E, input_F, input_G, input_nums], outputs = [out])
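For reference, here is a self-contained sketch of the same model written as a loop. It assumes the garbled sizes such as `1223 1` mean `1223 + 1`, and since the question never shows `out`, a single `Dense(1)` unit stands in for it; the input names are hypothetical.

```python
from tensorflow import keras

# Vocabulary sizes and embedding widths taken from the question
vocab_sizes = {"A": 1223, "B": 50, "C": 1259, "D": 3995, "E": 2040, "F": 174, "G": 227}
embed_dims = {"A": 50, "B": 25, "C": 50, "D": 50, "E": 50, "F": 50, "G": 50}

# Seven length-1 ID inputs plus one 4-wide numeric input (names are hypothetical)
id_inputs = [keras.Input(shape=(1,), name=k) for k in vocab_sizes]
input_nums = keras.Input(shape=(4,), name="nums")

embedded = []
for inp, key in zip(id_inputs, vocab_sizes):
    e = keras.layers.Embedding(vocab_sizes[key] + 1, embed_dims[key])(inp)  # (batch, 1, dim)
    embedded.append(keras.layers.Flatten()(e))  # (batch, dim)

# 6 * 50 + 25 embedding features, plus the 4 numeric features
x = keras.layers.concatenate(embedded + [input_nums])

# Assumption: placeholder output layer, since `out` is not shown in the question
out = keras.layers.Dense(1)(x)

model = keras.Model(inputs=id_inputs + [input_nums], outputs=out)
```

The model therefore expects exactly 8 input tensors, in the order of the `inputs` list.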
In the tf.data map function I tried to structure the input data like this, but fitting the model produces an error:
# Keras expects a tuple of either (inputs, targets) or (inputs, targets, sample_weights)
return (
    (example["A"], example["B"], example["C"],
     example["D"], example["E"], example["F"],
     example["G"],
     (example["num_A"], example["num_B"], example["num_C"], example["num_D"])
    ),
    label)
ValueError: Layer model expects 8 input(s), but it received 11 input tensors
How can I set up the map function of the tf.data dataset so that it works with this model?
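The count of 11 can be reproduced outside the model. Keras flattens a nested input structure before matching it against the model's inputs, so the inner tuple of four numeric features contributes four separate tensors (7 + 4 = 11). A minimal sketch, with dummy scalar tensors standing in for the real `example` dict (the values are arbitrary assumptions):

```python
import tensorflow as tf

# Dummy scalar tensors standing in for the parsed features (values are arbitrary)
example = {k: tf.constant(1) for k in ["A", "B", "C", "D", "E", "F", "G"]}
example.update({k: tf.constant(0.5) for k in ["num_A", "num_B", "num_C", "num_D"]})

# The inputs structured as in the question: a nested tuple for the numeric features
inputs = (
    example["A"], example["B"], example["C"],
    example["D"], example["E"], example["F"],
    example["G"],
    (example["num_A"], example["num_B"], example["num_C"], example["num_D"]),
)

# Flattening the nested structure yields one leaf per tensor: 7 IDs + 4 numerics
leaves = tf.nest.flatten(inputs)
print(len(leaves))  # 11
```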
Answer:
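I found that this works as the return value of the map function:
return (example["A"], example["B"], example["C"],
        example["D"], example["E"], example["F"],
        example["G"],
        [example["num_A"], example["num_B"], example["num_C"], example["num_D"]]
        ), label
The only change is that the four numeric features are wrapped in a list instead of a tuple. tf.data does not treat Python lists as nested structures; it converts them to a single tensor, here of shape (4,), which matches input_nums. A tuple stays a nested structure of four separate tensors, which Keras counts as four extra inputs.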
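To make the packing explicit instead of relying on tf.data's list-to-tensor conversion, the numeric features can be combined with `tf.stack`. The sketch below assumes the label is stored under a hypothetical `example["label"]` key and uses toy data in place of the real dataset. It also reshapes each ID to shape `(1,)` so that a batch is `(batch, 1)`, matching `keras.Input(shape=(1,))`.

```python
import tensorflow as tf

# Feature keys from the question
ID_KEYS = ["A", "B", "C", "D", "E", "F", "G"]
NUM_KEYS = ["num_A", "num_B", "num_C", "num_D"]


def to_model_inputs(example):
    # Give each ID shape (1,) so a batch is (batch, 1), matching keras.Input(shape=(1,))
    ids = tuple(tf.reshape(example[k], [1]) for k in ID_KEYS)
    # Pack the four numeric scalars into one tensor of shape (4,) for input_nums
    nums = tf.stack([example[k] for k in NUM_KEYS], axis=-1)
    # Hypothetical: the label is assumed to live under example["label"]
    return (*ids, nums), example["label"]


# Hypothetical toy data standing in for the real dataset: 8 examples of scalar features
n = 8
features = {k: tf.zeros([n], tf.int64) for k in ID_KEYS}
features.update({k: tf.random.uniform([n]) for k in NUM_KEYS})
features["label"] = tf.random.uniform([n])

ds = tf.data.Dataset.from_tensor_slices(features).map(to_model_inputs).batch(4)
inputs_spec, label_spec = ds.element_spec
print(len(inputs_spec))  # 8
```

Each element is now an `(inputs, label)` pair whose `inputs` tuple holds exactly 8 tensors, in the same order as the model's `inputs` list, so a dataset built this way should be usable directly with `model.fit(ds, ...)`.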