Reshaping output of MultiHeadAttention - Tensorflow

Time:07-06

The Keras MultiHeadAttention layer offers an output_shape argument that specifies the size the output is projected to. However, the batch and sequence dimensions do not seem to be alterable.

For example:

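import tensorflow as tf
from tensorflow.keras import layers

layer = layers.MultiHeadAttention(num_heads=2, key_dim=2, output_shape=[5])
target = tf.random.normal(shape=[3, 5, 1])  # query: batch 3, sequence length 5
source = tf.random.normal(shape=[3, 4, 1])  # value: batch 3, sequence length 4
output_tensor = layer(target, source)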

This call returns a tensor of shape TensorShape([3, 5, 5]). I understand that the batch dimension (3) and the sequence dimension (5, the query length) cannot be changed, because attention produces one output vector per query position. Now I want to reshape this tensor with another layer (I can do it outside the model, but I would like to keep it as part of the model).
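The shape rule described above can be written down as a small pure-Python sketch. The helper name mha_output_shape is hypothetical, and it assumes rank-3 (batch, seq, features) inputs with the default attention_axes; under those assumptions the batch and query-length axes pass through unchanged and only the last axis is replaced by output_shape.

```python
def mha_output_shape(query_shape, output_shape):
    """Hypothetical helper approximating the MultiHeadAttention output shape
    for rank-3 (batch, seq, features) inputs with default attention_axes."""
    # Batch and query-sequence axes are copied from the query;
    # the source (value) length never appears in the result.
    return tuple(query_shape[:-1]) + tuple(output_shape)


print(mha_output_shape((3, 5, 1), (5,)))  # (3, 5, 5)
```

Note that the source length (4 in the question) plays no part, which is why only the trailing axis can be customized through output_shape.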

To reshape it to something like [15, 5], I tried this layer:

reshape = layers.Reshape((15,5))

and then applied it to output_tensor. However, the reshape fails because Keras treats the target as [3, 15, 5], as this error shows: "Input to reshape is a tensor with 75 values, but the requested shape has 225 [Op:Reshape]".

I have also tried layers.Reshape((-1, 5)), but it leaves the tensor unchanged at [3, 5, 5].

Is reshaping outside the model the only way to customize the MHA output, or can it be done with a layer inside the model?

Appreciate the help.

Answer:

The core issue with your reshaping attempt is that you are trying to merge the batch dimension into another axis. Standard Keras layers such as Reshape keep the batch dimension fixed and reshape only the rest of the tensor.

For example, when you do reshape = layers.Reshape((15,5)) you get,

Input to reshape is a tensor with 75 values, but the requested shape has 225 [Op:Reshape]

What your layer is actually trying to do is reshape the tensor while keeping the batch dimension constant. That is, given a tensor of shape [3, 5, 5] (75 elements), it tries to produce a tensor of shape [3, 15, 5] (225 elements), hence the error.

Your next attempt, reshape = layers.Reshape((-1, 5)), again prepends the batch dimension and asks for [3, -1, 5]. Each sample has 5 * 5 = 25 elements, so the -1 resolves to 25 / 5 = 5 and the result is [3, 5, 5], which is exactly the shape your tensor already has. This is why you see no change.
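Both outcomes follow from the same rule: the batch axis is kept and the target shape applies per sample. A minimal pure-Python sketch of that bookkeeping, assuming a fully known input shape; keras_reshape is a hypothetical helper, not part of Keras, and its error text only mirrors the message quoted in the question.

```python
from math import prod


def keras_reshape(input_shape, target_shape):
    """Hypothetical helper mimicking how layers.Reshape resolves its target:
    the batch axis is kept and target_shape applies to the rest."""
    batch, *rest = input_shape
    per_sample = prod(rest)          # elements in one sample, e.g. 5 * 5 = 25
    target = list(target_shape)
    if -1 in target:
        # -1 absorbs whatever is left of the per-sample element count.
        known = prod(d for d in target if d != -1)
        target[target.index(-1)] = per_sample // known
    total = batch * per_sample       # elements actually present
    requested = batch * prod(target) # elements the target would need
    if requested != total:
        raise ValueError(f"Input to reshape is a tensor with {total} values, "
                         f"but the requested shape has {requested}")
    return (batch, *target)


print(keras_reshape((3, 5, 5), (-1, 5)))  # (3, 5, 5): no change
try:
    keras_reshape((3, 5, 5), (15, 5))
except ValueError as err:
    print(err)  # ... 75 values, but the requested shape has 225
```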

To change the batch dimension inside the model, call tf.reshape directly, for example wrapped in a Lambda layer:

reshaped_output = tf.keras.layers.Lambda(lambda x: tf.reshape(x, (-1, 5)))(output_tensor)

which gives (checking reshaped_output.shape):

(15, 5)
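Since the question asks for a custom layer, the same tf.reshape call can also be packaged in a subclassed layer instead of a Lambda. A minimal sketch, assuming TensorFlow 2.x; the class name FlattenBatch is hypothetical, and a random tensor stands in for the [3, 5, 5] MHA output.

```python
import tensorflow as tf


class FlattenBatch(tf.keras.layers.Layer):
    """Hypothetical layer that merges the batch and sequence axes of a
    (batch, seq, features) tensor into a single leading axis."""

    def call(self, inputs):
        # Keep the feature axis (assumed statically known) and let -1
        # absorb batch * seq, so the leading dimension is allowed to change.
        return tf.reshape(inputs, (-1, inputs.shape[-1]))


# Stand-in for the [3, 5, 5] MultiHeadAttention output from the question.
mha_out = tf.random.normal(shape=[3, 5, 5])
flat = FlattenBatch()(mha_out)
print(flat.shape)  # (15, 5)
```

Rows come out batch-major, so flat[i * 5 + j] holds the vector mha_out[i, j].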
