Writing a custom 'scoring' function for scikit-learn's cross validation

Time:12-19

I'm trying to perform hyperparameter optimisation (Tree-structured Parzen Estimator) for a multilabel classification problem. My target is a set of 14 bits, each of which can be on or off. The labels are also imbalanced, so I'm trying to use balanced_accuracy_score for my cross validation. However, this metric builds a confusion matrix and expects 1-D class labels as input, so I defined the following wrapper around it:

_balanced_accuracy_score = lambda grnd, pred: balanced_accuracy_score(grnd.argmax(axis=1), pred.argmax(axis=1))

However, I still get the following error:

<lambda>() takes 2 positional arguments but 3 were given

How should I write my scoring function so that cross_validate accepts it?
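For context, cross_validate invokes a callable passed as scoring as scorer(estimator, X_test, y_test), not as metric(y_true, y_pred); that is where the third positional argument comes from. Below is a minimal sketch of a scoring callable with that signature. It assumes multilabel indicator targets like those from make_multilabel_classification and keeps the question's argmax reduction, which is lossy: each row becomes the index of its first set bit, and all-zero rows become class 0. The function name argmax_balanced_accuracy is hypothetical.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import cross_validate


def argmax_balanced_accuracy(estimator, X, y):
    """Hypothetical scorer: cross_validate calls it as scorer(estimator, X_test, y_test)."""
    y_pred = estimator.predict(X)
    # Collapse each 14-bit indicator row to the index of its first set bit
    # (all-zero rows map to 0), mirroring the lambda in the question.
    return balanced_accuracy_score(
        np.asarray(y).argmax(axis=1), np.asarray(y_pred).argmax(axis=1)
    )


X, y = make_multilabel_classification(
    n_samples=1000, n_features=50, n_classes=14, random_state=0
)

result = cross_validate(
    RandomForestClassifier(max_depth=4, random_state=0),
    X,
    y,
    scoring=argmax_balanced_accuracy,
)
print(result["test_score"])
```

Equivalently, make_scorer(_balanced_accuracy_score) turns the original two-argument lambda into a scorer with this (estimator, X, y) signature.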

Answer:

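The error occurs because cross_validate calls a callable passed as scoring as scorer(estimator, X, y), not as metric(y_true, y_pred). If you can reframe the multilabel problem as multiclass, you can use balanced_accuracy_score inside cross_validate after wrapping it with make_scorer, which converts a (y_true, y_pred) metric into a scorer with that signature.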

Minimal reproducible example (14-class classification):

from sklearn.datasets import make_classification
from sklearn.metrics import balanced_accuracy_score, make_scorer
from sklearn.model_selection import cross_validate
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1000, n_features=50, n_clusters_per_class=1, n_informative=4, n_classes=14)

cross_validate(
    RandomForestClassifier(max_depth=4),
    X,
    y,
    scoring=make_scorer(balanced_accuracy_score),
)

Output:

{
    'fit_time': array([0.18832374, 0.17689013, 0.17627716, 0.17738914, 0.19771028]),
    'score_time': array([0.00904989, 0.00889611, 0.00884914, 0.00906992, 0.00915742]),
    'test_score': array([0.54115646, 0.53129252, 0.49727891, 0.54353741, 0.48401361])
}
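balanced_accuracy_score (with its default adjusted=False) is the mean of the per-class recalls, which is why it is less flattering than plain accuracy when one class dominates. A small check on hypothetical toy labels, comparing it with macro-averaged recall and with plain accuracy:

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, recall_score

# Hypothetical imbalanced toy labels: class 0 dominates.
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 2])

# Per-class recall: class 0 -> 6/6, class 1 -> 1/2, class 2 -> 1/2
per_class_recall = recall_score(y_true, y_pred, average=None)

# Balanced accuracy is the unweighted mean of those recalls: (1 + 0.5 + 0.5) / 3
bal_acc = balanced_accuracy_score(y_true, y_pred)

# Plain accuracy counts 8 of 10 correct and hides the weak minority classes
plain_acc = np.mean(y_true == y_pred)

print(per_class_recall, bal_acc, plain_acc)
```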

As of scikit-learn 1.2.0, balanced_accuracy_score does not support multilabel targets:

from sklearn.datasets import make_multilabel_classification
from sklearn.metrics import balanced_accuracy_score

X, y = make_multilabel_classification(n_samples=1000, n_features=50, n_classes=14)
print(balanced_accuracy_score(y, y))

Output:

ValueError: multilabel-indicator is not supported
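If you want to keep all 14 bits rather than collapsing them, one possible workaround (an assumption on my part, not a built-in scikit-learn metric) is to compute balanced accuracy separately for each label column, where each column is an ordinary binary problem, and average the results. A minimal sketch; mean_per_label_balanced_accuracy is a hypothetical name:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score, make_scorer
from sklearn.model_selection import cross_validate


def mean_per_label_balanced_accuracy(y_true, y_pred):
    """Hypothetical metric: binary balanced accuracy per label column, averaged."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Each column of a multilabel indicator matrix is a binary target,
    # which balanced_accuracy_score does support.
    scores = [
        balanced_accuracy_score(y_true[:, j], y_pred[:, j])
        for j in range(y_true.shape[1])
    ]
    return float(np.mean(scores))


X, y = make_multilabel_classification(
    n_samples=1000, n_features=50, n_classes=14, random_state=0
)

result = cross_validate(
    RandomForestClassifier(max_depth=4, random_state=0),
    X,
    y,
    # make_scorer adapts the (y_true, y_pred) metric to scorer(estimator, X, y)
    scoring=make_scorer(mean_per_label_balanced_accuracy),
)
print(result["test_score"])
```

Averaging per label weights every bit equally regardless of how often it is set. Whether that matches your notion of "balanced" for the 14-bit target is a modelling choice.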