I have some batched input x of shape [batch, time, feature], and some batched indices i of shape [batch, new_time] which I want to gather into the time dim of x. As output of this operation I want a tensor y of shape [batch, new_time, feature] with values like this:
y[b, t', f] = x[b, i[b, t'], f]
In Tensorflow, I can accomplish this by using the batch_dims: int argument of tf.gather: y = tf.gather(x, i, axis=1, batch_dims=1).
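To pin down the intended semantics before looking at vectorized options, here is a deliberately slow reference sketch that evaluates y[b, t', f] = x[b, i[b, t'], f] with explicit loops. The helper name gather_time_reference and the toy values are illustrative only; the loop is meant for checking results, not for real use.

```python
import torch

def gather_time_reference(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: y[b, t', f] = x[b, i[b, t'], f] computed with loops.

    x: [batch, time, feature], i: [batch, new_time] (integer indices into time).
    Returns y: [batch, new_time, feature].
    """
    batch, new_time = i.shape
    y = x.new_empty((batch, new_time, x.shape[-1]))
    for b in range(batch):
        for tp in range(new_time):
            # Copy the whole feature vector at time step i[b, tp] of batch b.
            y[b, tp] = x[b, int(i[b, tp])]
    return y

# Toy input: x[b, t, f] == b * 6 + t * 2 + f
x = torch.arange(2 * 3 * 2, dtype=torch.float32).reshape(2, 3, 2)
i = torch.tensor([[2, 0, 0, 1], [1, 1, 2, 0]])  # new_time = 4
y = gather_time_reference(x, i)
print(y.shape)  # torch.Size([2, 4, 2])
```

Any vectorized version discussed below should agree with this loop on the same inputs.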
In PyTorch, I can think of some functions which do similar things:

- torch.gather, of course, but this does not have an argument similar to Tensorflow's batch_dims. The output of torch.gather will always have the same shape as the indices, so I would need to unbroadcast the feature dim into i before passing it to torch.gather.
- torch.index_select, but here the indices must be one-dimensional. So to make it work I would need to unbroadcast x to add a "batch * new_time" dim, and then reshape the output after torch.index_select.
- torch.nn.functional.embedding. Here, the embedding matrices would correspond to x. But this embedding function does not support batched weights, so I run into the same issue as for torch.index_select (looking at the code, torch.nn.functional.embedding uses torch.index_select under the hood).
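For the torch.index_select route, one way to avoid materializing a copy of x is to merge the batch and time axes (a reshape, which is a view for a contiguous x) and shift each row of indices by b * time so it points into the merged axis. The sketch below assumes x is contiguous and i is int64; the helper name gather_time_index_select is hypothetical.

```python
import torch

def gather_time_index_select(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: batched time gather via a flattened index_select.

    x: [batch, time, feature], i: [batch, new_time] (int64).
    """
    batch, time, feature = x.shape
    # Merge batch and time into one axis; for a contiguous x this is a view, not a copy.
    flat_x = x.reshape(batch * time, feature)
    # Offset row b of the indices by b * time so it addresses the merged axis.
    offsets = torch.arange(batch, device=i.device).unsqueeze(1) * time
    flat_i = (i + offsets).reshape(-1)  # one-dimensional, as index_select requires
    out = torch.index_select(flat_x, 0, flat_i)  # [batch * new_time, feature]
    return out.reshape(batch, -1, feature)

x = torch.arange(2 * 3 * 2, dtype=torch.float32).reshape(2, 3, 2)
i = torch.tensor([[2, 0, 0, 1], [1, 1, 2, 0]])
y = gather_time_index_select(x, i)
print(y.shape)  # torch.Size([2, 4, 2])
```

The same flat_x and flat_i could also be passed to torch.nn.functional.embedding(flat_i, flat_x), since that performs the same row lookup.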
Is it possible to accomplish such a gather operation without relying on unbroadcasting, which is inefficient for large dims?
Answer:
This is a very common case: the input and index tensors don't have the same number of dimensions. You can still use torch.gather, though, since you can rewrite your expression:
y[b, t, f] = x[b, i[b, t], f]
as:
y[b, t, f] = x[b, i[b, t, f], f]
which ensures all three tensors have the same number of dimensions. This reveals a third dimension on i, which we can create for free by unsqueezing a trailing dimension and expanding it to the feature size of x. You can do so with i[..., None].expand(-1, -1, x.size(-1)). Since expand returns a view, no memory is copied along the feature dim.
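To illustrate the "for free" claim, the sketch below builds the expanded index and inspects it: the expanded axis gets a stride of 0 and the result shares storage with i, so no per-feature copy of the indices is made. The sizes are arbitrary toy values.

```python
import torch

batch, time, new_time, feature = 2, 3, 4, 5
i = torch.randint(0, time, (batch, new_time))     # [batch, new_time]
# Add a trailing axis, then broadcast it to the feature size without copying.
idx = i.unsqueeze(-1).expand(-1, -1, feature)     # [batch, new_time, feature]

print(idx.shape)                       # torch.Size([2, 4, 5])
print(idx.stride())                    # (4, 1, 0): stride 0 on the expanded axis
print(idx.data_ptr() == i.data_ptr())  # True: expand returns a view of i
```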
Here is a minimal example:
>>> import torch
>>> b, t, f, new_t = 2, 3, 4, 5
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, new_t))
>>> y = x.gather(1, i[..., None].expand(-1, -1, f))
>>> y.shape
torch.Size([2, 5, 4])
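As a sanity check, the gather approach can be wrapped in a small function and compared against plain advanced indexing, where a [batch, 1] arange broadcasts against i to select x[b, i[b, t'], :]. The helper name batched_time_gather is hypothetical, and the comparison is only a sketch on random toy data.

```python
import torch

def batched_time_gather(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: y[b, t', f] = x[b, i[b, t'], f] via torch.gather.

    x: [batch, time, feature]; i: [batch, new_time] with dtype int64.
    """
    # Zero-copy view of the indices with shape [batch, new_time, feature].
    index = i.unsqueeze(-1).expand(-1, -1, x.size(-1))
    return torch.gather(x, 1, index)

torch.manual_seed(0)
x = torch.rand(2, 3, 4)
i = torch.randint(0, 3, (2, 5))
y = batched_time_gather(x, i)

# Cross-check with advanced indexing: arange(batch)[:, None] broadcasts with i.
y_adv = x[torch.arange(x.size(0)).unsqueeze(1), i]
print(y.shape, torch.equal(y, y_adv))  # torch.Size([2, 5, 4]) True
```

Recent PyTorch versions also provide torch.take_along_dim, which broadcasts the index for you, so torch.take_along_dim(x, i.unsqueeze(-1), dim=1) should give the same result.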