How to create a custom layer that takes two inputs and applies tf.matmul to them

Time:06-20

The code for my model is:

class KeyQuery(keras.layers.Layer):
    def __init__(self, v):
        super(KeyQuery, self).__init__()
        self.v = tf.convert_to_tensor(v)
        
    def build(self, input_shape): 
        self.v = tf.Variable(self.v, trainable = True)
        print(self.v.shape)
    def call(self, inputs1, inputs2):
        y1 = tf.matmul(self.v, tf.transpose(inputs1))
        y2 = tf.matmul(y1, inputs2)
        return y2
    
keyquery = KeyQuery(v)

inputs1 = keras.Input(shape=(50,768))
inputs2 = keras.Input(shape=(50,3))
outputs = keyquery(inputs1,inputs2)
model = keras.Model([inputs1,inputs2], outputs)
model.summary()

Here v, as passed to KeyQuery(v), is a 2D array of shape (1, 768), which can also be viewed as a vector.

What I want is the following: in y1 = tf.matmul(self.v, tf.transpose(inputs1)), self.v has shape (1, 768) and inputs1 has shape (50, 768), so y1 should have shape (1, 50). Since inputs2 has shape (50, 3), y2 should then have shape (1, 3).

Taking the batch dimension into account, inputs1.shape is (None, 50, 768) and inputs2.shape is (None, 50, 3), so the layer should return a result of shape (None, 1, 3). Note that the shape passed to keras.Input does not include the batch dimension.

In practice, however, it raises ValueError: Dimensions must be equal, but are 768 and 50 for '{{node key_query_4/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=false](key_query_4/MatMul/ReadVariableOp, key_query_4/transpose)' with input shapes: [1,768], [768,50,?]. This is caused by the batch dimension, and I don't know how to fix my matrix multiplication.

CodePudding user response:

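You need to consider two things:

  1. The build method receives the shape of the input tensor, while the call method receives the tensor itself. To give the layer two inputs, pass them together as a single list or tuple and index into it inside call.
  2. Without perm, tf.transpose() reverses all dimensions, including the batch dimension. Pass perm=[0,2,1] so that the batch dimension stays in place and only the last two dimensions are swapped.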
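To illustrate the second point, here is a minimal sketch of what the two transpose variants do to a batched tensor. The batch size of 25 and the variable names are arbitrary choices for illustration only.

```python
import tensorflow as tf

v = tf.random.uniform((1, 768))        # stands in for the layer's weight
x = tf.random.uniform((25, 50, 768))   # (batch, 50, 768), batch size chosen arbitrarily

# Without perm, every axis is reversed, so the batch axis ends up last.
default_t = tf.transpose(x)
print(default_t.shape)   # (768, 50, 25)

# With perm=[0, 2, 1], the batch axis stays first and the last two axes swap.
batched_t = tf.transpose(x, perm=[0, 2, 1])
print(batched_t.shape)   # (25, 768, 50)

# tf.matmul broadcasts the (1, 768) matrix across the batch axis.
y1 = tf.matmul(v, batched_t)
print(y1.shape)          # (25, 1, 50)
```

The first shape matches the [768,50,?] reported in the error message, which is why the inner dimensions 768 and 50 no longer line up.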

Code:

import tensorflow as tf
import numpy as np

class KeyQuery(tf.keras.layers.Layer):
    def __init__(self, v):
        super(KeyQuery, self).__init__()
        self.v = tf.convert_to_tensor(v, dtype='float32')
        
    def build(self, input_shape): # here we have shape of input_tensor
        self.v = tf.Variable(self.v, trainable = True)

    def call(self, inputs): # here we have value of input_tensor
        y1 = tf.matmul(self.v, tf.transpose(inputs[0], perm=[0,2,1]))
        y2 = tf.matmul(y1, inputs[1])
        return y2
    
keyquery = KeyQuery(np.random.rand(1,768))
out = keyquery((tf.random.uniform((25, 50, 768)), tf.random.uniform((25, 50, 3))))
print(out.shape)
# (25, 1, 3)

# or with model
inputs1 = tf.keras.Input(shape=(50,768))
inputs2 = tf.keras.Input(shape=(50,3))
outputs = keyquery((inputs1,inputs2))
model = tf.keras.Model([inputs1,inputs2], outputs)
model.summary()

Output:

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_13 (InputLayer)          [(None, 50, 768)]    0           []                               
                                                                                                  
 input_14 (InputLayer)          [(None, 50, 3)]      0           []                               
                                                                                                  
 key_query_27 (KeyQuery)        (None, 1, 3)         768         ['input_13[0][0]',               
                                                                  'input_14[0][0]']               
                                                                                                  
==================================================================================================
Total params: 768
Trainable params: 768
Non-trainable params: 0
__________________________________________________________________________________________________
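As a possible variation (not part of the answer above), the transpose can be folded into tf.matmul with transpose_b=True, and the weight can be created with add_weight. This is only a sketch: the class name KeyQueryAlt and the attribute v_init are hypothetical, and initialising the weight to zeros before assigning the supplied values is just one way to do it.

```python
import numpy as np
import tensorflow as tf

class KeyQueryAlt(tf.keras.layers.Layer):
    """Hypothetical variant of KeyQuery, for illustration only."""

    def __init__(self, v, **kwargs):
        super().__init__(**kwargs)
        # Keep the initial values as a NumPy array until build() runs.
        self.v_init = np.asarray(v, dtype="float32")

    def build(self, input_shape):
        # Create a tracked, trainable weight, then load the supplied values.
        self.v = self.add_weight(
            name="v",
            shape=self.v_init.shape,
            initializer="zeros",
            trainable=True,
        )
        self.v.assign(self.v_init)

    def call(self, inputs):
        x1, x2 = inputs  # both inputs arrive as one tuple or list
        # transpose_b=True swaps only the last two axes of x1.
        y1 = tf.matmul(self.v, x1, transpose_b=True)  # (batch, 1, 50)
        return tf.matmul(y1, x2)                      # (batch, 1, 3)

layer = KeyQueryAlt(np.random.rand(1, 768))
out = layer((tf.random.uniform((25, 50, 768)), tf.random.uniform((25, 50, 3))))
print(out.shape)
# (25, 1, 3)
```

Both versions compute the same product; this one just avoids the explicit tf.transpose call.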