I have a model that consists of two heads, a concatenation, and a stack of layers, all dense. However, for the concatenation I use a `Lambda` layer, to not only concatenate both heads but also change the concatenation order for each entry in the batch using `gather`.
For this purpose I use an index `Input` of shape `(batch_size, 512)`, and the `Lambda` layer I'm using is this:
Lambda(lambda x: gather(Concatenate()([x[0], x[1]])[1], x[2]))([h1, h2, idx])
where `h1` is the output of the first head, `h2` is the output of the second head, and `idx` is the index tensor.
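
For context, here is a minimal sketch of how the pieces fit together. The input sizes, activations, and names other than `h1`, `h2`, and `idx` are assumptions for illustration, not taken from the actual model:

```python
import tensorflow as tf
from tensorflow import gather
from tensorflow.keras.layers import Input, Dense, Concatenate, Lambda

# Hypothetical input sizes; only the head outputs (256) and idx (512) are given.
in1 = Input(shape=(128,))
in2 = Input(shape=(128,))
idx = Input(shape=(512,), dtype="int32")  # per-sample reordering indices

h1 = Dense(256, activation="relu")(in1)  # first head, output shape (None, 256)
h2 = Dense(256, activation="relu")(in2)  # second head, output shape (None, 256)

# The problematic merge layer from the question:
merged = Lambda(lambda x: gather(Concatenate()([x[0], x[1]])[1], x[2]))([h1, h2, idx])
```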
If I remove `gather` and leave only `Concatenate`, the model learns and the loss decreases. With `gather` included, however, the loss gets stuck and does not decrease.
Concatenate()([h1, h2]) # this works well
Just in case: `idx` has a shape of `(None, 512)`, and `h1` and `h2` each have a shape of `(None, 256)`. The batch size is 2048.
What am I doing wrong? Any help would be much appreciated.
CodePudding user response:
I found the error. I had to pass the `batch_dims=1` parameter to `gather` and remove the `[1]` indexing on the `Concatenate` output. So now I have this, and it works well:
Lambda(lambda x: gather(Concatenate()([x[0], x[1]]), x[2], batch_dims=1))([h1, h2, idx])
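
For anyone wondering what `batch_dims=1` changes: with it, each row of the index tensor gathers within the matching row of the data tensor, instead of selecting whole rows across the batch. A small standalone check (the values are illustrative):

```python
import tensorflow as tf

params = tf.constant([[10., 20., 30.],
                      [40., 50., 60.]])   # batch of 2, 3 features each
indices = tf.constant([[2, 0, 1],
                       [0, 2, 1]])        # a different reordering per sample

print(tf.gather(params, indices, batch_dims=1).numpy())
# [[30. 10. 20.]
#  [40. 60. 50.]]
# result[i, j] == params[i, indices[i, j]]: a per-sample reordering, which is
# what the Lambda layer needs after concatenating the two heads.

# Without batch_dims, the same call gathers whole rows of `params` for every
# index, giving shape (2, 3, 3) instead of (2, 3).
print(tf.gather(params, indices).shape)  # (2, 3, 3)
```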