I wrote a custom Tree-RNN cell that can handle several different inputs when they are passed as a tuple:
...
treeCell3_1 = TreeRNNCell(units=encodingBitLength, num_children=2)
RNNC = layers.RNN(treeCell3_1, return_state=True, return_sequences=True)
h_c_batch, h, c = RNNC(
    inputs=(h_batch2_1, c_batch2_1, h_batch2_2, c_batch2_2))
This works fine, but now I want to wrap it in a submodel so that I can reduce these 4 lines to 2 and get a better overview (the tree gets big, so it's worth it):
class TreeCellModel(tf.keras.Model):
    def __init__(self, units, num_children):
        super().__init__()
        self.units = units
        self.num_children = num_children
        self.treeCell = TreeRNNCell(units=units, num_children=num_children)
        self.treeRNN = layers.RNN(self.treeCell, return_state=True, return_sequences=True)
        # Create the sub-layer once here instead of a new one on every call()
        self.addCellStates = AddCellStatesLayer(units=units)

    def call(self, inputs, **kwargs):
        # inputs is already the tuple of 4 tensors
        h_c_batch, h, c = self.treeRNN(inputs=inputs)
        h_batch, c_batch = self.addCellStates(h_c_batch)
        return h_batch, c_batch
treeCell2_1 = TreeCellModel(units=encodingBitLength, num_children=2)
h_batch2_1, c_batch2_1 = treeCell2_1(inputs=(h_batch1_1, c_batch1_1, h_batch1_2, c_batch1_2))
But now I get this error:

ValueError: Layer rnn expects 1 input(s), but it received 4 input tensors. Inputs received: [<tf.Tensor 'h_batch1_1' shape=(1, 5, 19) dtype=float32>, <tf.Tensor 'c_batch1_1' shape=(1, 5, 19) dtype=float32>, <tf.Tensor 'h_batch1_2' shape=(1, 5, 19) dtype=float32>, <tf.Tensor 'c_batch1_2' shape=(1, 5, 19) dtype=float32>]
I already looked this error up, and normally it is fixed by wrapping the inputs in a tuple. But that's what I'm already doing. I also double-checked by printing the type of "inputs", and it is a tuple.
Help please.
Answer:
RNN expects one input, so you must give it one input. The implementation of your cell probably doesn't matter.
You can change your code to stack the 4 tensors into a single tensor and split them apart again inside your cell. This works because all four tensors have the same shape, (1, 5, 19).
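The stack-and-split round trip can be checked on its own. This is a minimal sketch that uses NumPy as a stand-in for the TensorFlow ops (`np.stack` plays the role of `tf.stack` / `keras.backend.stack` here), with random arrays of the shape from the error message; the names `h1`, `c1`, `h2`, `c2` are illustrative only. It also shows that an RNN passes the cell one time step at a time, so inside the cell the stacking axis is the last axis of a 3-D tensor, not of a 4-D one.

```python
import numpy as np

# Four stand-in tensors with the shape from the error message: (batch, time, features)
batch, time, features = 1, 5, 19
rng = np.random.default_rng(0)
h1, c1, h2, c2 = (rng.standard_normal((batch, time, features)) for _ in range(4))

# Stack along a new last axis -> a single array of shape (1, 5, 19, 4)
joined = np.stack([h1, c1, h2, c2], axis=-1)

# An RNN hands its cell one time step at a time, so the cell sees (batch, 19, 4)
step = joined[:, 0]

# Ellipsis indexing selects along the last axis regardless of rank
split_full = [joined[..., i] for i in range(4)]  # each (1, 5, 19)
split_step = [step[..., i] for i in range(4)]    # each (1, 19)

# A fixed four-index slice such as step[:, :, :, 0] would raise IndexError here.
print(joined.shape, step.shape, split_full[0].shape, split_step[0].shape)
# (1, 5, 19, 4) (1, 19, 4) (1, 5, 19) (1, 19)
```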
In Keras, you could do the stacking with a Lambda layer:
joined_inputs = layers.Lambda(lambda x: keras.backend.stack(x, axis=-1))([input1, input2, input3, input4])
Then your cell should be able to separate the inputs:
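def call(self, inputTensor .....):
    # Inside the cell, inputTensor holds one time step: shape (batch, 19, 4),
    # so split along the last (stacking) axis rather than a fourth index.
    inputs = [inputTensor[..., i] for i in range(4)]
    ....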
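Putting both steps together, here is a minimal end-to-end sketch, assuming TensorFlow 2.x with `tf.keras`. `StackedInputCell` and `StackedTreeCellModel` are hypothetical names, and the cell's tanh recurrence is placeholder math standing in for the real `TreeRNNCell`, whose code is not shown in the question. One swap compared with the answer: the stacking uses `tf.stack` directly inside the subclassed model's `call` instead of a `Lambda` layer, which is enough because the model's `call` can use plain TensorFlow ops.

```python
import tensorflow as tf
from tensorflow.keras import layers


class StackedInputCell(layers.Layer):
    """Hypothetical toy cell (not a Tree-LSTM): each step receives the inputs
    stacked on the last axis, splits them, and runs a plain tanh recurrence."""

    def __init__(self, units, num_inputs=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_inputs = num_inputs
        self.state_size = units
        self.output_size = units

    def build(self, input_shape):
        # One time step has shape (batch, features, num_inputs)
        features = input_shape[-2]
        self.kernel = self.add_weight(
            name="kernel",
            shape=(features * self.num_inputs, self.units),
            initializer="glorot_uniform",
        )
        self.recurrent_kernel = self.add_weight(
            name="recurrent_kernel",
            shape=(self.units, self.units),
            initializer="orthogonal",
        )
        self.built = True

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        # Zero state; the keyword arguments cover how RNN may call this hook
        return [tf.zeros([batch_size, self.units], dtype=dtype or self.compute_dtype)]

    def call(self, inputs, states):
        prev = states[0] if tf.nest.is_nested(states) else states
        # Undo the stacking: each part has shape (batch, features)
        parts = [inputs[..., i] for i in range(self.num_inputs)]
        x = tf.concat(parts, axis=-1)
        h = tf.tanh(tf.matmul(x, self.kernel) + tf.matmul(prev, self.recurrent_kernel))
        new_states = [h] if tf.nest.is_nested(states) else h
        return h, new_states


class StackedTreeCellModel(tf.keras.Model):
    """Hypothetical wrapper: stacks the four tensors so the RNN gets one input."""

    def __init__(self, units):
        super().__init__()
        self.tree_rnn = layers.RNN(
            StackedInputCell(units), return_sequences=True, return_state=True
        )

    def call(self, inputs):
        joined = tf.stack(list(inputs), axis=-1)  # (batch, time, features, 4)
        seq, h = self.tree_rnn(joined)
        return seq, h


# Four random tensors with the shape from the error message
xs = [tf.random.normal((1, 5, 19)) for _ in range(4)]
model = StackedTreeCellModel(units=8)
seq, h = model(tuple(xs))
print(seq.shape, h.shape)  # (1, 5, 8) (1, 8)
```

The call site keeps the two-line form the question was after (build the model, then call it with the four tensors), while the RNN layer itself only ever receives a single tensor.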