Retrieving same output for different instances for XGBoost regression algorithm

Time:10-04

I am using the XGBoost regression algorithm on the following data to make predictions. The problem is that the model predicts the same output for any input, and I'm not sure why.

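import numpy as np
import pandas as pd
import xgboost as xgb
# `mse` is assumed to be scikit-learn's mean_squared_error under an alias
from sklearn.metrics import mean_squared_error as mse, r2_score
from sklearn.model_selection import train_test_split

data = pd.read_csv("depthwise_data.csv", delimiter=',', header=None, skiprows=1, names=['input_size','input_channels','conv_kernel','conv_strides','running_time'])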

X = data[['input_size', 'input_channels','conv_kernel', 'conv_strides']]
Y = data[["running_time"]]

X_train, X_test, y_train, y_test = train_test_split(
    np.array(X), np.array(Y), test_size=0.2, random_state=42)

y_train_log = np.log(y_train)
y_test_log = np.log(y_test)

xgb_depth_conv = xgb.XGBRegressor(objective='reg:squarederror',
                                  n_estimators=1000,
                                  seed=123,
                                  tree_method='hist',
                                  max_depth=10)

xgb_depth_conv.fit(X_train, y_train_log)
y_pred_train = xgb_depth_conv.predict(X_train)
#y_pred_test = xgb_depth_conv.predict(X_test)

X_data=[[8,576,3,2]] #instance
X_test=np.log(X_data)
y_pred_test=xgb_depth_conv.predict(X_test)
print(np.exp(y_pred_test))


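# These metrics need y_pred_test from the held-out X_test (the commented-out predict above),
# not the single-instance prediction. r2_score expects (y_true, y_pred).
MSE_test, MSE_train = mse(y_test_log, y_pred_test), mse(y_train_log, y_pred_train)
R_squared = r2_score(y_test_log, y_pred_test)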
print("MSE-Train = {}".format(MSE_train))
print("MSE-Test = {}".format(MSE_test))
print("R-Squared: ", np.round(R_squared, 2))

Output for first instance

X_data=[[8,576,3,2]]
print(np.exp(y_pred_test))
[0.7050679]

Output for second instance

X_data=[[4,960,3,1]]
print(np.exp(y_pred_test))
[0.7050679]

Answer:

Your problem stems from this line: X_test=np.log(X_data)

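Why are you applying the log to the test inputs when you did not apply it to the training inputs? The model was trained on the raw feature values, so log-transformed inputs (for example log(576) ≈ 6.36 instead of 576) are far smaller than anything it saw during training. They most likely fall below every split threshold, so each tree sends every such instance to the same leaf and the model returns the same prediction regardless of the input.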
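A minimal sketch of the mismatch described above, assuming scikit-learn's DecisionTreeRegressor as a stand-in for one of XGBoost's trees, plus a made-up grid of feature values and a made-up target (the asker's depthwise_data.csv is not available here). The tree is trained on raw features. Predicting on the raw versions of the two instances from the question gives different outputs, while predicting on their logs gives identical outputs, because every log value is smaller than the smallest training value of its feature.

```python
import itertools

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical grid of feature values (NOT the asker's data):
# columns are input_size, input_channels, conv_kernel, conv_strides.
grid = np.array(
    list(itertools.product([4, 8, 16, 32], [16, 64, 576, 960], [3, 5], [1, 2])),
    dtype=float,
)
# Made-up target that depends on every feature, for illustration only.
target = grid[:, 0] ** 2 * grid[:, 1] * grid[:, 2] ** 2 / grid[:, 3] ** 2

# Train on RAW feature values, as in the question.
tree = DecisionTreeRegressor(random_state=0).fit(grid, target)

new_instances = np.array([[8, 576, 3, 2], [4, 960, 3, 1]], dtype=float)

# Same preprocessing as training: the two predictions differ.
pred_raw = tree.predict(new_instances)

# Log applied only at prediction time: each log value is below the smallest
# training value of its feature, so both rows take the "<= threshold" branch
# at every split and end up in the same leaf.
pred_log = tree.predict(np.log(new_instances))

print("raw inputs:", pred_raw)
print("log inputs:", pred_log)
```

A boosted ensemble behaves the same way: if every tree routes the inputs to the same leaf, the summed prediction is identical too.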

If you remove np.log completely, including from the target (y), you get really good results. I tested it myself with the data you provided.
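If you want to keep the log transforms rather than dropping them, one way to make sure training and prediction see identical preprocessing is to make the transforms part of the fitted model. The sketch below assumes scikit-learn's Pipeline, FunctionTransformer and TransformedTargetRegressor, uses DecisionTreeRegressor as a stand-in for xgb.XGBRegressor (which follows the scikit-learn estimator interface and could likely go in the same slot), and reuses a hypothetical grid and made-up target rather than the real data.

```python
import itertools

import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.tree import DecisionTreeRegressor

# Hypothetical grid and made-up target, for illustration only.
X = np.array(
    list(itertools.product([4, 8, 16, 32], [16, 64, 576, 960], [3, 5], [1, 2])),
    dtype=float,
)
y = X[:, 0] ** 2 * X[:, 1] * X[:, 2] ** 2 / X[:, 3] ** 2

model = TransformedTargetRegressor(
    # The pipeline applies np.log to the features at BOTH fit and predict time.
    regressor=make_pipeline(
        FunctionTransformer(np.log),
        DecisionTreeRegressor(random_state=0),  # stand-in for xgb.XGBRegressor
    ),
    func=np.log,          # log-transform the target before fitting
    inverse_func=np.exp,  # map predictions back to the original scale
)
model.fit(X, y)

# Pass RAW values: no manual np.log on the inputs, no manual np.exp on the output.
X_data = np.array([[8, 576, 3, 2], [4, 960, 3, 1]], dtype=float)
print(model.predict(X_data))
```

Because the log is stored inside the fitted model, it cannot be applied on one side only, which is exactly the mistake behind the constant predictions.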
