I have a pandas DataFrame that looks something like this (after scaling):
Age blood_Press golucse Cholesterol
0 1.953859 -1.444088 -1.086684 -1.981315
1 0.357992 -0.123270 -0.585981 0.934929
2 0.997219 0.998712 2.005212 0.019169
3 2.589318 -0.528543 -1.123484 -1.299904
4 2.088141 0.792976 0.021526 -0.777959
and a binary target column:
y
0 1.0
1 1.0
2 1.0
3 0.0
4 1.0
I want to make a bar chart of each predictive feature against the y target. The y values (just 1 or 0) would be on the x-axis, and the values of the predictive feature would be on the y-axis. For example, something that looks like this (ignore the features used here; it's just an example of what I need). Here, instead of male and female, I'd have 1 and 0...
The code for this plot is something like this:
import seaborn as sns

# One bar chart: target classes on the x-axis, one feature's values on the y-axis
myPlot = sns.catplot(data=data, x='the y feature', y='the x feature', kind='bar')
myPlot.fig.suptitle('title title', size=15, y=1.)
myPlot.set_ylabels('Y label whatever', fontsize=15, x=1.02)
myPlot.fig.set_size_inches(9, 8)
But I don't want to repeat this for every feature; I'm sure there is a simpler way. How can I do it?
Answer:
Setup (the target y added to df as a column):
print(df)
Age blood_Press golucse Cholesterol y
0 1.953859 -1.444088 -1.086684 -1.981315 1.0
1 0.357992 -0.123270 -0.585981 0.934929 1.0
2 0.997219 0.998712 2.005212 0.019169 1.0
3 2.589318 -0.528543 -1.123484 -1.299904 0.0
4 2.088141 0.792976 0.021526 -0.777959 1.0
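The question keeps the features and the target in separate objects, while this answer works on a single DataFrame with y as a column. A minimal sketch of that setup step, assuming the scaled features are in a DataFrame named X and the target is a Series named target (both names are hypothetical), with the values copied from the question:

```python
import pandas as pd

# Scaled features, copied from the question (column names kept as-is, including "golucse")
X = pd.DataFrame({
    "Age": [1.953859, 0.357992, 0.997219, 2.589318, 2.088141],
    "blood_Press": [-1.444088, -0.123270, 0.998712, -0.528543, 0.792976],
    "golucse": [-1.086684, -0.585981, 2.005212, -1.123484, 0.021526],
    "Cholesterol": [-1.981315, 0.934929, 0.019169, -1.299904, -0.777959],
})

# Binary target; "target" is a hypothetical name for the question's y object
target = pd.Series([1.0, 1.0, 1.0, 0.0, 1.0])

# Attach the target as a column named "y"; assign aligns rows on the index
df = X.assign(y=target)
print(df)
```

If the target is a one-column DataFrame instead of a Series, X.join(target_df) does the same job. Both approaches align rows on the index, so the two objects need matching indexes.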
Melt the DataFrame to convert it from wide to long format:
m = df.melt(id_vars=['y'], var_name='feature')
print(m)
# y feature value
# 0 1.0 Age 1.953859
# 1 1.0 Age 0.357992
# 2 1.0 Age 0.997219
# 3 0.0 Age 2.589318
# 4 1.0 Age 2.088141
# 5 1.0 blood_Press -1.444088
# 6 1.0 blood_Press -0.123270
# 7 1.0 blood_Press 0.998712
# 8 0.0 blood_Press -0.528543
# 9 1.0 blood_Press 0.792976
# 10 1.0 golucse -1.086684
# 11 1.0 golucse -0.585981
# 12 1.0 golucse 2.005212
# 13 0.0 golucse -1.123484
# 14 1.0 golucse 0.021526
# 15 1.0 Cholesterol -1.981315
# 16 1.0 Cholesterol 0.934929
# 17 1.0 Cholesterol 0.019169
# 18 0.0 Cholesterol -1.299904
# 19 1.0 Cholesterol -0.777959
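With kind='bar', seaborn draws the mean of value for each y level by default (plus an error bar for the confidence interval), so the bar heights in each panel can be checked directly from the long-format frame. A small sketch that rebuilds m from the question's values and computes those means:

```python
import pandas as pd

# Same data as the setup above, with the target already attached as column "y"
df = pd.DataFrame({
    "Age": [1.953859, 0.357992, 0.997219, 2.589318, 2.088141],
    "blood_Press": [-1.444088, -0.123270, 0.998712, -0.528543, 0.792976],
    "golucse": [-1.086684, -0.585981, 2.005212, -1.123484, 0.021526],
    "Cholesterol": [-1.981315, 0.934929, 0.019169, -1.299904, -0.777959],
    "y": [1.0, 1.0, 1.0, 0.0, 1.0],
})

# Wide -> long: one row per (sample, feature) pair
m = df.melt(id_vars=["y"], var_name="feature")

# Mean value per feature and target class: the heights of the bars catplot draws
bar_heights = m.groupby(["feature", "y"])["value"].mean().unstack("y")
print(bar_heights)
```

With only one sample in class 0.0 here, that bar is just the single value (for example 2.589318 for Age) and has no meaningful error bar.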
Then call catplot and pass the feature column as the col parameter, so that it draws one bar-chart panel per feature:
sns.catplot(data=m, x='y', y='value', col='feature', kind='bar', col_wrap=2)
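The per-plot customization from the question (overall title, axis labels, figure size) can be applied once to the whole grid instead of once per feature. A sketch under a few assumptions: the label and title strings are placeholders, height and aspect replace set_size_inches for sizing each panel, and .figure is the newer name for the grid's .fig attribute:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook or interactive session
import pandas as pd
import seaborn as sns

df = pd.DataFrame({
    "Age": [1.953859, 0.357992, 0.997219, 2.589318, 2.088141],
    "blood_Press": [-1.444088, -0.123270, 0.998712, -0.528543, 0.792976],
    "golucse": [-1.086684, -0.585981, 2.005212, -1.123484, 0.021526],
    "Cholesterol": [-1.981315, 0.934929, 0.019169, -1.299904, -0.777959],
    "y": [1.0, 1.0, 1.0, 0.0, 1.0],
})
m = df.melt(id_vars=["y"], var_name="feature")

# One bar-chart panel per feature, two panels per row
g = sns.catplot(
    data=m, x="y", y="value", col="feature", kind="bar",
    col_wrap=2, height=4, aspect=1.1,
)
g.set_titles("{col_name}")                       # each panel titled with its feature name
g.set_axis_labels("y (target)", "Scaled value")  # placeholder axis labels
g.figure.suptitle("Feature values by target class", size=15, y=1.02)
```

Call plt.show() or g.savefig(...) afterwards to display or save the grid. Because the features are all scaled to similar ranges, the shared y-axis (the default) keeps panels comparable; pass sharey=False if each panel should get its own range.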