I am doing text classification and want to get the class probabilities for the three classes with the highest probability. I need your help. Thanks.
import numpy as np
probability = get_predict_proba(X)
print(probability)
[[0.15682828 0.11664342 0.11088368 0.12925814 0.09544043 0.10655934 0.14538805 0.13899866]]
CodePudding user response:
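This (probability has shape (1, 8), so select row 0 first; indexing the 2-D array directly with the sorted indices raises an IndexError):
top3 = np.argsort(probability[0])[-3:] # indices of the 3 'best' classes, in ascending order of probability
probability[0][top3] # 3 'best' probabilities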
(np.argsort returns the indices that would sort the array in ascending order, so the last three entries point to the three largest values.)
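If you later predict for more than one sample at a time, a row-wise version may be handy. The following is only a sketch: top_k is a hypothetical helper name, not part of NumPy or of your code, and the array is copied from the output in the question rather than produced by get_predict_proba.

```python
import numpy as np

# Example probabilities copied from the question: one sample, eight classes.
probability = np.array([[0.15682828, 0.11664342, 0.11088368, 0.12925814,
                         0.09544043, 0.10655934, 0.14538805, 0.13899866]])

def top_k(proba, k=3):
    """Hypothetical helper: (indices, values) of the k largest entries per row, largest first."""
    proba = np.asarray(proba)
    # argsort sorts ascending, so keep the last k columns and reverse them.
    idx = np.argsort(proba, axis=1)[:, -k:][:, ::-1]
    # take_along_axis picks, for each row, the values at that row's indices.
    vals = np.take_along_axis(proba, idx, axis=1)
    return idx, vals

top_idx, top_vals = top_k(probability, k=3)
print(top_idx)   # [[0 6 7]]
print(top_vals)  # the matching probabilities, highest first
```

Both results keep one row per sample, so top_idx[i] and top_vals[i] belong to the i-th input row.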