How to perform and plot a linear regression on a simple data set

Time:11-16

I am writing a simple Python program to analyze a data set using linear regression. The program is structured like this:

# Author: Evan Gertis
# Date 11/15
# program: linear regression

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# read the data from the file
df = pd.read_csv('father_son_data.csv')

logging.info(f"{df}")

# split the column names into consecutive groups of `size` to be plotted together
def chunker(seq, size):
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]

# function call
col_list = chunker(df.columns, 2)

# iterate through each group of column names and scatter-plot each (x, y) pair
for x, y in col_list:
    plot = sns.scatterplot(data=df, x=x, y=y, label=y)

fig = plot.get_figure()
fig.savefig("Father_Son.png")
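The chunker helper (with the `+` in the slice restored) splits a sequence into consecutive slices of length `size`, the last one possibly shorter. A small sketch on plain lists shows what the loop iterates over; the column names are copied from the data below.

```python
# The chunker helper from the question, with the "+" restored in the slice.
def chunker(seq, size):
    """Split seq into consecutive slices of length `size` (the last may be shorter)."""
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]

# With exactly two columns there is a single (x, y) group,
# so the plotting loop draws one scatter plot.
columns = ["Height X of father (in)", "Height Y of son (in)"]
print(chunker(columns, 2))
# [['Height X of father (in)', 'Height Y of son (in)']]

# With an odd number of items the last group has only one element.
print(chunker([1, 2, 3, 4, 5], 2))
# [[1, 2], [3, 4], [5]]
```

With an odd number of columns, `for x, y in ...` would therefore raise a ValueError on the last group.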

The data is:

Height X of father (in), Height Y of son (in) 
65, 68 
63, 66 
67, 68 
68, 65 
62, 69 
70, 68 
66, 65 
67, 67 
67, 68 
68, 69 
71, 70
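For a single predictor, the least-squares line y = a + b·x has a closed form: b = S_xy / S_xx and a = ȳ − b·x̄, where S_xy and S_xx are the centered sums of products and squares. A minimal pure-Python sketch applying it to the eleven pairs above (the numbers are copied from the table; nothing else is assumed):

```python
# Father (x) and son (y) heights in inches, copied from the data above.
x = [65, 63, 67, 68, 62, 70, 66, 67, 67, 68, 71]
y = [68, 66, 68, 65, 69, 68, 65, 67, 68, 69, 70]

n = len(x)
x_mean = sum(x) / n
y_mean = sum(y) / n

# Centered sum of cross-products and centered sum of squares of x.
s_xy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
s_xx = sum((xi - x_mean) ** 2 for xi in x)

# Ordinary least squares: slope b and intercept a of y = a + b*x.
slope = s_xy / s_xx
intercept = y_mean - slope * x_mean

print(f"y = {intercept:.2f} + {slope:.3f} x")
# y = 57.71 + 0.147 x
```

The small slope (about 0.15 in of son height per inch of father height) means the fitted line is fairly flat for this sample.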

What is the best way to plot a line through the data points? I would like to figure out how best to construct such a line using Python.

Expected: a straight line fitted through the points shown in Father_Son.png

Actual: only the points shown in Father_Son.png, with no line through them

Thank you in advance!
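One minimal way to get the expected straight line, assuming NumPy and matplotlib are available, is to fit a first-degree polynomial with numpy.polyfit and draw it over the scatter points. This is a swapped-in technique rather than the asker's seaborn code (seaborn's regplot is another option that draws the points and a fitted line in one call); the DataFrame below is built from the data above instead of being read from the CSV file.

```python
import matplotlib
matplotlib.use("Agg")  # render without a display; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Same data as the question; in the real script this comes from pd.read_csv.
df = pd.DataFrame({
    "Height X of father (in)": [65, 63, 67, 68, 62, 70, 66, 67, 67, 68, 71],
    "Height Y of son (in)": [68, 66, 68, 65, 69, 68, 65, 67, 68, 69, 70],
})
x_col, y_col = df.columns

# Fit a degree-1 polynomial (a straight line); polyfit returns [slope, intercept].
slope, intercept = np.polyfit(df[x_col], df[y_col], 1)

fig, ax = plt.subplots()
ax.scatter(df[x_col], df[y_col], label="data")

# Two endpoints are enough to draw a straight line across the observed x range.
x_line = np.array([df[x_col].min(), df[x_col].max()])
ax.plot(x_line, intercept + slope * x_line, "--", label="least-squares fit")

ax.set_xlabel(x_col)
ax.set_ylabel(y_col)
ax.legend()
fig.savefig("Father_Son.png")
```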


Answer:

Assuming that the csv file has exactly the format shown in the question, with the first column holding the independent variable and the second the dependent one:
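One detail of that format: the data in the question has a space after each comma, so by default pandas keeps a leading space in the second column name. That only affects the axis label in the code below, which indexes columns by position, but it matters if columns are looked up by name. A minimal sketch using read_csv's skipinitialspace option on a few rows held in memory (io.StringIO stands in for the real file):

```python
import io

import pandas as pd

# A few rows in the question's format (note the space after each comma).
raw = """Height X of father (in), Height Y of son (in)
65, 68
63, 66
67, 68
"""

# Default parsing keeps the space in front of the second column name.
df_default = pd.read_csv(io.StringIO(raw))
print(list(df_default.columns))
# ['Height X of father (in)', ' Height Y of son (in)']

# skipinitialspace=True drops spaces that directly follow the delimiter.
df = pd.read_csv(io.StringIO(raw), skipinitialspace=True)
print(list(df.columns))
# ['Height X of father (in)', 'Height Y of son (in)']
```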

# few libraries
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from scipy import stats

#read the dataset
df = pd.read_csv(FILE_PATH)

#initialize the figure
plt.figure(figsize=FIG_SIZE)

#perform the linear regression (first column: x, second column: y)
linear_regression = stats.linregress(df.iloc[:, 0], df.iloc[:, 1])

#scatter the points
plt.scatter(
    *df.T.values,
    alpha=.5, color=SCATTER_COLOR
)

#define the range of the independent variable as its extreme values in the dataset
x_range = np.array([df.iloc[:,0].min(), df.iloc[:,0].max()])

#define the label shown in the legend
label = """
regression line:
slope = %.{decimals}f ± %.{decimals}f
R = %.{decimals}f (p-value: %.{decimals}e)
""".format(decimals=DECIMALS)%(
    linear_regression.slope, #the line slope
    linear_regression.stderr, #the standard error of the estimated slope
    linear_regression.rvalue, #Pearson's correlation coefficient
    linear_regression.pvalue, #the p-value for the null hypothesis slope = 0 (Wald test with a t-distribution)
)

#plot the line
plt.plot(
    x_range,
    linear_regression.intercept + x_range*linear_regression.slope, #y=a+bx
    alpha=.75, color=LINE_COLOR, linestyle=LINE_STYLE,
    label=label
)

plt.xlabel(df.columns[0])
plt.ylabel(df.columns[1])
plt.grid(True, linestyle=":", alpha=.5)
plt.legend(loc="best")
plt.savefig(IMG_PATH)
plt.show()

The output, both in the IMG_PATH file and in the displayed figure, is the following: [output figure]
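To sanity-check the numbers that end up in the legend, the sketch below runs scipy.stats.linregress directly on the question's data (typed in by hand as arrays) and compares rvalue, which is Pearson's r, with numpy.corrcoef. Squaring r gives the coefficient of determination R² of a simple linear fit.

```python
import numpy as np
from scipy import stats

# Father (x) and son (y) heights from the question.
x = np.array([65, 63, 67, 68, 62, 70, 66, 67, 67, 68, 71])
y = np.array([68, 66, 68, 65, 69, 68, 65, 67, 68, 69, 70])

res = stats.linregress(x, y)

# rvalue is Pearson's r, so it should match the off-diagonal entry of corrcoef.
r_numpy = np.corrcoef(x, y)[0, 1]

# For a simple linear fit, R^2 is the square of Pearson's r.
r_squared = res.rvalue ** 2

print(f"slope={res.slope:.3f} intercept={res.intercept:.2f} "
      f"r={res.rvalue:.3f} R^2={r_squared:.3f} p={res.pvalue:.3f}")
```

For this sample r is about 0.24, so the fitted line explains only about 6% of the variance in the sons' heights.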

The parameters I have used are the following:

FIG_SIZE = (10,5)
SCATTER_COLOR = "brown"
LINE_COLOR = "black"
LINE_STYLE = "--"
DECIMALS = 2
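With DECIMALS = 2, the label is built in two steps: str.format first turns each `%.{decimals}f` into an ordinary `%.2f` placeholder, then the `%` operator fills the placeholders with numbers. A small sketch of the same mechanism, with an f-string equivalent for comparison (the slope and stderr values are illustrative, not results from the data):

```python
DECIMALS = 2

# Step 1: str.format inserts the precision, producing %-style placeholders.
template = "slope = %.{decimals}f ± %.{decimals}f".format(decimals=DECIMALS)
print(template)  # slope = %.2f ± %.2f

# Step 2: the % operator fills those placeholders.
slope, stderr = 1.23456, 0.0789  # illustrative values only
label = template % (slope, stderr)
print(label)  # slope = 1.23 ± 0.08

# Equivalent single step with a nested precision inside an f-string.
same = f"slope = {slope:.{DECIMALS}f} ± {stderr:.{DECIMALS}f}"
```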

You also have to assign to FILE_PATH the path of your input csv file and to IMG_PATH the path of your output image file. Note that the .csv extension of the input file is not mandatory: any comma-separated text file works, as long as it is formatted correctly. The format of the output figure is inferred from the file name (e.g. if IMG_PATH="./img.svg", the image is saved as SVG).
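The extension-based format inference can be checked with a short sketch that saves one figure under three names and inspects the resulting files (the temporary directory and the file names are illustrative). If the extension and the desired format ever need to differ, savefig also accepts an explicit format= argument.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # no display needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], "--")

out_dir = tempfile.mkdtemp()
png_path, svg_path, pdf_path = (
    os.path.join(out_dir, name) for name in ("img.png", "img.svg", "img.pdf")
)

# No format argument: matplotlib picks PNG, SVG or PDF from each extension.
for path in (png_path, svg_path, pdf_path):
    fig.savefig(path)

# Read the files back to see which format was actually written.
with open(png_path, "rb") as f:
    png_header = f.read(8)
with open(svg_path, "rb") as f:
    svg_bytes = f.read()
with open(pdf_path, "rb") as f:
    pdf_header = f.read(5)
```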
