I'm not sure what needs to be reshaped in my data

Time:04-01

I'm trying to use a LinearRegression() algorithm to predict the price of a house.

Here's my code:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

df = pd.read_csv('data.csv')
df = df.drop(columns=['date', 'street', 'city', 'statezip', 'country'])

X = df.drop(columns=['price'])
y = df['price']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

lr = LinearRegression()
lr.fit(X_train, y_train)
pred = lr.predict(X_test)
pred.reshape((-1, 1))
acc = lr.score(pred, y_test)

However, I keep on getting this error:

Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

I've tried reshaping the attributes in my data, but pred is the only thing I'm able to reshape, and I still get the same error after doing that.
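One likely reason the reshape seemed to have no effect: NumPy's ndarray.reshape returns a new array instead of modifying the original, so a bare pred.reshape((-1, 1)) call throws its result away. The sketch below uses a small made-up array as a stand-in for pred.

```python
import numpy as np

# Hypothetical stand-in for the 1-D array returned by lr.predict(X_test)
pred = np.array([310000.0, 455000.0, 289000.0])

pred.reshape((-1, 1))            # builds a new (3, 1) array, which is discarded
print(pred.shape)                # pred itself is still 1-D

pred_2d = pred.reshape((-1, 1))  # assign the result to keep the reshaped array
print(pred_2d.shape)
```

Even a correctly reshaped pred would not fix the score call, though, because score expects the feature matrix as its first argument.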

How should I fix this error?

Thanks in advance.

Answer:

Based on the documentation of sklearn.linear_model.LinearRegression.score:

score(X, y, sample_weight=None)

Returns the coefficient of determination R^2 of self.predict(X) with respect to y.
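For reference, R^2 = 1 - SS_res / SS_tot, where SS_res = Σ(y - ŷ)^2 is the residual sum of squares and SS_tot = Σ(y - ȳ)^2 is the total sum of squares. A minimal sketch comparing a hand computation with sklearn.metrics.r2_score; the arrays are made-up numbers:

```python
import numpy as np
from sklearn.metrics import r2_score

# Made-up true values and predictions
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_pred = np.array([2.5, 5.0, 7.5, 9.0])

ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
r2_manual = 1 - ss_res / ss_tot

print(r2_manual, r2_score(y_true, y_pred))      # both approximately 0.975
```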

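score() calls self.predict() on its first argument, so lr.score(pred, y_test) tries to predict from the 1-D array of predictions, and that is what triggers the "Reshape your data" error. Pass the test features X_test as the first argument instead: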

lr.fit(X_train, y_train)
acc = lr.score(X_test, y_test)
print(acc)
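A small reproduction of the error and the fix on synthetic data. The random data and the three-feature shape are assumptions standing in for the house-price CSV:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the house data: 50 samples, 3 numeric features
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([2.0, -1.0, 0.5]) + 4.0

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
lr = LinearRegression().fit(X_train, y_train)
pred = lr.predict(X_test)  # 1-D array with one value per test sample

try:
    lr.score(pred, y_test)  # wrong: score() runs predict() on this 1-D array
    wrong_call_failed = False
except ValueError as err:
    wrong_call_failed = True
    print("ValueError:", err)

acc = lr.score(X_test, y_test)  # right: pass the 2-D test features
print(acc)
```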

Alternatively, you can pass the predictions to sklearn.metrics.r2_score, which takes the true values first and the predictions second:

from sklearn.metrics import r2_score
acc = r2_score(y_test, pred)
print(acc)
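r2_score is not symmetric, so the argument order matters: swapping y_test and pred generally gives a different number. A sketch on made-up noisy data checking that lr.score(X_test, y_test) and r2_score(y_test, pred) agree:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# Made-up noisy data: 40 samples, 2 features
rng = np.random.default_rng(1)
X = rng.normal(size=(40, 2))
y = X @ np.array([1.5, -3.0]) + rng.normal(scale=2.0, size=40)

# Simple fixed split: first 30 rows for training, last 10 for testing
X_train, X_test = X[:30], X[30:]
y_train, y_test = y[:30], y[30:]

lr = LinearRegression().fit(X_train, y_train)
pred = lr.predict(X_test)

score_method = lr.score(X_test, y_test)  # R^2 via the estimator
score_metric = r2_score(y_test, pred)    # R^2 via the metric, y_true first
score_swapped = r2_score(pred, y_test)   # arguments swapped: generally a different value

print(score_method, score_metric, score_swapped)
```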

Example:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = np.dot(X, np.array([1, 2])) + 3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)

lr = LinearRegression()
lr.fit(X_train, y_train)
pred = lr.predict(X_test)
acc = lr.score(X_test, y_test)
print(acc)
# Or
from sklearn.metrics import r2_score
acc = r2_score(y_test, pred)
print(acc)

Output:

0.8888888888888888
0.8888888888888888