How to get the MAX value of previous group in PANDAS?


I'm trying to get the max of the previous group (category). I wrote a function with def and used apply, but this code is slow. Is there a faster way to do it?

Below is my code:

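# For each row, filter the whole frame for the previous trend_count and take the max Price
def Price_MAX_prev_TREND(row):
    prev_trend_count = row['trend_count'] - 1
    x = df_filtered.loc[df_filtered['trend_count'] == prev_trend_count, 'Price'].max()
    return x

# apply with axis=1 calls the function once per row, so the filter runs len(df_filtered) times
df_filtered['Price_MAX_prev_TREND'] = df_filtered.apply(Price_MAX_prev_TREND, axis=1)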

Answer:

Compute the max per group with groupby.max, then shift that Series with shift and map it back onto the group column, so each row gets the max of the previous group.

Here is a simple example:

import pandas as pd

df = pd.DataFrame({'group': [0, 0, 0, 1, 1, 1, 2, 2, 2],
                   'value': [1, 4, 2, 3, 2, 0, 5, 3, 3]})

# only use sort=False if you want to keep the original order
s = df.groupby('group', sort=False)['value'].max()

df['max_previous'] = df['group'].map(s.shift())


   group  value  max_previous
0      0      1           NaN
1      0      4           NaN
2      0      2           NaN
3      1      3           4.0
4      1      2           4.0
5      1      0           4.0
6      2      5           3.0
7      2      3           3.0
8      2      3           3.0
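The sort=False comment above matters when the groups do not appear in sorted order: shift moves values by position in s, so "previous" means previous in whatever order s has. A small sketch with made-up data (the groups are deliberately out of order) to show the difference:

```python
import pandas as pd

# Hypothetical data: groups appear in the order 2, 0, 1
df = pd.DataFrame({'group': [2, 2, 0, 0, 1],
                   'value': [5, 3, 1, 4, 3]})

# sort=False keeps first-appearance order (2, 0, 1), so "previous" follows the rows
s_appearance = df.groupby('group', sort=False)['value'].max()
df['prev_by_appearance'] = df['group'].map(s_appearance.shift())

# sort=True (the default) orders the groups by label (0, 1, 2)
s_sorted = df.groupby('group')['value'].max()
df['prev_by_label'] = df['group'].map(s_sorted.shift())

print(df)
```

With sort=False, group 0 receives 5.0 (the max of group 2, which appears before it), while with the default sort it receives NaN because 0 is the smallest label.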

If your groups are discontinuous integers (years, etc.) and you want to be sure to map the n-1 group rather than simply the preceding one, change the index with set_axis instead:

df['max_previous'] = df['group'].map(s.set_axis(s.index + 1))

Example to see the difference (there is no group 3, so group 4 has no n-1 group):

    group  value  max_previous_shift  max_previous_discontinuous
0       0      1                 NaN                         NaN
1       0      4                 NaN                         NaN
2       0      2                 NaN                         NaN
3       1      3                 4.0                         4.0
4       1      2                 4.0                         4.0
5       1      0                 4.0                         4.0
6       2      5                 3.0                         3.0
7       2      3                 3.0                         3.0
8       2      3                 3.0                         3.0
9       4      7                 5.0                         NaN
10      4      3                 5.0                         NaN
11      4      1                 5.0                         NaN
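The comparison table can be reproduced with a short script. This is a sketch that reuses the first example's data and adds a group 4; the column names are the ones shown in the table:

```python
import pandas as pd

# Same data as the first example plus a group 4 (there is no group 3)
df = pd.DataFrame({'group': [0, 0, 0, 1, 1, 1, 2, 2, 2, 4, 4, 4],
                   'value': [1, 4, 2, 3, 2, 0, 5, 3, 3, 7, 3, 1]})

s = df.groupby('group', sort=False)['value'].max()

# shift(): previous group by position, so group 4 receives the max of group 2
df['max_previous_shift'] = df['group'].map(s.shift())

# set_axis(s.index + 1): strictly the n-1 group, so group 4 finds no group 3
df['max_previous_discontinuous'] = df['group'].map(s.set_axis(s.index + 1))

print(df)
```

The printed frame should contain the same values as the table above.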

Answer:

First aggregate the max values with GroupBy.max, then create the new column with Series.map using the previous trend: add 1 to the index labels (the trend_count values) so that each group's max lines up with the next trend. sort=False is used to improve performance:

s = df.groupby('trend_count', sort=False)['Price'].max()
df['Price_MAX_prev_TREND'] = df['trend_count'].map(s.rename(lambda x: x + 1))

Another option is to add 1 to trend_count before aggregating the max:

s = df.assign(trend_count = df['trend_count'].add(1)).groupby('trend_count', sort=False)['Price'].max()
df['Price_MAX_prev_TREND'] = df['trend_count'].map(s)
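Both variants shift the group labels by one rather than shifting by position, so, like set_axis, they look up strictly the n-1 trend. A minimal check on made-up data (the values are hypothetical) that the two agree, including a gap in trend_count:

```python
import pandas as pd

# Hypothetical data with a gap: there is no trend_count 3
df = pd.DataFrame({'trend_count': [1, 1, 2, 2, 4],
                   'Price': [10, 20, 5, 7, 9]})

# Variant 1: aggregate, then relabel the index n -> n + 1
s1 = df.groupby('trend_count', sort=False)['Price'].max()
via_rename = df['trend_count'].map(s1.rename(lambda x: x + 1))

# Variant 2: relabel trend_count first, then aggregate
s2 = (df.assign(trend_count=df['trend_count'].add(1))
        .groupby('trend_count', sort=False)['Price'].max())
via_assign = df['trend_count'].map(s2)

print(via_rename.tolist())            # [nan, nan, 20.0, 20.0, nan]
print(via_assign.equals(via_rename))  # True
```

Trend 4 gets NaN in both cases because trend 3 does not exist, which matches the question's row['trend_count'] - 1 lookup.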

Performance depends on the data, so it is best to test on your real data:

import numpy as np
import pandas as pd

N = 10000
df = pd.DataFrame({'trend_count': np.random.randint(1000, size=N),
                   'Price': np.random.randint(1000, size=N)})

#original solution
In [192]: %%timeit
     ...: df['Price_MAX_prev_TREND1'] = df.apply(Price_MAX_prev_TREND, axis = 1)
4.02 s ± 197 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [193]: %%timeit
     ...: s = df.groupby('trend_count', sort=False)['Price'].max()
     ...: df['Price_MAX_prev_TREND'] = df['trend_count'].map(s.rename(lambda x: x + 1))
2.38 ms ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [194]: %%timeit
     ...: s = df.assign(trend_count = df['trend_count'].add(1)).groupby('trend_count', sort=False)['Price'].max()
     ...: df['Price_MAX_prev_TREND'] = df['trend_count'].map(s)
2.26 ms ± 16.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

#mozway solution
In [195]: %%timeit
     ...: s = df.groupby('trend_count', sort=False)['Price'].max()
     ...: df['max_previous'] = df['trend_count'].map(s.set_axis(s.index + 1))
3.01 ms ± 340 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
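Timing is only half the story; before swapping in a vectorized version it is worth checking that it returns the same values as the original apply function. A sketch under assumptions: prev_max_apply and prev_max_vectorized are hypothetical helpers (prev_max_apply rewrites the question's Price_MAX_prev_TREND to take the frame as an argument instead of reading the global df_filtered), and the data is random and smaller than the benchmark so apply finishes quickly:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
N = 500
df = pd.DataFrame({'trend_count': rng.integers(50, size=N),
                   'Price': rng.integers(1000, size=N)})

# Hypothetical rewrite of the question's row-wise function (one filter per row)
def prev_max_apply(frame):
    def per_row(row):
        prev = row['trend_count'] - 1
        return frame.loc[frame['trend_count'] == prev, 'Price'].max()
    return frame.apply(per_row, axis=1)

# Vectorized version using the rename approach: one groupby, one map
def prev_max_vectorized(frame):
    s = frame.groupby('trend_count', sort=False)['Price'].max()
    return frame['trend_count'].map(s.rename(lambda x: x + 1))

slow = prev_max_apply(df).to_numpy(dtype=float)
fast = prev_max_vectorized(df).to_numpy(dtype=float)

# NaN appears where no n-1 trend exists; treat NaN == NaN as a match
print(np.array_equal(slow, fast, equal_nan=True))  # True
```

The shift-based variant would only pass this check when trend_count has no gaps, since it looks up the previous group by position rather than by label.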