Use .take() to index multidimensional array

Time:11-03

I have a multidimensional array of shape (n, x, y). For this example, I can use this array:

A = array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8],
            [ 9, 10, 11]],

           [[12, 13, 14],
            [15, 16, 17],
            [18, 19, 20],
            [21, 22, 23]],

           [[24, 25, 26],
            [27, 28, 29],
            [30, 31, 32],
            [33, 34, 35]]])

I then have another array containing index values that I want to use on the original array, A. It has shape (z, 2), and its values are row indices:

Row_values = array([[0,1],
                    [0,2],
                    [1,2],
                    [1,3]])

I want to apply every index pair in Row_values to each of the three 2D arrays in A, so that I end up with a final array of shape (12, 2, 3), that is, 3 blocks × 4 index pairs:

Result = array([[[ 0,  1,  2],
                 [ 3,  4,  5]],
                [[ 0,  1,  2],
                 [ 6,  7,  8]],
                [[ 3,  4,  5],
                 [ 6,  7,  8]],
                [[ 3,  4,  5],
                 [ 9, 10, 11]],
                [[12, 13, 14],
                 [15, 16, 17]],
                [[12, 13, 14],
                 [18, 19, 20]],
                [[15, 16, 17],
                 [18, 19, 20]],
                [[15, 16, 17],
                 [21, 22, 23]],
                [[24, 25, 26],
                 [27, 28, 29]],
                [[24, 25, 26],
                 [30, 31, 32]],
                [[27, 28, 29],
                 [30, 31, 32]],
                [[27, 28, 29],
                 [33, 34, 35]]])
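To pin down exactly what is being asked, here is a minimal sketch that builds the desired result with a plain double loop. It assumes NumPy is imported as np and rebuilds A with np.arange, which reproduces the values shown above. The loop is slow for large inputs but makes the indexing rule explicit.

```python
import numpy as np

# The example data: A has shape (n, x, y) = (3, 4, 3)
A = np.arange(36).reshape(3, 4, 3)
Row_values = np.array([[0, 1],
                       [0, 2],
                       [1, 2],
                       [1, 3]])

# Reference loop: for every 2D block A[i], pick the pair of rows given by
# each entry of Row_values and collect the (2, 3) results in order
blocks = []
for i in range(A.shape[0]):
    for pair in Row_values:
        blocks.append(A[i][pair])  # A[i][pair] has shape (2, 3)
expected = np.stack(blocks)

print(expected.shape)  # (12, 2, 3)
print(expected[3])     # rows 1 and 3 of A[0]
```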

I have tried using np.take() but haven’t been able to make it work. Is there another NumPy function that would be easier to use?

Answer:

We can take advantage of NumPy's advanced indexing, combined with np.repeat and np.tile.

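import numpy as np

# Indices into A's second axis: Row_values stacked once per block, shape (12, 2)
cidx = np.tile(Row_values, (A.shape[0], 1))
# Indices into A's first axis: each block number repeated once per pair, shape (12,)
ridx = np.repeat(np.arange(A.shape[0]), Row_values.shape[0])

out = A[ridx[:, None], cidx]
# out.shape -> (12, 2, 3)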
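To see why the broadcast works, the sketch below prints the intermediate index arrays for the example data (A rebuilt with np.arange, assumed to match the question's values). The key point is that ridx[:, None] has shape (12, 1) and broadcasts against cidx's (12, 2), so each output entry out[k, j] is the row A[ridx[k], cidx[k, j]].

```python
import numpy as np

A = np.arange(36).reshape(3, 4, 3)
Row_values = np.array([[0, 1], [0, 2], [1, 2], [1, 3]])

cidx = np.tile(Row_values, (A.shape[0], 1))
ridx = np.repeat(np.arange(A.shape[0]), Row_values.shape[0])

print(ridx)         # [0 0 0 0 1 1 1 1 2 2 2 2]
print(cidx.shape)   # (12, 2)

# (12, 1) broadcast with (12, 2) -> (12, 2) index pairs; the untouched
# last axis of length 3 is appended, giving (12, 2, 3)
out = A[ridx[:, None], cidx]
print(out.shape)    # (12, 2, 3)
print(out[5])       # block 1, rows 0 and 2
```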

Alternatively, using np.take:

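# Taking along axis 1 gives shape (3, 4, 2, 3); the reshape merges the first two axes
np.take(A, Row_values, axis=1).reshape((-1, 2, 3))
# Or, equivalently, with plain advanced indexing
A[:, Row_values].reshape((-1, 2, 3))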

Output:

array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 0,  1,  2],
        [ 6,  7,  8]],

       [[ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 3,  4,  5],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]],

       [[12, 13, 14],
        [18, 19, 20]],

       [[15, 16, 17],
        [18, 19, 20]],

       [[15, 16, 17],
        [21, 22, 23]],

       [[24, 25, 26],
        [27, 28, 29]],

       [[24, 25, 26],
        [30, 31, 32]],

       [[27, 28, 29],
        [30, 31, 32]],

       [[27, 28, 29],
        [33, 34, 35]]])
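The reshape target (-1, 2, 3) in the np.take version is hard-coded to this example. A minimal sketch that derives it from the inputs instead, wrapped in a hypothetical helper named take_row_groups (not a NumPy function), assuming A is always 3D and the index array is always 2D:

```python
import numpy as np

def take_row_groups(A, idx):
    """Hypothetical helper: apply every row-index group in idx to each block A[i].

    A has shape (n, x, y) and idx has shape (z, k); the result has shape
    (n * z, k, y), ordered block by block.
    """
    # np.take along axis 1 yields shape (n, z, k, y)
    picked = np.take(A, idx, axis=1)
    # Merge the block axis and the group axis into one
    return picked.reshape(-1, idx.shape[1], A.shape[2])

A = np.arange(36).reshape(3, 4, 3)
Row_values = np.array([[0, 1], [0, 2], [1, 2], [1, 3]])

out_take = take_row_groups(A, Row_values)
out_fancy = A[:, Row_values].reshape(-1, Row_values.shape[1], A.shape[2])
print(out_take.shape)                        # (12, 2, 3)
print(np.array_equal(out_take, out_fancy))   # True
```

Because the target shape is read from idx and A, the same call also works for groups of three rows, e.g. an index array of shape (2, 3) gives a result of shape (6, 3, 3).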
