I am trying to take the inner product of two vectors in TensorFlow, for which I use tf.tensordot:
import tensorflow as tf

x = tf.constant([1, 2, 3], dtype=tf.int32)
y = tf.constant([4, 5, 6], dtype=tf.int32)
# desired result: 1*4 + 2*5 + 3*6
tf.tensordot(x, y, axes=1)
# Output: 32
Now I'm dealing with batch tensors which both have shape (32, 3). I still want the same operation, yielding an output vector of shape (32,). My only successful attempt so far is:
tf.linalg.diag_part(tf.tensordot(x, y, axes=[[1], [1]]))
# Output: <tf.Tensor: shape=(32,)>
# where each entry is the inner product of the vectors of length 3
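However, this computes 32 times as many inner products as required: tensordot builds the full 32 x 32 matrix of every row of x against every row of y, and diag_part then discards all but the 32 diagonal entries.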
How do I solve my problem more efficiently?
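To make the overhead concrete, here is a minimal sketch using a hypothetical batch built with tf.range (the values are illustrative only, not from the question). Contracting axis 1 of both tensors produces a (32, 32) matrix, and diag_part keeps only 32 of those 1024 entries.

```python
import tensorflow as tf

# Hypothetical batch: 32 rows, each a length-3 vector (values are illustrative).
x = tf.reshape(tf.range(96, dtype=tf.int32), (32, 3))
y = x + 1

# Contracting axis 1 of x with axis 1 of y pairs every row of x with every row of y.
full = tf.tensordot(x, y, axes=[[1], [1]])
print(full.shape)  # (32, 32): 1024 inner products

# Only the diagonal, where row i of x meets row i of y, is actually wanted.
wanted = tf.linalg.diag_part(full)
print(wanted.shape)  # (32,)
```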
Answer:
Think about what this operation really is: an element-wise multiplication followed by a sum over axis 1. So you can simply do:
tf.reduce_sum(x * y, axis=1)  # shape (32,): one inner product per row
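A minimal sketch checking this against the diag_part approach from the question, on a hypothetical float32 batch built with tf.range (values chosen only for illustration; they are small integers, so float32 sums are exact). The tf.einsum line is an alternative spelling that is not part of the original answer.

```python
import tensorflow as tf

# Hypothetical batch of 32 length-3 vectors (illustrative values, exact in float32).
x = tf.reshape(tf.range(96, dtype=tf.float32), (32, 3))
y = x + 1.0

# Element-wise product keeps shape (32, 3); summing over axis 1 collapses each
# row to its inner product, giving shape (32,) with only 32 * 3 multiplications.
batched = tf.reduce_sum(x * y, axis=1)

# Alternative with einsum: index j is repeated and dropped from the output, so it
# is summed over; index i is kept, so there is one result per row.
batched_einsum = tf.einsum("ij,ij->i", x, y)

# Reference: the diagonal of the full (32, 32) product used in the question.
reference = tf.linalg.diag_part(tf.tensordot(x, y, axes=[[1], [1]]))

print(batched[:3].numpy().tolist())  # [8.0, 62.0, 170.0]
```

Both batched and batched_einsum match reference while never materializing the (32, 32) intermediate.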