How to combine rows in groupby with several conditions?

Time:12-16

I want to combine rows in a pandas DataFrame with the following logic:

  • the dataframe is grouped by user
  • rows are ordered by start_at_min
  • rows are combined when:

Case A: if start_at_min <= 200:

  • row2[start_at_min] - row1[stop_at_min] < 5
  • (e.g. 101 - 100 = 1 -> combine; 200 - 100 = 100 -> don't combine)

Case B: if 200 < start_at_min < 400:

  • change the threshold to 3

Case C: if start_at_min > 400:

  • never combine
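
The three cases amount to a threshold lookup keyed on the later row's start_at_min. A minimal sketch of that rule, assuming the boundaries sit at 200 and 400 and that 400 itself falls under Case B (the question leaves that value unspecified). The helper names gap_threshold and should_combine are hypothetical and not part of the original post:

```python
from typing import Optional


def gap_threshold(start_at_min: float) -> Optional[int]:
    """Return the exclusive upper bound on the gap, or None for 'never combine'."""
    if start_at_min <= 200:
        return 5      # Case A
    if start_at_min <= 400:
        return 3      # Case B (assumption: 400 itself is treated as Case B)
    return None       # Case C


def should_combine(prev_stop: float, next_start: float) -> bool:
    """True when the gap between two consecutive rows is below the threshold."""
    limit = gap_threshold(next_start)
    return limit is not None and (next_start - prev_stop) < limit


print(should_combine(150, 152))  # gap 2, threshold 5
print(should_combine(201, 205))  # gap 4, threshold 3
```

The threshold is chosen from the second row of each pair, which matches the example below where the 201 -> 205 pair is judged against 3 rather than 5.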

Example df

   user  start_at_min  stop_at_min
0     1           100          150  
1     1           152          201 #row0 with row1: combine (152 - 150 = 2 < 5)
2     1           205          260 #row1 with row2: no -> start_at_min above 200 -> threshold = 3
3     2            65          100 #no
4     2           200          265 #no
5     2           300          451 #no
6     2           452          460 #no -> start_at_min above 400 -> never combine

Expected output:

   user  start_at_min  stop_at_min
0     1           100          201 #row0 and row1 combined
2     1           205          260 #row1 with row2: no -> start_at_min above 200 -> threshold = 3
3     2            65          100 #no
4     2           200          265 #no
5     2           300          451 #no
6     2           452          460 #no -> start_at_min above 400 -> never combine

I have written the function combine_rows, which takes two Series and applies the threshold-5 rule:

def combine_rows(s1: pd.Series, s2: pd.Series):
    # take 2 rows and combine them if start_at_min row2 - stop_at_min row1 < 5
    if s2['start_at_min'] - s1['stop_at_min'] < 5:
        return pd.Series({
            'user': s1['user'],
            'start_at_min': s1['start_at_min'],
            'stop_at_min': s2['stop_at_min']
        })
    else:
        return pd.concat([s1, s2], axis=1).T

However, I am unable to apply this function to the dataframe. This was my attempt:

df.groupby('user').sort_values(by=['start_at_min']).apply(combine_rows)  # this does not work

Here is the full code:

import pandas as pd 
import numpy as np


df = pd.DataFrame({
    "user"       :  [1, 1, 2,2],
    'start_at_min': [60, 101, 65, 200], 
    'stop_at_min' : [100, 135, 100, 265] 
})

def combine_rows(s1: pd.Series, s2: pd.Series):
    # take 2 rows and combine them if start_at_min row2 - stop_at_min row1 < 5
    if s2['start_at_min'] - s1['stop_at_min'] < 5:
        return pd.Series({
            'user': s1['user'],
            'start_at_min': s1['start_at_min'],
            'stop_at_min': s2['stop_at_min']
        })
    else:
        return pd.concat([s1, s2], axis=1).T

df.groupby('user').sort_values(by=['start_at_min']).apply(combine_rows)  # this does not work

CodePudding user response:

version 1: one condition

Perform a custom groupby.agg:

threshold = 5
# if the successive stop/start per group are above threshold
# start a new group
group = (df['start_at_min']
         .sub(df.groupby('user')['stop_at_min'].shift())
         .ge(threshold).cumsum()
        )

# groupby.agg
out = (df.groupby(['user', group], as_index=False)
         .agg({'start_at_min': 'min',
               'stop_at_min': 'max'
              })
      )

Output:

   user  start_at_min  stop_at_min
0     1            60          135
1     2            65          100
2     2           200          265

Intermediate:

(df['start_at_min']
 .sub(df.groupby('user')['stop_at_min'].shift())
)

0      NaN
1      1.0  # below threshold, this will be merged
2      NaN
3    100.0  # above threshold, keep separate
dtype: float64
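
To see how those differences turn into group labels, the comparison and the cumulative sum can be inspected on their own. A small sketch on the four-row df from the question's full code; the variable names gap, new_group and labels are illustrative only:

```python
import pandas as pd

df = pd.DataFrame({
    "user": [1, 1, 2, 2],
    "start_at_min": [60, 101, 65, 200],
    "stop_at_min": [100, 135, 100, 265],
})

# gap between each row's start and the previous stop of the same user
gap = df["start_at_min"].sub(df.groupby("user")["stop_at_min"].shift())

# NaN (first row of each user) compares as False, so it never opens a new label
new_group = gap.ge(5)

# each True increments the running label
labels = new_group.cumsum()
print(labels.tolist())  # [0, 0, 0, 1]
```

The labels are not unique per user by themselves (the first row of user 2 shares label 0 with user 1), which is why the aggregation groups on both user and the label.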

version 2: multiple conditions

# define variable threshold (5 up to start_at_min 200, else 3)
threshold = np.where(df['start_at_min'].le(200), 5, 3)
# on the 7-row example df: array([5, 5, 3, 5, 5, 3, 3])

# compute the new starts of group like in version 1
# but using the now variable threshold
m1 = (df['start_at_min']
         .sub(df.groupby('user')['stop_at_min'].shift())
         .ge(threshold)    
        )
# add a second restart condition (>400)
m2 = df['start_at_min'].gt(400)

# if either mask is True, start a new group
group = (m1|m2).cumsum()

# groupby.agg
out = (df.groupby(['user', group], as_index=False)
         .agg({'start_at_min': 'min',
               'stop_at_min': 'max'
              })
      )

Output:

   user  start_at_min  stop_at_min
0     1           100          201
1     1           205          260
2     2            65          100
3     2           200          265
4     2           300          451
5     2           452          460
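
For reference, a self-contained sketch of version 2 run on the seven-row example df from the question. Two details are assumptions rather than part of the answer: the frame is sorted by user and start_at_min first (the shift-based comparison relies on that order), and the group key is given the hypothetical name block so that it cannot clash with the start_at_min column when the index is reset.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "user": [1, 1, 1, 2, 2, 2, 2],
    "start_at_min": [100, 152, 205, 65, 200, 300, 452],
    "stop_at_min": [150, 201, 260, 100, 265, 451, 460],
})

# assumption: rows must be ordered per user before comparing neighbours
df = df.sort_values(["user", "start_at_min"], ignore_index=True)

# threshold 5 up to start_at_min 200, otherwise 3
threshold = np.where(df["start_at_min"].le(200), 5, 3)
gap = df["start_at_min"].sub(df.groupby("user")["stop_at_min"].shift())

# a new block starts when the gap reaches the threshold
# or when the row starts after minute 400
block = (gap.ge(threshold) | df["start_at_min"].gt(400)).cumsum().rename("block")

out = (df.groupby(["user", block])
         .agg(start_at_min=("start_at_min", "min"),
              stop_at_min=("stop_at_min", "max"))
         .reset_index()
         .drop(columns="block"))
print(out)
```

Because each row is compared with the previous row's stop_at_min, chains of more than two close rows collapse into a single block without any explicit loop.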