I'm trying to fit a single-variable regression using decision tree regression. However, when I plot the results, multiple lines show up in the plot, as in the image linked below. I didn't encounter this problem when I used linear regression.
https://snipboard.io/v9QaoC.jpg (I can't post images since I have less than 10 reputation.)
My code:
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
# Fit regression model
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X_train.values.reshape(-1, 1), y_train.values.reshape(-1, 1))
regr_2.fit(X_train.values.reshape(-1, 1), y_train.values.reshape(-1, 1))
# Predict
y_1 = regr_1.predict(X_test.values.reshape(-1, 1))
y_2 = regr_2.predict(X_test.values.reshape(-1, 1))
# Plot the results
plt.figure()
plt.scatter(X_train, y_train, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
CodePudding user response:
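Your plot most likely looks like this because your test samples aren't sorted by their X values. plt.plot draws a line segment between consecutive points in the order they are given, so you are 'connecting the dots' between test datapoints in arbitrary order. You didn't notice this with linear regression because all of its predictions lie on a single straight line, so the back-and-forth segments overlap.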
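To make the effect concrete, here is a small sketch in plain Python that counts how often a line drawn through the points in the given order would double back along the x-axis. The helper count_reversals and the sample values are made up for illustration and are not part of your code or of matplotlib.

```python
def count_reversals(xs):
    """Count how many times consecutive x values change direction."""
    reversals = 0
    prev_dir = 0
    for a, b in zip(xs, xs[1:]):
        d = (b > a) - (b < a)  # +1 moving right, -1 moving left, 0 no move
        if d != 0:
            if prev_dir != 0 and d != prev_dir:
                reversals += 1
            prev_dir = d
    return reversals


# Hypothetical test inputs in the order a train/test split might return them
x_unsorted = [0.7, 0.1, 0.9, 0.3, 0.5]
x_sorted = sorted(x_unsorted)

unsorted_reversals = count_reversals(x_unsorted)  # line zigzags left and right
sorted_reversals = count_reversals(x_sorted)      # line only moves left to right
```

Every reversal is a place where the plotted line jumps back across the figure, which is what shows up as "multiple lines" for a step-shaped prediction.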
You can get the plot you expect by sorting your test data:
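# Sort (assumes X_test is a pandas Series or a single-column DataFrame, as in your code)
X_test = np.sort(X_test.values.ravel())  # 1-D NumPy array in ascending order
# Predict (X_test is now a NumPy array, so reshape it directly instead of using .values)
y_1 = regr_1.predict(X_test.reshape(-1, 1))
y_2 = regr_2.predict(X_test.reshape(-1, 1))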
# Plot the results
plt.figure()
plt.scatter(X_train, y_train, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
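If you have already computed the predictions, or you also want to plot y_test against the same inputs, an alternative is to compute the sort order once with np.argsort and apply it to every array so the pairs stay aligned. A minimal sketch, with made-up stand-in arrays (x_test and y_pred below are hypothetical, not your data):

```python
import numpy as np

# Hypothetical stand-ins for the test inputs and a model's predictions
x_test = np.array([0.7, 0.1, 0.9, 0.3, 0.5])
y_pred = np.array([7.0, 1.0, 9.0, 3.0, 5.0])

# Indices that would put x_test in ascending order
order = np.argsort(x_test)

# Apply the same ordering to both arrays so each x keeps its own prediction
x_sorted = x_test[order]
y_sorted = y_pred[order]
```

You would then pass x_sorted and y_sorted to plt.plot. With a pandas Series you would index the underlying array, for example X_test.values[order].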