Filter values in pandas dataframe based on complex columns conditions

Time:08-26

I have a dataframe that looks like this

import pandas as pd

# renamed from `dict` so the builtin isn't shadowed
data = {'trade_date': {1350: 20151201,
  6175: 20151201,
  3100: 20151201,
  5650: 20151201,
  3575: 20151201,
     1: 20170301,
     2: 20170301},
 'comId': {1350: '257762',
  6175: '1038328',
  3100: '315476',
  5650: '658776',
  3575: '329376',
     1: '123456',
     2: '987654'},
 'return': {1350: -0.0018,
  6175: 0.0023,
  3100: -0.0413,
  5650: 0.1266,
  3575: 0.0221,
  1: '0.9',
  2: '0.01'}}

df = pd.DataFrame(data)

The expected output should look like this:
dict2 = {'trade_date': {5650: 20151201,
     1: 20170301},
 'comId': {5650: '658776',
     1: '123456'},
 'return': {5650: 0.1266,
  1: '0.9'}}

I need to filter it based on the following condition: for each trade_date value, I want to keep only the top 20% of entries, ranked by the value in column return. So for the 20151201 date in this example, it would filter out everything but the company with comId 658776 and return 0.1266 (and likewise keep only 123456 for 20170301).

Bear in mind there might be trade_dates with more companies associated with them. In that case the 20% should be rounded up or down to the nearest integer. For example, if there are 9 companies associated with a date, 20% * 9 = 1.8, so it should keep only the top two based on the values in column return.
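One explicit way to express the "round 20% of the group size and keep that many top rows" rule is groupby + nlargest. The sketch below assumes the return column has already been converted to numeric (the sample data mixes strings and floats), and keeps at least one row per date:

```python
import pandas as pd

data = {
    'trade_date': [20151201] * 5 + [20170301] * 2,
    'comId': ['257762', '1038328', '315476', '658776', '329376', '123456', '987654'],
    'return': [-0.0018, 0.0023, -0.0413, 0.1266, 0.0221, 0.9, 0.01],
}
df = pd.DataFrame(data)

def top_20pct(group):
    # 20% of the group size, rounded to the nearest integer, but never zero rows
    k = max(1, round(0.2 * len(group)))
    return group.nlargest(k, 'return')

result = df.groupby('trade_date', group_keys=False).apply(top_20pct)
```

With 9 companies, round(0.2 * 9) gives 2, matching the example above; note that Python's round uses banker's rounding at exact .5 ties, so swap in a floor/ceil of your choice if you need a specific tie rule.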

Any ideas on how best to approach this? I'm a bit lost.

CodePudding user response:

I think this should work:

(df
 .groupby("trade_date")
 .apply(lambda x: x[x["return"] >
        x["return"].quantile(0.8, interpolation="nearest")])
 .reset_index(drop=True))

CodePudding user response:

You can use groupby().transform to get the threshold for each row. This would be a bit faster than groupby().apply:

thresholds = df.groupby('trade_date')['return'].transform('quantile', q=0.8)
df[df['return'] > thresholds]

Output:

      trade_date   comId  return
5650    20151201  658776  0.1266

CodePudding user response:

Create a temporary dataframe holding only the rows with a given trade_date. Then sort it with df.sort_values(by='return', ascending=False) and drop the bottom 80%. Loop through all the dates, and every time you get the top 20%, append those rows to a new dataframe.
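The loop described above can be sketched roughly as follows (assuming the return column is converted to numeric first, since the sample data mixes strings and floats):

```python
import pandas as pd

data = {
    'trade_date': [20151201] * 5 + [20170301] * 2,
    'comId': ['257762', '1038328', '315476', '658776', '329376', '123456', '987654'],
    'return': [-0.0018, 0.0023, -0.0413, 0.1266, 0.0221, '0.9', '0.01'],
}
df = pd.DataFrame(data)
df['return'] = pd.to_numeric(df['return'])  # mixed strings/floats -> float

pieces = []
for date in df['trade_date'].unique():
    # rows for this date only, best return first
    sub = df[df['trade_date'] == date].sort_values(by='return', ascending=False)
    # keep the top 20%, rounded, but at least one row
    k = max(1, round(0.2 * len(sub)))
    pieces.append(sub.head(k))

result = pd.concat(pieces)
```

This is more verbose than the groupby solutions above, but it makes the per-date rounding step explicit.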
