I created a function to perform a grid search for the optimal parameters of my XGBoost classifier. Because my training set is large, I want to limit the grid search to a sample of about 5000 observations.
This is the function:
```python
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from xgboost import XGBClassifier


def xgboost_search(X, y, search_verbose=1):
    # Candidate values for each hyperparameter; every combination is tried
    params = {
        "gamma": [0.5, 1, 1.5, 2, 5],
        "max_depth": [3, 4, 5, 6],
        "min_child_weight": [100],
        "subsample": [0.6, 0.8, 1.0],
        "colsample_bytree": [0.6, 0.8, 1.0],
        "learning_rate": [0.1, 0.01, 0.001],
    }
    xgb = XGBClassifier(objective="binary:logistic", eval_metric="auc", use_label_encoder=False)
    # Shuffled, stratified 3-fold CV keeps the class ratio similar in every fold
    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=1234)
    # Passing the splitter itself lets GridSearchCV call skf.split(X, y) internally
    grid_search = GridSearchCV(estimator=xgb, param_grid=params, scoring="roc_auc",
                               n_jobs=1, cv=skf, verbose=search_verbose)
    grid_search.fit(X, y)
    print("Best estimator: ")
    print(grid_search.best_estimator_)
    print("Parameters: ", grid_search.best_params_)
    print("Highest AUC: %.2f" % grid_search.best_score_)
    return grid_search.best_params_
```
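To see why subsampling pays off, it helps to count the work: `GridSearchCV` fits one model per parameter combination per fold (plus one final refit of the best candidate, since `refit` defaults to `True`). A small sketch of that arithmetic in plain Python, with the grid values copied from the function and no XGBoost needed:

```python
import math

# Same grid as in xgboost_search (values copied verbatim)
params = {
    "gamma": [0.5, 1, 1.5, 2, 5],
    "max_depth": [3, 4, 5, 6],
    "min_child_weight": [100],
    "subsample": [0.6, 0.8, 1.0],
    "colsample_bytree": [0.6, 0.8, 1.0],
    "learning_rate": [0.1, 0.01, 0.001],
}
n_splits = 3  # matches StratifiedKFold(n_splits=3)

# A full grid is the Cartesian product, so its size is the product of the list lengths
n_candidates = math.prod(len(values) for values in params.values())
n_fits = n_candidates * n_splits  # one fit per candidate per fold

print(n_candidates, n_fits)  # 540 1620
```

Every one of those fits runs on the full input, so shrinking the input to 5000 rows reduces the cost of each of them.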
This is what I tried to get the 5000 observations:
```python
import random

rows = random.sample(list(X_res), 5000)
model_params = xgboost_search(X_res[rows], Y_res[rows])
```
I got this error:
```
IndexError                                Traceback (most recent call last)
/var/folders/cf/yh2vvpdn0klby68k9zrttfv00000gp/T/ipykernel_80963/3533706692.py in <module>
      1 rows = random.sample(list(X_res), 5000)
----> 2 model_params = xgboost_search(X_res[rows], Y_res[rows])

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
```
I think this is because my `X_res` and `Y_res` are arrays and `rows` is a list.
Can someone help?
Answer:
Arrays can be indexed by a list; the issue here is the type of the elements in your list, which are not integers. The "only integers ... are valid indices" part of the error points to exactly that.
This happens because `random.sample(list(X_res), 5000)` samples 5000 *rows* of `X_res` (iterating over a 2-D array yields its rows), not 5000 row positions from `range(len(X_res))`, as you probably intended.
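The difference is easy to see on a small array. This sketch assumes NumPy is installed and uses a made-up 10-by-3 float array, `X_demo`, as a hypothetical stand-in for `X_res`:

```python
import random

import numpy as np

# Hypothetical stand-in for X_res: 10 rows of float features
X_demo = np.arange(30, dtype=float).reshape(10, 3)

# Iterating over a 2-D array yields its rows, so this samples row *arrays*
sampled_rows = random.sample(list(X_demo), 3)
print(type(sampled_rows[0]))  # <class 'numpy.ndarray'>

# Using those rows as an index fails: their entries are float values, not positions
try:
    X_demo[sampled_rows]
    raised = False
except IndexError:
    raised = True
print(raised)  # True

# Sampling from range(len(...)) yields integer positions instead
sampled_idx = random.sample(range(len(X_demo)), 3)
subset = X_demo[sampled_idx]  # a list of ints is a valid index and selects whole rows
print(subset.shape)  # (3, 3)
```

The exact `IndexError` message depends on the NumPy version, but indexing with the sampled rows fails either way.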
Try:
```python
rows = random.sample(range(len(X_res)), 5000)
model_params = xgboost_search(X_res[rows], Y_res[rows])
```
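If you prefer to stay within NumPy, the same idea can be written with a seeded `Generator`. This is a minimal sketch, assuming `X_res` and `Y_res` are NumPy arrays with the same number of rows; the helper name `subsample`, the seed, and the demo arrays are illustrative, not part of the original code:

```python
import numpy as np


def subsample(X, y, n, seed=1234):
    """Draw n rows without replacement, using the same positions for X and y.

    Hypothetical helper for illustration; assumes X and y are NumPy arrays.
    """
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(X), size=n, replace=False)  # distinct integer positions
    return X[idx], y[idx]


# Small synthetic arrays standing in for X_res / Y_res
X_demo = np.arange(200, dtype=float).reshape(100, 2)
y_demo = np.arange(100) % 2

X_small, y_small = subsample(X_demo, y_demo, 20)
print(X_small.shape, y_small.shape)  # (20, 2) (20,)
```

With the real data this would be `X_small, Y_small = subsample(X_res, Y_res, 5000)` followed by `xgboost_search(X_small, Y_small)`. Drawing the positions once and applying them to both arrays keeps features and labels aligned. If `X_res` were a pandas DataFrame, the helper would need `X.iloc[idx]` instead of `X[idx]`. For a sample that also preserves the class ratio, scikit-learn's `train_test_split(X_res, Y_res, train_size=5000, stratify=Y_res)` is another option (keep the training part of its output).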