I am trying to keep the top 3 values for each date/row and set every other value to zero.
I created a sample dataframe:
import pandas as pd
df1 = pd.DataFrame({
    'Date': ['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04'],
    '01K W': [0, 1.2, 0.3, 2],
    '02K W': [0.5, 2, 1.4, 3],
    '03K W': [2, 1.6, 3, 5],
    '04K W': [7, 0.5, 2.4, 5],
    '05K W': [4, 2, 4.5, 1],
    '06K W': [2.7, 0, 0, 0],
    '07K W': [4, 3, 3, 2],
    '08K W': [3.8, 1, 9, 2],
    '09K W': [1, 4, 0.4, 6.3],
    '10K W': [0, 0, 9, 5.6]})
df1 = df1.set_index('Date')
I am struggling to adapt the following apply function so that it selects the top n values instead of just the max. I tried it with nlargest, but then I get the error: Can only compare identically-labeled Series objects.
df1.apply(lambda x: x == x.max(), axis=1)
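Presumably the nlargest attempt compared each row with `x.nlargest(3)`. That Series has only 3 labels while the row has 10, so `==` raises the "identically-labeled" error. A minimal sketch of an apply-based fix, assuming ties at the third-largest value should be kept (as in the desired output below), compares each row against the smallest of its three largest values instead:

```python
import pandas as pd

# Sample data from the question.
df1 = pd.DataFrame({
    'Date': ['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04'],
    '01K W': [0, 1.2, 0.3, 2],
    '02K W': [0.5, 2, 1.4, 3],
    '03K W': [2, 1.6, 3, 5],
    '04K W': [7, 0.5, 2.4, 5],
    '05K W': [4, 2, 4.5, 1],
    '06K W': [2.7, 0, 0, 0],
    '07K W': [4, 3, 3, 2],
    '08K W': [3.8, 1, 9, 2],
    '09K W': [1, 4, 0.4, 6.3],
    '10K W': [0, 0, 9, 5.6]}).set_index('Date')

row = df1.iloc[0]
comparison_failed = False
try:
    # 10 labels vs. 3 labels: pandas refuses to compare them with ==
    row == row.nlargest(3)
except ValueError as err:
    comparison_failed = True
    print('ValueError:', err)

# Keep every value that is >= the third-largest value of its row
# (so ties at the cut-off survive) and replace everything else with 0.
top3 = df1.apply(lambda x: x.where(x >= x.nlargest(3).min(), 0), axis=1)
print(top3)
```

Note that the result is all floats, because each row is passed to the lambda as a single float Series.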
What I would like to get instead is the following (ties at the third-largest value are all kept, e.g. both 2s on 2021-01-02):
df2 = pd.DataFrame({
    'Date': ['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04'],
    '01K W': [0, 0, 0, 0],
    '02K W': [0, 2, 0, 0],
    '03K W': [0, 0, 0, 5],
    '04K W': [7, 0, 0, 5],
    '05K W': [4, 2, 4.5, 0],
    '06K W': [0, 0, 0, 0],
    '07K W': [4, 3, 0, 0],
    '08K W': [0, 0, 9, 0],
    '09K W': [0, 4, 0, 6.3],
    '10K W': [0, 0, 9, 5.6]})
df2 = df2.set_index('Date')
Thanks a lot
Answer:
DataFrame.rank
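Rank the values in each row (along the columns axis) in descending order, then mask (replace with 0) the values whose rank is greater than 3. With method='min', tied values share the lowest rank of their group, so ties at the third place are all kept; that is why 2021-01-02 and 2021-01-04 keep four values each.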
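To see what `method` does with ties, here is a small sketch that ranks the 2021-01-02 row on its own (rebuilt by hand from the sample data) with `method='min'` and, for comparison, `method='first'`:

```python
import pandas as pd

# The 2021-01-02 row of the sample data: two 2s tie for third place.
row = pd.Series([1.2, 2, 1.6, 0.5, 2, 0, 3, 1, 4, 0],
                index=['01K W', '02K W', '03K W', '04K W', '05K W',
                       '06K W', '07K W', '08K W', '09K W', '10K W'])

# method='min': tied values share the lowest rank of their group.
ranks_min = row.rank(method='min', ascending=False)
# method='first': ties are broken by position, so the two 2s get different ranks.
ranks_first = row.rank(method='first', ascending=False)

print(ranks_min[['02K W', '05K W']])  # both 3.0
print((ranks_min <= 3).sum())         # 4 values survive a "rank > 3" mask
print((ranks_first <= 3).sum())       # only 3 survive
```

So `method='min'` is what makes the answer match the desired output; with `'first'` one of the tied values would be zeroed.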
df1.mask(df1.rank(axis=1, method='min', ascending=False).gt(3), 0)
            01K W  02K W  03K W  04K W  05K W  06K W  07K W  08K W  09K W  10K W
Date
2021-01-01    0.0    0.0    0.0    7.0    4.0    0.0      4    0.0    0.0    0.0
2021-01-02    0.0    2.0    0.0    0.0    2.0    0.0      3    0.0    4.0    0.0
2021-01-03    0.0    0.0    0.0    0.0    4.5    0.0      0    9.0    0.0    9.0
2021-01-04    0.0    0.0    5.0    5.0    0.0    0.0      0    0.0    6.3    5.6
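The same one-liner can be wrapped in a small helper for any n. `keep_top_n` is a hypothetical name used only for illustration; the sketch assumes numeric values and that ties at the n-th place should all be kept (hence `method='min'`). The comparison with the expected frame uses `check_dtype=False` because several columns are integers in one frame and floats in the other.

```python
import pandas as pd


def keep_top_n(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Keep the n largest values of each row (ties at the cut-off included); set the rest to 0."""
    ranks = df.rank(axis=1, method='min', ascending=False)
    return df.mask(ranks.gt(n), 0)


dates = ['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04']

# Sample data and desired output from the question.
df1 = pd.DataFrame({
    'Date': dates,
    '01K W': [0, 1.2, 0.3, 2], '02K W': [0.5, 2, 1.4, 3],
    '03K W': [2, 1.6, 3, 5], '04K W': [7, 0.5, 2.4, 5],
    '05K W': [4, 2, 4.5, 1], '06K W': [2.7, 0, 0, 0],
    '07K W': [4, 3, 3, 2], '08K W': [3.8, 1, 9, 2],
    '09K W': [1, 4, 0.4, 6.3], '10K W': [0, 0, 9, 5.6]}).set_index('Date')
df2 = pd.DataFrame({
    'Date': dates,
    '01K W': [0, 0, 0, 0], '02K W': [0, 2, 0, 0],
    '03K W': [0, 0, 0, 5], '04K W': [7, 0, 0, 5],
    '05K W': [4, 2, 4.5, 0], '06K W': [0, 0, 0, 0],
    '07K W': [4, 3, 0, 0], '08K W': [0, 0, 9, 0],
    '09K W': [0, 4, 0, 6.3], '10K W': [0, 0, 9, 5.6]}).set_index('Date')

# Raises AssertionError if the frames differ in values, labels or shape.
pd.testing.assert_frame_equal(keep_top_n(df1, 3), df2, check_dtype=False)
print(keep_top_n(df1, 1))
```

With `n=1` this keeps only each row's maximum, except on 2021-01-03 where both 9s tie and are kept.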