NumPy one-liner equivalent to this loop, condition changes according to index

Time:11-23

In the code below, I want to replace the loop with a compact NumPy one-liner. I think the code is self-explanatory, but here is a short explanation: in the array of predictions, I want to threshold each prediction according to a threshold specific to the predicted class (i.e. if I predict 1, I compare its probability to th[1], and if I predict 2, I compare it to th[2]). If the probability exceeds that threshold, the prediction is set to 0. The loop does the work, but I think a one-liner would be more compact and generalizable.

import numpy as np

y_pred = np.array([1, 2, 2, 1, 1, 3, 3])
y_prob = np.array([0.5, 0.5, 0.75, 0.25, 0.75, 0.60, 0.40])

th = [0, 0.4, 0.7, 0.5]

z_true = np.array([0, 2, 0, 1, 0, 0, 3])  # expected result of the loop
z_pred = y_pred.copy()

# I want to replace this loop with a NumPy one-liner
for i in range(len(z_pred)):
    if y_prob[i] > th[y_pred[i]]:
        z_pred[i] = 0

print(z_pred)

Answer:

If you make th a NumPy array, you can index it with y_pred directly:

th = np.array(th)

z_pred = np.where(y_prob > th[y_pred], 0, y_pred)

Or, converting th to an array inline:

z_pred = np.where(y_prob > np.array(th)[y_pred], 0, y_pred)

Output: array([0, 2, 0, 1, 0, 0, 3])
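A variant that is not part of the original answer: if you prefer to keep the question's `z_pred = y_pred.copy()` structure, the same rule can be applied with boolean-mask assignment instead of `np.where`. This is a sketch that reuses the question's arrays:

```python
import numpy as np

y_pred = np.array([1, 2, 2, 1, 1, 3, 3])
y_prob = np.array([0.5, 0.5, 0.75, 0.25, 0.75, 0.60, 0.40])
th = np.array([0, 0.4, 0.7, 0.5])

# Start from the raw predictions, as the original loop does
z_pred = y_pred.copy()
# Zero out every prediction whose probability exceeds its own class threshold
z_pred[y_prob > th[y_pred]] = 0

print(z_pred)  # [0 2 0 1 0 0 3]
```

Unlike `np.where`, which builds a new array, this modifies the copy in place; both give the same result here.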

Intermediates:

np.array(th)
# array([0. , 0.4, 0.7, 0.5])

np.array(th)[y_pred]
# array([0.4, 0.7, 0.7, 0.4, 0.4, 0.5, 0.5])

y_prob > np.array(th)[y_pred]
# array([ True, False,  True, False,  True,  True, False])
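Because `np.array(th)[y_pred]` simply gathers one threshold per element, the one-liner works for any number of classes, provided `th` has an entry for every label that can appear in `y_pred` (an out-of-range label raises an `IndexError`). The sketch below wraps it in a hypothetical helper, `threshold_predictions` (not from the original post), and checks it against the question's loop on random inputs:

```python
import numpy as np

def threshold_predictions(y_pred, y_prob, th):
    """Hypothetical helper: set a prediction to 0 when its probability
    exceeds the threshold of its own predicted class."""
    th = np.asarray(th)  # accept a plain list as in the question
    return np.where(y_prob > th[y_pred], 0, y_pred)

def threshold_loop(y_pred, y_prob, th):
    """Reference implementation: the loop from the question."""
    z_pred = y_pred.copy()
    for i in range(len(z_pred)):
        if y_prob[i] > th[y_pred[i]]:
            z_pred[i] = 0
    return z_pred

rng = np.random.default_rng(0)
n_classes = 6
th = rng.random(n_classes)                      # one threshold per class label 0..5
y_pred = rng.integers(1, n_classes, size=1000)  # labels 1..n_classes-1
y_prob = rng.random(1000)

vectorized = threshold_predictions(y_pred, y_prob, th)
looped = threshold_loop(y_pred, y_prob, th)
print(np.array_equal(vectorized, looped))  # True
```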