Filter out index from 2-D numpy array

Time:11-01

Say I have a 2-D numpy array,

A = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])  # Shape: (3, 4)

and I have another 1-D array,

B = np.array([0, 2, 1])  # Shape: (3,)

For each row i of A, I want to extract the element whose column index is given by the i-th element of B. For example, for the 3rd element of B, the output should be the 1st element of the 3rd row of A (that is, 8).

Thus, the desired output is:

C = [0, 6, 8]

This can easily be done with a for loop, but I am looking for a faster, vectorized way of doing it.
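For reference, here is a minimal sketch of the plain-loop approach, assuming A and B are NumPy arrays holding the values above. Run as written, it gives [0, 6, 9]: B[2] is 1, and A[2, 1] is 9.

```python
import numpy as np

A = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
B = np.array([0, 2, 1])

# Straightforward loop: for row i, take the element in column B[i].
C = np.empty(A.shape[0], dtype=A.dtype)
for i in range(A.shape[0]):
    C[i] = A[i, B[i]]

print(C)  # [0 6 9]
```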

I tried np.take(), but it does not give the desired result.
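A short sketch of why np.take() misbehaves here, using the example values: without an axis argument it indexes the flattened array, and with axis=1 it selects whole columns for every row. np.take_along_axis, a different NumPy function that is not part of the original question, does the per-row selection, provided the index array has the same number of dimensions as A.

```python
import numpy as np

A = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
B = np.array([0, 2, 1])

# Without an axis, np.take indexes the flattened array: A.ravel()[[0, 2, 1]].
flat = np.take(A, B)            # [0 2 1]

# With axis=1, np.take selects whole columns 0, 2, 1 for every row,
# giving a (3, 3) result instead of one element per row.
cols = np.take(A, B, axis=1)    # [[0 2 1], [4 6 5], [8 10 9]]

# np.take_along_axis pairs each row with its own index; the indices must
# have the same number of dimensions as A, hence B[:, None] of shape (3, 1).
C = np.take_along_axis(A, B[:, None], axis=1).ravel()   # [0 6 9]

print(flat, cols, C, sep="\n")
```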

Response:

If B gives, for each position along the first axis (each row), the index you want along the second axis (the column), just pair it with a matching set of indices into the first dimension:

In [4]: A[range(A.shape[0]), B]
Out[4]: array([0, 6, 9])

This is equivalent to:

In [5]: A[[0, 1, 2], [0, 2, 1]]
Out[5]: array([0, 6, 9])
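To illustrate the pairing rule this relies on, here is a small sketch with the example data: equal-length integer index arrays are combined element-wise, whereas np.ix_ (not used in the answer, shown only for contrast) combines every row with every listed column.

```python
import numpy as np

A = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
B = np.array([0, 2, 1])
rows = np.arange(A.shape[0])   # [0 1 2]

# Two equal-length integer arrays are paired element-wise:
# the result is [A[0, 0], A[1, 2], A[2, 1]].
paired = A[rows, B]            # [0 6 9]

# np.ix_ builds an open mesh instead, so every row is combined with every
# listed column, the same as np.take(A, B, axis=1).
crossed = A[np.ix_(rows, B)]   # [[0 2 1], [4 6 5], [8 10 9]]

print(paired)
print(crossed)
```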

Response:

This is simple enough with advanced (integer-array) indexing:

C = A[np.arange(B.size), B]

For your example this yields [0, 6, 9], which, as @MichaelDelgado points out, is the correct result.
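A rough timing sketch comparing the advanced-indexing approach with a per-row Python loop. The array sizes, random seed, and the helper names loop_pick and fancy_pick are hypothetical, chosen for illustration. Exact timings depend on your machine, but on inputs like these the single indexing call is expected to be considerably faster.

```python
import timeit

import numpy as np

# Hypothetical larger inputs, chosen only to make timing differences visible.
rng = np.random.default_rng(0)
A = rng.integers(0, 100, size=(10_000, 50))
B = rng.integers(0, A.shape[1], size=A.shape[0])


def loop_pick(A, B):
    # One Python-level indexing operation per row.
    return np.array([A[i, B[i]] for i in range(A.shape[0])])


def fancy_pick(A, B):
    # A single advanced-indexing call; row i is paired with column B[i].
    return A[np.arange(B.size), B]


# Both approaches must agree before comparing speed.
assert np.array_equal(loop_pick(A, B), fancy_pick(A, B))

for fn in (loop_pick, fancy_pick):
    seconds = timeit.timeit(lambda: fn(A, B), number=20)
    print(f"{fn.__name__}: {seconds:.4f} s for 20 calls")
```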

Response:

Answer: Use List Comprehension with Enumerate

You could use a list comprehension with enumerate, which yields (index, value) tuples. For instance, list(enumerate([0, 2])) gives [(0, 0), (1, 2)].

Applied to your arrays, this would be:

In: C = [A[i][j] for i, j in enumerate(B)]
Out: [0, 6, 9]

This gives the result you're looking for. I believe the 8 in your expected output is incorrect: B[2] is 1, so the third element should be A[2][1], which is 9.
