What's the right way to specify a Keras input that contains data of different shape?

Time:07-26

I'm playing with some reinforcement learning to create a solver for a word game.

A game has several turns. In each turn the player chooses 5 letters (a to z) to form a word, and the environment colours each entered letter with one of 3 colours depending on how "correct" it is.

My current plan is creating a deep neural network with the following inputs:

  1. The one-hot-encoded colour of each letter (shape (26, 3)).
  2. The current turn number (shape (1,)).

How can I "mix" these two inputs into a single input layer? Is there a better option than a one-dimensional layer with shape (26 * 3 + 1,)?
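For comparison, the one-dimensional option amounts to flattening the letter array outside the model and appending the turn number, giving a vector of 26 * 3 + 1 = 79 values. A minimal NumPy sketch, where encode_observation is a hypothetical name:

```python
import numpy as np


def encode_observation(letter_state, turn):
    """Flatten a (26, 3) letter-colour array and append the turn number.

    Returns a float32 vector of shape (26 * 3 + 1,), i.e. (79,).
    Row-major flattening puts letter 0's three colour values first.
    """
    letter_state = np.asarray(letter_state, dtype=np.float32)
    if letter_state.shape != (26, 3):
        raise ValueError(f"expected shape (26, 3), got {letter_state.shape}")
    turn_part = np.array([turn], dtype=np.float32)
    return np.concatenate([letter_state.reshape(-1), turn_part])


letters = np.zeros((26, 3))
letters[0, 2] = 1  # letter "a" marked with colour index 2
obs = encode_observation(letters, turn=2)
print(obs.shape)
```

This works, but the encoding then lives outside the model, so every caller has to reproduce it exactly.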

Answer:

If I understand your question correctly, you want to combine the 2 inputs into a single layer of shape (79,) inside the model instead of outside it.

You can do it with a Flatten layer followed by a Concatenate layer:

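import tensorflow as tf

# Letter colours (26 letters x 3 colours) and the turn number as separate inputs
input_1 = tf.keras.layers.Input(shape=(26,3))
input_2 = tf.keras.layers.Input(shape=(1,))

# (None, 26, 3) -> (None, 78), then append the turn number -> (None, 79)
flatten = tf.keras.layers.Flatten()(input_1)
concat = tf.keras.layers.Concatenate(axis=1)([flatten, input_2])

model = tf.keras.Model(inputs=[input_1,input_2],outputs=concat)
model.summary()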


>>>
Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_55 (InputLayer)          [(None, 26, 3)]      0           []                               
                                                                                                  
 flatten_15 (Flatten)           (None, 78)           0           ['input_55[0][0]']               
                                                                                                  
 input_56 (InputLayer)          [(None, 1)]          0           []                               
                                                                                                  
 concatenate_1 (Concatenate)    (None, 79)           0           ['flatten_15[0][0]',             
                                                                  'input_56[0][0]']               
                                                                                                  
==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________
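For intuition, Flatten and Concatenate(axis=1) perform the equivalent of the NumPy operations below on a batch: each sample's (26, 3) array becomes 78 values, and the turn number is appended along the feature axis, so the batch dimension (None in the summary) is preserved. The random batch is made-up data used only to show the shapes.

```python
import numpy as np

batch_size = 4
rng = np.random.default_rng(0)

# Made-up batch: 4 samples of (26, 3) letter colours and 4 turn numbers.
letters = rng.integers(0, 2, size=(batch_size, 26, 3)).astype(np.float32)
turns = np.arange(batch_size, dtype=np.float32).reshape(batch_size, 1)

# Flatten(): (batch, 26, 3) -> (batch, 78), keeping the batch axis.
flat = letters.reshape(batch_size, -1)
# Concatenate(axis=1): join along the feature axis -> (batch, 79).
combined = np.concatenate([flat, turns], axis=1)

print(flat.shape, combined.shape)
```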

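I was also able to do this with a Lambda layer. Note, however, that this version reshapes to (78,) and takes input_2[0], which drops the batch dimension: the summary reports an output shape of (79,) instead of (None, 79), so it only works with a batch size of 1.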

def concat(input_1,input_2):
  return tf.concat([tf.reshape(input_1,(78,)),input_2[0]],axis=0)


input_1 = tf.keras.layers.Input(shape=(26,3))
input_2 = tf.keras.layers.Input(shape=(1,))

output = tf.keras.layers.Lambda(lambda x: concat(x[0],x[1]),name="LambdaLayer")([input_1,input_2])

model = tf.keras.Model(inputs=[input_1,input_2],outputs=output)
model.summary()

>>>
Model: "model_7"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_71 (InputLayer)          [(None, 26, 3)]      0           []                               
                                                                                                  
 input_72 (InputLayer)          [(None, 1)]          0           []                               
                                                                                                  
 LambdaLayer (Lambda)           (79,)                0           ['input_71[0][0]',               
                                                                  'input_72[0][0]']               
                                                                                                  
==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________
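A batch-safe variant of the Lambda approach reshapes with -1 for the batch axis and concatenates along axis=1, so the output keeps its batch dimension like the Concatenate version. This is a sketch assuming TensorFlow 2.x with tf.keras; in practice the built-in Flatten and Concatenate layers are usually preferable, since Lambda layers are harder to serialize.

```python
import numpy as np
import tensorflow as tf


def concat(letters, turn):
    # Flatten each sample to 26 * 3 = 78 values, keeping the batch axis (-1),
    # then append the turn number along the feature axis.
    flat = tf.reshape(letters, (-1, 26 * 3))
    return tf.concat([flat, turn], axis=1)


input_1 = tf.keras.layers.Input(shape=(26, 3))
input_2 = tf.keras.layers.Input(shape=(1,))

output = tf.keras.layers.Lambda(
    lambda x: concat(x[0], x[1]), name="LambdaLayer"
)([input_1, input_2])

model = tf.keras.Model(inputs=[input_1, input_2], outputs=output)

# Run a made-up batch of 4 samples to check that the batch dimension survives.
letters = np.zeros((4, 26, 3), dtype=np.float32)
turns = np.arange(4, dtype=np.float32).reshape(4, 1)
result = model.predict([letters, turns], verbose=0)
print(result.shape)
```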