Expand several dims from tensor's shape?

Time:09-28

Suppose I have a tensor t of shape (2, 3, 4). I'd like to construct a tuple/list from t.shape:

(np.arange(2)[:,None,None], np.arange(3)[:,None], np.arange(4))

where the number inside each np.arange is the corresponding entry of t.shape, and the number of newly added axes goes down from len(t.shape)-1 (i.e. 2) to 0.

Is there a fast and elegant way to construct such a tuple/list, such as a list comprehension? I've also tried np.expand_dims(a, (-1,-1,-1)), hoping to add 3 new axes, but it doesn't allow repeated axes. Thanks for any help!

CodePudding user response:

expand_dims may not allow repeated axes, but it does accept a sequence of distinct axes:

In [62]: np.expand_dims(np.arange(3),(0,2))
Out[62]: 
array([[[0],
        [1],
        [2]]])
In [63]: _.shape
Out[63]: (1, 3, 1)
In [64]: np.expand_dims(np.arange(3),(0,1,2))
Out[64]: array([[[[0, 1, 2]]]])
In [65]: _.shape
Out[65]: (1, 1, 1, 3)
In [66]: np.expand_dims(np.arange(3),(-1,-2))
Out[66]: 
array([[[0]],

       [[1]],

       [[2]]])
In [67]: _.shape
Out[67]: (3, 1, 1)

Under the covers expand_dims is just a reshape call.
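
Since it comes down to a reshape, the tuple from the question can be built with a comprehension over the shape. A minimal sketch; the helper name trailing_axes_aranges is made up for illustration and is not a NumPy function:

```python
import numpy as np

def trailing_axes_aranges(shape):
    # Hypothetical helper: the i-th array is arange(n) followed by
    # (len(shape) - 1 - i) trailing length-1 axes.
    ndim = len(shape)
    return tuple(
        np.arange(n).reshape((n,) + (1,) * (ndim - 1 - i))
        for i, n in enumerate(shape)
    )

arrays = trailing_axes_aranges((2, 3, 4))
print([a.shape for a in arrays])  # [(2, 1, 1), (3, 1), (4,)]
```

In practice you would pass t.shape instead of the literal tuple.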

In [69]: (np.arange(2)[:,None,None], np.arange(3)[:,None], np.arange(4))
Out[69]: 
(array([[[0]],
 
        [[1]]]),
 array([[0],
        [1],
        [2]]),
 array([0, 1, 2, 3]))

np.ix_ creates an equivalent tuple of broadcastable arrays, each with the full 3 dimensions:

In [71]: np.ix_(np.arange(2), np.arange(3), np.arange(4))
Out[71]: 
(array([[[0]],
 
        [[1]]]),
 array([[[0],
         [1],
         [2]]]),
 array([[[0, 1, 2, 3]]]))

For broadcasting purposes, (1,1,4) is the same as (4,).
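
One way to convince yourself of this is to index an array with both tuples and compare the results. In the sketch below, t is an arbitrary stand-in with the shape from the question:

```python
import numpy as np

t = np.arange(24).reshape(2, 3, 4)  # stand-in data, shape (2, 3, 4)

original = (np.arange(2)[:, None, None], np.arange(3)[:, None], np.arange(4))
via_ix = np.ix_(np.arange(2), np.arange(3), np.arange(4))

# Both tuples broadcast to the same (2, 3, 4) index shape.
print(np.broadcast(*original).shape, np.broadcast(*via_ix).shape)

# Used as a fancy index, they select the same elements.
same = np.array_equal(t[original], t[via_ix])
print(same)  # True
```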

np.ogrid gives the same result:

In [72]: np.ogrid[:2,:3,:4]
Out[72]: 
[array([[[0]],
 
        [[1]]]),
 array([[[0],
         [1],
         [2]]]),
 array([[[0, 1, 2, 3]]])]

and so does a sparse np.meshgrid:

np.meshgrid(np.arange(2), np.arange(3), np.arange(4),sparse=True, indexing='ij')

I don't think times differ significantly, but your original is probably as fast as any. As for elegance, that's subjective.

CodePudding user response:

I think what you're looking for is the functionality of np.ogrid:

>>> np.ogrid[:2, :3, :4]
[array([[[0]],
 
        [[1]]]),
 array([[[0],
         [1],
         [2]]]),
 array([[[0, 1, 2, 3]]])]

More generally, you could use np.meshgrid with sparse=True and indexing='ij'.
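
Neither call needs the sizes hard-coded; both can be driven by t.shape. A sketch, with t as a stand-in array:

```python
import numpy as np

t = np.zeros((2, 3, 4))  # stand-in with the shape from the question

# ogrid[:2, :3, :4] is indexing with a tuple of slices, so build that tuple.
grid = np.ogrid[tuple(slice(n) for n in t.shape)]

# indexing='ij' keeps the axes in the same order as t.shape.
mesh = np.meshgrid(*(np.arange(n) for n in t.shape), sparse=True, indexing='ij')

print([g.shape for g in grid])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]
print([m.shape for m in mesh])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]
```

Whether these return a list or a tuple depends on the NumPy version, so convert with tuple(...) before using the result as an index.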

But this might be an XY problem, because what you really need is np.indices:

>>> np.indices((2, 3, 4), sparse=True)  # t.shape in there
(array([[[0]],
 
        [[1]]]),
 array([[[0],
         [1],
         [2]]]),
 array([[[0, 1, 2, 3]]]))

As you noted in comments, the above is not exactly what you asked for, because it has the same number of dimensions for each array. But your use case is that you want to use these arrays in a fancy index together with t:

# x.shape == (2, 3, 2, 2)
# t.shape == (2, 2, 2)
out = x[np.arange(2)[:,None,None], t, np.arange(2)[:,None], np.arange(2)]
# out.shape == (2, 2, 2)

So you can still do this with np.indices:

i, k, l = np.indices(t.shape, sparse=True)
out = x[i, t, k, l]

Quick verification:

>>> np.array_equal(
...     x[np.arange(2)[:,None,None], t, np.arange(2)[:,None], np.arange(2)],
...     x[i, t, k, l]
... )
True
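
The snippets above assume x and t already exist. For a self-contained run, here is a sketch with random stand-in data (the seed and values are arbitrary) that also checks the result element by element:

```python
import numpy as np

rng = np.random.default_rng(0)          # arbitrary seed
x = rng.random((2, 3, 2, 2))            # stand-in data
t = rng.integers(0, 3, size=(2, 2, 2))  # valid indices into axis 1 of x

i, k, l = np.indices(t.shape, sparse=True)
out = x[i, t, k, l]

# out[a, b, c] should equal x[a, t[a, b, c], b, c].
expected = np.empty(t.shape)
for a, b, c in np.ndindex(*t.shape):
    expected[a, b, c] = x[a, t[a, b, c], b, c]

print(out.shape, np.array_equal(out, expected))  # (2, 2, 2) True
```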