I am using groupby() to plot each group. I have a dataframe in which every row of a group has the same 'normal' value:
id date value normal
1 5.2 20 200
1 5.4 100 200
1 6.9 30 200
2 2.4 20 500
2 3.4 100 500
I want to plot each group's date and value columns (as the x and y axes) and add an axhline at the normal value that corresponds to that group. I've tried:
ax = sns.scatterplot("date", "value", data = data)
grouped = data.groupby('id')
normal_val= grouped['normal']
plt.axhline(normal_val,c='red', ls=':') #===> THIS LINE
plt.show()
but it does not work. I hope someone can help. Thanks!
Answer:
Note that df.groupby('id').normal on its own is not a value you can plot: it is just a groupby object. You need to apply an aggregating function to it first. Here I take the mean of the normal values for each id, but you can use another aggregation if it suits your data better.
import matplotlib.pyplot as plt

# One aggregated 'normal' value per id
normal_val = df.groupby('id').normal.mean()
for item in normal_val:
    plt.axhline(y=item, c='red', ls=':')
plt.show()
This draws one red dotted horizontal line per id (at 200 and 500 for the sample data).
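A quick way to see why the aggregation is needed, using the question's sample data: selecting normal from the groupby gives a SeriesGroupBy object rather than numbers, and mean() turns it into a Series with one value per id. The first() variant is an addition of mine, valid only because 'normal' is constant within each id.

```python
import pandas as pd

# The question's sample data
df = pd.DataFrame({
    "id": [1, 1, 1, 2, 2],
    "date": [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500],
})

grouped_normal = df.groupby("id").normal
# Not a number yet: a lazy per-group view of the column
print(type(grouped_normal).__name__)  # SeriesGroupBy

# Aggregating gives one value per id, indexed by id
normal_val = grouped_normal.mean()
print(normal_val.tolist())  # [200.0, 500.0]

# first() gives the same values here because 'normal' is constant per group,
# and it keeps the integer dtype
print(grouped_normal.first().tolist())  # [200, 500]
```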
Answer:
Regarding the highlighted line: normal_val holds more than one value, so it is not clear which one you want to draw.
print(data.groupby('id')[['normal']].mean())
normal
id
1.0 200.0
2.0 500.0
If you replace that line with explicit values like the ones below, the lines are drawn:
normal_val1 = 200
normal_val2 = 500
plt.axhline(normal_val1, c='red', ls=':')
plt.axhline(normal_val2, c='green', ls='-.')
Edit: Depending on how many 'normal' values you have, you can extend the color list c. I've started with 8 colors, so this code works for up to 8 groups:
import seaborn as sns
import matplotlib.pyplot as plt

# Newer seaborn versions require x and y as keyword arguments
ax = sns.scatterplot(x="date", y="value", data=data)
df = data.groupby('id')[['normal']].mean().reset_index()
c = ['red', 'green', 'yellow', 'blue', 'navy', 'cyan', 'pink', 'brown']
# Only draw the lines if there is a color for every group
if len(df['normal']) <= len(c):
    for i, normal in enumerate(df['normal']):
        plt.axhline(normal, c=c[i], ls='-.')
plt.show()
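If there can be more than eight ids, the if guard above silently skips every line. A sketch of one alternative (itertools.cycle is my substitution, not part of the answer): repeat the palette so every group still gets a line. It uses first() instead of mean(), assuming 'normal' is constant within each id, and rebuilds the question's sample data so it runs on its own.

```python
from itertools import cycle

import matplotlib.pyplot as plt
import pandas as pd

# The question's sample data
data = pd.DataFrame({
    "id": [1, 1, 1, 2, 2],
    "date": [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500],
})

fig, ax = plt.subplots()
ax.scatter(data["date"], data["value"])

# One 'normal' value per id; first() is enough because it is constant per group
normals = data.groupby("id")["normal"].first()

# cycle() repeats the palette, so more than 8 groups still get a line each
colors = cycle(['red', 'green', 'yellow', 'blue', 'navy', 'cyan', 'pink', 'brown'])
for (group_id, normal), color in zip(normals.items(), colors):
    ax.axhline(normal, c=color, ls='-.', label=f"normal (id {group_id})")

ax.legend()
plt.show()
```

With nine or more groups, colors start repeating, so the legend labels are what tell the lines apart.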
Answer:
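This is because the line grouped = data.groupby('id') returns a pandas.core.groupby.generic.DataFrameGroupBy object, and grouped['normal'] is in turn a SeriesGroupBy object, not a number that plt.axhline can use as its y position.
You might want to try looping over the ids instead: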
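import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

data = pd.DataFrame({
    "id": [1, 1, 1, 2, 2],
    "date": [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500],
})

# One figure per group: scatter the group's points, then draw its normal line
for group_id in data["id"].unique():
    plot_data = data.loc[data["id"] == group_id]
    ax = sns.scatterplot(x="date", y="value", data=plot_data)
    # 'normal' is constant within a group, so take its first value as a scalar
    ax.axhline(plot_data["normal"].iloc[0], c='red', ls=':')
    plt.show()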
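If you would rather have all groups on one set of axes, with each normal line colored like its group's points, here is a sketch using plain matplotlib and direct iteration over the groupby object (which yields (key, sub-DataFrame) pairs). The helper name plot_groups_with_normal is hypothetical, and it assumes 'normal' is constant within each id.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_groups_with_normal(data):
    """Hypothetical helper: scatter each id's points and draw its normal line.

    Assumes columns id, date, value and normal, with 'normal' constant
    within each id group.
    """
    fig, ax = plt.subplots()
    # Iterating over a GroupBy yields (group key, sub-DataFrame) pairs
    for group_id, group in data.groupby("id"):
        points = ax.scatter(group["date"], group["value"], label=f"id {group_id}")
        # Reuse the scatter's face color so the line matches its group
        color = points.get_facecolor()[0]
        ax.axhline(group["normal"].iloc[0], color=color, ls=":")
    ax.set_xlabel("date")
    ax.set_ylabel("value")
    ax.legend()
    return fig, ax


# The question's sample data
data = pd.DataFrame({
    "id": [1, 1, 1, 2, 2],
    "date": [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500],
})

fig, ax = plot_groups_with_normal(data)
plt.show()
```

Each ax.scatter call without an explicit color takes the next color from the axes' color cycle, which is why reading it back with get_facecolor() keeps points and line in sync.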