TensorFlow Inner Product Multiplication

Time:02-28

I am trying to take the inner product of two vectors in TensorFlow, for which I use the dot product:

import tensorflow as tf

x = tf.constant([1, 2, 3], dtype=tf.int32)
y = tf.constant([4, 5, 6], dtype=tf.int32)

# desired result
tf.tensordot(x, y, axes=1)
# Output: 32
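As a quick check of what the 1-D call computes, here is a small sketch (assuming TensorFlow 2.x with eager execution) that compares `tf.tensordot(x, y, axes=1)` with an explicit element-wise product followed by a sum. Both give 1·4 + 2·5 + 3·6 = 32.

```python
import tensorflow as tf

x = tf.constant([1, 2, 3], dtype=tf.int32)
y = tf.constant([4, 5, 6], dtype=tf.int32)

# tensordot with axes=1 contracts the last axis of x with the first axis of y,
# which for two 1-D vectors is the ordinary dot product.
dot = tf.tensordot(x, y, axes=1)

# The same value written out: 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
manual = tf.reduce_sum(x * y)

print(dot.numpy(), manual.numpy())  # 32 32
```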

Now I'm dealing with batch tensors, which both have shape (32, 3). I still want the same operation, yielding an output vector of shape (32,). My only successful attempt so far is:

tf.linalg.diag_part(tf.tensordot(x, y, axes=[[1], [1]]))
# Output: <tf.Tensor: shape=(32,)>
# where each entry is the inner product of the vectors of length 3

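However, this computes 32 times as many inner products as required: the tensordot builds the full (32, 32) matrix of pairwise inner products, and only its diagonal is kept.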
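To make the waste concrete, here is a minimal sketch (assuming TensorFlow 2.x with eager execution; the random integer-valued float32 data is a hypothetical stand-in for the real batch) that prints the shape of the intermediate tensordot result before the diagonal is taken.

```python
import tensorflow as tf

# Hypothetical batch: 32 vectors of length 3 with small integer values,
# stored as float32 so that all products and sums are exact.
x = tf.cast(tf.random.uniform((32, 3), minval=0, maxval=10, dtype=tf.int32), tf.float32)
y = tf.cast(tf.random.uniform((32, 3), minval=0, maxval=10, dtype=tf.int32), tf.float32)

# Contracting axis 1 of x with axis 1 of y pairs EVERY row of x with EVERY row of y,
# i.e. it computes x @ transpose(y): a (32, 32) matrix of 1024 inner products.
pairwise = tf.tensordot(x, y, axes=[[1], [1]])
print(pairwise.shape)  # (32, 32)

# Only the diagonal (row i of x with row i of y) is actually wanted.
wanted = tf.linalg.diag_part(pairwise)
print(wanted.shape)  # (32,)
```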

How do I solve my problem more efficiently?

Answer:

Think about what this operation is at the end of the day: element-wise multiplication followed by a sum over axis 1. So you can simply do this:

tf.reduce_sum(x * y, axis=1)
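A minimal sketch, assuming TensorFlow 2.x with eager execution and random stand-in data, that checks the answer's `reduce_sum` against the original `diag_part` approach and against `tf.einsum("ij,ij->i", x, y)`, an equivalent formulation that the answer does not mention. The data are small integers stored as float32 so the three results can be compared exactly.

```python
import tensorflow as tf

# Hypothetical batch of 32 vectors of length 3 (integer-valued, float32 dtype).
x = tf.cast(tf.random.uniform((32, 3), minval=0, maxval=10, dtype=tf.int32), tf.float32)
y = tf.cast(tf.random.uniform((32, 3), minval=0, maxval=10, dtype=tf.int32), tf.float32)

# The answer's approach: multiply element-wise, then sum over the feature axis.
# Only 32 inner products are computed, one per row.
row_dots = tf.reduce_sum(x * y, axis=1)

# Equivalent einsum formulation: index j is summed away, batch index i is kept.
row_dots_einsum = tf.einsum("ij,ij->i", x, y)

# Reference result from the question's diag_part approach (32 x 32 products).
row_dots_diag = tf.linalg.diag_part(tf.tensordot(x, y, axes=[[1], [1]]))

print(row_dots.shape)                                     # (32,)
print(bool(tf.reduce_all(row_dots == row_dots_einsum)))  # True
print(bool(tf.reduce_all(row_dots == row_dots_diag)))    # True
```

The einsum subscript string states the intent directly, and like `reduce_sum` it computes one inner product per row instead of the full pairwise matrix.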