Cannot evaluate f1-score on sklearn cross_val_score

Time:03-25

I'm working on a multiclass classification problem.

I can get f1 scores when using train_test_split and then getting the classification report as shown below:

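    from sklearn.metrics import classification_report
    from sklearn.model_selection import train_test_split
    from sklearn.neighbors import KNeighborsClassifier

    X_train, X_test, y_train, y_test = train_test_split(data, y_data, test_size=0.3, random_state=1, stratify=y_data)
    knn_clf = KNeighborsClassifier(n_neighbors)
    knn_clf.fit(X_train, y_train)
    ypred = knn_clf.predict(X_test)  # predicted labels for the test set
    print(classification_report(y_test, ypred))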

However, due to the size of the dataset, cross-validation is more applicable. The problem is that I cannot get the F1 scores using the cross-validation method. Without F1 scoring, the cross-validation looks like this:

    knn_cv = KNeighborsClassifier(n_neighbors)
    cv_scores = cross_val_score(knn_cv, data, y_data, cv=3)

    #print each cv score (accuracy) and average them
    print("cv_scores: ", cv_scores)
    print('cv_scores mean:{}'.format(np.mean(cv_scores)))

Which outputs:

    cv_scores: [0.83333333 1. 1. ]
    cv_scores mean:0.9444444444444445

When I add in F1 as follows:

    print(cross_val_score(knn_cv, data, y_data, scoring="f1", cv=3))

It outputs:

    [nan nan nan]
    cv_scores: [nan nan nan]
    cv_scores mean:nan

Any help would be greatly appreciated, thanks!

EDIT: The error it raises is: ValueError: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].

However when running:

    print(cross_val_score(knn_cv, data, y_data, average='weighted', scoring='f1', cv=3, error_score="raise"))

The returned error is: TypeError: cross_val_score() got an unexpected keyword argument 'average'

This occurs for every value I tried: None, 'micro', 'macro', and 'weighted'.

Answer:

The average argument mentioned in the error message belongs to the function f1_score, not to cross_val_score, which is why passing average=... to cross_val_score raises a TypeError. When you specify the scorer for cross_val_score as a string, the averaging mode is part of the scorer name: use scoring="f1_weighted" (or "f1_micro", "f1_macro"); see https://scikit-learn.org/stable/modules/model_evaluation.html#common-cases-predefined-values.
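As a minimal sketch of the fix, the snippet below runs cross_val_score with the predefined "f1_weighted" scorer and, for comparison, with an equivalent scorer built via make_scorer(f1_score, average="weighted"). The iris dataset and n_neighbors=5 are assumptions used only to make the example self-contained; substitute your own data, y_data, and n_neighbors.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Assumption: iris stands in for the asker's multiclass `data` / `y_data`.
X, y = load_iris(return_X_y=True)

# Assumption: n_neighbors=5 is an arbitrary illustrative value.
knn_cv = KNeighborsClassifier(n_neighbors=5)

# Option 1: predefined scorer name that encodes the averaging mode.
scores_named = cross_val_score(
    knn_cv, X, y, scoring="f1_weighted", cv=3, error_score="raise"
)

# Option 2: build the scorer explicitly; `average` is forwarded to f1_score.
weighted_f1 = make_scorer(f1_score, average="weighted")
scores_custom = cross_val_score(
    knn_cv, X, y, scoring=weighted_f1, cv=3, error_score="raise"
)

print("f1_weighted per fold:", scores_named)
print("f1_weighted mean:", np.mean(scores_named))
```

Keeping error_score="raise" while debugging surfaces the underlying exception instead of silently returning nan for each fold.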
