How to use tf.gather inside lambda function in Keras?

Time:10-03

I have a model consisting of two heads, a concatenation, and several further layers, all of them dense. For the concatenation, however, I use a Lambda layer that not only concatenates both heads but also changes the order of the concatenated features for each entry in the batch, using gather.

For this purpose I use an index Input of shape (batch_size, 512), and the Lambda layer I'm using is this:

Lambda(lambda x: gather(Concatenate()([x[0], x[1]])[1], x[2]))([h1, h2, idx])

where h1 is the output of the first head, h2 is the output of the second head, and idx is the index tensor.
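Evaluating the question's expression eagerly on tiny tensors shows what it actually computes. This is a sketch assuming TensorFlow 2.x with eager execution; `h1`, `h2`, and `idx` below are made-up stand-ins (a batch of 3, with 2 features per head instead of 256), and `tf.concat(..., axis=-1)` stands in for `Concatenate()`, which concatenates along the same default axis. The trailing `[1]` selects row 1 of the concatenated batch, and `tf.gather` with its defaults (`axis=0`, `batch_dims=0`) then builds every output row from that single sample.

```python
import tensorflow as tf

# Toy stand-ins (hypothetical values): batch of 3, 2 features per head.
h1 = tf.constant([[0., 1.], [10., 11.], [20., 21.]])
h2 = tf.constant([[2., 3.], [12., 13.], [22., 23.]])
# One index row per sample, each a permutation of the 4 concatenated positions.
idx = tf.constant([[3, 2, 1, 0], [0, 1, 2, 3], [1, 0, 3, 2]])

concat = tf.concat([h1, h2], axis=-1)  # same as Concatenate()([h1, h2]); shape (3, 4)
row = concat[1]                        # the trailing [1]: sample 1 only, shape (4,)
out = tf.gather(row, idx)              # default axis=0, batch_dims=0; shape (3, 4)
# out is [[13, 12, 11, 10], [10, 11, 12, 13], [11, 10, 13, 12]]:
# every row is a reordering of sample 1's features [10, 11, 12, 13].
```

Samples 0 and 2 never reach the output, so the downstream layers see reorderings of the same features for every entry in the batch, which would plausibly leave the loss stuck.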

If I remove gather and keep only Concatenate, the model learns and the loss decreases. With gather, however, it doesn't learn: the loss gets stuck.

Concatenate()([h1, h2]) # this works well

For reference, idx has shape (None, 512), and h1 and h2 each have shape (None, 256). The batch size is 2048.
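The question does not show how `idx` is built. A minimal sketch, assuming each row of `idx` is meant to be an independent random permutation of the 512 concatenated positions, is to argsort random noise row by row; `make_permutation_indices` is a hypothetical helper, not part of TensorFlow. Because `tf.gather` only accepts integer indices (int32 or int64), the matching Keras input would be declared with an integer dtype, e.g. `Input(shape=(512,), dtype="int32")`, rather than the default float32.

```python
import tensorflow as tf


def make_permutation_indices(batch_size, width, seed=None):
    """Hypothetical helper: one random permutation of range(width) per row.

    Returns an int32 tensor of shape (batch_size, width) that can serve as
    per-sample gather indices.
    """
    noise = tf.random.uniform((batch_size, width), seed=seed)
    # Argsorting i.i.d. uniform noise yields a random permutation of
    # 0..width-1 in every row.
    return tf.argsort(noise, axis=-1)


idx = make_permutation_indices(2048, 512)  # shape (2048, 512), dtype int32
```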

What am I doing wrong? Any help would be much appreciated.

Answer:

I found the error. I had to pass batch_dims=1 to gather and remove the [1] indexing after Concatenate. This version works well:

Lambda(lambda x: gather(Concatenate()([x[0], x[1]]), x[2], batch_dims=1))([h1, h2, idx])
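With `batch_dims=1`, `tf.gather` treats the leading axis as a batch axis, so each output row is built only from the same sample's features: `out[b, j] = params[b, idx[b, j]]`. The sketch below first checks that on toy values, then wires the fixed Lambda into a small functional model. It assumes TensorFlow 2.x with `tf.keras`; the 32-feature inputs `x1`/`x2`, the 128-unit hidden layer, and the single output unit are placeholders, not details from the question. Inside the lambda it uses `tf.concat(..., axis=-1)`, which computes the same result as `Concatenate()` with its default `axis=-1`.

```python
import numpy as np
import tensorflow as tf

layers = tf.keras.layers

# batch_dims=1 on toy data: row b of the result is toy_params[b] gathered by toy_idx[b].
toy_params = tf.constant([[0., 1., 2.], [10., 11., 12.]])
toy_idx = tf.constant([[2, 0, 1], [1, 1, 0]])
toy = tf.gather(toy_params, toy_idx, batch_dims=1)  # rows [2, 0, 1] and [11, 11, 10]

# Hypothetical model: two 256-unit dense heads on made-up 32-feature inputs.
x1 = layers.Input(shape=(32,), name="x1")
x2 = layers.Input(shape=(32,), name="x2")
idx = layers.Input(shape=(512,), dtype="int32", name="idx")  # integer indices
h1 = layers.Dense(256, activation="relu")(x1)  # (None, 256)
h2 = layers.Dense(256, activation="relu")(x2)  # (None, 256)

# Concatenate to (None, 512), then reorder each sample by its own index row.
mixed = layers.Lambda(
    lambda t: tf.gather(tf.concat([t[0], t[1]], axis=-1), t[2], batch_dims=1)
)([h1, h2, idx])

hidden = layers.Dense(128, activation="relu")(mixed)  # placeholder size
out = layers.Dense(1)(hidden)
model = tf.keras.Model([x1, x2, idx], out)

# Forward pass on random data (batch of 4) to check that the graph builds.
rng = np.random.default_rng(0)
xa = rng.random((4, 32), dtype=np.float32)
xb = rng.random((4, 32), dtype=np.float32)
perm = np.argsort(rng.random((4, 512)), axis=-1).astype("int32")
preds = model.predict([xa, xb, perm], verbose=0)  # shape (4, 1)
```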