What does the scikit-learn library do in this code?

Time:05-25

I am interested in machine learning. I tried to understand the following code but could not. Can anyone explain it to me simply?

from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split  # This module divides our data into two parts, train and test
import sklearn.metrics as met
from sklearn.datasets import load_boston


boston = load_boston()
x = boston.data
y = boston.target

xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.2, random_state=42)

model = LinearRegression()
model.fit(xtrain, ytrain)

ypredict = model.predict(xtest)

plt.scatter(ytest, ypredict)
plt.show()
print(met.mean_squared_error(ytest, ypredict))
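Note that `load_boston` was deprecated in scikit-learn 1.0 and removed in 1.2, so the code above fails with an ImportError on current versions. Below is a minimal sketch of the same pipeline that should run on recent releases. It assumes the bundled diabetes dataset (`load_diabetes`) as a stand-in regression dataset, since it needs no download. The output file name `pred_vs_actual.png` is hypothetical, and the plot is saved rather than shown so it also works without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to a file, no window needed
import matplotlib.pyplot as plt
import sklearn.metrics as met
from sklearn.datasets import load_diabetes  # stand-in for the removed load_boston
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Features (x) and target (y); load_diabetes ships with scikit-learn
data = load_diabetes()
x = data.data
y = data.target

# Hold out 20% of the rows for testing; random_state makes the shuffle reproducible
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.2, random_state=42)

# Fit ordinary least squares on the training rows only
model = LinearRegression()
model.fit(xtrain, ytrain)

# Predict the held-out rows and score the predictions
ypredict = model.predict(xtest)
mse = met.mean_squared_error(ytest, ypredict)
print(mse)

# Actual vs predicted; points on the dashed y = x line would be perfect predictions
plt.scatter(ytest, ypredict)
lims = [min(ytest.min(), ypredict.min()), max(ytest.max(), ypredict.max())]
plt.plot(lims, lims, "k--")
plt.xlabel("actual")
plt.ylabel("predicted")
plt.savefig("pred_vs_actual.png")  # hypothetical output file name
```

Only the dataset loader, the reference line, and the save-instead-of-show step differ from the original; the split, fit, predict, and scoring calls are the same.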

Answer:

Here are the steps:

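  1. Load the data and assign the features to x and the target (the median home value) to y:
boston = load_boston()
x = boston.data
y = boston.target
  2. Split the data into a training set and a validation (test) set. An 80/20 split is common, and test_size=0.2 puts 20% of the rows in the test set. Setting random_state fixes the random shuffle so the split is identical on every run; the value 42 itself is arbitrary and is a nod to The Hitchhiker's Guide to the Galaxy (I do not want to spoil anything...)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.2, random_state=42)
  3. Create the model and fit it to the training data. Because x has several feature columns (13 for the Boston data), this is a multiple linear regression, not a univariate one.
model = LinearRegression()
model.fit(xtrain, ytrain)
  4. Use the model trained in the previous step to predict the targets of the validation (test) data.
ypredict = model.predict(xtest)
  5. Draw a scatter plot of the actual test targets against the predicted values.
plt.scatter(ytest, ypredict)
plt.show()
  6. Print the mean squared error of the predictions. This is an error measure, not an accuracy: lower is better, and 0 would mean perfect predictions.
print(met.mean_squared_error(ytest, ypredict))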
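To make "fit", "predict", and "mean squared error" concrete, here is a minimal sketch, again assuming the bundled diabetes dataset as a stand-in for the removed Boston data. It checks that `predict` is just the learned linear combination `x @ coef_ + intercept_` (one coefficient per feature column, i.e. multiple linear regression when there is more than one feature), that the MSE is the mean of the squared differences, and that a fixed `random_state` reproduces the same split.

```python
import numpy as np
import sklearn.metrics as met
from sklearn.datasets import load_diabetes  # stand-in for the removed load_boston
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

x, y = load_diabetes(return_X_y=True)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.2, random_state=42)

# test_size=0.2 puts ceil(0.2 * n_samples) rows in the test set, the rest in training
print(x.shape, xtrain.shape, xtest.shape)

# Same random_state gives an identical split on every call
xtrain2, xtest2, _, _ = train_test_split(x, y, test_size=0.2, random_state=42)
print(np.array_equal(xtest, xtest2))  # True

model = LinearRegression().fit(xtrain, ytrain)

# One learned weight per feature column, plus a single intercept
print(model.coef_.shape, model.intercept_)

# predict() is the weighted sum of the features plus the intercept
manual_pred = xtest @ model.coef_ + model.intercept_
ypredict = model.predict(xtest)

# MSE: average of squared (actual - predicted) differences; lower is better, 0 is perfect
manual_mse = np.mean((ytest - ypredict) ** 2)
library_mse = met.mean_squared_error(ytest, ypredict)
print(manual_mse, library_mse)
```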

Answer:

Follow the comments to understand the code.

# importing various modules 
# Imports Linear Regression model to fit the features in a linear combination to derive the target value.
from sklearn.linear_model import LinearRegression

# Import matplotlib to plot/visualize the data
import matplotlib.pyplot as plt

# This module divides our data into two parts, train and test
from sklearn.model_selection import train_test_split  

# metrics is used to analyze the model performance (such as mean squared error)
import sklearn.metrics as met

# sklearn.datasets has various datasets for quick use
from sklearn.datasets import load_boston

# Load the Boston housing dataset from the standard sklearn.datasets module (load_boston was removed in scikit-learn 1.2)
boston = load_boston()
# Separate the features (x) and the target variable (y) of the Boston dataset
x = boston.data
y = boston.target

# Dividing the dataset into training and test to train the model and evaluate the model.
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.2, random_state=42)

# Create a linear regression object
model = LinearRegression()
# Train (fit) the model on the training data
model.fit(xtrain, ytrain)
# Once the model is trained, predict the target values for the test data, which was not used in training (i.e. data the model has not seen)
ypredict = model.predict(xtest)

# Plot the actual values against the predicted values in a scatter plot to visualize how close the predictions are to the actual values
plt.scatter(ytest, ypredict)
plt.show()
# Calculate the mean squared error: the average squared difference between the predictions and the actual values (lower is better)
print(met.mean_squared_error(ytest, ypredict))
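A raw MSE is hard to judge on its own because it is in squared target units. One common way to read it, sketched below under the same assumption (the bundled diabetes data standing in for Boston), is to take its square root (RMSE, back in the target's units) and compare it with a trivial baseline that ignores the features and always predicts the training mean. `DummyRegressor` is scikit-learn's built-in estimator for such baselines.

```python
import numpy as np
import sklearn.metrics as met
from sklearn.datasets import load_diabetes  # stand-in for the removed load_boston
from sklearn.dummy import DummyRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

x, y = load_diabetes(return_X_y=True)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.2, random_state=42)

# The actual model
model = LinearRegression().fit(xtrain, ytrain)
model_mse = met.mean_squared_error(ytest, model.predict(xtest))

# Baseline: always predicts the mean of ytrain, whatever the features are
baseline = DummyRegressor(strategy="mean").fit(xtrain, ytrain)
baseline_mse = met.mean_squared_error(ytest, baseline.predict(xtest))

# RMSE is in the same units as the target, which is easier to interpret
print("model RMSE:", np.sqrt(model_mse))
print("baseline RMSE:", np.sqrt(baseline_mse))
```

A model whose error is not clearly below the baseline's has learned little from the features.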