I need to implement a pandas groupby operation that is more complex than the simple aggregations I usually do. The table I'm working with has the following structure:
   category  price
0  A            89
1  A            58
2  ...         ...
3  B            75
4  B           120
5  ...         ...
6  C            90
7  C           199
8  ...         ...
As shown above, my example DataFrame has 3 categories, A, B and C (the real DataFrame I'm working on has ~1000 categories). Assume that category A has 20 rows and categories B and C each have more than 100 rows; the rows not shown are denoted by the dots (...).
I would like to calculate the average price of each category with the following conditions:
If the number of elements in the category is greater than 100 (i.e., B and C in this example), the average should be calculated after excluding values that are more than 3 standard deviations away from that category's mean.
Otherwise, for categories with 100 or fewer elements (i.e., A in this example), the average should be calculated over the entire group, without any exclusion.
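Stated as a formula, for a category g with n_g rows, prices p_i, mean mu_g and standard deviation sigma_g, the requested average is the following. Two readings here are assumptions: "3 standard deviations away" is taken as a band that includes its boundary, and a category with exactly 100 rows falls in the no-exclusion case.

```latex
\bar{p}_g =
\begin{cases}
  \dfrac{1}{n_g} \displaystyle\sum_{i \in g} p_i, & n_g \le 100, \\[8pt]
  \dfrac{1}{|S_g|} \displaystyle\sum_{i \in S_g} p_i, & n_g > 100,
\end{cases}
\qquad
S_g = \{\, i \in g : |p_i - \mu_g| \le 3\sigma_g \,\}
```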
Calculating the average price for each category without any condition on the groups is straightforward, df.groupby("category").agg({"price": "mean"}), but I'm stuck on the extra conditions.
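The conditions depend on three per-category quantities: the row count, the mean and the standard deviation. As a starting point, the same groupby can return all three at once. This is a sketch on a tiny made-up frame, not the real data:

```python
import pandas as pd

# Tiny made-up frame, only for illustration
df = pd.DataFrame({"category": ["A", "A", "B", "B", "B"],
                   "price": [89, 58, 75, 120, 90]})

# size = number of rows; std is the sample standard deviation (ddof=1)
stats = df.groupby("category")["price"].agg(["size", "mean", "std"])
print(stats)
```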
I always try to provide a reproducible example when asking questions here, but I don't know how to write one for this problem with fake data. I hope this format is still OK.
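One way to build a reproducible example with fake data of the same shape is to draw prices with NumPy using a fixed seed. A minimal sketch; the row counts for B and C (150 and 200), the normal distribution parameters and the injected outlier are all made-up assumptions:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # fixed seed so the example is reproducible

# Assumed sizes: A is small (20 rows), B and C exceed the 100-row threshold
sizes = {"A": 20, "B": 150, "C": 200}
df = pd.concat(
    [pd.DataFrame({"category": cat,
                   "price": rng.normal(loc=100, scale=20, size=n).round()})
     for cat, n in sizes.items()],
    ignore_index=True,
)

# Inject one obvious outlier into C (the last row) so the 3-std filter has something to drop
df.loc[df.index[-1], "price"] = 10_000

print(df.groupby("category")["price"].size())
```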
Answer:
Maybe you can do it like this?
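import numpy as np

# Categories with at most 100 rows: plain mean.
# Larger categories: mean of the prices within 3 standard deviations of the category mean.
df.groupby('category')['price'].apply(
    lambda x: np.mean(x) if len(x) <= 100
    else np.mean(x[(x >= np.mean(x) - 3*np.std(x))
                   & (x <= np.mean(x) + 3*np.std(x))]))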
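Or without NumPy (the NumPy version may be faster). Note that np.std uses the population standard deviation (ddof=0), while Series.std uses the sample standard deviation (ddof=1), so the two versions can exclude slightly different values: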
# Same logic using pandas methods only
df.groupby('category')['price'].apply(
    lambda x: x.mean() if len(x) <= 100
    else x[(x >= x.mean() - 3*x.std())
           & (x <= x.mean() + 3*x.std())].mean())
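A vectorized alternative (a swapped-in technique, not what the answer above uses) avoids calling a Python lambda once per group: broadcast each group's size, mean and standard deviation back onto its rows with groupby(...).transform, build one boolean mask, and take a single grouped mean. A sketch on a small made-up frame; robust_group_mean and its min_rows/k parameters are hypothetical names, and the threshold is lowered to 10 so the example stays short. Like the pandas-only version, it uses the sample standard deviation (ddof=1):

```python
import pandas as pd

def robust_group_mean(df, min_rows=100, k=3):
    """Hypothetical helper: mean price per category; categories with more than
    min_rows rows drop prices more than k standard deviations from their mean."""
    g = df.groupby("category")["price"]
    # Per-row copies of the group statistics (aligned with df's index)
    size = g.transform("size")
    mean = g.transform("mean")
    std = g.transform("std")  # sample std, ddof=1
    # Keep every row of small groups; keep only in-band rows of large groups
    keep = (size <= min_rows) | ((df["price"] - mean).abs() <= k * std)
    return df[keep].groupby("category")["price"].mean()

# Made-up data: A has 4 rows (kept whole, even its 1000),
# B has 15 rows including one outlier (1000) that should be dropped
df = pd.DataFrame({
    "category": ["A"] * 4 + ["B"] * 15,
    "price": [89, 58, 70, 1000] + [75] * 14 + [1000],
})
result = robust_group_mean(df, min_rows=10)
print(result)
```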
Answer:
I'm not sure you can do all of this in a single step. Try breaking it down, like this:
- Identify the number of elements per category:
df_elements = df.groupby('category').agg({'price':'count'}).reset_index()
df_elements.rename({'price':'n_elements'}, inplace=True,axis=1)
- Check whether the number of elements is at most 100 or greater than 100, then perform the appropriate average calculation:
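import pandas as pd

aux = []
for cat in df_elements.category.unique():
    # .iloc[0] extracts the scalar count; comparing the whole Series in an
    # if-statement would raise "The truth value of a Series is ambiguous"
    n_elements = df_elements.loc[df_elements.category == cat, 'n_elements'].iloc[0]
    if n_elements <= 100:
        # Small category: plain mean over all rows
        df_aux = df[df.category == cat].groupby('category').agg({'price': 'mean'})
        aux.append(df_aux.reset_index())
    else:
        # Large category: drop prices more than 3 standard deviations from the mean
        std_cat = df[df.category == cat]['price'].std()
        mean_cat = df[df.category == cat]['price'].mean()
        th = 3*std_cat
        df_cut = df[(df.category == cat) & (df.price <= mean_cat + th) & (df.price >= mean_cat - th)]
        df_aux = df_cut.groupby('category').agg({'price': 'mean'})
        aux.append(df_aux.reset_index())
final_df = pd.concat(aux, axis=0, ignore_index=True)
final_df.rename({'price': 'avg_price'}, axis=1, inplace=True)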
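The same step-by-step logic can be written without re-filtering df for every category by iterating over the groupby object directly, which yields (category, sub-DataFrame) pairs. A sketch assuming the question's column names; avg_price_by_category, min_rows and k are hypothetical names, the 100-row threshold and the avg_price output column follow the answer above, and the demo data is made up:

```python
import pandas as pd

def avg_price_by_category(df, min_rows=100, k=3):
    """Hypothetical helper: one row per category with its (possibly trimmed) mean."""
    rows = []
    # Iterating a GroupBy yields (category, sub-DataFrame) pairs, sorted by category
    for cat, grp in df.groupby("category"):
        prices = grp["price"]
        if len(prices) > min_rows:
            # Large category: keep prices within k standard deviations of the mean
            mean, std = prices.mean(), prices.std()
            prices = prices[(prices - mean).abs() <= k * std]
        rows.append({"category": cat, "avg_price": prices.mean()})
    return pd.DataFrame(rows)

# Made-up data: 3 rows of A, 101 rows of B with one outlier (5000)
df = pd.DataFrame({
    "category": ["A"] * 3 + ["B"] * 101,
    "price": [10, 20, 30] + [50] * 100 + [5000],
})
final_df = avg_price_by_category(df)
print(final_df)
```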