Pandas get rows based on column max in slice with multiindex

Time:11-25

I have a DataFrame that looks something like this:

            Text    text_len
userId                                                                                     
0        firstext          8
0  firsttextmore0         14
1       text ones          9
2           third          5
2       third two          9

It is indexed by userId, and there may be several rows per userId. It doesn't matter what is in the Text column.

I also have a dictionary that tells me which rows of a given user belong together. The dict has userId as key and a list of tuples as value; each tuple should be interpreted as a range of row positions for that userId. An example is given here:

session_dict = {0: [(0, 1), (1, 4), (4, 6)],
                1: [(0,1)], ...}

For example, for a userId with 6 rows, the value in session_dict could be [(0, 1), (1, 4), (4, 6)], meaning the row ranges 0:1, 1:4 and 4:6.
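
As a small illustration of that convention, the tuples are meant as half-open ranges, the way Python slices and .iloc[start:stop] work. The list `rows` below is a made-up stand-in for one user's six rows, not part of my real data:

```python
# Stand-in for the six rows of one user, identified by position.
rows = ["r0", "r1", "r2", "r3", "r4", "r5"]
sessions = [(0, 1), (1, 4), (4, 6)]

# Each (start, stop) tuple selects positions start .. stop - 1.
chunks = [rows[start:stop] for start, stop in sessions]
print(chunks)
# [['r0'], ['r1', 'r2', 'r3'], ['r4', 'r5']]
```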

I would then like to extract the row with the maximum text_len for each session in session_dict, for each userId. If there are 3 tuples for a given key in session_dict, I want to extract 3 rows: the row with the maximum text length in each of that userId's sessions.

What I do now is a nested for loop: I loop over each key in session_dict, then over the tuples in its value, get the max and append the row to a results dataframe. The code looks something like this:

    for userId, session in session_dict.items():
        for tup_index in session:
 
            longest_text_arg_max = df.loc[userId].iloc[tup_index[0] : tup_index[1]][
                "text_len"
            ].argmax()

            to_append = (
                df.loc[userId]
                .iloc[tup_index[0] : tup_index[1]]
                .iloc[longest_text_arg_max]
            )
            df_result = df_result.append(to_append.reset_index(), ignore_index=True)

I feel like there is a lot of looping going on, and I want to find a better way of doing it.

I hope I explained the issue clearly.

EDIT: Fixed index range

CodePudding user response:

I would create a helper series that holds a group number per session. First, add an extra index level with groupby.cumcount:

import pandas as pd

# dummy data
df = pd.DataFrame(
    {"userId": [0]*6 + [1]*7,
     "Text": list('abcdefghijklm'),
     "text_len": [18, 27, 14, 11, 19, 8, 0, 26, 7, 15, 1, 5, 12]}
).set_index('userId')

session_dict = {0: [(0, 1), (1, 4), (4, 6)],
                1: [(0,5), (5,7)],} 

# first add an extra level of index with cumcount
df = df.set_index(df.groupby(level='userId').cumcount(), append=True)
print(df)
#          Text  text_len
# userId                 
# 0      0    a        18
#        1    b        27
#        2    c        14
#        3    d        11
#        4    e        19
#        5    f         8
# 1      0    g         0
#        1    h        26
#        2    i         7
#        3    j        15
#        4    k         1
#        5    l         5
#        6    m        12

Now you can extract the start of each session per user from your dictionary and create a MultiIndex:

idx = pd.MultiIndex.from_tuples(
    [(u_id, st) for u_id, _l in session_dict.items() 
                for st, _ in _l]
)
print(idx)
# MultiIndex([(0, 0),
#             (0, 1),
#             (0, 4),
#             (1, 0),
#             (1, 5)],
#            )

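Create the helper series with the group id: build a series of ones on the previous MultiIndex, reindex it with df.index (filling with 0), and take the cumsum. Every session start bumps the counter by one, so all rows of a session share the same id.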

gr = (
    pd.Series(1, index=idx)
      .reindex(df.index, fill_value=0)
      .cumsum()
)
print(gr)
# userId   
# 0       0    1
#         1    2
#         2    2
#         3    2
#         4    3
#         5    3
# 1       0    4
#         1    4
#...
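
One thing to keep in mind: the helper series only uses the start of each tuple, so this approach assumes that the sessions of each user cover all of that user's rows without gaps or overlaps. If that is not guaranteed, a quick check along these lines could be run first. `sessions_tile_rows` is a hypothetical helper name used for this sketch, not a pandas function:

```python
def sessions_tile_rows(sessions, n_rows):
    """Return True if the (start, stop) tuples cover positions 0..n_rows
    exactly once, with no gaps, overlaps, or empty ranges."""
    expected_start = 0
    for start, stop in sorted(sessions):
        if start != expected_start or stop <= start:
            return False
        expected_start = stop
    return expected_start == n_rows


print(sessions_tile_rows([(0, 1), (1, 4), (4, 6)], 6))  # True
print(sessions_tile_rows([(0, 1), (2, 4)], 4))          # False, row 1 is in no session
```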

Now you can use groupby.idxmax to get the index of the max per session, use it in a loc to keep all the wanted rows, and drop the extra index level:

res = (
    df.loc[df.groupby(gr)['text_len'].idxmax()]
      .reset_index(level=1, drop=True)
)
print(res)
#        Text  text_len
# userId               
# 0         a        18
# 0         b        27
# 0         e        19
# 1         h        26
# 1         m        12
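
For reuse, the steps above could be wrapped in a single function. This is only a sketch: `max_row_per_session` and its `col` parameter are names made up here, and it assumes the index is named userId and that the sessions of each user cover that user's rows without gaps.

```python
import pandas as pd


def max_row_per_session(df, session_dict, col="text_len"):
    """Hypothetical helper: return one row per session, the one with the largest `col`."""
    # Position of each row within its userId becomes a second index level.
    pos = df.groupby(level="userId").cumcount()
    work = df.set_index(pos, append=True)
    # Mark the first row of every session, then cumsum to get a session id per row.
    starts = pd.MultiIndex.from_tuples(
        [(u_id, start) for u_id, sessions in session_dict.items()
                       for start, _ in sessions]
    )
    group_id = pd.Series(1, index=starts).reindex(work.index, fill_value=0).cumsum()
    # Index label (userId, position) of the max row in each session.
    best = work.groupby(group_id)[col].idxmax()
    return work.loc[best.tolist()].reset_index(level=1, drop=True)


# Same dummy data as above.
df = pd.DataFrame(
    {"userId": [0]*6 + [1]*7,
     "Text": list("abcdefghijklm"),
     "text_len": [18, 27, 14, 11, 19, 8, 0, 26, 7, 15, 1, 5, 12]}
).set_index("userId")
session_dict = {0: [(0, 1), (1, 4), (4, 6)],
                1: [(0, 5), (5, 7)]}

res = max_row_per_session(df, session_dict)
print(res["Text"].tolist())
# ['a', 'b', 'e', 'h', 'm']
```

If two rows in a session share the maximum, idxmax keeps the first one.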