Correct axes to use dot product to evaluate the final output of a listwise learning to rank model

Time:02-16

I can't find the correct configuration to pass to tf.keras.layers.Dot to compute a pairwise dot product when each entry holds a list of vectors, as in a listwise learning-to-rank model. For instance, suppose:

repeated_query_vector = [
  [[1, 2], [1, 2]],
  [[3, 4], [3, 4]]
]

document_vectors = [
  [[5, 6], [7, 8]],
  [[9, 10], [11, 12]],
]

When calling tf.keras.layers.Dot(axes=??)([repeated_query_vector, document_vectors]), I want the output to be:

[
  [1*5 + 2*6, 1*7 + 2*8],
  [3*9 + 4*10, 3*11 + 4*12]
]
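
Written out numerically, the target above is [[17, 23], [67, 81]]. As a quick sanity check of that arithmetic, here is a minimal plain-Python sketch that uses only nested lists; the helper name expected_scores is hypothetical and is not part of TensorFlow or Keras:

```python
# Hypothetical helper: one dot product per (query vector, document vector) pair.
def expected_scores(queries, documents):
    return [
        [sum(q * d for q, d in zip(q_vec, d_vec))
         for q_vec, d_vec in zip(q_list, d_list)]
        for q_list, d_list in zip(queries, documents)
    ]

repeated_query_vector = [[[1, 2], [1, 2]], [[3, 4], [3, 4]]]
document_vectors = [[[5, 6], [7, 8]], [[9, 10], [11, 12]]]

print(expected_scores(repeated_query_vector, document_vectors))
# [[17, 23], [67, 81]]
```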

All examples I found in the documentation have one dimension less than my use case. What would be the correct value of axes for this call?

CodePudding user response:

One approach uses tf.keras.layers.Multiply() together with tf.reshape and tf.reduce_sum:

import tensorflow as tf

repeated_query_vector = tf.constant([
  [[1, 2], [1, 2]],
  [[3, 4], [3, 4]]
])

document_vectors = tf.constant([
  [[5, 6], [7, 8]],
  [[9, 10], [11, 12]],
])

# Element-wise product, shape (batch, list_size, dim)
multiply_layer = tf.keras.layers.Multiply()
result = multiply_layer([repeated_query_vector, document_vectors])
shape = tf.shape(result)
# Flatten the last two axes and sum over them: one value per batch entry
result = tf.reduce_sum(tf.reshape(result, (shape[0], shape[1] * shape[2])), axis=1, keepdims=True)
print(result)

tf.Tensor(
[[ 40]
 [148]], shape=(2, 1), dtype=int32)

Or with tf.keras.layers.Dot and tf.reshape:

import tensorflow as tf

repeated_query_vector = tf.constant([
  [[1, 2], [1, 2]],
  [[3, 4], [3, 4]]
])

document_vectors = tf.constant([
  [[5, 6], [7, 8]],
  [[9, 10], [11, 12]],
])

# Flatten each (list_size, dim) block into a single vector per batch entry
flat_query = tf.reshape(repeated_query_vector, (repeated_query_vector.shape[0], -1))
flat_documents = tf.reshape(document_vectors, (document_vectors.shape[0], -1))

dot_layer = tf.keras.layers.Dot(axes=1)
result = dot_layer([flat_query, flat_documents])
print(result)

tf.Tensor(
[[ 40]
 [148]], shape=(2, 1), dtype=int32)
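
Note that both snippets above add up every product within a batch entry, so they return 40 (= 17 + 23) and 148 (= 67 + 81) rather than one score per document. If the goal is the (batch, list_size) output described in the question, a sketch that reduces only over the last axis may be closer to what is needed. It sidesteps tf.keras.layers.Dot and uses element-wise multiplication with tf.reduce_sum; the variable name scores is illustrative:

```python
import tensorflow as tf

repeated_query_vector = tf.constant([
    [[1, 2], [1, 2]],
    [[3, 4], [3, 4]],
])
document_vectors = tf.constant([
    [[5, 6], [7, 8]],
    [[9, 10], [11, 12]],
])

# Multiply element-wise, then sum over the last axis only,
# leaving one score per (batch entry, document) pair.
scores = tf.reduce_sum(repeated_query_vector * document_vectors, axis=-1)

print(scores.numpy().tolist())
# [[17, 23], [67, 81]]
```

Inside a Keras model, the same expression could presumably be wrapped in a tf.keras.layers.Lambda layer; that wrapping is an assumption and has not been tested here.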