Linear regression line on a scatter plot in python

Time:03-26

I am trying to add a linear regression line to my scatter plot, but the way my data are stored and/or structured makes the task tricky for me. I managed to create a scatter plot with my data using "sns.scatterplot", but couldn't replicate it with "sns.regplot".

Using the following code, I was able to produce the figure below:

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(color_codes=True)

# One scatter call per group, each at a fixed x position
sns.scatterplot(x=100, y=dict_all_mv['dict_100'])
sns.scatterplot(x=90, y=dict_all_mv['dict_90'])
sns.scatterplot(x=80, y=dict_all_mv['dict_80'])
sns.scatterplot(x=60, y=dict_all_mv['dict_60'])
sns.scatterplot(x=40, y=dict_all_mv['dict_40'])
sns.scatterplot(x=20, y=dict_all_mv['dict_20'])

plt.show()

[Figure: scatter plot without regression line]

The data displayed on the y axis are stored in nested dictionaries. The type and contents of one of the dictionaries are shown below:

print(type(dict_all_mv['dict_100']))
print(dict_all_mv['dict_100'])
<class 'dict'>
{'P1S1': 0.2825781935483871, 'P2S1': 0.1566529827586207, 'P3S1': 0.16534197660818717, 'P4S1': 0.2463573039215686, 'P5S1': 0.21233809917355378, 'P6S1': 0.2780758316831682, 'P7S1': 0.261270830357143, 'P8S1': 0.25920007258064504, 'P9S1': 0.17294190131578946, 'P10S1': 0.18588214503816788, 'P1S2': 0.1870260291970803, 'P2S2': 0.2059823125, 'P3S2': 0.17782266216216217, 'P4S2': 0.1835217697841726, 'P5S2': 0.21022608870967743, 'P6S2': 0.2695895306122449, 'P7S2': 0.2393472735042735, 'P8S2': 0.18402302247191013, 'P9S2': 0.17260260273972605, 'P10S2': 0.20730387826086952}
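A regression needs one x value per y value, so one way to prepare this structure is to flatten the nested dictionaries into two parallel lists, repeating each group's x value once per measurement. This is a minimal sketch that assumes every key of dict_all_mv has the form 'dict_<x>' (as the keys above suggest). `flatten` is a hypothetical helper, and only the 'dict_100' values below come from the printout; the other groups hold placeholder values for illustration.

```python
# Hypothetical stand-in for dict_all_mv. The 'dict_100' values come from the
# printout above; 'dict_90' and 'dict_20' are placeholder values, not real data.
dict_all_mv = {
    'dict_100': {'P1S1': 0.2825781935483871, 'P2S1': 0.1566529827586207,
                 'P3S1': 0.16534197660818717},
    'dict_90': {'P1S1': 0.25, 'P2S1': 0.14},
    'dict_20': {'P1S1': 0.10},
}


def flatten(nested):
    """Return parallel lists (x, y) with one x entry per stored value."""
    x, y = [], []
    for key, inner in nested.items():
        # Assumes keys look like 'dict_100', so the x value is the suffix.
        level = int(key.split('_')[1])
        for value in inner.values():
            x.append(level)
            y.append(value)
    return x, y


x, y = flatten(dict_all_mv)
print(x)  # [100, 100, 100, 90, 90, 20]
```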

I want to plot all the values stored in each dictionary on the y axis, with the x value equal to 100 for each value of 'dict_100', 90 for each value of 'dict_90', 80 for 'dict_80', etc. I have tried several ways to plot a linear regression line with sns.regplot and sns.lmplot, but without success. Here is one of my failed attempts:

sns.regplot(x=100,y=dict_all_mv['dict_100'])
sns.regplot(x=90,y=dict_all_mv['dict_90'])
sns.regplot(x=80,y=dict_all_mv['dict_80'])

plt.show()

I got this error message:

AttributeError: 'int' object has no attribute 'shape'
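The error most likely comes from passing a single integer where sns.regplot expects a vector: regplot fits a line through (x, y) point pairs, so x must supply one entry per y value (or be a column name when `data=` is given). Separate calls per group would not help even with x repeated, because every point in a call would share the same x and there would be no slope to estimate. The groups have to be pooled. The sketch below shows the ordinary least-squares computation behind a straight-line fit in plain Python; `least_squares_line` is a hypothetical helper, not a seaborn function, and the numbers are made up for illustration.

```python
def least_squares_line(x, y):
    """Ordinary least-squares fit of y = slope * x + intercept."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError("x and y need the same length (at least 2 points)")
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    if sxx == 0:
        # All points share one x value (a single group): no slope to fit.
        raise ValueError("x must contain at least two distinct values")
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Pooled groups: each x is repeated once per y value (made-up numbers).
x = [100, 100, 90, 90, 80, 80]
y = [0.30, 0.26, 0.28, 0.24, 0.26, 0.22]
slope, intercept = least_squares_line(x, y)
```

Calling `least_squares_line([100, 100], [0.2, 0.3])` raises ValueError, which mirrors why each single-group regplot call cannot produce a meaningful line.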

My goal is to obtain a graph like this one:

[Figure: sns.regplot with concatenated values]
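One possible route to a figure like that is to build a long-form pandas DataFrame (one row per measurement) and make a single sns.regplot call on it. This is a sketch under the same assumption that keys follow the 'dict_<x>' pattern; the dictionary below is a hypothetical stand-in with placeholder values, not the real measurements.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for dict_all_mv (placeholder values, not real data).
dict_all_mv = {
    'dict_100': {'P1S1': 0.28, 'P2S1': 0.16},
    'dict_90': {'P1S1': 0.25, 'P2S1': 0.15},
    'dict_80': {'P1S1': 0.23, 'P2S1': 0.14},
}

# Long-form table: one row per measurement, x parsed from the key suffix.
df = pd.DataFrame(
    [(int(key.split('_')[1]), value)
     for key, inner in dict_all_mv.items()
     for value in inner.values()],
    columns=['x', 'y'],
)

sns.set_theme(color_codes=True)
# A single regplot call over all groups, so the fit uses every point at once.
ax = sns.regplot(data=df, x='x', y='y')
plt.show()
```

Because all points go through one call, the fitted line spans every group. If the stacked points at each x overlap heavily, regplot's `x_jitter` argument adds horizontal noise to the scatter only; the regression is computed on the original values.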
