I would like to calculate a generalised inner product in TensorFlow, similar to this discussion for numpy. In particular, I would like a function inner_product(f, a, b) that takes a function f (of two 1D tensors, returning a scalar tensor) and applies f to slices of a and b such that the i,j-th element of the output is given by f(a[i, :], b[:, j]).

This is just tf.matmul if f(x, y) = tf.reduce_sum(x * y). However, I'm struggling to come up with an efficient solution for other f functions. Something that gets the correct answer (assuming a function f with arguments f(x, y)) is
import tensorflow as tf
from functools import partial

def inner_product(f, a, b):
    # For a fixed column of b (passed in as `row`), map f over the rows of a.
    def f_row_function(row, a):
        return tf.map_fn(partial(f, y=row), a)

    # Outer loop over the columns of b (rows of b transposed), then transpose back.
    return tf.transpose(
        tf.map_fn(partial(f_row_function, a=a), tf.transpose(b))
    )
but this is very slow (it's doing effectively two loops over f
).
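As a quick sanity check on the tf.matmul equivalence mentioned above, the slow implementation can be compared against tf.matmul for the dot-product case; this is just a sketch, and dot, a_check and b_check are names introduced here purely for illustration:

# With f as a plain dot product, inner_product should agree with tf.matmul.
def dot(x, y):
    return tf.reduce_sum(x * y)

a_check = tf.random.normal([3, 5])
b_check = tf.random.normal([5, 4])
diff = inner_product(dot, a_check, b_check) - tf.matmul(a_check, b_check)
print(tf.reduce_max(tf.abs(diff)))  # should be numerically close to 0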
As an example, with
a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)

def f(x, y):
    r = tf.norm(x - y)
    return tf.exp(-((0.1 * r) ** 2))
inner_product(f, a, b) should give
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[0.82695913, 0.423162 , 0.5117086 , 0.5433509 ],
[0.75578374, 0.6505091 , 0.6838614 , 0.4771139 ],
[0.6004956 , 0.506617 , 0.50157607, 0.45384485]], dtype=float32)>
CodePudding user response:
Here is my attempt to trade some memory for time. The idea is to generate the necessary pairs of a and b beforehand and then apply the reduction f to this (now larger) tensor once. Let me know if that's faster for you.
import tensorflow as tf
a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)
@tf.function
def modified_f(x):
    # x[..., 0] holds rows of a, x[..., 1] holds columns of b; reduce over the last axis.
    r = tf.norm(x[..., 0] - x[..., 1], axis=-1)
    return tf.exp(-((0.1 * r) ** 2))
# Create indices for the axes we want to iterate (i.e. i and j)
a_is = tf.range(a.shape[0])
b_js = tf.range(b.shape[-1])
print(a_is.shape) # (3,)
print(b_js.shape) # (4,)
A_IS, B_JS = tf.meshgrid(a_is, b_js) # get all combinations of indices. The first two axes now correspond to the [i,j]
all_a = tf.gather(a, A_IS) # Now we extract the corresponding values from "a"
all_b = tf.gather(tf.transpose(b), B_JS) # and now from "b"
x = tf.stack([all_a, all_b], axis=-1) # stack both into a single array, you probably could skip this...
print(all_a.shape) # (4, 3, 5)
print(all_b.shape) # (4, 3, 5)
print(x.shape) # (4, 3, 5, 2)
out = tf.transpose(modified_f(x))
print(out.shape) # (3, 4)
print(out)
# tf.Tensor(
# [[0.82695913 0.423162   0.5117086  0.5433508 ]
#  [0.75578374 0.65050906 0.68386143 0.47711396]
#  [0.6004956  0.50661695 0.50157607 0.45384485]], shape=(3, 4), dtype=float32)
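If it helps for reuse, the same steps could be folded into a function along these lines; this is only a sketch of the code above, and pairwise_inner_product is a name made up here:

def pairwise_inner_product(f, a, b):
    # Pair every row of a with every column of b up front, then call f once.
    a_is = tf.range(tf.shape(a)[0])
    b_js = tf.range(tf.shape(b)[-1])
    A_IS, B_JS = tf.meshgrid(a_is, b_js)      # index grids of shape (n_cols_b, n_rows_a)
    all_a = tf.gather(a, A_IS)                # (n_cols_b, n_rows_a, k)
    all_b = tf.gather(tf.transpose(b), B_JS)  # (n_cols_b, n_rows_a, k)
    x = tf.stack([all_a, all_b], axis=-1)     # (n_cols_b, n_rows_a, k, 2)
    return tf.transpose(f(x))                 # (n_rows_a, n_cols_b)

print(pairwise_inner_product(modified_f, a, b))  # same (3, 4) result as above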
CodePudding user response:
Inspired by @André's answer, the solution I ended up using (which runs nice and quickly on the real dataset sizes I'm working with) looks like
def f(x):
    r = tf.norm(x[0] - x[1], axis=-1)
    return tf.exp(-((0.1 * r) ** 2))

def inner_product(f, a, b):
    all_b = tf.tile(tf.transpose(b)[:, tf.newaxis, :], [1, tf.shape(a)[0], 1])
    all_a = tf.tile(a[tf.newaxis, :, :], [tf.shape(b)[-1], 1, 1])
    x = [all_a, all_b]
    return tf.transpose(f(x))
This explicitly constructs, with tf.tile, the same expanded a and b tensors that @André's method generates using tf.gather.
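For reference, calling this on the example tensors from the question should reproduce the expected output:

out = inner_product(f, a, b)
print(out.shape)  # (3, 4)
print(out)        # should match the expected values shown in the question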