Find top 3 items in each group in a dataset

Time:10-20

I have a dataset of movies. I want to find the top 3 genres in each year (the genres with the highest movie count for that year). An excerpt of the dataset is below:

      year      genre  imdb_title_id
 19   1894    Romance              1
 29   1906  Biography              1
 31   1906      Crime              1
 33   1906      Drama              1
 58   1911      Drama              4
 73   1911        War              2
 52   1911  Adventure              1
 60   1911    Fantasy              1
 62   1911    History              1
 83   1912      Drama              5
 87   1912    History              2
 79   1912  Biography              1
 81   1912      Crime              1
 91   1912    Mystery              1
 98   1912        War              1
 108  1913      Drama             11
 106  1913      Crime              4
 110  1913    Fantasy              3
 102  1913  Adventure              2
 113  1913     Horror              2

How can I do this kind of operation in pandas? I have tried nlargest, but I am not getting the correct result. The expected output for this case would be:

19   1894    Romance              1
29   1906  Biography              1
31   1906      Crime              1
33   1906      Drama              1
58   1911      Drama              4
73   1911        War              2
52   1911  Adventure              1
83   1912      Drama              5
87   1912    History              2
79   1912  Biography              1
108  1913      Drama             11
106  1913      Crime              4
110  1913    Fantasy              3

CodePudding user response:

I think this works:

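# Sort so that the largest counts come first in each year's list.
df = df.sort_values(["imdb_title_id"], ascending=False)
# Keep only the first three genres and counts per year, as lists.
df = df.groupby("year", as_index=False).agg({"genre": lambda x: list(x)[:3], "imdb_title_id": lambda x: list(x)[:3]})
# Expand the lists back into one row per genre; both columns explode in the same order.
result = df.explode("genre", ignore_index=True)
result["imdb_title_id"] = df.explode("imdb_title_id")["imdb_title_id"].values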

There are probably better ways to do it, though.
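One shorter alternative, offered only as a sketch: sort the rows by year and by count, then take the first three rows of each group with groupby(...).head(3). The sample frame below is rebuilt from a few rows of the question and is an assumption, not the asker's full dataset. When several genres tie for third place, which one is kept depends on the order the sort leaves them in.

```python
import pandas as pd

# Sample rows taken from the question's excerpt (1911 and 1912 only).
df = pd.DataFrame(
    {
        "year": [1911, 1911, 1911, 1911, 1911, 1912, 1912, 1912, 1912],
        "genre": ["Drama", "War", "Adventure", "Fantasy", "History",
                  "Drama", "History", "Biography", "Crime"],
        "imdb_title_id": [4, 2, 1, 1, 1, 5, 2, 1, 1],
    }
)

# Order by year ascending, then by movie count descending.
ordered = df.sort_values(["year", "imdb_title_id"], ascending=[True, False])

# head(3) on a groupby keeps the first three rows of every year.
top3 = ordered.groupby("year").head(3)
print(top3)
```

This avoids building lists and exploding them again, and it keeps the original index labels of the selected rows.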

CodePudding user response:

nlargest() should just work, but here is some sample code that deals with the awkward index it returns.

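# nlargest on a groupby returns a (year, original label) MultiIndex; drop the year level.
top3_idx = df.groupby("year")["imdb_title_id"].nlargest(3).droplevel(0).index
# These are index labels, not positions, so select with .loc rather than .iloc.
top3_df = df.loc[top3_idx]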

Basically, you get the nlargest rows per year, then use their index values to filter your dataframe.
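To try the idea end to end, here is a small self-contained sketch. The sample frame is rebuilt from the 1911 and 1913 rows of the question, keeping their non-consecutive index labels; that choice of rows is an assumption made for illustration. Because the recovered index holds labels rather than positions, the sketch selects rows by label with .loc.

```python
import pandas as pd

# Rows for 1911 and 1913 from the question, with their original index labels.
df = pd.DataFrame(
    {
        "year": [1911, 1911, 1911, 1911, 1911, 1913, 1913, 1913, 1913, 1913],
        "genre": ["Drama", "War", "Adventure", "Fantasy", "History",
                  "Drama", "Crime", "Fantasy", "Adventure", "Horror"],
        "imdb_title_id": [4, 2, 1, 1, 1, 11, 4, 3, 2, 2],
    },
    index=[58, 73, 52, 60, 62, 108, 106, 110, 102, 113],
)

# The grouped nlargest returns a Series indexed by (year, original label).
largest = df.groupby("year")["imdb_title_id"].nlargest(3)

# Drop the year level to recover the original row labels.
top3_idx = largest.droplevel(0).index

# Select by label; positional selection would fail or pick the wrong rows here.
top3_df = df.loc[top3_idx]
print(top3_df)
```

Note that this relies on the index labels being unique; with duplicate labels, .loc would return every row that shares a selected label.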
