How to make a barplot for the target variable with each predictive variable?

Time:12-31

I have a pandas DataFrame that looks something like this (after scaling):

        Age  blood_Press   golucse  Cholesterol
0  1.953859    -1.444088 -1.086684    -1.981315
1  0.357992    -0.123270 -0.585981     0.934929
2  0.997219     0.998712  2.005212     0.019169
3  2.589318    -0.528543 -1.123484    -1.299904
4  2.088141     0.792976  0.021526    -0.777959

and a binary target feature:

     y
0   1.0
1   1.0
2   1.0
3   0.0
4   1.0

I want to make a bar chart for each predictive feature against the target y. The x-axis would show the y values (just 0 or 1), and the y-axis would show the values of the predictive feature. Here is an example of what I mean (ignore the features used there; it only illustrates the layout I need). Instead of male and female, I'd have 1 and 0.

[Image: example bar chart with male and female as the x-axis categories]

The code for this plot is something like this:

import seaborn as sns

# 'y' is the target column of data and 'Age' one predictive feature
my_plot = sns.catplot(data=data, x='y', y='Age', kind='bar')
my_plot.fig.suptitle('title title', size=15, y=1.)
my_plot.set_ylabels('Y label whatever', fontsize=15, x=1.02)
my_plot.fig.set_size_inches(9, 8)

But I don't want to repeat this for every feature; I'm sure there is a simpler way. How can I do it?

Answer:

Setup

print(df)

        Age  blood_Press   golucse  Cholesterol    y
0  1.953859    -1.444088 -1.086684    -1.981315  1.0
1  0.357992    -0.123270 -0.585981     0.934929  1.0
2  0.997219     0.998712  2.005212     0.019169  1.0
3  2.589318    -0.528543 -1.123484    -1.299904  0.0
4  2.088141     0.792976  0.021526    -0.777959  1.0
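This setup assumes the target is already a column of df. If the features and the target live in separate objects, as in the question, one way to combine them is pd.concat along the columns. A minimal sketch; X and y are hypothetical names for the feature DataFrame and the target Series, assumed to share the same index (the Series must be named "y", otherwise its column label would be 0):

```python
import pandas as pd

# Hypothetical inputs: X holds the scaled features, y the binary target,
# both indexed 0..4 as in the question.
X = pd.DataFrame({
    "Age": [1.953859, 0.357992, 0.997219, 2.589318, 2.088141],
    "blood_Press": [-1.444088, -0.123270, 0.998712, -0.528543, 0.792976],
    "golucse": [-1.086684, -0.585981, 2.005212, -1.123484, 0.021526],
    "Cholesterol": [-1.981315, 0.934929, 0.019169, -1.299904, -0.777959],
})
y = pd.Series([1.0, 1.0, 1.0, 0.0, 1.0], name="y")

# Align rows on the index and append the target as the last column.
df = pd.concat([X, y], axis=1)
print(df)
```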

Melt the DataFrame to convert it from wide to long format, keeping y as the identifier column:

m = df.melt(id_vars=['y'], var_name='feature')
print(m)

#       y      feature     value
# 0   1.0          Age  1.953859
# 1   1.0          Age  0.357992
# 2   1.0          Age  0.997219
# 3   0.0          Age  2.589318
# 4   1.0          Age  2.088141
# 5   1.0  blood_Press -1.444088
# 6   1.0  blood_Press -0.123270
# 7   1.0  blood_Press  0.998712
# 8   0.0  blood_Press -0.528543
# 9   1.0  blood_Press  0.792976
# 10  1.0      golucse -1.086684
# 11  1.0      golucse -0.585981
# 12  1.0      golucse  2.005212
# 13  0.0      golucse -1.123484
# 14  1.0      golucse  0.021526
# 15  1.0  Cholesterol -1.981315
# 16  1.0  Cholesterol  0.934929
# 17  1.0  Cholesterol  0.019169
# 18  0.0  Cholesterol -1.299904
# 19  1.0  Cholesterol -0.777959
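If only some features should be plotted, melt also accepts value_vars; columns listed in neither id_vars nor value_vars are dropped from the result. A small sketch on a hypothetical two-row frame taken from rows 0 and 3 of the data:

```python
import pandas as pd

df = pd.DataFrame({
    "Age": [1.953859, 2.589318],
    "blood_Press": [-1.444088, -0.528543],
    "Cholesterol": [-1.981315, -1.299904],
    "y": [1.0, 0.0],
})

# Keep y as the identifier and unpivot only two of the feature columns;
# blood_Press is left out of the long-format result.
m = df.melt(id_vars=["y"], value_vars=["Age", "Cholesterol"], var_name="feature")
print(m)
```

Passing this m to the catplot call with col='feature' would then produce one panel per selected feature only.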

Then call sns.catplot and pass the feature column as the col parameter, so that each feature gets its own panel (col_wrap=2 arranges the panels two per row):

sns.catplot(data=m, x='y', y='value', col='feature', kind='bar', col_wrap=2)

[Image: grid of bar plots, one panel per feature, with y (0.0 and 1.0) on the x-axis and value on the y-axis]
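With kind='bar', seaborn draws the mean of value for each level of x by default (plus an error bar for its confidence interval), so each panel shows the average of one feature within class 0 and within class 1. A sketch that computes those bar heights directly with a pandas groupby, using the data from the question, can be handy for checking the plot:

```python
import pandas as pd

df = pd.DataFrame({
    "Age": [1.953859, 0.357992, 0.997219, 2.589318, 2.088141],
    "blood_Press": [-1.444088, -0.123270, 0.998712, -0.528543, 0.792976],
    "golucse": [-1.086684, -0.585981, 2.005212, -1.123484, 0.021526],
    "Cholesterol": [-1.981315, 0.934929, 0.019169, -1.299904, -0.777959],
    "y": [1.0, 1.0, 1.0, 0.0, 1.0],
})
m = df.melt(id_vars=["y"], var_name="feature")

# Mean of each feature within each target class, i.e. the bar heights;
# unstack puts the classes 0.0 and 1.0 side by side as columns.
heights = m.groupby(["feature", "y"])["value"].mean().unstack("y")
print(heights)
```

Note that class 0 contains a single row in this sample, so its bars are just that row's values.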
