pandas groupby: select a value from another column for each group for plotting

Time:08-07

I am using groupby() to plot each group. I have a dataframe in which every row of a group has the same 'normal' value:

id  date  value  normal
1   5.2   20     200
1   5.4   100    200
1   6.9   30     200
2   2.4   20     500
2   3.4   100    500

I want to plot each group's date and value columns (as the x and y axes) and add an axhline at the normal value that corresponds to that group. I've tried:

import matplotlib.pyplot as plt
import seaborn as sns

ax = sns.scatterplot(x="date", y="value", data=data)
grouped = data.groupby('id')
normal_val = grouped['normal']
plt.axhline(normal_val, c='red', ls=':')    #===> THIS LINE
plt.show()

but it does not work. I hope someone can help. Thanks!

Answer:

Note that df.groupby('id').normal on its own is just a groupby object (a SeriesGroupBy), not a number, so it can't be passed to axhline. You need to apply an aggregating function to it first. Here I take the mean of the normal values for each id, but you can use a different aggregation (for example first()) if you prefer.

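import matplotlib.pyplot as plt

# One value per id; 'normal' is constant within each group, so mean() just recovers it
normal_val = df.groupby('id').normal.mean()
for item in normal_val:
    plt.axhline(y=item, c='red', ls=':')  # one dotted line per group
plt.show()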

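To see why the unaggregated object fails while the aggregated one works, here is a small check. It is a sketch that assumes the question's sample table is rebuilt as a DataFrame named `data`; `grouped` and `normal_by_id` are illustrative names.

```python
import pandas as pd

# Sample data rebuilt from the question's table (assumed name: data)
data = pd.DataFrame({
    "id": [1, 1, 1, 2, 2],
    "date": [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500],
})

grouped = data.groupby("id")["normal"]
# Without an aggregation this is a SeriesGroupBy object, not a number,
# which is why plt.axhline cannot use it
print(type(grouped).__name__)  # SeriesGroupBy

# Aggregating reduces it to one value per id; because 'normal' is constant
# within each group, first() and mean() give the same numbers
normal_by_id = grouped.first()
print(normal_by_id.tolist())     # [200, 500]
print(grouped.mean().tolist())   # [200.0, 500.0]
```

Each value in normal_by_id can then be passed to plt.axhline, as in the loop above.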

Answer:

Referring to the highlighted line: normal_val holds more than one value (one per group), so it isn't clear which one you want to draw.

print(data.groupby('id')[['normal']].mean())

     normal
id         
1.0   200.0
2.0   500.0

If you change that line to something like the following, you will get an output:

normal_val1 = 200
normal_val2 = 500
plt.axhline(normal_val1, c='red', ls=':')
plt.axhline(normal_val2, c='green', ls='-.')


Edit: Depending on how many 'normal' values you have, you may want to add more colors to the list c (I've started with 8). This code should work:

import seaborn as sns
import matplotlib.pyplot as plt

# x and y must be passed as keyword arguments in seaborn >= 0.12
ax = sns.scatterplot(x="date", y="value", data=data)
df = data.groupby('id')[['normal']].mean().reset_index()
c = ['red', 'green', 'yellow', 'blue', 'navy', 'cyan', 'pink', 'brown']
for i, normal in enumerate(df['normal']):
    # Cycle through the colors if there are more groups than colors
    plt.axhline(normal, c=c[i % len(c)], ls='-.')
plt.show()
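If you also want each group's points to share a color with its normal line, one option is to draw the points with plain matplotlib instead of seaborn. This is a sketch that assumes the question's sample data; the `C0`, `C1`, ... color specs index matplotlib's default color cycle, and `group_id` and `color` are illustrative names.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Sample data rebuilt from the question's table (assumed name: data)
data = pd.DataFrame({
    "id": [1, 1, 1, 2, 2],
    "date": [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500],
})

fig, ax = plt.subplots()
# Draw each group's points and its normal line in a shared color
# taken from matplotlib's default color cycle ("C0", "C1", ...)
for i, (group_id, group) in enumerate(data.groupby("id")):
    color = f"C{i}"
    ax.scatter(group["date"], group["value"], color=color, label=f"id {group_id}")
    # 'normal' is constant within a group, so the first row is enough
    ax.axhline(group["normal"].iloc[0], color=color, ls=":")
ax.set_xlabel("date")
ax.set_ylabel("value")
ax.legend()
plt.show()
```

Because the color comes from the loop index, this avoids maintaining a hand-written color list; matplotlib wraps the `CN` index around its color cycle when there are more groups than cycle colors.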

Answer:

This fails because grouped = data.groupby('id') returns a pandas.core.groupby.generic.DataFrameGroupBy object, and grouped['normal'] is a SeriesGroupBy, not a number that axhline can use.

You might want to try this instead:

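import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

data = {
    "id": [1, 1, 1, 2, 2],
    "date":  [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500]
}

data = pd.DataFrame(data)

# One figure per group (group_id avoids shadowing the built-in id)
for group_id in data["id"].unique():
    plot_data = data.loc[data["id"] == group_id]
    ax = sns.scatterplot(x="date", y="value", data=plot_data)
    # 'normal' is constant within the group, so pass its single value as a scalar
    ax.axhline(plot_data["normal"].iloc[0], c='red', ls=':')
    plt.show()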
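Calling plt.show() inside the loop opens one figure per group. If you would rather see all groups side by side in a single figure, a sketch using plt.subplots could look like this. It assumes the question's sample data and uses matplotlib's scatter in place of seaborn; `groups` and `group_id` are illustrative names.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Sample data rebuilt from the question's table (assumed name: data)
data = pd.DataFrame({
    "id": [1, 1, 1, 2, 2],
    "date": [5.2, 5.4, 6.9, 2.4, 3.4],
    "value": [20, 100, 30, 20, 100],
    "normal": [200, 200, 200, 500, 500],
})

groups = data.groupby("id")
# One panel per group; squeeze=False keeps axes 2-D even with a single group
fig, axes = plt.subplots(1, groups.ngroups, squeeze=False,
                         figsize=(4 * groups.ngroups, 3))
for ax, (group_id, group) in zip(axes[0], groups):
    ax.scatter(group["date"], group["value"])
    # 'normal' is constant within a group, so the first row is enough
    ax.axhline(group["normal"].iloc[0], c="red", ls=":")
    ax.set_title(f"id {group_id}")
    ax.set_xlabel("date")
    ax.set_ylabel("value")
fig.tight_layout()
plt.show()
```

Iterating over the groupby object yields (key, sub-DataFrame) pairs directly, so there is no need to filter with data["id"] == group_id on each pass.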