Filter a grouped dataframe based on column value in pyspark


I have the dataframe below. I want to group it by company and date, and within each group keep the row based on category, prioritizing QF if available, otherwise SAF, and otherwise AF. I am trying to assign ranks using a window function (a rough sketch of my attempt is below the expected output), but maybe there is an easier way.

    company     date     value  category
    ------------------------------------
      xyz    31-12-2020    12      AF
      xyz    31-12-2020    10      SAF
      xyz    31-12-2020    11      QF
      xyz    30-06-2020    14      AF
      xyz    30-06-2020    16      SAF
      xyz    30-09-2020    13      SAF
      xyz    31-03-2019    20      AF

Expected output:

    company     date     value  category
    ------------------------------------
      xyz    31-12-2020    11      QF
      xyz    30-06-2020    16      SAF
      xyz    30-09-2020    13      SAF
      xyz    31-03-2019    20      AF
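
This is roughly the window-ranking attempt I have in mind (just a sketch, with df being the dataframe above):

    from pyspark.sql import functions as F, Window

    # rank rows inside each company/date group by a mapped category priority
    priority = (F.when(F.col('category') == 'QF', 1)
                 .when(F.col('category') == 'SAF', 2)
                 .when(F.col('category') == 'AF', 3))

    w = Window.partitionBy('company', 'date').orderBy(priority)
    df.withColumn('rnk', F.rank().over(w)).filter(F.col('rnk') == 1).drop('rnk')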

CodePudding user response:

We can assign a rank to the categories with chained when() conditions and retain the records that have the minimum rank within each company and date group.

import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd

data_sdf. \
    withColumn('cat_rank',
               # map each category to a priority rank: QF -> 1, SAF -> 2, AF -> 3
               func.when(func.col('cat') == 'QF', func.lit(1)).
               when(func.col('cat') == 'SAF', func.lit(2)).
               when(func.col('cat') == 'AF', func.lit(3))
               ). \
    withColumn('min_cat_rank',
               # best (lowest) rank available within each company/date group
               func.min('cat_rank').over(wd.partitionBy('company', 'dt'))
               ). \
    filter(func.col('min_cat_rank').isNotNull()). \
    filter(func.col('min_cat_rank') == func.col('cat_rank')). \
    drop('cat_rank', 'min_cat_rank'). \
    show()

#  ------- ---------- --- --- 
# |company|        dt|val|cat|
#  ------- ---------- --- --- 
# |    xyz|30-09-2020| 13|SAF|
# |    xyz|30-06-2020| 16|SAF|
# |    xyz|31-03-2019| 20| AF|
# |    xyz|31-12-2020| 11| QF|
#  ------- ---------- --- --- 
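
For reference, data_sdf above can be recreated from the question's sample data like this (the spark session and the shortened column names dt, val and cat are assumptions inferred from the snippet and its output):

# hypothetical setup matching the column names used in the snippet above
data_sdf = spark.createDataFrame(
    [('xyz', '31-12-2020', 12, 'AF'),
     ('xyz', '31-12-2020', 10, 'SAF'),
     ('xyz', '31-12-2020', 11, 'QF'),
     ('xyz', '30-06-2020', 14, 'AF'),
     ('xyz', '30-06-2020', 16, 'SAF'),
     ('xyz', '30-09-2020', 13, 'SAF'),
     ('xyz', '31-03-2019', 20, 'AF')],
    ['company', 'dt', 'val', 'cat']
)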

CodePudding user response:

Assuming that there is only a limited number of categories and that there are no duplicate entries per category, I would suggest mapping the categories to integers so you can order by them. Afterwards you can simply partition, sort, and pick the first entry of each partition.

import pyspark.sql.functions as f
from pyspark.sql.window import Window

# map each category to an integer priority: QF -> 1, SAF -> 2, AF -> 3
df = df.withColumn('mapping',
            f.when(f.col('category') == 'QF', f.lit(1)).otherwise(
            f.when(f.col('category') == 'SAF', f.lit(2)).otherwise(
            f.when(f.col('category') == 'AF', f.lit(3)).otherwise(f.lit(None)))))

# within each company/date group, keep the row with the best (lowest) priority
w = Window.partitionBy('company', 'date').orderBy(f.col('mapping'))
df.withColumn('row', f.row_number().over(w))\
   .filter(f.col('row') == 1)\
   .drop('row', 'mapping')\
   .show()
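
As a side note, the same mapping could also be built as a literal map column instead of chained when() calls; a sketch, assuming Spark 3.x bracket syntax for the map lookup and the same column names as above:

from itertools import chain
import pyspark.sql.functions as f
from pyspark.sql.window import Window

# literal priority map: QF -> 1, SAF -> 2, AF -> 3
priority = {'QF': 1, 'SAF': 2, 'AF': 3}
mapping_col = f.create_map(*[f.lit(x) for x in chain(*priority.items())])

# look up each row's priority in the map and keep the best row per group
w = Window.partitionBy('company', 'date').orderBy(mapping_col[f.col('category')])
df.withColumn('row', f.row_number().over(w))\
   .filter(f.col('row') == 1)\
   .drop('row')\
   .show()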

CodePudding user response:

Supposing there can be multiple values for the same category in a combination of company and date, and that we want to keep the maximum value for the preferred category, here is a solution with two window functions:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w_company_date = Window.partitionBy('company', 'date')
w_company_date_category = Window.partitionBy('company', 'date', 'category')

df = (df
  # priority: 1 = QF (preferred), 2 = SAF, 3 = AF
  .withColumn('priority', F.when(F.col('category') == 'QF', 1)
                           .when(F.col('category') == 'SAF', 2)
                           .when(F.col('category') == 'AF', 3)
                           .otherwise(None))
  # flag rows holding the best priority of the group and the max value for that category
  .withColumn('top_choice', F.when((F.col('priority') == F.min('priority').over(w_company_date))
                                   & (F.col('value') == F.max('value').over(w_company_date_category)), 1)
                             .otherwise(0))
  .filter(F.col('top_choice') == 1)
  .drop('priority', 'top_choice')
)

df.show()

 ------- ---------- ----- -------- 
|company|      date|value|category|
 ------- ---------- ----- -------- 
|    xyz|2019-03-31|   20|      AF|
|    xyz|2020-06-30|   16|     SAF|
|    xyz|2020-09-30|   13|     SAF|
|    xyz|2020-12-31|   11|      QF|
 ------- ---------- ----- -------- 
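
If exactly one row per group is wanted even when values tie, the same idea can be collapsed into a single row_number() window, ordering by the mapped priority and then by value descending. A sketch (applied to the original, unfiltered dataframe):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# assuming df is the original, unfiltered dataframe from the question
w = Window.partitionBy('company', 'date').orderBy(F.col('priority').asc(), F.col('value').desc())

(df
  .withColumn('priority', F.when(F.col('category') == 'QF', 1)
                           .when(F.col('category') == 'SAF', 2)
                           .when(F.col('category') == 'AF', 3))
  .withColumn('rn', F.row_number().over(w))
  .filter(F.col('rn') == 1)
  .drop('priority', 'rn')
  .show())

Unlike the two-window version, row_number() keeps exactly one row per group even when two rows tie on both category and value.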