I am implementing a linear regression model from scratch, that is, without using the sklearn package.
Everything was working fine until I tried plotting the result.
I looked at a bunch of solutions, but none of them addressed my problem.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

data = pd.read_csv(r'C:\Salary.csv')
x = data['Salary']
y = data['YearsExperience']

# y = m*x + b
m = 0
b = 0
Learning_Rate = .01
epochs = 5000
n = float(x.shape[0])
error = []

for i in range(epochs):
    Y_hat = m*x + b
    # error
    mse = (1/n)*np.sum((y-Y_hat)**2)
    error.append(mse)
    # gradient descent
    db = (-2/n) * np.sum(x*(y-Y_hat))
    dm = (-2/n) * np.sum((y-Y_hat))
    m = m - Learning_Rate * dm
    b = b - Learning_Rate * db

# tracing the x and y line
x_line = np.linspace(0, 15, 100)
y_line = (m*x_line) + b

# plotting the result
plt.figure(figsize=(8,6))
plt.title('LR result')
plt.plot(x_line, y_line)  # the problem is apparently here, I just don't know what to do
plt.scatter(x, y)
plt.show()
Apart from that, there is no problem with the code.
CodePudding user response:
Your code has multiple problems:
- You are plotting the line from 0 to 15, while the data range from about 40000 to 140000. Even if you compute the line correctly, you are going to plot it in a region far away from your data.
- In the loop there is a mistake in the computation of dm and db: they are swapped. The corrected expressions are:
  dm = (-2/n)*np.sum(x*(y - Y_hat))
  db = (-2/n)*np.sum((y - Y_hat))
- Your x and y data are on very different scales: x is of magnitude ~10⁴, while y is ~10¹. For this reason, m and b will also likely be very different from each other (different orders of magnitude). This is why you should use two different learning rates for the quantities you are optimizing: Learning_Rate_m for m and Learning_Rate_b for b.
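To make the three points concrete, here is a minimal sketch that applies all of them at once. It is not a drop-in fix for your script: the data are synthetic stand-ins for Salary.csv, the helper name fit_line is made up for illustration, and the two learning rates (5e-11 for m, 0.1 for b) are values I picked by hand for x in the 40000 to 140000 range. You will probably have to tune them for your real data.

```python
import numpy as np

def fit_line(x, y, lr_m, lr_b, epochs=5000):
    """Fit y = m*x + b by gradient descent on the mean squared error."""
    m, b = 0.0, 0.0
    n = float(len(x))
    for _ in range(epochs):
        residual = y - (m * x + b)
        dm = (-2 / n) * np.sum(x * residual)  # derivative w.r.t. the slope m
        db = (-2 / n) * np.sum(residual)      # derivative w.r.t. the intercept b
        m -= lr_m * dm                        # separate step size for m
        b -= lr_b * db                        # separate step size for b
    return m, b

# Synthetic stand-in for Salary.csv (hypothetical values, not the real file)
rng = np.random.default_rng(0)
x = np.linspace(40_000, 140_000, 30)            # plays the role of 'Salary'
y = 1e-4 * x - 3 + rng.normal(0, 0.3, x.size)   # plays the role of 'YearsExperience'

# Hand-picked learning rates for this x scale (assumed values)
m, b = fit_line(x, y, lr_m=5e-11, lr_b=0.1)

# Draw the line over the range of the data, not over 0..15
x_line = np.linspace(x.min(), x.max(), 100)
y_line = m * x_line + b
print(m, b)
```

With x_line and y_line built this way, plt.plot(x_line, y_line) followed by plt.scatter(x, y) should put the line on top of the points. An alternative that avoids tuning two rates is to rescale x (for example, divide it by 10000) before fitting, at the cost of having to convert m back afterwards.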
CodePudding user response:
The problem is not in the plotting itself but in the values passed to plt.plot(x_line, y_line). I tested your code and found that y_line contains only NaN values, so double-check the calculations (y_line, m, dm).
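If you want to check this yourself, something like the sketch below may help. It rests on an assumption of mine: that the NaN values come from the updates blowing up when a single learning rate of .01 is used with x values in the tens of thousands. The data are made up, so the exact epoch will differ on your file. Note also that warnings.filterwarnings('ignore') in your script hides the overflow warnings NumPy would otherwise print.

```python
import numpy as np

# Hypothetical data on the same scale as the question (not the real Salary.csv)
x = np.linspace(40_000, 140_000, 30)
y = 1e-4 * x - 3

m, b, lr = 0.0, 0.0, 0.01   # single learning rate, as in the question
n = float(len(x))
first_bad_epoch = None

# Silence the overflow/invalid warnings here only to keep the output short
with np.errstate(over="ignore", invalid="ignore"):
    for epoch in range(5000):
        residual = y - (m * x + b)
        m -= lr * (-2 / n) * np.sum(x * residual)
        b -= lr * (-2 / n) * np.sum(residual)
        # Record the first epoch where a parameter stops being a finite number
        if first_bad_epoch is None and not (np.isfinite(m) and np.isfinite(b)):
            first_bad_epoch = epoch
    y_line = m * np.linspace(0, 15, 100) + b

print("first non-finite epoch:", first_bad_epoch)
print("y_line is all NaN:", np.isnan(y_line).all())
```

Printing m and b for the first few epochs is an even simpler check: if their magnitude grows by several orders of magnitude per step, the learning rate is too large for the scale of x.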