I have a DataFrame that looks like this:

import pandas as pd

data = {'trade_date': {1350: 20151201,
                       6175: 20151201,
                       3100: 20151201,
                       5650: 20151201,
                       3575: 20151201,
                       1: 20170301,
                       2: 20170301},
        'comId': {1350: '257762',
                  6175: '1038328',
                  3100: '315476',
                  5650: '658776',
                  3575: '329376',
                  1: '123456',
                  2: '987654'},
        'return': {1350: -0.0018,
                   6175: 0.0023,
                   3100: -0.0413,
                   5650: 0.1266,
                   3575: 0.0221,
                   1: '0.9',
                   2: '0.01'}}
df = pd.DataFrame(data)
The expected output should look like this:

dict2 = {'trade_date': {5650: 20151201,
                        1: 20170301},
         'comId': {5650: '658776',
                   1: '123456'},
         'return': {5650: 0.1266,
                    1: '0.9'}}
I need to filter it based on the following condition: for each trade_date value, I want to keep only the top 20% of entries, based on the value in the return column. So for this example, it would filter out everything but the company with comId value 658776 and return value 0.1266.

Bear in mind there might be trade_dates with more companies associated with them. In that case, 20% of the count should be rounded up or down to the nearest integer. For example, if there are 9 companies associated with a date, 20% * 9 = 1.8, so it should keep only the top two based on the values in the return column.

Any ideas how best to approach this? I'm a bit lost.
CodePudding user response:
I think this should work:

# note: the sample 'return' column mixes floats and strings,
# so convert it to numeric first or quantile() will fail
df['return'] = pd.to_numeric(df['return'])

df.groupby("trade_date") \
  .apply(lambda x: x[x["return"] > x["return"].quantile(0.8, interpolation="nearest")]) \
  .reset_index(drop=True)
CodePudding user response:
You can use groupby().transform to get the threshold for each row. This would be a bit faster than groupby().apply:

thresholds = df.groupby('trade_date')['return'].transform('quantile', q=0.8)
df[df['return'] > thresholds]
Output:
trade_date comId return
5650 20151201 658776 0.1266
CodePudding user response:
Create a temporary DataFrame holding only the rows with a given trade_date. Then sort it with df.sort_values(by='return', ascending=False) and drop the bottom 80%. Loop through all possible dates, and every time you get the top 20%, append those rows to a new DataFrame.
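A minimal sketch of that loop, assuming the mixed-type return column is converted to numeric first and that at least one row is kept per date (the `max(1, ...)` floor is an assumption, but it matches the expected output above, where the 2-row date still keeps one row):

```python
import pandas as pd

data = {'trade_date': {1350: 20151201, 6175: 20151201, 3100: 20151201,
                       5650: 20151201, 3575: 20151201, 1: 20170301, 2: 20170301},
        'comId': {1350: '257762', 6175: '1038328', 3100: '315476',
                  5650: '658776', 3575: '329376', 1: '123456', 2: '987654'},
        'return': {1350: -0.0018, 6175: 0.0023, 3100: -0.0413,
                   5650: 0.1266, 3575: 0.0221, 1: '0.9', 2: '0.01'}}
df = pd.DataFrame(data)

# the sample data mixes floats and strings in 'return'
df['return'] = pd.to_numeric(df['return'])

parts = []
for date, group in df.groupby('trade_date'):
    # 20% of the group size, rounded to the nearest integer, at least 1
    n_keep = max(1, round(0.2 * len(group)))
    parts.append(group.sort_values('return', ascending=False).head(n_keep))

result = pd.concat(parts)
```

For the sample data this keeps one row per date (round(0.2 * 5) = 1 and round(0.2 * 2) = 0, floored to 1), i.e. the rows for comId 658776 and 123456, matching the expected output.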