I have a pandas DataFrame with the following columns:
col1 col2 col3 col4
A101 3 LLT 10028980
A101 7 LLT 10028980
A101 7 PT 10028980
A102 5 LLT 10028981
A102 3 PT 10028981
A103 2 PT 10028982
A103 4 LLT 10028982
I would like to extract all rows where col2 equals the maximum within its col1 group, keeping every row that ties for the maximum. The expected output is:
col1 col2 col3 col4
A101 7 LLT 10028980
A101 7 PT 10028980
A102 5 LLT 10028981
A103 4 LLT 10028982
I am using the following lines of code, but they keep only one row per group, so when several rows share the max value some of them are dropped (row 1 is excluded).
m = df.notnull().all(axis=1)
df = df.loc[m].groupby('col1').max().reset_index()
I am getting this output:
col1 col2 col3 col4
A101 7 PT 10028980
A102 5 LLT 10028981
A103 4 LLT 10028982
Answer:
You can broadcast the per-group maximum back onto every row with transform, compare col2 against it to flag the rows that hit the maximum, and index with that boolean mask:
>>> df.loc[df["col2"].eq(df.groupby("col1")["col2"].transform("max"))]
col1 col2 col3 col4
1 A101 7 LLT 10028980
2 A101 7 PT 10028980
3 A102 5 LLT 10028981
6 A103 4 LLT 10028982
Here's what .agg would do (one value per group):
>>> df.groupby("col1")["col2"].agg("max")
col1
A101 7
A102 5
A103 4
Name: col2, dtype: int64
And here's what .transform does (one value per original row):
>>> df.groupby("col1")["col2"].transform("max")
0 7
1 7
2 7
3 5
4 5
5 4
6 4
Name: col2, dtype: int64
The transformed Series has the same index as the original frame, which allows for an element-wise comparison with the col2 column.
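For anyone who wants to reproduce this, here is a minimal self-contained sketch of the same approach. It rebuilds the sample frame from the question by hand; the intermediate names group_max, is_max and result are only illustrative and do not come from the original post.

```python
import pandas as pd

# Sample data copied from the question.
df = pd.DataFrame(
    {
        "col1": ["A101", "A101", "A101", "A102", "A102", "A103", "A103"],
        "col2": [3, 7, 7, 5, 3, 2, 4],
        "col3": ["LLT", "LLT", "PT", "LLT", "PT", "PT", "LLT"],
        "col4": [10028980, 10028980, 10028980, 10028981,
                 10028981, 10028982, 10028982],
    }
)

# Per-group maximum of col2, repeated for every row of its group.
group_max = df.groupby("col1")["col2"].transform("max")

# True wherever a row's col2 equals the maximum of its group (ties included).
is_max = df["col2"].eq(group_max)

# Keep only the flagged rows; the original index labels are preserved.
result = df.loc[is_max]
print(result)
```

Because the mask is computed row by row, both A101 rows with col2 equal to 7 survive. By contrast, groupby('col1').max() aggregates each column independently and returns a single row per group, which is why the attempt in the question cannot keep ties.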