Can you please explain the output of this code? And how do I use the where function in NumPy?

Time:12-17

import numpy as np

a = np.array([[1, 2], [3, 4]])
np.where(a < 4)

Output:

(array([0, 0, 1]), array([0, 1, 0]))

Please explain this output.


Answer:

With only a condition argument, numpy.where gives you the indices of the elements where the condition is True.
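A quick way to see this: per the NumPy documentation, calling np.where with only a condition is equivalent to np.asarray(condition).nonzero(). This small sketch compares the two on the array from the question (the names mask, rows_w, cols_w, rows_n and cols_n are just illustrative):

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
mask = a < 4  # boolean array with the same shape as a

# With a single (condition) argument, np.where(mask) matches np.nonzero(mask)
rows_w, cols_w = np.where(mask)
rows_n, cols_n = np.nonzero(mask)

print(rows_w, cols_w)  # [0 0 1] [0 1 0]
print(np.array_equal(rows_w, rows_n) and np.array_equal(cols_w, cols_n))  # True
```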

I hope this breakdown helps you understand the logic:

import numpy as np

a = np.array([[1, 2], [3, 4]])
#   col:    0  1
# array([[1, 2],   # row 0
#        [3, 4]])  # row 1

a < 4
#   col:      0      1
# array([[ True,  True],   # row 0
#        [ True, False]])  # row 1

# flattened view, with each element's (row, col) position
# row:      0      0      1      1
# col:      0      1      0      1
# value: True   True   True  False

# keep only the positions where the value is True
# row:  [0, 0, 1]
# col:  [0, 1, 0]

np.where(a < 4)
# (array([0, 0, 1]), array([0, 1, 0]))
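To read the two arrays as positions, pair them element by element: the i-th row index and the i-th column index together locate one matching element. A sketch of that, plus np.argwhere, which returns the same positions as one (row, col) pair per line (variable names here are illustrative):

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
rows, cols = np.where(a < 4)

# The i-th row index and the i-th column index together locate one match
coords = list(zip(rows.tolist(), cols.tolist()))
print(coords)  # [(0, 0), (0, 1), (1, 0)]

for r, c in coords:
    print(f"a[{r}, {c}] = {a[r, c]}")  # every value printed here is < 4

# np.argwhere gives the same positions as an (N, ndim) array, one row per match
print(np.argwhere(a < 4))
# [[0 0]
#  [0 1]
#  [1 0]]
```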

Answer:

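With only a condition argument, np.where returns a tuple containing one index array per dimension of the input, so for this 2-D array it has 2 elements.

The first element holds the row indices of the elements meeting the condition.

The second element holds the column indices of the elements meeting the condition.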

To check this, save the result:

ind = np.where(a < 4)

Indexing with that tuple, a[ind], then returns a 1-D array of the elements that meet the condition:

array([1, 2, 3])
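The question also asks how to use np.where more generally. Indexing with the tuple (or directly with the boolean mask a < 4) extracts the matching values into a flat array. If you instead want to keep the original shape, the three-argument form np.where(condition, x, y) takes values from x where the condition is True and from y elsewhere. A minimal sketch; the replacement value 0 is an arbitrary choice for illustration:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
ind = np.where(a < 4)

# Selecting with the index tuple flattens the matches into a 1-D array
selected = a[ind]
print(selected)  # [1 2 3]

# A boolean mask selects the same elements in the same (row-major) order
print(np.array_equal(selected, a[a < 4]))  # True

# Three-argument form: keep a's 2-D shape, taking a where the condition holds
# and 0 everywhere else
replaced = np.where(a < 4, a, 0)
print(replaced)
# [[1 2]
#  [3 0]]
```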

If your source array has more dimensions, the resulting tuple has correspondingly more components: one index array per dimension.
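For example, with a 3-D array the tuple has three index arrays, one per axis. This sketch uses an arbitrary example array built with np.arange (the name b is illustrative):

```python
import numpy as np

b = np.arange(8).reshape(2, 2, 2)  # a small 3-D array holding 0..7
ind = np.where(b > 5)

print(len(ind))  # 3: one index array per dimension (b.ndim == 3)
print(ind)       # axis-0, axis-1 and axis-2 indices of the matches
print(b[ind])    # [6 7]
```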
