How to get a vector from a 2D numpy array using argmax?

Time:08-30

I have the following numpy.ndarray:

testArr:

array([[  2.55053788e-01,   6.25406146e-01,   1.19271643e-01,
          2.68359261e-04],
       [  2.59611636e-01,   0.19562805e-01,   1.20518960e-01,
          3.06535745e-01],
       [  8.52524495e-01,   5.24317825e-01,   1.22851081e-01,
          3.06610862e-04],
       [  2.55068243e-01,   6.24345124e-01,   1.20263465e-01,
          3.23178538e-04],
       [  2.46678621e-01,   6.29301071e-01,   1.23693809e-01,
          3.26490292e-04]], dtype=float32)

If I do testVec = np.argmax(testArr), I get a single number. How can I get a vector containing, for each row of the 2D array testArr, the column index (0, 1, 2 or 3) of that row's maximum value?

Expected output:

[1, 3, 0, 1, 1]

Answer:

If you look at the documentation for np.argmax, you'll see it has an axis parameter that lets you choose the axis along which the operation is performed. From the docs:

Returns the indices of the maximum values along an axis.

In this case, you want:

np.argmax(testArr, axis=1)
# array([1, 3, 0, 1, 1], dtype=int64)
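To make "along an axis" concrete, here is a small sketch comparing axis=1 with an explicit loop that calls np.argmax on each row separately. The array `m` is hypothetical and used only for illustration.

```python
import numpy as np

# Hypothetical small array, used only for illustration
m = np.array([[0.2, 0.7, 0.1],
              [0.9, 0.05, 0.05],
              [0.3, 0.3, 0.4]])

# Vectorized: one column index per row
vectorized = np.argmax(m, axis=1)

# Equivalent explicit loop: argmax of each row on its own
looped = np.array([np.argmax(row) for row in m])

print(vectorized)                          # [1 0 2]
print(np.array_equal(vectorized, looped))  # True
```

The vectorized call avoids the Python-level loop, which matters for large arrays.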

Answer:

By default, np.argmax returns the index of the maximum value in the flattened array. To get the index of the maximum along a single dimension (for example, the maximum in each row), you have to pass the keyword argument axis. It must be an integer: axis=0 finds the maximum down each column (one result per column), and axis=1 finds the maximum across each row (one result per row). For an array with n dimensions, any integer from 0 to n-1 works, and negative values count back from the last axis.

import numpy as np

testArr = np.array([[2.55053788e-01, 6.25406146e-01, 1.19271643e-01, 2.68359261e-04],
                    [2.59611636e-01, 0.19562805e-01, 1.20518960e-01, 3.06535745e-01],
                    [8.52524495e-01, 5.24317825e-01, 1.22851081e-01, 3.06610862e-04],
                    [2.55068243e-01, 6.24345124e-01, 1.20263465e-01, 3.23178538e-04],
                    [2.46678621e-01, 6.29301071e-01, 1.23693809e-01, 3.26490292e-04]],
                   dtype=np.float32)

# axis=1: for each row, return the column index of its largest element
print(np.argmax(testArr, axis=1))
# [1 3 0 1 1]
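To illustrate the difference between the default (flattened) behaviour and the axis argument, here is a sketch on a hypothetical 2x3 array `m`. It also uses np.unravel_index, which the answers do not mention, to convert the flat index returned without an axis back into (row, column) coordinates.

```python
import numpy as np

# Hypothetical 2x3 array, used only for illustration
m = np.array([[3, 8, 1],
              [9, 2, 5]])

# No axis: index into the flattened array m.ravel(), i.e. [3, 8, 1, 9, 2, 5]
flat = np.argmax(m)
print(flat)  # 3

# Convert the flat index back to (row, column) coordinates
row, col = np.unravel_index(flat, m.shape)
print(int(row), int(col))  # 1 0

# axis=0: for each column, the row index of that column's maximum
by_col = np.argmax(m, axis=0)
print(by_col)  # [1 0 1]

# axis=1 (same as axis=-1 for a 2-D array): for each row, the column index of its maximum
by_row = np.argmax(m, axis=1)
print(by_row)  # [1 0]
```

Note that axis=0 yields one entry per column and axis=1 one entry per row, so only axis=1 gives the per-row vector asked for in the question.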