I am trying to build an optimized SVM classification model using scikit-learn. I am fairly new to Python, though not to ML in general. Here's the code I am using:
# Training the SVM model on the Training set
from sklearn.svm import SVC
from sklearn.model_selection import RepeatedStratifiedKFold, RandomizedSearchCV
from sklearn.metrics import accuracy_score
classifier = SVC()
# define evaluation
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=5, random_state=1)
# define search space
space = dict()
space['kernel'] = ["linear", "rbf", "sigmoid", "poly"]
space['C'] = [0.1, 1, 10, 100, 1000]
space['gamma'] = [1, 0.1, 0.01, 0.001, 0.0001]
space['tol'] = [1e-3, 1e-4, 1e-5, 1e-6]
# define search
search = RandomizedSearchCV(classifier, space, n_iter=500, scoring='accuracy', n_jobs=-1, cv=cv, random_state=1)
# execute search
result = search.fit(X_train, y_train)
# summarize result
print('Best Score: %s' % result.best_score_)
print('Best Hyperparameters: %s' % result.best_params_)
bestModel = result.best_estimator_
#Test
a = X_train
b = y_train
grid_predictions = bestModel.predict(a)
print(accuracy_score(b, grid_predictions))
I am trying to see how well my training data is classified. My question is: why am I getting different accuracy outputs from result.best_score_ (which is the best search model's accuracy) and accuracy_score(b, grid_predictions) (which is where the exact training data is fed to the best-performing model)?
CodePudding user response:
The difference is because best_score_ shows the best score (in your case, accuracy) of the best estimator, the "estimator which gave highest score (or smallest loss if specified) on the left out data". The left-out data comes from your cross-validation, meaning best_score_ is the mean accuracy over the held-out folds of your CV (remember that this cross-validation happens inside RandomizedSearchCV).
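As a quick sanity check, here is a minimal sketch (assuming your search object is named result, as in your code) that reads the same number back out of cv_results_; with your RepeatedStratifiedKFold settings there are 10 splits x 5 repeats = 50 held-out folds:

import numpy as np

# best_score_ is the mean held-out-fold score of the best parameter setting
mean_cv_score = result.cv_results_['mean_test_score'][result.best_index_]
print(mean_cv_score == result.best_score_)  # True

# per-fold scores for the best candidate; their mean equals best_score_
# (up to floating-point rounding)
fold_scores = [result.cv_results_['split%d_test_score' % i][result.best_index_]
               for i in range(50)]
print(np.mean(fold_scores))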
On the other hand, the output of accuracy_score(b, grid_predictions) is computed by the same predictor, but on data it has already seen: not a held-out fold, but your entire training set (based on the code you provided). By default, RandomizedSearchCV refits the best estimator on the full training set (refit=True), so predicting on that same training data yields an optimistic, in-sample accuracy.
This means that both metrics are calculated in the same manner, using the same model, but predicting on different sets of data, which is why they differ.
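If you want an honest estimate of generalization, evaluate the refitted best estimator on data that was never used during the search. A minimal sketch, assuming the full feature matrix X and labels y are available before splitting (the train_test_split parameters here are illustrative, not from your code):

from sklearn.model_selection import train_test_split

# hold out a test set before running the search
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=1)

# ... run the RandomizedSearchCV on X_train / y_train as above ...

# in-sample accuracy: optimistic, since the refit model has seen this data
print(accuracy_score(y_train, bestModel.predict(X_train)))
# held-out accuracy: comparable in spirit to best_score_
print(accuracy_score(y_test, bestModel.predict(X_test)))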