I have the following Pandas DataFrame:
true_y m1_labels m1_probs_0 m1_probs_1 m2_labels m2_probs_0 m2_probs_1
0 0 0.628205 0.371795 1 0.491648 0.508352
0 0 0.564113 0.435887 1 0.474973 0.525027
0 1 0.463897 0.536103 0 0.660307 0.339693
0 1 0.454559 0.545441 0 0.512349 0.487651
0 0 0.608345 0.391655 1 0.499531 0.500469
0 0 0.816127 0.183873 1 0.456669 0.543331
0 1 0.442693 0.557307 0 0.573354 0.426646
1 0 0.653497 0.346503 1 0.487212 0.512788
0 1 0.392380 0.607620 0 0.627419 0.372581
0 1 0.375816 0.624184 0 0.631532 0.368468
This is a collection of disagreeing ML model predictions, with the hard labels and class probabilities of two models (m1, m2) and the actual label (true_y).
For each row, I would like to take whichever hard label prediction (m1_labels or m2_labels) comes from the model that assigns the higher probability to its own predicted class. So for row #1 I expect 0, since m1 is more confident in its prediction of 0 than m2 is in its prediction of 1. Essentially, this is a manual voting ensemble of the two models.
How can I get this vector with a Pandas query?
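For reference, the frame can be rebuilt from the printout above (a sketch; the default RangeIndex is assumed):

import pandas as pd

df = pd.DataFrame({
    "true_y":     [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
    "m1_labels":  [0, 0, 1, 1, 0, 0, 1, 0, 1, 1],
    "m1_probs_0": [0.628205, 0.564113, 0.463897, 0.454559, 0.608345,
                   0.816127, 0.442693, 0.653497, 0.392380, 0.375816],
    "m1_probs_1": [0.371795, 0.435887, 0.536103, 0.545441, 0.391655,
                   0.183873, 0.557307, 0.346503, 0.607620, 0.624184],
    "m2_labels":  [1, 1, 0, 0, 1, 1, 0, 1, 0, 0],
    "m2_probs_0": [0.491648, 0.474973, 0.660307, 0.512349, 0.499531,
                   0.456669, 0.573354, 0.487212, 0.627419, 0.631532],
    "m2_probs_1": [0.508352, 0.525027, 0.339693, 0.487651, 0.500469,
                   0.543331, 0.426646, 0.512788, 0.372581, 0.368468],
})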
CodePudding user response:
You can use the apply function for this:
df.apply(
    lambda x: x["m1_labels"]
    if max(x["m1_probs_0"], x["m1_probs_1"]) > max(x["m2_probs_0"], x["m2_probs_1"])
    else x["m2_labels"],
    axis=1,
)
This selects the first model's label if the probability of its predicted class is higher than the probability of the second model's predicted class; otherwise, it selects the label from the second model.
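Applied to the sample frame, this should yield (matching the manual reasoning above for row #1):

0    0
1    0
2    0
3    1
4    0
5    0
6    0
7    0
8    0
9    0
dtype: int64

Note that apply with axis=1 runs the lambda once per row in Python, so it can be slow on large frames; the vectorized answer below avoids that.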
CodePudding user response:
You can use:
# get max probability for m1
p1 = df.filter(like='m1_probs').max(axis=1)
# get max probability for m2
p2 = df.filter(like='m2_probs').max(axis=1)
# m1_label if it has a greater probability, else m2_label
df['best'] = df['m1_labels'].where(p1.gt(p2), df['m2_labels'])
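If you prefer NumPy, the same selection can be written with numpy.where (a sketch, equivalent to the .where call above):

import numpy as np

df['best'] = np.where(p1.gt(p2), df['m1_labels'], df['m2_labels'])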
output:
true_y m1_labels m1_probs_0 m1_probs_1 m2_labels m2_probs_0 m2_probs_1 best
0 0 0 0.628205 0.371795 1 0.491648 0.508352 0
1 0 0 0.564113 0.435887 1 0.474973 0.525027 0
2 0 1 0.463897 0.536103 0 0.660307 0.339693 0
3 0 1 0.454559 0.545441 0 0.512349 0.487651 1
4 0 0 0.608345 0.391655 1 0.499531 0.500469 0
5 0 0 0.816127 0.183873 1 0.456669 0.543331 0
6 0 1 0.442693 0.557307 0 0.573354 0.426646 0
7 1 0 0.653497 0.346503 1 0.487212 0.512788 0
8 0 1 0.392380 0.607620 0 0.627419 0.372581 0
9 0 1 0.375816 0.624184 0 0.631532 0.368468 0
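If more models are added later, the same idea generalizes. Here is a minimal sketch, assuming every model follows the <name>_labels / <name>_probs_* column pattern (the vote helper is hypothetical, not part of pandas):

import pandas as pd

def vote(df, models=("m1", "m2")):
    # per-row confidence of each model = probability of its predicted class
    conf = pd.DataFrame({m: df.filter(like=f"{m}_probs").max(axis=1) for m in models})
    # name of the most confident model per row
    winner = conf.idxmax(axis=1)
    # pick that model's hard label for each row
    return pd.Series(
        [df.at[i, f"{m}_labels"] for i, m in winner.items()],
        index=df.index,
    )

df['best'] = vote(df)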