How to index a probability matrix using an argmax-matrix in NumPy?

Time:09-16

Assume that we have defined the following matrices:

import numpy as np

# shape = (2, 3)
m = np.array(
    [
        [0, 1, 0],
        [1, 0, 1]
    ]
)

# shape = (2, 3, 2)
p = np.array(
    [
        [
            [0.6, 0.4], [0.3, 0.7], [0.8, 0.2]
        ],
        [
            [0.35, 0.65], [0.7, 0.3], [0.1, 0.9]
        ],
    ]
)

In this case, p is a probability matrix (the two values along its last axis sum to 1), and m contains the index of the maximum probability along that axis, i.e. m is p.argmax(axis=-1).

How can we index p with m to get an array of shape m.shape in which each element is the probability selected by the corresponding index in m? The expected result is:

result = np.array(
    [
        [0.6, 0.7, 0.8],
        [0.65, 0.7, 0.9]
    ]
)
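When m really is the argmax of p along its last axis, as in this example, the target matrix is simply p.max(axis=-1); indexing with m is only needed when m can hold arbitrary indices (for instance labels rather than argmaxes). A quick sanity check, reusing the question's arrays:

```python
import numpy as np

# Same arrays as in the question.
m = np.array([[0, 1, 0],
              [1, 0, 1]])
p = np.array([[[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]],
              [[0.35, 0.65], [0.7, 0.3], [0.1, 0.9]]])

# m is the index of the largest probability along the last axis of p.
assert np.array_equal(p.argmax(axis=-1), m)

# Because of that, the expected result is just the maximum along the same axis.
result = p.max(axis=-1)
print(result)
```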

Answer:

Using np.indices:

>>> p[tuple(np.indices(m.shape)) + (m,)]
array([[0.6 , 0.7 , 0.8 ],
       [0.65, 0.7 , 0.9 ]])
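To see why this works: np.indices(m.shape) yields one integer array per axis of m, and appending (m,) turns them into the index tuple (rows, cols, m), so element [i, j] of the result is p[i, j, m[i, j]]. The sketch below spells that out and compares it with an explicit loop (the names rows, cols, index and looped are illustrative only):

```python
import numpy as np

m = np.array([[0, 1, 0],
              [1, 0, 1]])
p = np.array([[[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]],
              [[0.35, 0.65], [0.7, 0.3], [0.1, 0.9]]])

# np.indices(m.shape) returns one array per axis of m, each of shape m.shape:
# rows[i, j] == i and cols[i, j] == j.
rows, cols = np.indices(m.shape)

# tuple(...) + (m,) builds (rows, cols, m); advanced indexing then picks
# p[rows[i, j], cols[i, j], m[i, j]] == p[i, j, m[i, j]] for every (i, j).
index = tuple(np.indices(m.shape)) + (m,)
result = p[index]

# The same selection written as an explicit (slow) loop, for comparison.
looped = np.empty(m.shape)
for i in range(m.shape[0]):
    for j in range(m.shape[1]):
        looped[i, j] = p[i, j, m[i, j]]

assert np.array_equal(result, looped)
```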

A more readable two-line version:

>>> ii, jj = np.indices(m.shape, sparse=True)   # sparse=True uses less memory
>>> p[ii, jj, m]
array([[0.6 , 0.7 , 0.8 ],
       [0.65, 0.7 , 0.9 ]])
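With sparse=True, np.indices returns arrays of shape (2, 1) and (1, 3) instead of two dense (2, 3) arrays, and broadcasting stretches them against m. The same broadcastable index arrays can also be written by hand with np.arange; a minimal sketch (rows and cols are illustrative names):

```python
import numpy as np

m = np.array([[0, 1, 0],
              [1, 0, 1]])
p = np.array([[[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]],
              [[0.35, 0.65], [0.7, 0.3], [0.1, 0.9]]])

# Sparse index arrays: one column vector and one row vector.
ii, jj = np.indices(m.shape, sparse=True)
print(ii.shape, jj.shape)  # (2, 1) (1, 3)

# Hand-written equivalent: a column of row numbers and a row of column numbers.
rows = np.arange(m.shape[0])[:, None]  # shape (2, 1)
cols = np.arange(m.shape[1])           # shape (3,), broadcasts like (1, 3)

# Both index triples broadcast to m.shape and select the same elements.
assert np.array_equal(p[ii, jj, m], p[rows, cols, m])
```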

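Since m only contains 0s and 1s (p has just two probabilities per position), np.where also works: it takes the first column of p where m is 0 and the second column otherwise.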

>>> np.where(m.ravel() == 0, *p.reshape(-1, 2).T).reshape(m.shape)
array([[0.6 , 0.7 , 0.8 ],
       [0.65, 0.7 , 0.9 ]])
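The np.where trick depends on there being exactly two classes. A sketch that works for any number of classes swaps in np.take_along_axis (a separate NumPy function, not one of the answers above); it needs an index array with the same number of dimensions as p, hence the extra trailing axis on m. The three-class array q is made-up illustration data:

```python
import numpy as np

m = np.array([[0, 1, 0],
              [1, 0, 1]])
p = np.array([[[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]],
              [[0.35, 0.65], [0.7, 0.3], [0.1, 0.9]]])

# Add a trailing axis so the indices have shape (2, 3, 1), matching p's ndim.
picked = np.take_along_axis(p, m[..., np.newaxis], axis=-1)

# Drop the length-1 trailing axis to get back to m.shape.
result = picked[..., 0]

# The same call works unchanged with more classes per position.
q = np.array([[[0.2, 0.5, 0.3], [0.1, 0.1, 0.8]]])  # hypothetical 3-class data
k = q.argmax(axis=-1)  # [[1, 2]]
best = np.take_along_axis(q, k[..., np.newaxis], axis=-1)[..., 0]
```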