aggregate as list with max two elements

Time:03-13

Given the following user table:

   user       query
0    a1      orange
1    a1  strawberry
2    a1        pear
3    a2      orange
4    a2  strawberry
5    a2       lemon
6    a3      orange
7    a3      banana
8    a6        meat
9    a7        beer
10   a8       juice

I want to group by user, aggregate query into a list, and keep only the first two items when a user has more than two. The expected outcome is:

  user                        query
0   a1         [orange, strawberry]
1   a2         [orange, strawberry]
2   a3             [orange, banana]
3   a6                       [meat]
4   a7                       [beer]
5   a8                      [juice]

With the code below,

import pandas as pd

df_user = pd.DataFrame( {'user': {0: 'a1', 1: 'a1', 2: 'a1', 3: 'a2', 
                                  4: 'a2', 5: 'a2', 6: 'a3', 7: 'a3', 
                                  8: 'a6', 9: 'a7', 10: 'a8'}, 
                         'query': {0: 'orange', 1: 'strawberry', 
                                   2: 'pear', 3: 'orange', 4: 'strawberry', 
                                   5: 'lemon', 6: 'orange', 7: 'banana', 
                                   8: 'meat', 9: 'beer', 10: 'juice'}} )

print(df_user.groupby(['user'], as_index=False).agg(list))

I get:

  user                        query
0   a1   [orange, strawberry, pear]
1   a2  [orange, strawberry, lemon]
2   a3             [orange, banana]
3   a6                       [meat]
4   a7                       [beer]
5   a8                      [juice]

What is a good way to get the desired outcome?

Answer:

You could use iloc to slice up to two items from each group:

df_user.groupby(['user'], as_index=False).agg(lambda s: s.iloc[:2].to_list())

Output:

  user                 query
0   a1  [orange, strawberry]
1   a2  [orange, strawberry]
2   a3      [orange, banana]
3   a6                [meat]
4   a7                [beer]
5   a8               [juice]
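A self-contained version of the iloc approach, rebuilding the question's data so it runs on its own. The lambda receives each group's query column as a Series, and iloc[:2] simply returns fewer rows when a group is shorter, which is why a6, a7 and a8 still come out as one-item lists.

```python
import pandas as pd

# Same data as in the question
df_user = pd.DataFrame({
    "user": ["a1", "a1", "a1", "a2", "a2", "a2", "a3", "a3", "a6", "a7", "a8"],
    "query": ["orange", "strawberry", "pear", "orange", "strawberry",
              "lemon", "orange", "banana", "meat", "beer", "juice"],
})

# For every user, take at most the first two queries (by position) as a list.
out = df_user.groupby("user", as_index=False).agg(lambda s: s.iloc[:2].to_list())
print(out)
```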

Answer:

Here's one way:

out = df_user[df_user.groupby('user').cumcount()<2].groupby('user', as_index=False).agg(list)

Output:

  user                 query
0   a1  [orange, strawberry]
1   a2  [orange, strawberry]
2   a3      [orange, banana]
3   a6                [meat]
4   a7                [beer]
5   a8               [juice]
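To see why the cumcount filter works, the sketch below splits the one-liner into two steps. cumcount() numbers the rows inside each group 0, 1, 2, ... in their original order, so comparing it with 2 yields a boolean mask that keeps only the first two rows per user before aggregating.

```python
import pandas as pd

# Same data as in the question
df_user = pd.DataFrame({
    "user": ["a1", "a1", "a1", "a2", "a2", "a2", "a3", "a3", "a6", "a7", "a8"],
    "query": ["orange", "strawberry", "pear", "orange", "strawberry",
              "lemon", "orange", "banana", "meat", "beer", "juice"],
})

# Position of each row within its user group.
position = df_user.groupby("user").cumcount()
print(position.tolist())  # [0, 1, 2, 0, 1, 2, 0, 1, 0, 0, 0]

# Keep rows at positions 0 and 1, then collect the survivors per user.
out = df_user[position < 2].groupby("user", as_index=False).agg(list)
print(out)
```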

Answer:

You can use GroupBy.nth() to select rows from each group by position (if they exist):

new_df = df_user.groupby('user').nth([0, 1]).groupby('user').agg(list)

Output:

>>> new_df
                     query
user                      
a1    [orange, strawberry]
a2    [orange, strawberry]
a3        [orange, banana]
a6                  [meat]
a7                  [beer]
a8                 [juice]
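A self-contained sketch of the nth() approach. Grouping the second step by the name "user" rather than by index level matters in pandas 2.0 and later: there nth() acts as a filter and returns the selected rows with their original row index and the user column, instead of indexing them by group. In older versions "user" is the index level name, and groupby accepts that name too, so the same line should work in both.

```python
import pandas as pd

# Same data as in the question
df_user = pd.DataFrame({
    "user": ["a1", "a1", "a1", "a2", "a2", "a2", "a3", "a3", "a6", "a7", "a8"],
    "query": ["orange", "strawberry", "pear", "orange", "strawberry",
              "lemon", "orange", "banana", "meat", "beer", "juice"],
})

# Rows at positions 0 and 1 of each group; one-row groups keep their only row.
firsts = df_user.groupby("user").nth([0, 1])

# "user" is a column (pandas >= 2.0) or the index level name (older pandas);
# either way groupby resolves it, then the queries are collected into lists.
new_df = firsts.groupby("user").agg(list)
print(new_df)
```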

Note that list(range(2)) is more flexible than [0, 1] if you don't want to type out every position :)
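Following that idea, the cap can become a parameter. This is a minimal sketch; the helper name first_n_as_list and its key/n parameters are hypothetical, and it simply passes list(range(n)) to nth() and groups the result by the key name, as in the previous answer.

```python
import pandas as pd


def first_n_as_list(df: pd.DataFrame, key: str, n: int) -> pd.DataFrame:
    """Hypothetical helper: group df by key and collect at most the first n
    rows of the remaining columns into lists."""
    # list(range(n)) == [0, 1, ..., n - 1], i.e. the first n positions.
    return df.groupby(key).nth(list(range(n))).groupby(key).agg(list)


# Same data as in the question
df_user = pd.DataFrame({
    "user": ["a1", "a1", "a1", "a2", "a2", "a2", "a3", "a3", "a6", "a7", "a8"],
    "query": ["orange", "strawberry", "pear", "orange", "strawberry",
              "lemon", "orange", "banana", "meat", "beer", "juice"],
})

top2 = first_n_as_list(df_user, "user", 2)  # the result asked for above
top1 = first_n_as_list(df_user, "user", 1)  # only the first query per user
print(top2)
print(top1)
```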
